Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add random temporal subsample transform #93

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions pytorchvideo/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,33 @@ def uniform_temporal_subsample(
return torch.index_select(x, temporal_dim, indices)


def random_temporal_subsample(
    x: torch.Tensor, num_samples: int, temporal_dim: int = 1
) -> torch.Tensor:
    """
    Randomly subsamples num_samples indices from the temporal dimension of the
    video, keeping the selected frames in their original temporal order.
    When num_samples is larger than the size of the temporal dimension of the
    video, the same frames will be sampled multiple times.

    Args:
        x (torch.Tensor): A video tensor with dimension larger than one with torch
            tensor type includes int, long, float, complex, etc.
        num_samples (int): The number of randomly selected samples.
        temporal_dim (int): dimension of temporal to perform temporal subsample.

    Returns:
        An x-like Tensor with subsampled temporal dimension.

    Raises:
        ValueError: If ``num_samples`` is not positive or the temporal
            dimension of ``x`` is empty.
    """
    t = x.shape[temporal_dim]
    # Validate with explicit raises rather than assert, which is silently
    # stripped when Python runs with -O.
    if num_samples <= 0 or t <= 0:
        raise ValueError(
            f"num_samples ({num_samples}) and temporal size ({t}) must be positive."
        )
    indices = torch.randperm(t)
    # Repeat the permutation enough times to cover num_samples > t; every
    # frame then appears floor/ceil(num_samples / t) times.
    ntimes = math.ceil(num_samples / t)
    indices = indices.repeat(ntimes)[:num_samples]
    # Sort so the sampled frames keep the video's original temporal order.
    indices, _ = torch.sort(indices)
    return torch.index_select(x, temporal_dim, indices.long())


@torch.jit.ignore
def _interpolate_opencv(
x: torch.Tensor, size: Tuple[int, int], interpolation: str
Expand Down
17 changes: 17 additions & 0 deletions pytorchvideo/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,23 @@ def forward(self, x: torch.Tensor):
)


class RandomTemporalSubsample(torch.nn.Module):
    """
    ``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.random_temporal_subsample``.
    """

    def __init__(self, num_samples: int, temporal_dim: int = 1):
        """
        Args:
            num_samples (int): The number of frames to randomly subsample.
            temporal_dim (int): dimension of temporal to perform temporal
                subsample. Defaults to 1, matching (C, T, H, W) input.
        """
        super().__init__()
        self._num_samples = num_samples
        self._temporal_dim = temporal_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): video tensor with shape (C, T, H, W).

        Returns:
            An x-like tensor whose temporal dimension has num_samples frames.
        """
        return pytorchvideo.transforms.functional.random_temporal_subsample(
            x, self._num_samples, self._temporal_dim
        )


class ShortSideScale(torch.nn.Module):
"""
``nn.Module`` wrapper for ``pytorchvideo.transforms.functional.short_side_scale``.
Expand Down
23 changes: 23 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
ShortSideScale,
UniformCropVideo,
UniformTemporalSubsample,
RandomTemporalSubsample,
RandomResizedCrop,
)
from pytorchvideo.transforms import create_video_transform
Expand Down Expand Up @@ -1227,3 +1228,25 @@ def test_div_255(self):

self.assertEqual(output_tensor.shape, video_tensor.shape)
self.assertTrue(bool(torch.all(torch.eq(output_tensor, expect_tensor))))

def test_random_temporal_subsample(self):
    # Covers both the normal case and oversampling (num subsampled > num frames);
    # all non-temporal dimensions must be preserved.
    for num_samples, num_frames in ((10, 100), (95, 10)):
        transform = RandomTemporalSubsample(num_samples)
        video = torch.rand(3, num_frames, 200, 200)
        result = transform(video)
        co, to, ho, wo = result.shape
        self.assertEqual(co, 3)
        self.assertEqual(to, num_samples)
        self.assertEqual(ho, 200)
        self.assertEqual(wo, 200)