Skip to content

Commit

Permalink
Merge branch 'main' into width_and_height
Browse files Browse the repository at this point in the history
  • Loading branch information
MattUnderscoreZhang authored Jan 18, 2024
2 parents 2636951 + 0f18b90 commit 50fc2a8
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 10 deletions.
2 changes: 1 addition & 1 deletion tests/test_subsamplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def test_audio_rate_subsampler(sample_rate, n_audio_channels):
audio_bytes = aud_f.read()

streams = {"audio": [audio_bytes]}
subsampler = AudioRateSubsampler(sample_rate, {"audio": "mp3"}, n_audio_channels)
subsampler = AudioRateSubsampler(sample_rate, "mp3", n_audio_channels)

subsampled_streams, _, error_message = subsampler(streams)
assert error_message is None
Expand Down
9 changes: 5 additions & 4 deletions video2dataset/subsamplers/audio_rate_subsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,13 @@ class AudioRateSubsampler:
"""
Adjusts the sample rate of the audio to the specified sample rate.
Args:
frame_rate (int): Target frame rate of the videos.
sample_rate (int): Target sample rate of the audio.
encode_format (str): Format to encode in (e.g. m4a)
"""

def __init__(self, sample_rate, encode_formats, n_audio_channels=None):
def __init__(self, sample_rate, encode_format, n_audio_channels=None):
self.sample_rate = sample_rate
self.encode_formats = encode_formats
self.encode_format = encode_format
self.n_audio_channels = n_audio_channels

def __call__(self, streams, metadata=None):
Expand All @@ -27,7 +28,7 @@ def __call__(self, streams, metadata=None):
with tempfile.TemporaryDirectory() as tmpdir:
with open(os.path.join(tmpdir, "input.m4a"), "wb") as f:
f.write(aud_bytes)
ext = self.encode_formats["audio"]
ext = self.encode_format
try:
# TODO: for now assuming m4a, change this
ffmpeg_args = {"ar": str(self.sample_rate), "f": ext}
Expand Down
4 changes: 3 additions & 1 deletion video2dataset/subsamplers/frame_subsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class FrameSubsampler(Subsampler):
yt_subtitle: temporary special case where you want a frame at the beginning of each yt_subtitle
we will want to turn this into something like frame_timestamps and introduce
this as a fusing option with clipping_subsampler
encode_format (str): Format to encode in (e.g. mp4)
TODO: n_frame
TODO: generalize interface, should be like (frame_rate, n_frames, sampler, output_format)
Expand All @@ -31,10 +32,11 @@ class FrameSubsampler(Subsampler):
# output_format - save as video, or images
"""

def __init__(self, frame_rate, downsample_method="fps"):
def __init__(self, frame_rate, downsample_method="fps", encode_format="mp4"):
self.frame_rate = frame_rate
self.downsample_method = downsample_method
self.output_modality = "video" if downsample_method == "fps" else "jpg"
self.encode_format = encode_format

def __call__(self, streams, metadata=None):
# TODO: you might not want to pop it (f.e. in case of other subsamplers)
Expand Down
7 changes: 5 additions & 2 deletions video2dataset/subsamplers/resolution_subsampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,23 @@ class ResolutionSubsampler(Subsampler):
height (int): Height of video.
width (int): Width of video.
video_size (int): Both height and width.
encode_format (str): Format to encode in (e.g. mp4)
"""

def __init__(
self,
resize_mode: Literal["scale", "crop", "pad"],
height: int = -1,
width: int = -1,
video_size: int = -1,
encode_format: str = "mp4",
):
if video_size > 0 and (height > 0 or width > 0):
raise Exception("Either set video_size, or set height and/or width")
self.resize_mode = resize_mode
self.height = height if video_size < 0 else video_size
self.width = width if video_size < 0 else video_size
self.resize_mode = resize_mode
self.video_size = video_size
self.encode_format = encode_format

def __call__(self, streams, metadata=None):
video_bytes = streams["video"]
Expand Down
10 changes: 8 additions & 2 deletions video2dataset/workers/subset_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,15 @@ def process_shard(
# The subsamplers might change the output format, so we need to update the writer
writer_encode_formats = self.encode_formats.copy()
if self.subsamplers["audio"]:
writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_formats["audio"]
assert (
len({s.encode_format for s in self.subsamplers["audio"]}) == 1
) # assert that all audio subsamplers have the same output format
writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format
if self.subsamplers["video"]:
writer_encode_formats["video"] = self.subsamplers["video"][0].encode_formats["video"]
assert (
len({s.encode_format for s in self.subsamplers["video"]}) == 1
) # assert that all video subsamplers have the same output format
writer_encode_formats["video"] = self.subsamplers["video"][0].encode_format

# give schema to writer
sample_writer = self.sample_writer_class(
Expand Down

0 comments on commit 50fc2a8

Please sign in to comment.