Skip to content

Commit

Permalink
Enable subsampling of displayed images in ShowImages.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 701298818
  • Loading branch information
The kauldron Authors committed Nov 29, 2024
1 parent 2ba436d commit 9e7f94f
Showing 1 changed file with 30 additions and 1 deletion.
31 changes: 30 additions & 1 deletion kauldron/summaries/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class ShowImages(metrics.Metric):
Attributes:
images: Key to the images to display.
num_images: Number of images to collect and display. Default 5.
vrange: Optional value range of the input images. Used to clip aand then
vrange: Optional value range of the input images. Used to clip and then
rescale the images to [0, 1].
rearrange: Optional einops string to reshape the images.
rearrange_kwargs: Optional keyword arguments for the einops reshape.
Expand All @@ -55,6 +55,9 @@ class ShowImages(metrics.Metric):
rearrange: Optional[str] = None
rearrange_kwargs: Mapping[str, Any] | None = None

subsample_dim: int | None = None
subsample_step: int | None = None

@struct.dataclass
class State(metrics.AutoState["ShowImages"]):
"""Collects the first num_images images."""
Expand All @@ -75,6 +78,9 @@ def get_state(
images: Float["..."],
) -> ShowImages.State:
# maybe rearrange and then check shape
images = _maybe_subsample(
images, self.subsample_dim, step=self.subsample_step
)
images = _maybe_rearrange(images, self.rearrange, self.rearrange_kwargs)
check_type(images, Float["n h w #3"])

Expand Down Expand Up @@ -315,6 +321,29 @@ def get_state(
return self.State(diff_images=diff_images)


def _maybe_subsample(
array: Array["..."] | None,
dimension: int | None = None,
step: int | None = None,
) -> Array["..."] | None:
"""Subsamples the array along the given dimension with the given step.
Args:
array: The array to subsample.
dimension: The dimension to subsample along.
step: The subsampling step.
Returns:
The subsampled array.
"""
if array is None or step is None:
return array

slices = [slice(None)] * array.ndim
slices[dimension] = slice(None, None, step)
return array[tuple(slices)]


def _maybe_rearrange(
array: Array["..."] | None,
rearrange: Optional[str] = None,
Expand Down

0 comments on commit 9e7f94f

Please sign in to comment.