From 0f18b900ad33153aa4e63b1db2490720ea36bbaa Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Thu, 18 Jan 2024 11:47:10 -0500 Subject: [PATCH] Fix for issue #263 (#271) * Fix for issue #263 * Fixed unit test * Fixed linting * Black formatting on just subset_worker.py --------- Co-authored-by: iejMac --- tests/test_subsamplers.py | 2 +- video2dataset/subsamplers/audio_rate_subsampler.py | 9 +++++---- video2dataset/subsamplers/frame_subsampler.py | 4 +++- video2dataset/subsamplers/resolution_subsampler.py | 4 +++- video2dataset/workers/subset_worker.py | 10 ++++++++-- 5 files changed, 20 insertions(+), 9 deletions(-) diff --git a/tests/test_subsamplers.py b/tests/test_subsamplers.py index 511539e1..2c78e258 100644 --- a/tests/test_subsamplers.py +++ b/tests/test_subsamplers.py @@ -158,7 +158,7 @@ def test_audio_rate_subsampler(sample_rate, n_audio_channels): audio_bytes = aud_f.read() streams = {"audio": [audio_bytes]} - subsampler = AudioRateSubsampler(sample_rate, {"audio": "mp3"}, n_audio_channels) + subsampler = AudioRateSubsampler(sample_rate, "mp3", n_audio_channels) subsampled_streams, _, error_message = subsampler(streams) assert error_message is None diff --git a/video2dataset/subsamplers/audio_rate_subsampler.py b/video2dataset/subsamplers/audio_rate_subsampler.py index 48d94038..ad8b12e1 100644 --- a/video2dataset/subsamplers/audio_rate_subsampler.py +++ b/video2dataset/subsamplers/audio_rate_subsampler.py @@ -12,12 +12,13 @@ class AudioRateSubsampler: """ Adjusts the frame rate of the videos to the specified frame rate. Args: - frame_rate (int): Target frame rate of the videos. + sample_rate (int): Target sample rate of the audio. + encode_format (str): Format to encode in (i.e. m4a) """ - def __init__(self, sample_rate, encode_formats, n_audio_channels=None): + def __init__(self, sample_rate, encode_format, n_audio_channels=None): self.sample_rate = sample_rate - self.encode_formats = encode_formats + self.encode_format = encode_format self.n_audio_channels = n_audio_channels def __call__(self, streams, metadata=None): @@ -27,7 +28,7 @@ def __call__(self, streams, metadata=None): with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, "input.m4a"), "wb") as f: f.write(aud_bytes) - ext = self.encode_formats["audio"] + ext = self.encode_format try: # TODO: for now assuming m4a, change this ffmpeg_args = {"ar": str(self.sample_rate), "f": ext} diff --git a/video2dataset/subsamplers/frame_subsampler.py b/video2dataset/subsamplers/frame_subsampler.py index 845c6677..e3c1b572 100644 --- a/video2dataset/subsamplers/frame_subsampler.py +++ b/video2dataset/subsamplers/frame_subsampler.py @@ -22,6 +22,7 @@ class FrameSubsampler(Subsampler): yt_subtitle: temporary special case where you want a frame at the beginning of each yt_subtitle we will want to turn this into something like frame_timestamps and introduce this as a fusing option with clipping_subsampler + encode_format (str): Format to encode in (i.e. mp4) TODO: n_frame TODO: generalize interface, should be like (frame_rate, n_frames, sampler, output_format) @@ -31,10 +32,11 @@ class FrameSubsampler(Subsampler): # output_format - save as video, or images """ - def __init__(self, frame_rate, downsample_method="fps"): + def __init__(self, frame_rate, downsample_method="fps", encode_format="mp4"): self.frame_rate = frame_rate self.downsample_method = downsample_method self.output_modality = "video" if downsample_method == "fps" else "jpg" + self.encode_format = encode_format def __call__(self, streams, metadata=None): # TODO: you might not want to pop it (f.e. in case of other subsamplers) diff --git a/video2dataset/subsamplers/resolution_subsampler.py b/video2dataset/subsamplers/resolution_subsampler.py index 365bd6a6..d38f4f7a 100644 --- a/video2dataset/subsamplers/resolution_subsampler.py +++ b/video2dataset/subsamplers/resolution_subsampler.py @@ -18,11 +18,13 @@ class ResolutionSubsampler(Subsampler): scale: scale video keeping aspect ratios (currently always picks video height) crop: center crop to video_size x video_size pad: center pad to video_size x video_size + encode_format (str): Format to encode in (i.e. mp4) """ - def __init__(self, video_size, resize_mode): + def __init__(self, video_size, resize_mode, encode_format="mp4"): self.video_size = video_size self.resize_mode = resize_mode + self.encode_format = encode_format def __call__(self, streams, metadata=None): video_bytes = streams["video"] diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 9e6e142a..20d77eb8 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -114,9 +114,15 @@ def process_shard( # The subsamplers might change the output format, so we need to update the writer writer_encode_formats = self.encode_formats.copy() if self.subsamplers["audio"]: - writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"] + assert ( + len({s.encode_format for s in self.subsamplers["audio"]}) == 1 + ) # assert that all audio subsamplers have the same output format + writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format if self.subsamplers["video"]: - writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"] + assert ( + len({s.encode_format for s in self.subsamplers["video"]}) == 1 + ) # assert that all video subsamplers have the same output format + writer_encode_formats["video"] = self.subsamplers["video"][0].encode_format # give schema to writer sample_writer = self.sample_writer_class(