diff --git a/pytorchvideo/transforms/functional.py b/pytorchvideo/transforms/functional.py
index e848a159..d16ecd50 100644
--- a/pytorchvideo/transforms/functional.py
+++ b/pytorchvideo/transforms/functional.py
@@ -41,6 +41,33 @@ def uniform_temporal_subsample(
     return torch.index_select(x, temporal_dim, indices)
 
 
+def random_temporal_subsample(
+    x: torch.Tensor, num_samples: int, temporal_dim: int = 1
+) -> torch.Tensor:
+    """
+    Randomly subsamples num_samples indices from the temporal dimension of the video.
+    When num_samples is larger than the size of temporal dimension of the video, it
+    will randomly sample the same frames multiple times.
+
+    Args:
+        x (torch.Tensor): A video tensor with dimension larger than one with torch
+            tensor type includes int, long, float, complex, etc.
+        num_samples (int): The number of samples to be selected at random.
+        temporal_dim (int): dimension of temporal to perform temporal subsample.
+
+    Returns:
+        An x-like Tensor with subsampled temporal dimension.
+    """
+    t = x.shape[temporal_dim]
+    assert num_samples > 0 and t > 0
+    indices = torch.randperm(t)
+    # Repeat indices ntimes if num_samples > t.
+    ntimes = math.ceil(num_samples / t)
+    indices = indices.repeat(ntimes)[:num_samples]
+    indices, _ = torch.sort(indices)
+    return torch.index_select(x, temporal_dim, indices.long())
+
+
 @torch.jit.ignore
 def _interpolate_opencv(
     x: torch.Tensor, size: Tuple[int, int], interpolation: str
diff --git a/pytorchvideo/transforms/transforms.py b/pytorchvideo/transforms/transforms.py
index 786e66cb..f528d2b1 100644
--- a/pytorchvideo/transforms/transforms.py
+++ b/pytorchvideo/transforms/transforms.py
@@ -89,6 +89,23 @@ def forward(self, x: torch.Tensor):
         )
 
 
+class RandomTemporalSubsample(torch.nn.Module):
+    """
+    ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.random_temporal_subsample``.
+    """
+
+    def __init__(self, num_samples: int):
+        super().__init__()
+        self._num_samples = num_samples
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            x (torch.Tensor): video tensor with shape (C, T, H, W).
+        """
+        return pytorchvideo.transforms.functional.random_temporal_subsample(x, self._num_samples)
+
+
 class ShortSideScale(torch.nn.Module):
     """
     ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.short_side_scale``.
diff --git a/tests/test_transforms.py b/tests/test_transforms.py
index 1b35077b..a9b1810a 100644
--- a/tests/test_transforms.py
+++ b/tests/test_transforms.py
@@ -21,6 +21,7 @@
     ShortSideScale,
     UniformCropVideo,
     UniformTemporalSubsample,
+    RandomTemporalSubsample,
     RandomResizedCrop,
 )
 from pytorchvideo.transforms import create_video_transform
@@ -1227,3 +1228,25 @@ def test_div_255(self):
 
         self.assertEqual(output_tensor.shape, video_tensor.shape)
         self.assertTrue(bool(torch.all(torch.eq(output_tensor, expect_tensor))))
+
+    def test_random_temporal_subsample(self):
+        subsample = RandomTemporalSubsample(10)
+        c, t, h, w = 3, 100, 200, 200
+        in_tensor = torch.rand(c, t, h, w)
+        out_tensor = subsample(in_tensor)
+        co, to, ho, wo = out_tensor.shape
+        self.assertEqual(co, 3)
+        self.assertEqual(to, 10)
+        self.assertEqual(ho, 200)
+        self.assertEqual(wo, 200)
+
+        # num subsampled > num frames
+        subsample = RandomTemporalSubsample(95)
+        c, t, h, w = 3, 10, 200, 200
+        in_tensor = torch.rand(c, t, h, w)
+        out_tensor = subsample(in_tensor)
+        co, to, ho, wo = out_tensor.shape
+        self.assertEqual(co, 3)
+        self.assertEqual(to, 95)
+        self.assertEqual(ho, 200)
+        self.assertEqual(wo, 200)