From 9e7f94f7644e1aac739c1a8750463a387933d535 Mon Sep 17 00:00:00 2001 From: The kauldron Authors Date: Fri, 29 Nov 2024 08:47:17 -0800 Subject: [PATCH] Enable subsampling of displayed images in ShowImages. PiperOrigin-RevId: 701298818 --- kauldron/summaries/images.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/kauldron/summaries/images.py b/kauldron/summaries/images.py index b4297dbf..8976a095 100644 --- a/kauldron/summaries/images.py +++ b/kauldron/summaries/images.py @@ -41,7 +41,7 @@ class ShowImages(metrics.Metric): Attributes: images: Key to the images to display. num_images: Number of images to collect and display. Default 5. - vrange: Optional value range of the input images. Used to clip aand then + vrange: Optional value range of the input images. Used to clip and then rescale the images to [0, 1]. rearrange: Optional einops string to reshape the images. rearrange_kwargs: Optional keyword arguments for the einops reshape. @@ -55,6 +55,9 @@ class ShowImages(metrics.Metric): rearrange: Optional[str] = None rearrange_kwargs: Mapping[str, Any] | None = None + subsample_dim: int | None = None + subsample_step: int | None = None + @struct.dataclass class State(metrics.AutoState["ShowImages"]): """Collects the first num_images images.""" @@ -75,6 +78,9 @@ def get_state( images: Float["..."], ) -> ShowImages.State: # maybe rearrange and then check shape + images = _maybe_subsample( + images, self.subsample_dim, step=self.subsample_step + ) images = _maybe_rearrange(images, self.rearrange, self.rearrange_kwargs) check_type(images, Float["n h w #3"]) @@ -315,6 +321,29 @@ def get_state( return self.State(diff_images=diff_images) +def _maybe_subsample( + array: Array["..."] | None, + dimension: int | None = None, + step: int | None = None, +) -> Array["..."] | None: + """Subsamples the array along the given dimension with the given step. + + Args: + array: The array to subsample. + dimension: The dimension to subsample along. + step: The subsampling step. + + Returns: + The subsampled array. + """ + if array is None or step is None: + return array + + slices = [slice(None)] * array.ndim + slices[dimension] = slice(None, None, step) + return array[tuple(slices)] + + def _maybe_rearrange( array: Array["..."] | None, rearrange: Optional[str] = None,