diff --git a/tests/test_subsamplers.py b/tests/test_subsamplers.py index 32f708fa..e6a5b5f0 100644 --- a/tests/test_subsamplers.py +++ b/tests/test_subsamplers.py @@ -188,7 +188,7 @@ def test_audio_rate_subsampler(sample_rate, n_audio_channels): audio_bytes = aud_f.read() streams = {"audio": [audio_bytes]} - subsampler = AudioRateSubsampler(sample_rate, {"audio": "mp3"}, n_audio_channels) + subsampler = AudioRateSubsampler(sample_rate, "mp3", n_audio_channels) subsampled_streams, _, error_message = subsampler(streams) assert error_message is None diff --git a/video2dataset/subsamplers/audio_rate_subsampler.py b/video2dataset/subsamplers/audio_rate_subsampler.py index 48d94038..ad8b12e1 100644 --- a/video2dataset/subsamplers/audio_rate_subsampler.py +++ b/video2dataset/subsamplers/audio_rate_subsampler.py @@ -12,12 +12,13 @@ class AudioRateSubsampler: """ Adjusts the frame rate of the videos to the specified frame rate. Args: - frame_rate (int): Target frame rate of the videos. + sample_rate (int): Target sample rate of the audio. + encode_format (str): Format to encode in (e.g. 
m4a) """ - def __init__(self, sample_rate, encode_formats, n_audio_channels=None): + def __init__(self, sample_rate, encode_format, n_audio_channels=None): self.sample_rate = sample_rate - self.encode_formats = encode_formats + self.encode_format = encode_format self.n_audio_channels = n_audio_channels def __call__(self, streams, metadata=None): @@ -27,7 +28,7 @@ def __call__(self, streams, metadata=None): with tempfile.TemporaryDirectory() as tmpdir: with open(os.path.join(tmpdir, "input.m4a"), "wb") as f: f.write(aud_bytes) - ext = self.encode_formats["audio"] + ext = self.encode_format try: # TODO: for now assuming m4a, change this ffmpeg_args = {"ar": str(self.sample_rate), "f": ext} diff --git a/video2dataset/subsamplers/frame_subsampler.py b/video2dataset/subsamplers/frame_subsampler.py index 845c6677..e3c1b572 100644 --- a/video2dataset/subsamplers/frame_subsampler.py +++ b/video2dataset/subsamplers/frame_subsampler.py @@ -22,6 +22,7 @@ class FrameSubsampler(Subsampler): yt_subtitle: temporary special case where you want a frame at the beginning of each yt_subtitle we will want to turn this into something like frame_timestamps and introduce this as a fusing option with clipping_subsampler + encode_format (str): Format to encode in (e.g. mp4) TODO: n_frame TODO: generalize interface, should be like (frame_rate, n_frames, sampler, output_format) @@ -31,10 +32,11 @@ class FrameSubsampler(Subsampler): # output_format - save as video, or images """ - def __init__(self, frame_rate, downsample_method="fps"): + def __init__(self, frame_rate, downsample_method="fps", encode_format="mp4"): self.frame_rate = frame_rate self.downsample_method = downsample_method self.output_modality = "video" if downsample_method == "fps" else "jpg" + self.encode_format = encode_format def __call__(self, streams, metadata=None): # TODO: you might not want to pop it (f.e. 
in case of other subsamplers) diff --git a/video2dataset/subsamplers/resolution_subsampler.py b/video2dataset/subsamplers/resolution_subsampler.py index 59224546..46545f92 100644 --- a/video2dataset/subsamplers/resolution_subsampler.py +++ b/video2dataset/subsamplers/resolution_subsampler.py @@ -23,20 +23,23 @@ class ResolutionSubsampler(Subsampler): height (int): Height of video. width (int): Width of video. video_size (int): Both height and width. + encode_format (str): Format to encode in (e.g. mp4) """ - def __init__( self, resize_mode: Literal["scale", "crop", "pad"], height: int = -1, width: int = -1, video_size: int = -1, + encode_format: str = "mp4", ): if video_size > 0 and (height > 0 or width > 0): raise Exception("Either set video_size, or set height and/or width") + self.resize_mode = resize_mode self.height = height if video_size < 0 else video_size self.width = width if video_size < 0 else video_size - self.resize_mode = resize_mode + self.video_size = video_size + self.encode_format = encode_format def __call__(self, streams, metadata=None): video_bytes = streams["video"] diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 9e6e142a..20d77eb8 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -114,9 +114,15 @@ def process_shard( # The subsamplers might change the output format, so we need to update the writer writer_encode_formats = self.encode_formats.copy() if self.subsamplers["audio"]: - writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"] + assert ( + len({s.encode_format for s in self.subsamplers["audio"]}) == 1 + ) # assert that all audio subsamplers have the same output format + writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format if self.subsamplers["video"]: - writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"] + assert ( + len({s.encode_format for s in 
self.subsamplers["video"]}) == 1 + ) # assert that all video subsamplers have the same output format + writer_encode_formats["video"] = self.subsamplers["video"][0].encode_format # give schema to writer sample_writer = self.sample_writer_class(