From 2efa849602b6b6040d7f93690068a1b9842eff43 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 10:03:29 -0500
Subject: [PATCH 01/23] ClippingSubsampler rewrite and bug fixes

---
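[Editor's note, not part of the commit: a quick trace of the new `_get_clip_intervals` helper added below. Given clips [[1.0, 2.0], [2.0, 4.0], [7.0, 9.0]], it emits every boundary that ffmpeg's segment muxer needs, plus the indices of the segments that correspond to requested clips rather than filler:

    >>> _get_clip_intervals([[1.0, 2.0], [2.0, 4.0], [7.0, 9.0]])
    ('0.0,1.0,2.0,4.0,7.0,9.0', [1, 2, 4])

Splitting at those timestamps yields segments [0,1], [1,2], [2,4], [4,7], [7,9]; indices 1, 2 and 4 are kept and the 0-1 and 4-7 filler segments are discarded.]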
 .../subsamplers/clipping_subsampler.py       | 201 ++++++++++--------
 1 file changed, 115 insertions(+), 86 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 466fd8e8..6424eaf9 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -7,12 +7,26 @@
 import ffmpeg
 import tempfile
 from collections.abc import Iterable
+from typing import Annotated, TypedDict, Literal, cast
 import datetime
 
 from .subsampler import Subsampler
 
 
-def _get_seconds(t):
+ClipTimes = Annotated[list[float], 2]
+
+
+class EncodeFormats(TypedDict):
+    video: str
+    audio: str
+
+
+class Streams(TypedDict):
+    video: bytes
+    audio: bytes
+
+
+def _get_seconds(t: str | float) -> float:
     if not isinstance(t, str):
         return float(t)  # already seconds
     time_format = "%H:%M:%S.%f"  # TODO: maybe parameterize this?
@@ -20,7 +34,7 @@ def _get_seconds(t):
     return t_obj.second + t_obj.microsecond / 1e6 + t_obj.minute * 60 + t_obj.hour * 3600
 
 
-def _get_strtime(t_sec):
+def _get_strtime(t_sec: float) -> str:
     hour = int(t_sec // 3600)
     minute = int((t_sec // 60) % 60)
     second = int(t_sec % 60)
@@ -29,24 +43,20 @@ def _get_strtime(t_sec):
     return f"{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}"
 
 
-def _split_time_frame(s, e, min_length, max_length):
+def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> list[ClipTimes]:
     """Filters out cuts by min and max length"""
     time_d = e - s
-    time_frames = [
-        (s + i * max_length, min(s + (i + 1) * max_length, e))
-        for i in range(int(time_d // max_length) + (1 if time_d % max_length > 0 else 0))
-    ]
-    if len(time_frames) == 0:
-        return []
-    last_time_d = time_frames[-1][1] - time_frames[-1][0]
-    time_frames = time_frames if last_time_d >= min_length else time_frames[:-1]
-    return time_frames
-
-
-def _adjust_ranges_to_keyframes(ranges, keyframes):
-    """Translates ranges into keyframe vocab"""
+    n_full_clips = int(time_d // max_length)
+    clip_times = [[s + i * max_length, s + (i + 1) * max_length] for i in range(n_full_clips)] + (
+        [[s + (n_full_clips - 1) * max_length, e]] if time_d % max_length > min_length else []
+    )
+    return clip_times
+
+
+def _adjust_clip_times_to_keyframes(clip_times: list[ClipTimes], keyframes: list[float]) -> list[ClipTimes]:
+    """Translates clip_times into keyframe vocab"""
     adjusted_ranges = []
-    for start, end in ranges:
+    for start, end in clip_times:
         keyframes_in_range = [k for k in keyframes if start <= k <= end]
         if keyframes_in_range:
             adjusted_start = min(keyframes_in_range)
@@ -56,6 +66,52 @@
     return adjusted_ranges
 
 
+def _adjust_clip_times(
+    clip_times: list[ClipTimes],
+    keyframe_timestamps: list[float] | None,
+    min_length: float,
+    max_length: float,
+    max_length_strategy: str,
+) -> list[ClipTimes]:
+    if not isinstance(clip_times[0], Iterable):  # make sure clip_times looks like [[start, end]] and not [start, end]
+        clip_times = cast(list[ClipTimes], [clip_times])
+    clip_times = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_times]
+
+    if keyframe_timestamps:
+        clip_times = _adjust_clip_times_to_keyframes(clip_times, keyframe_timestamps)
+
+    filtered_clip_times = []
+    for s, e in clip_times:
+        max_len_clip_times = _split_time_frame(s, e, min_length, max_length)
+        if max_length_strategy == "first":
+            max_len_clip_times = max_len_clip_times[:1]
+        filtered_clip_times += max_len_clip_times
+    return filtered_clip_times
+
+
+def _get_clip_intervals(clip_times: list[ClipTimes]) -> tuple[str, list[int]]:
+    s_clip, e_clip = clip_times[0]
+    skip_first_interval = int(s_clip > 0.0)
+
+    # which timestamp intervals to take, used to discard non-contiguous sections
+    intervals = [skip_first_interval]
+    timestamps = [0.0] + skip_first_interval * [s_clip] + [e_clip]
+    interval = 1 + skip_first_interval
+    for s, e in clip_times[1:]:
+        if s == e_clip:  # situations like [0, 1], [1, 2], [2, 3] -> 1, 2
+            timestamps += [e]
+            intervals.append(interval)
+            interval += 1
+        else:
+            timestamps += [s, e]
+            intervals.append(interval + 1)
+            interval += 2
+        e_clip = e
+
+    timestamps = ",".join([str(time) for time in timestamps])
+    return timestamps, intervals
+
+
 class ClippingSubsampler(Subsampler):
     """
     Cuts videos up into segments according to the 'clips' metadata
@@ -85,72 +141,59 @@ class ClippingSubsampler(Subsampler):
 
     def __init__(
         self,
-        oom_clip_count,
-        encode_formats,
-        min_length=0.0,
-        max_length=999999.0,
-        max_length_strategy="all",
-        precision="low",
+        oom_clip_count: int,
+        encode_formats: EncodeFormats,
+        min_length: float = 0.0,
+        max_length: float = 999999.0,
+        max_length_strategy: Literal["all", "first"] = "all",
+        precision: Literal["low", "keyframe_adjusted", "exact"] = "low",
     ):
+        assert max_length_strategy in ["all", "first"]
+        assert precision in ["exact", "low", "keyframe_adjusted"]
        self.oom_clip_count = oom_clip_count
         self.encode_formats = encode_formats
         self.min_length = min_length
-        self.max_length, self.max_length_strategy = max_length, max_length_strategy
-        assert precision in ["exact", "low", "keyframe_adjusted"]
+        self.max_length = max_length
+        self.max_length_strategy = max_length_strategy
         self.precision = precision
 
     def __call__(self, streams, metadata):
-        clips = metadata.pop("clips")
-
-        if not isinstance(clips[0], Iterable):  # make sure clips looks like [[start, end]] and not [start, end]
-            clips = [clips]
+        strtime_formatting = isinstance(metadata["clips"][0][0], str)
 
-        is_strtime = isinstance(clips[0][0], str)
+        clip_times = _adjust_clip_times(
+            clip_times=metadata.pop("clips"),
+            keyframe_timestamps=(
+                # TODO: make it so if keyframe timestamps not present, get it yourself
+                metadata["video_metadata"].pop("keyframe_timestamps")
+                if self.precision == "keyframe_adjusted"
+                else None
+            ),
+            min_length=self.min_length,
+            max_length=self.max_length,
+            max_length_strategy=self.max_length_strategy,
+        )
+        if len(clip_times) == 0:
+            return {}, [], f"Video had no clip_times longer than {self.min_length}"
 
-        if self.precision == "keyframe_adjusted":
-            # TODO: make it so if not present, get it yourself
-            keyframe_timestamps = metadata["video_metadata"].pop("keyframe_timestamps")
-            s_clips = [[_get_seconds(s), _get_seconds(e)] for (s, e) in clips]
-            clips = _adjust_ranges_to_keyframes(s_clips, keyframe_timestamps)
+        timestamps, intervals = _get_clip_intervals(clip_times)
 
-        filtered_clips = []
-        for s, e in clips:
-            max_len_clips = _split_time_frame(_get_seconds(s), _get_seconds(e), self.min_length, self.max_length)
+        ffmpeg_kwargs = {
+            "map": 0,
+            "f": "segment",
+            "segment_times": timestamps,
+            "reset_timestamps": 1,
+        }
+        if self.precision == "exact":
+            ffmpeg_kwargs["force_key_frames"] = timestamps
+        else:
+            ffmpeg_kwargs["c"] = "copy"
 
-            if self.max_length_strategy == "first":
-                max_len_clips = max_len_clips[:1]
-            filtered_clips += max_len_clips
-        clips = filtered_clips
-        if len(clips) == 0:
-            # return an error
-            return {}, [], f"Video had no clips longer than {self.min_length}"
-        start_0 = _get_seconds(clips[0][0]) == 0.0
-        ind = 1 + int(not start_0)
-        s_p, e_p = clips[0]
-        s_p, e_p = _get_seconds(s_p), _get_seconds(e_p)
-        splits = (not start_0) * [s_p] + [e_p]
-        # list of indicies of clips to take, used to discard non-contiguous sections
-        take_inds = [int(not start_0)]
-        # TODO: make nicer
-        for s, e in clips[1:]:
-            s, e = _get_seconds(s), _get_seconds(e)
-            if s == e_p:  # situations like [0, 1], [1, 2], [2, 3] -> 1, 2
-                splits += [e]
-                take_inds.append(ind)
-                ind += 1
-            else:
-                splits += [s, e]
-                take_inds.append(ind + 1)
-                ind += 2
-            e_p = e
-
-        segment_times = ",".join([str(spl) for spl in splits])
         streams_clips = {}
 
         for k in streams.keys():
@@ -165,25 +208,11 @@ def __call__(self, streams, metadata):
             with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f:
                 f.write(stream_bytes)
             try:
-                kwargs = {
-                    "map": 0,
-                    "f": "segment",
-                    "segment_times": segment_times,
-                    "reset_timestamps": 1,
-                }
-
-                # Precision things, tradeoff for speed
-                if self.precision != "exact":
-                    kwargs["c"] = "copy"
-                else:
-                    kwargs["force_key_frames"] = segment_times
-
-                _ = (
+                (
                     ffmpeg.input(f"{tmpdir}/input.{encode_format}")
-                    .output(f"{tmpdir}/clip_%d.{encode_format}", **kwargs)
+                    .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs)
                     .run(capture_stdout=True, quiet=True)
                 )
-
             except Exception as err:  # pylint: disable=broad-except
                 return {}, [], str(err)
 
@@ -191,10 +220,10 @@ def __call__(self, streams, metadata):
             stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
 
             correct_clips = []
-            for clip_id, (clip, ind) in enumerate(zip(clips, take_inds)):
+            for clip_id, (clip, ind) in enumerate(zip(clip_times, intervals)):
                 if ind < len(stream_clips):
                    correct_clips.append((clip_id, clip, stream_clips[ind]))
-            # clips_lost = len(take_inds) - len(correct_clips)  # TODO report this somehow
+            # clips_lost = len(intervals) - len(correct_clips)  # TODO report this somehow
 
             stream_clips, metadata_clips = [], []
             for clip_id, clip_span, clip_pth in correct_clips:
@@ -207,8 +236,8 @@ def __call__(self, streams, metadata):
                 )
                 meta_clip = copy.deepcopy(metadata)
                 # set the timeframe of this clip
-                if is_strtime:
-                    # Keep clips in the original format to be compatible with the data schema.
+                if strtime_formatting:
+                    # Keep clip_times in the original format to be compatible with the data schema.
                     meta_clip["clips"] = [(_get_strtime(clip_span[0]), _get_strtime(clip_span[1]))]
                 else:
                     meta_clip["clips"] = [clip_span]

From a5c9649b32af7541e7887a0460e4ecf46e855f4f Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 11:59:45 -0500
Subject: [PATCH 02/23] More refactoring of ClippingSubsampler, plus a fix to
 _get_clip_intervals

---
 .../subsamplers/clipping_subsampler.py       | 52 +++++++++----------
 1 file changed, 25 insertions(+), 27 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 6424eaf9..4d2e578e 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -55,15 +55,15 @@ def _split_time_frame(s: float, e: float, min_length: float, max_length: float)
 
 def _adjust_clip_times_to_keyframes(clip_times: list[ClipTimes], keyframes: list[float]) -> list[ClipTimes]:
     """Translates clip_times into keyframe vocab"""
-    adjusted_ranges = []
+    adjusted_clip_times = []
     for start, end in clip_times:
         keyframes_in_range = [k for k in keyframes if start <= k <= end]
         if keyframes_in_range:
             adjusted_start = min(keyframes_in_range)
             adjusted_end = max(keyframes_in_range)
             if adjusted_start != adjusted_end:
-                adjusted_ranges.append((adjusted_start, adjusted_end))
-    return adjusted_ranges
+                adjusted_clip_times.append((adjusted_start, adjusted_end))
+    return adjusted_clip_times
 
 
 def _adjust_clip_times(
@@ -89,27 +89,25 @@ def _adjust_clip_times(
     return filtered_clip_times
 
 
-def _get_clip_intervals(clip_times: list[ClipTimes]) -> tuple[str, list[int]]:
-    s_clip, e_clip = clip_times[0]
-    skip_first_interval = int(s_clip > 0.0)
+def _get_clip_times(clip_times: list[ClipTimes]) -> tuple[str, list[int]]:
+    all_clip_times = [0.0]
+    clip_idxs = []
+    e_prev = 0.0
+    clip_idx = 0
 
-    # which timestamp intervals to take, used to discard non-contiguous sections
-    intervals = [skip_first_interval]
-    timestamps = [0.0] + skip_first_interval * [s_clip] + [e_clip]
-    interval = 1 + skip_first_interval
-    for s, e in clip_times[1:]:
-        if s == e_clip:  # situations like [0, 1], [1, 2], [2, 3] -> 1, 2
-            timestamps += [e]
-            intervals.append(interval)
-            interval += 1
-        else:
-            timestamps += [s, e]
-            intervals.append(interval + 1)
-            interval += 2
-        e_clip = e
+    for s, e in clip_times:
+        if s == e_prev:  # clip starts where last one left off
+            all_clip_times += [e]
+            clip_idxs.append(clip_idx)
+            clip_idx += 1
+        else:  # next clip skips over some time
+            all_clip_times += [s, e]
+            clip_idxs.append(clip_idx + 1)
+            clip_idx += 2
+        e_prev = e
 
-    timestamps = ",".join([str(time) for time in timestamps])
-    return timestamps, intervals
+    all_clip_times = ",".join([str(time) for time in all_clip_times])
+    return all_clip_times, clip_idxs
 
 
 class ClippingSubsampler(Subsampler):
@@ -175,16 +173,16 @@ def __call__(self, streams, metadata):
         if len(clip_times) == 0:
             return {}, [], f"Video had no clip_times longer than {self.min_length}"
 
-        timestamps, intervals = _get_clip_intervals(clip_times)
+        all_clip_times, clip_idxs = _get_clip_times(clip_times)
 
         ffmpeg_kwargs = {
             "map": 0,
             "f": "segment",
-            "segment_times": timestamps,
+            "segment_times": all_clip_times,
             "reset_timestamps": 1,
         }
         if self.precision == "exact":
-            ffmpeg_kwargs["force_key_frames"] = timestamps
+            ffmpeg_kwargs["force_key_frames"] = all_clip_times
         else:
             ffmpeg_kwargs["c"] = "copy"
 
@@ -220,10 +218,10 @@ def __call__(self, streams, metadata):
             stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
 
             correct_clips = []
-            for clip_id, (clip, ind) in enumerate(zip(clip_times, intervals)):
+            for clip_id, (clip, ind) in enumerate(zip(clip_times, clip_idxs)):
                 if ind < len(stream_clips):
                     correct_clips.append((clip_id, clip, stream_clips[ind]))
-            # clips_lost = len(intervals) - len(correct_clips)  # TODO report this somehow
+            # clips_lost = len(clip_idxs) - len(correct_clips)  # TODO report this somehow
 
             stream_clips, metadata_clips = [], []
             for clip_id, clip_span, clip_pth in correct_clips:

From 2cb5854b03760d4d75404c25ac3f553586c43876 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 12:48:17 -0500
Subject: [PATCH 03/23] Finished refactoring ClippingSubsampler

---
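[Editor's note, not part of the commit: for readers less familiar with ffmpeg-python, the kwargs assembled in the refactored `_get_clips` below translate roughly to the following call; the timestamps string is illustrative:

    import ffmpeg  # the same ffmpeg-python package the module uses

    (
        ffmpeg.input("input.mp4")
        .output("clip_%d.mp4", map=0, f="segment",
                segment_times="0.0,5.0,10.0", reset_timestamps=1, c="copy")
        .run()
    )
    # roughly: ffmpeg -i input.mp4 -map 0 -f segment \
    #   -segment_times 0.0,5.0,10.0 -reset_timestamps 1 -c copy clip_%d.mp4

Stream copy makes the split fast but can only cut on existing keyframes; with precision "exact", `-c copy` is replaced by re-encoding with `-force_key_frames` at the same timestamps. That is the speed/precision tradeoff the `precision` option exposes.]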
 .../subsamplers/clipping_subsampler.py       | 281 ++++++++++--------
 1 file changed, 159 insertions(+), 122 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 4d2e578e..9ae4ee60 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -13,7 +13,7 @@
 from .subsampler import Subsampler
 
 
-ClipTimes = Annotated[list[float], 2]
+ClipSpans = Annotated[list[float], 2]
 
 
 class EncodeFormats(TypedDict):
@@ -43,71 +43,180 @@ def _get_strtime(t_sec: float) -> str:
     return f"{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}"
 
 
-def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> list[ClipTimes]:
+def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> list[ClipSpans]:
     """Filters out cuts by min and max length"""
     time_d = e - s
     n_full_clips = int(time_d // max_length)
-    clip_times = [[s + i * max_length, s + (i + 1) * max_length] for i in range(n_full_clips)] + (
+    clip_spans = [[s + i * max_length, s + (i + 1) * max_length] for i in range(n_full_clips)] + (
         [[s + (n_full_clips - 1) * max_length, e]] if time_d % max_length > min_length else []
     )
-    return clip_times
+    return clip_spans
 
 
-def _adjust_clip_times_to_keyframes(clip_times: list[ClipTimes], keyframes: list[float]) -> list[ClipTimes]:
-    """Translates clip_times into keyframe vocab"""
-    adjusted_clip_times = []
-    for start, end in clip_times:
+def _adjust_clip_spans_to_keyframes(clip_spans: list[ClipSpans], keyframes: list[float]) -> list[ClipSpans]:
+    """Translates clip_spans into keyframe vocab"""
+    adjusted_clip_spans = []
+    for start, end in clip_spans:
         keyframes_in_range = [k for k in keyframes if start <= k <= end]
         if keyframes_in_range:
             adjusted_start = min(keyframes_in_range)
             adjusted_end = max(keyframes_in_range)
             if adjusted_start != adjusted_end:
-                adjusted_clip_times.append((adjusted_start, adjusted_end))
-    return adjusted_clip_times
+                adjusted_clip_spans.append((adjusted_start, adjusted_end))
+    return adjusted_clip_spans
 
 
-def _adjust_clip_times(
-    clip_times: list[ClipTimes],
+def _adjust_clip_spans(
+    clip_spans: list[ClipSpans],
     keyframe_timestamps: list[float] | None,
     min_length: float,
     max_length: float,
     max_length_strategy: str,
-) -> list[ClipTimes]:
-    if not isinstance(clip_times[0], Iterable):  # make sure clip_times looks like [[start, end]] and not [start, end]
-        clip_times = cast(list[ClipTimes], [clip_times])
-    clip_times = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_times]
+) -> list[ClipSpans]:
+    if not isinstance(clip_spans[0], Iterable):  # make sure clip_spans looks like [[start, end]] and not [start, end]
+        clip_spans = cast(list[ClipSpans], [clip_spans])
+    clip_spans = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_spans]
 
     if keyframe_timestamps:
-        clip_times = _adjust_clip_times_to_keyframes(clip_times, keyframe_timestamps)
+        clip_spans = _adjust_clip_spans_to_keyframes(clip_spans, keyframe_timestamps)
 
-    filtered_clip_times = []
-    for s, e in clip_times:
-        max_len_clip_times = _split_time_frame(s, e, min_length, max_length)
+    filtered_clip_spans = []
+    for s, e in clip_spans:
+        max_len_clip_spans = _split_time_frame(s, e, min_length, max_length)
         if max_length_strategy == "first":
-            max_len_clip_times = max_len_clip_times[:1]
-        filtered_clip_times += max_len_clip_times
-    return filtered_clip_times
+            max_len_clip_spans = max_len_clip_spans[:1]
+        filtered_clip_spans += max_len_clip_spans
+    return filtered_clip_spans
 
 
-def _get_clip_times(clip_times: list[ClipTimes]) -> tuple[str, list[int]]:
-    all_clip_times = [0.0]
+def _get_clip_spans(clip_spans: list[ClipSpans]) -> tuple[str, list[int]]:
+    segment_times = [0.0]
     clip_idxs = []
     e_prev = 0.0
     clip_idx = 0
 
-    for s, e in clip_times:
+    for s, e in clip_spans:
         if s == e_prev:  # clip starts where last one left off
-            all_clip_times += [e]
+            segment_times += [e]
             clip_idxs.append(clip_idx)
             clip_idx += 1
         else:  # next clip skips over some time
-            all_clip_times += [s, e]
+            segment_times += [s, e]
            clip_idxs.append(clip_idx + 1)
             clip_idx += 2
         e_prev = e
 
-    all_clip_times = ",".join([str(time) for time in all_clip_times])
-    return all_clip_times, clip_idxs
+    segment_times = ",".join([str(time) for time in segment_times])
+    return segment_times, clip_idxs
+
+
+def _process_stream(stream_bytes: bytes, encode_format: str, ffmpeg_kwargs: dict) -> list[str]:
+    with tempfile.TemporaryDirectory() as tmpdir:
+        # TODO: we need to put the extension into the metadata
+        # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn
+        with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f:
+            f.write(stream_bytes)
+        try:
+            (
+                ffmpeg.input(f"{tmpdir}/input.{encode_format}")
+                .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs)
+                .run(capture_stdout=True, quiet=True)
+            )
+        except Exception as err:  # pylint: disable=broad-except
+            raise err
+        stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}")
+        stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
+        return stream_clips
+
+
+def _get_clip_metadata(
+    clip_spans: list[ClipSpans],
+    clip_idxs: list[int],
+    metadata: dict,
+    oom_clip_count: int,
+    strtime_formatting: bool,
+) -> list[dict]:
+    metadata_clips = []
+    for clip_id, (clip_span, _) in enumerate(zip(clip_spans, clip_idxs)):
+        clip_key = "{clip_id:0{oom_clip_count}d}".format(  # pylint: disable=consider-using-f-string
+            clip_id=clip_id, oom_clip_count=oom_clip_count
+        )
+        meta_clip = copy.deepcopy(metadata)
+        # set the timeframe of this clip
+        if strtime_formatting:
+            # Keep clip_spans in the original format to be compatible with the data schema.
+            meta_clip["clips"] = [(_get_strtime(clip_span[0]), _get_strtime(clip_span[1]))]
+        else:
+            meta_clip["clips"] = [clip_span]
+        meta_clip["key"] = f"{meta_clip['key']}_{clip_key}"
+
+        yt_md_dict = meta_clip.get("yt_meta_dict", {})
+        if (yt_md_dict is not None) and (yt_md_dict.get("subtitles", None) is not None):
+            clip_subtitles = []
+            s_c, e_c = _get_seconds(clip_span[0]), _get_seconds(clip_span[1])
+            for line in meta_clip["yt_meta_dict"]["subtitles"]:
+                s, e = _get_seconds(line["start"]), _get_seconds(line["end"])
+                if max(s_c, s) < min(e_c, e):
+                    clip_subtitles.append(line)
+                elif s > e_c:
+                    break
+            # full video subtitles might still be useful for context
+            meta_clip["clip_subtitles"] = clip_subtitles
+        metadata_clips.append(meta_clip)
+    return metadata_clips
+
+
+def _get_clips(
+    streams: Streams,
+    encode_formats: EncodeFormats,
+    precision: str,
+    clip_spans: list[ClipSpans],
+    metadata: dict,
+    oom_clip_count: int,
+    strtime_formatting: bool,
+) -> tuple[dict[str, list[str]], list[dict]]:
+    segment_times, clip_idxs = _get_clip_spans(clip_spans)
+
+    ffmpeg_kwargs = {
+        "map": 0,
+        "f": "segment",
+        "segment_times": segment_times,
+        "reset_timestamps": 1,
+    }
+    if precision == "exact":
+        ffmpeg_kwargs["force_key_frames"] = segment_times
+    else:
+        ffmpeg_kwargs["c"] = "copy"
+
+    clips = {}
+    for k in streams.keys():
+        stream_bytes = streams[k][0]  # pre-broadcast so only one
+        if stream_bytes is None:
+            continue
+        try:
+            stream_clips = _process_stream(
+                stream_bytes=stream_bytes,
+                encode_format=encode_formats[k],
+                ffmpeg_kwargs=ffmpeg_kwargs,
+            )
+        except Exception as err:
+            raise err
+
+        clips[k] = []
+        for _, (_, clip_idx) in enumerate(zip(clip_spans, clip_idxs)):
+            with open(stream_clips[clip_idx], "rb") as vid_f:
+                clip_bytes = vid_f.read()
+            clips[k].append(clip_bytes)
+
+    clip_metadata = _get_clip_metadata(
+        clip_spans=clip_spans,
+        clip_idxs=clip_idxs,
+        metadata=metadata,
+        oom_clip_count=oom_clip_count,
+        strtime_formatting=strtime_formatting,
+    )
+
+    return clips, clip_metadata
 
 
 class ClippingSubsampler(Subsampler):
     """
     Cuts videos up into segments according to the 'clips' metadata
@@ -155,11 +264,11 @@ class ClippingSubsampler(Subsampler):
         self.max_length_strategy = max_length_strategy
         self.precision = precision
 
-    def __call__(self, streams, metadata):
+    def __call__(self, streams: Streams, metadata: dict):
         strtime_formatting = isinstance(metadata["clips"][0][0], str)
 
-        clip_times = _adjust_clip_times(
-            clip_times=metadata.pop("clips"),
+        clip_spans = _adjust_clip_spans(
+            clip_spans=metadata.pop("clips"),
             keyframe_timestamps=(
                 # TODO: make it so if keyframe timestamps not present, get it yourself
                 metadata["video_metadata"].pop("keyframe_timestamps")
@@ -170,92 +279,20 @@ class ClippingSubsampler(Subsampler):
             max_length=self.max_length,
             max_length_strategy=self.max_length_strategy,
         )
-        if len(clip_times) == 0:
-            return {}, [], f"Video had no clip_times longer than {self.min_length}"
-
-        all_clip_times, clip_idxs = _get_clip_times(clip_times)
-
-        ffmpeg_kwargs = {
-            "map": 0,
-            "f": "segment",
-            "segment_times": all_clip_times,
-            "reset_timestamps": 1,
-        }
-        if self.precision == "exact":
-            ffmpeg_kwargs["force_key_frames"] = all_clip_times
-        else:
-            ffmpeg_kwargs["c"] = "copy"
-
-
-
-
-
-
-
-
-        streams_clips = {}
-
-        for k in streams.keys():
-            stream_bytes = streams[k][0]  # pre-broadcast so only one
-            if stream_bytes is None:
-                continue
-            encode_format = self.encode_formats[k]
-
-            with tempfile.TemporaryDirectory() as tmpdir:
-                # TODO: we need to put the extension into the metadata
-                # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn
-                with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f:
-                    f.write(stream_bytes)
-                try:
-                    (
-                        ffmpeg.input(f"{tmpdir}/input.{encode_format}")
-                        .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs)
-                        .run(capture_stdout=True, quiet=True)
-                    )
-                except Exception as err:  # pylint: disable=broad-except
-                    return {}, [], str(err)
-
-                stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}")
-                stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
-
-                correct_clips = []
-                for clip_id, (clip, ind) in enumerate(zip(clip_times, clip_idxs)):
-                    if ind < len(stream_clips):
-                        correct_clips.append((clip_id, clip, stream_clips[ind]))
-                # clips_lost = len(clip_idxs) - len(correct_clips)  # TODO report this somehow
-
-                stream_clips, metadata_clips = [], []
-                for clip_id, clip_span, clip_pth in correct_clips:
-                    with open(clip_pth, "rb") as vid_f:
-                        clip_bytes = vid_f.read()
-                    stream_clips.append(clip_bytes)
-
-                    clip_key = "{clip_id:0{oom_clip_count}d}".format(  # pylint: disable=consider-using-f-string
-                        clip_id=clip_id, oom_clip_count=self.oom_clip_count
-                    )
-                    meta_clip = copy.deepcopy(metadata)
-                    # set the timeframe of this clip
-                    if strtime_formatting:
-                        # Keep clip_times in the original format to be compatible with the data schema.
-                        meta_clip["clips"] = [(_get_strtime(clip_span[0]), _get_strtime(clip_span[1]))]
-                    else:
-                        meta_clip["clips"] = [clip_span]
-                    meta_clip["key"] = f"{meta_clip['key']}_{clip_key}"
-
-                    yt_md_dict = meta_clip.get("yt_meta_dict", {})
-                    if (yt_md_dict is not None) and (yt_md_dict.get("subtitles", None) is not None):
-                        clip_subtitles = []
-                        s_c, e_c = _get_seconds(clip_span[0]), _get_seconds(clip_span[1])
-                        for line in meta_clip["yt_meta_dict"]["subtitles"]:
-                            s, e = _get_seconds(line["start"]), _get_seconds(line["end"])
-                            if max(s_c, s) < min(e_c, e):
-                                clip_subtitles.append(line)
-                            elif s > e_c:
-                                break
-                        # full video subtitles might still be useful for context
-                        meta_clip["clip_subtitles"] = clip_subtitles
-
-                    metadata_clips.append(meta_clip)
-
-                streams_clips[k] = stream_clips
-
-        return streams_clips, metadata_clips, None
+        if len(clip_spans) == 0:
+            return {}, [], f"Video had no clip_spans longer than {self.min_length}"
+
+        try:
+            clips, clip_metadata = _get_clips(
+                streams=streams,
+                encode_formats=self.encode_formats,
+                precision=self.precision,
+                clip_spans=clip_spans,
+                metadata=metadata,
+                oom_clip_count=self.oom_clip_count,
+                strtime_formatting=strtime_formatting,
+            )
+        except Exception as err:
+            return {}, [], str(err)
+
+        return clips, clip_metadata, None

From 5d03b720e10345a5492de4d0c8529a354c3a48fc Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 20:48:19 -0500
Subject: [PATCH 04/23] Final code changes

---
 .../subsamplers/clipping_subsampler.py       | 94 ++++++++++---------
 1 file changed, 50 insertions(+), 44 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 9ae4ee60..a0960ff0 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -7,7 +7,7 @@
 import ffmpeg
 import tempfile
 from collections.abc import Iterable
-from typing import Annotated, TypedDict, Literal, cast
+from typing import Any, Annotated, TypedDict, Literal, cast
 import datetime
 
 from .subsampler import Subsampler
@@ -89,42 +89,46 @@ def _adjust_clip_spans(
     return filtered_clip_spans
 
 
-def _get_clip_spans(clip_spans: list[ClipSpans]) -> tuple[str, list[int]]:
-    segment_times = [0.0]
+def _collate_clip_spans(clip_spans: list[ClipSpans]) -> tuple[str, list[int]]:
+    clip_times = [0.0]
     clip_idxs = []
     e_prev = 0.0
     clip_idx = 0
 
     for s, e in clip_spans:
         if s == e_prev:  # clip starts where last one left off
-            segment_times += [e]
+            clip_times += [e]
             clip_idxs.append(clip_idx)
             clip_idx += 1
         else:  # next clip skips over some time
-            segment_times += [s, e]
+            clip_times += [s, e]
             clip_idxs.append(clip_idx + 1)
             clip_idx += 2
         e_prev = e
 
-    segment_times = ",".join([str(time) for time in segment_times])
-    return segment_times, clip_idxs
-
-
-def _process_stream(stream_bytes: bytes, encode_format: str, ffmpeg_kwargs: dict) -> list[str]:
-    with tempfile.TemporaryDirectory() as tmpdir:
-        # TODO: we need to put the extension into the metadata
-        # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn
-        with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f:
-            f.write(stream_bytes)
-        try:
-            (
-                ffmpeg.input(f"{tmpdir}/input.{encode_format}")
-                .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs)
-                .run(capture_stdout=True, quiet=True)
-            )
-        except Exception as err:  # pylint: disable=broad-except
-            raise err
-        stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}")
+    clip_times = ",".join([str(time) for time in clip_times])
+    return clip_times, clip_idxs
+
+
+def _process_stream(
+    tmpdir: Any,  # BytesPath
+    stream_bytes: bytes,
+    encode_format: str,
+    ffmpeg_kwargs: dict,
+) -> list[str]:
+    # TODO: we need to put the extension into the metadata
+    # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn
+    with open(os.path.join(tmpdir, f"input.{encode_format}"), "wb") as f:
+        f.write(stream_bytes)
+    try:
+        (
+            ffmpeg.input(f"{tmpdir}/input.{encode_format}")
+            .output(f"{tmpdir}/clip_%d.{encode_format}", **ffmpeg_kwargs)
+            .run(capture_stdout=True, quiet=True)
+        )
+    except Exception as err:  # pylint: disable=broad-except
+        raise err
+    stream_clips = glob.glob(f"{tmpdir}/clip*.{encode_format}")
     stream_clips.sort(key=lambda x: int(x.split("_")[-1].split(".")[0]))
     return stream_clips
 
@@ -179,38 +183,40 @@ def _get_clips(
     oom_clip_count: int,
     strtime_formatting: bool,
 ) -> tuple[dict[str, list[str]], list[dict]]:
-    segment_times, clip_idxs = _get_clip_spans(clip_spans)
+    clip_times, clip_idxs = _collate_clip_spans(clip_spans)
 
     ffmpeg_kwargs = {
         "map": 0,
         "f": "segment",
-        "segment_times": segment_times,
+        "segment_times": clip_times,
         "reset_timestamps": 1,
     }
     if precision == "exact":
-        ffmpeg_kwargs["force_key_frames"] = segment_times
+        ffmpeg_kwargs["force_key_frames"] = clip_times
     else:
         ffmpeg_kwargs["c"] = "copy"
 
     clips = {}
     for k in streams.keys():
-        stream_bytes = streams[k][0]  # pre-broadcast so only one
-        if stream_bytes is None:
-            continue
-        try:
-            stream_clips = _process_stream(
-                stream_bytes=stream_bytes,
-                encode_format=encode_formats[k],
-                ffmpeg_kwargs=ffmpeg_kwargs,
-            )
-        except Exception as err:
-            raise err
-
-        clips[k] = []
-        for _, (_, clip_idx) in enumerate(zip(clip_spans, clip_idxs)):
-            with open(stream_clips[clip_idx], "rb") as vid_f:
-                clip_bytes = vid_f.read()
-            clips[k].append(clip_bytes)
+        with tempfile.TemporaryDirectory() as tmpdir:
+            stream_bytes = streams[k][0]  # pre-broadcast so only one
+            if stream_bytes is None:
+                continue
+            try:
+                stream_clips = _process_stream(
+                    tmpdir=tmpdir,
+                    stream_bytes=stream_bytes,
+                    encode_format=encode_formats[k],
+                    ffmpeg_kwargs=ffmpeg_kwargs,
+                )
+            except Exception as err:
+                raise err
+
+            clips[k] = []
+            for _, (_, clip_idx) in enumerate(zip(clip_spans, clip_idxs)):
+                with open(stream_clips[clip_idx], "rb") as vid_f:
+                    clip_bytes = vid_f.read()
+                clips[k].append(clip_bytes)
 
     clip_metadata = _get_clip_metadata(
         clip_spans=clip_spans,

From 47c7d647a5fc36cd11953aaa5baaf000f200458e Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 20:53:16 -0500
Subject: [PATCH 05/23] Added docstrings

---
 video2dataset/subsamplers/clipping_subsampler.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index a0960ff0..3f18d703 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -27,6 +27,7 @@ class Streams(TypedDict):
 
 
 def _get_seconds(t: str | float) -> float:
+    """Converts time to seconds"""
     if not isinstance(t, str):
         return float(t)  # already seconds
     time_format = "%H:%M:%S.%f"  # TODO: maybe parameterize this?
@@ -35,6 +36,7 @@ def _get_seconds(t: str | float) -> float:
 
 
 def _get_strtime(t_sec: float) -> str:
+    """Converts time to string"""
     hour = int(t_sec // 3600)
     minute = int((t_sec // 60) % 60)
     second = int(t_sec % 60)
@@ -73,6 +75,7 @@ def _adjust_clip_spans(
     max_length: float,
     max_length_strategy: str,
 ) -> list[ClipSpans]:
+    """Adjusts cut times around keyframes, filtering by min and max length"""
     if not isinstance(clip_spans[0], Iterable):  # make sure clip_spans looks like [[start, end]] and not [start, end]
         clip_spans = cast(list[ClipSpans], [clip_spans])
     clip_spans = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_spans]
@@ -90,6 +93,7 @@ def _adjust_clip_spans(
 
 
 def _collate_clip_spans(clip_spans: list[ClipSpans]) -> tuple[str, list[int]]:
+    """Collates clip spans into a single string for ffmpeg and a list of clip idxs"""
     clip_times = [0.0]
     clip_idxs = []
     e_prev = 0.0
     clip_idx = 0
@@ -116,6 +120,7 @@ def _process_stream(
     encode_format: str,
     ffmpeg_kwargs: dict,
 ) -> list[str]:
+    """Processes a stream into clips using ffmpeg"""
     # TODO: we need to put the extension into the metadata
     # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn
@@ -140,6 +145,7 @@ def _get_clip_metadata(
     oom_clip_count: int,
     strtime_formatting: bool,
 ) -> list[dict]:
+    """Gets metadata for each clip"""
     metadata_clips = []
     for clip_id, (clip_span, _) in enumerate(zip(clip_spans, clip_idxs)):
@@ -179,6 +185,7 @@ def _get_clips(
     oom_clip_count: int,
     strtime_formatting: bool,
 ) -> tuple[dict[str, list[str]], list[dict]]:
+    """Gets clips from streams"""
     clip_times, clip_idxs = _collate_clip_spans(clip_spans)
 
     ffmpeg_kwargs = {

From 5aa84d49d95535fd3db4f109ef94ba8238973e87 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 21:35:15 -0500
Subject: [PATCH 06/23] Passed tests and linting

---
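[Editor's note, not part of the commit: besides the test updates, this patch fixes an off-by-one in `_split_time_frame` — the tail clip now starts at `n_full_clips * max_length` instead of `(n_full_clips - 1) * max_length`. With s=0, e=8.5, min_length=2, max_length=3:

    >>> _split_time_frame(0.0, 8.5, 2.0, 3.0)
    [[0.0, 3.0], [3.0, 6.0], [6.0, 8.5]]

The old arithmetic produced a tail of [3.0, 8.5], overlapping the second full clip. The `_collate_clip_spans` change likewise drops the spurious leading 0.0 boundary.]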
 tests/test_subsamplers.py                        | 10 +++++-----
 video2dataset/subsamplers/__init__.py            |  2 +-
 video2dataset/subsamplers/clipping_subsampler.py | 10 +++++-----
 3 files changed, 11 insertions(+), 11 deletions(-)

diff --git a/tests/test_subsamplers.py b/tests/test_subsamplers.py
index e6a5b5f0..28ace480 100644
--- a/tests/test_subsamplers.py
+++ b/tests/test_subsamplers.py
@@ -11,6 +11,7 @@
     ClippingSubsampler,
     _get_seconds,
     _split_time_frame,
+    Streams,
     FFProbeSubsampler,
     ResolutionSubsampler,
     FrameSubsampler,
@@ -45,8 +46,8 @@ def test_clipping_subsampler(clips):
     min_length = 5.0 if clips == MULTI else 2.0
     max_length = 999999.0 if clips == MULTI else 3.0
     subsampler = ClippingSubsampler(
-        3,
-        {"video": "mp4", "audio": "mp3"},
+        oom_clip_count=3,
+        encode_formats={"video": "mp4", "audio": "mp3"},
         min_length=min_length,
         max_length=max_length,
         max_length_strategy="all",
@@ -58,7 +59,7 @@ def test_clipping_subsampler(clips):
         "clips": clips,
     }
 
-    streams = {"video": [video_bytes], "audio": [audio_bytes]}
+    streams: Streams = {"video": [video_bytes], "audio": [audio_bytes]}
     stream_fragments, meta_fragments, error_message = subsampler(streams, metadata)
     video_fragments = stream_fragments["video"]
     audio_fragments = stream_fragments["audio"]
@@ -84,7 +85,7 @@ def test_clipping_subsampler(clips):
         s_target, e_target = clips[key_ind]
         s_target, e_target = _get_seconds(s_target), _get_seconds(e_target)
         expected_clips = _split_time_frame(s_target, e_target, min_length, max_length)
-        assert (_get_seconds(s), _get_seconds(e)) in expected_clips
+        assert [_get_seconds(s), _get_seconds(e)] in expected_clips
         assert _get_seconds(e) - _get_seconds(s) >= min_length
 
         s_s, e_s = _get_seconds(s), _get_seconds(e)
@@ -92,7 +93,6 @@ def test_clipping_subsampler(clips):
         video_stream = [stream for stream in probe["streams"] if stream["codec_type"] == "video"][0]
         frag_len = float(video_stream["duration"])
 
-        # currently some segments can be pretty innacurate
         assert abs(frag_len - (e_s - s_s)) < 5.0
 
diff --git a/video2dataset/subsamplers/__init__.py b/video2dataset/subsamplers/__init__.py
index 5d4741f8..90e4cd58 100644
--- a/video2dataset/subsamplers/__init__.py
+++ b/video2dataset/subsamplers/__init__.py
@@ -3,7 +3,7 @@
 """
 
 from .audio_rate_subsampler import AudioRateSubsampler
-from .clipping_subsampler import ClippingSubsampler, _get_seconds, _split_time_frame
+from .clipping_subsampler import ClippingSubsampler, _get_seconds, _split_time_frame, Streams
 from .frame_subsampler import FrameSubsampler
 from .ffprobe_subsampler import FFProbeSubsampler
 from .noop_subsampler import NoOpSubsampler
diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 3f18d703..b3ae717a 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -22,8 +22,8 @@ class EncodeFormats(TypedDict):
 
 
 class Streams(TypedDict):
-    video: bytes
-    audio: bytes
+    video: list[bytes]
+    audio: list[bytes]
 
 
 def _get_seconds(t: str | float) -> float:
@@ -50,7 +50,7 @@ def _split_time_frame(s: float, e: float, min_length: float, max_length: float)
     time_d = e - s
     n_full_clips = int(time_d // max_length)
     clip_spans = [[s + i * max_length, s + (i + 1) * max_length] for i in range(n_full_clips)] + (
-        [[s + (n_full_clips - 1) * max_length, e]] if time_d % max_length > min_length else []
+        [[s + (n_full_clips) * max_length, e]] if time_d % max_length > min_length else []
     )
     return clip_spans
 
@@ -94,7 +94,7 @@ def _adjust_clip_spans(
 
 def _collate_clip_spans(clip_spans: list[ClipSpans]) -> tuple[str, list[int]]:
     """Collates clip spans into a single string for ffmpeg and a list of clip idxs"""
-    clip_times = [0.0]
+    clip_times = []
     clip_idxs = []
     e_prev = 0.0
     clip_idx = 0
@@ -216,7 +216,7 @@ def _get_clips(
                 raise err
 
             clips[k] = []
-            for _, (_, clip_idx) in enumerate(zip(clip_spans, clip_idxs)):
+            for clip_idx in clip_idxs:
                 with open(stream_clips[clip_idx], "rb") as vid_f:
                     clip_bytes = vid_f.read()
                 clips[k].append(clip_bytes)

From 140e1abbe4445916b5f81347673adba1f7e9ebbe Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 21:48:17 -0500
Subject: [PATCH 07/23] Made type annotations consistent with Python 3.8

---
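[Editor's note, not part of the commit: context for the annotation churn in this and the next two patches — built-in generics such as `list[float]` only parse on Python >= 3.9 (PEP 585), and the `X | Y` union syntax needs >= 3.10 (PEP 604), so targeting 3.8 means falling back to the typing module:

    from typing import List, Union

    ClipSpans = List[float]  # not list[float] on 3.8

    def f(t: Union[str, float]) -> float:  # not str | float
        ...

]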
 .../subsamplers/clipping_subsampler.py       | 36 +++++++++----------
 1 file changed, 18 insertions(+), 18 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index b3ae717a..25c7f665 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -7,13 +7,13 @@
 import ffmpeg
 import tempfile
 from collections.abc import Iterable
-from typing import Any, Annotated, TypedDict, Literal, cast
+from typing import Any, Union, List, TypedDict, Literal, cast
 import datetime
 
 from .subsampler import Subsampler
 
 
-ClipSpans = Annotated[list[float], 2]
+ClipSpans = List[float]  # [start, end]
 
 
 class EncodeFormats(TypedDict):
@@ -22,11 +22,11 @@ class EncodeFormats(TypedDict):
 
 
 class Streams(TypedDict):
-    video: list[bytes]
-    audio: list[bytes]
+    video: List[bytes]
+    audio: List[bytes]
 
 
-def _get_seconds(t: str | float) -> float:
+def _get_seconds(t: Union[str, float]) -> float:
     """Converts time to seconds"""
     if not isinstance(t, str):
         return float(t)  # already seconds
@@ -45,7 +45,7 @@ def _get_strtime(t_sec: float) -> str:
     return f"{hour:02d}:{minute:02d}:{second:02d}.{microsecond:03d}"
 
 
-def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> list[ClipSpans]:
+def _split_time_frame(s: float, e: float, min_length: float, max_length: float) -> List[ClipSpans]:
     """Filters out cuts by min and max length"""
     time_d = e - s
     n_full_clips = int(time_d // max_length)
@@ -55,7 +55,7 @@ def _split_time_frame(s: float, e: float, min_length: float, max_length: float)
     return clip_spans
 
 
-def _adjust_clip_spans_to_keyframes(clip_spans: list[ClipSpans], keyframes: list[float]) -> list[ClipSpans]:
+def _adjust_clip_spans_to_keyframes(clip_spans: List[ClipSpans], keyframes: List[float]) -> List[ClipSpans]:
     """Translates clip_spans into keyframe vocab"""
     adjusted_clip_spans = []
     for start, end in clip_spans:
@@ -69,15 +69,15 @@ def _adjust_clip_spans_to_keyframes(clip_spans: List[ClipSpans], keyframes: Lis
 
 
 def _adjust_clip_spans(
-    clip_spans: list[ClipSpans],
-    keyframe_timestamps: list[float] | None,
+    clip_spans: List[ClipSpans],
+    keyframe_timestamps: List[float] | None,
     min_length: float,
     max_length: float,
     max_length_strategy: str,
-) -> list[ClipSpans]:
+) -> List[ClipSpans]:
     """Adjusts cut times around keyframes, filtering by min and max length"""
     if not isinstance(clip_spans[0], Iterable):  # make sure clip_spans looks like [[start, end]] and not [start, end]
-        clip_spans = cast(list[ClipSpans], [clip_spans])
+        clip_spans = cast(List[ClipSpans], [clip_spans])
     clip_spans = [[_get_seconds(s), _get_seconds(e)] for [s, e] in clip_spans]
 
     if keyframe_timestamps:
@@ -92,7 +92,7 @@ def _adjust_clip_spans(
     return filtered_clip_spans
 
 
-def _collate_clip_spans(clip_spans: list[ClipSpans]) -> tuple[str, list[int]]:
+def _collate_clip_spans(clip_spans: List[ClipSpans]) -> tuple[str, List[int]]:
     """Collates clip spans into a single string for ffmpeg and a list of clip idxs"""
     clip_times = []
     clip_idxs = []
@@ -119,7 +119,7 @@ def _process_stream(
     stream_bytes: bytes,
     encode_format: str,
     ffmpeg_kwargs: dict,
-) -> list[str]:
+) -> List[str]:
     """Processes a stream into clips using ffmpeg"""
     # TODO: we need to put the extension into the metadata
     # TODO: This can be done better using pipes I just don't feel like sinking too much time into this rn
@@ -139,12 +139,12 @@ def _process_stream(
 
 
 def _get_clip_metadata(
-    clip_spans: list[ClipSpans],
-    clip_idxs: list[int],
+    clip_spans: List[ClipSpans],
+    clip_idxs: List[int],
     metadata: dict,
     oom_clip_count: int,
     strtime_formatting: bool,
-) -> list[dict]:
+) -> List[dict]:
     """Gets metadata for each clip"""
     metadata_clips = []
     for clip_id, (clip_span, _) in enumerate(zip(clip_spans, clip_idxs)):
@@ -180,11 +180,11 @@ def _get_clips(
     streams: Streams,
     encode_formats: EncodeFormats,
     precision: str,
-    clip_spans: list[ClipSpans],
+    clip_spans: List[ClipSpans],
     metadata: dict,
     oom_clip_count: int,
     strtime_formatting: bool,
-) -> tuple[dict[str, list[str]], list[dict]]:
+) -> tuple[dict[str, List[str]], List[dict]]:
     """Gets clips from streams"""
     clip_times, clip_idxs = _collate_clip_spans(clip_spans)

From 077ca27e78d713ff5b13c69692a66ed2dc95381d Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 21:59:26 -0500
Subject: [PATCH 08/23] More annotation fixes

---
 .../subsamplers/clipping_subsampler.py       | 20 ++++++++++---------
 1 file changed, 11 insertions(+), 9 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 25c7f665..317c6f92 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -7,7 +7,7 @@
 import ffmpeg
 import tempfile
 from collections.abc import Iterable
-from typing import Any, Union, List, TypedDict, Literal, cast
+from typing import Any, Union, List, Tuple, Dict, TypedDict, Literal, cast
 import datetime
 
 from .subsampler import Subsampler
@@ -64,13 +64,13 @@ def _adjust_clip_spans_to_keyframes(clip_spans: List[ClipSpans], keyframes: Lis
             adjusted_start = min(keyframes_in_range)
             adjusted_end = max(keyframes_in_range)
             if adjusted_start != adjusted_end:
-                adjusted_clip_spans.append((adjusted_start, adjusted_end))
+                adjusted_clip_spans.append([adjusted_start, adjusted_end])
     return adjusted_clip_spans
 
 
 def _adjust_clip_spans(
     clip_spans: List[ClipSpans],
-    keyframe_timestamps: List[float] | None,
+    keyframe_timestamps: Union[List[float], None],
     min_length: float,
     max_length: float,
     max_length_strategy: str,
@@ -92,7 +92,7 @@ def _adjust_clip_spans(
     return filtered_clip_spans
 
 
-def _collate_clip_spans(clip_spans: List[ClipSpans]) -> tuple[str, List[int]]:
+def _collate_clip_spans(clip_spans: List[ClipSpans]) -> Tuple[str, List[int]]:
     """Collates clip spans into a single string for ffmpeg and a list of clip idxs"""
     clip_times = []
     clip_idxs = []
@@ -110,8 +110,8 @@ def _collate_clip_spans(clip_spans: List[ClipSpans]) -> Tuple[str, List[int]]:
         clip_idx += 2
         e_prev = e
 
-    clip_times = ",".join([str(time) for time in clip_times])
-    return clip_times, clip_idxs
+    clip_times_str = ",".join([str(time) for time in clip_times])
+    return clip_times_str, clip_idxs
 
@@ -184,7 +184,7 @@ def _get_clips(
     metadata: dict,
     oom_clip_count: int,
     strtime_formatting: bool,
-) -> tuple[dict[str, List[str]], List[dict]]:
+) -> Tuple[Dict[str, List[bytes]], List[dict]]:
     """Gets clips from streams"""
     clip_times, clip_idxs = _collate_clip_spans(clip_spans)
 
@@ -199,8 +199,10 @@ def _get_clips(
     else:
         ffmpeg_kwargs["c"] = "copy"
 
-    clips = {}
-    for k in streams.keys():
+    clips: Dict[str, List[bytes]] = {}
+    for k in Streams.__annotations__.keys():
+        if k not in streams:
+            continue
         with tempfile.TemporaryDirectory() as tmpdir:
             stream_bytes = streams[k][0]  # pre-broadcast so only one
             if stream_bytes is None:

From 32fa4eaf760302a011bf263876c9f7bb17313205 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Thu, 18 Jan 2024 22:05:04 -0500
Subject: [PATCH 09/23] The Python 3.8 annotation needs a lot of hand-holding,
 it seems

---
 video2dataset/subsamplers/clipping_subsampler.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 317c6f92..3c07e2de 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -200,9 +200,8 @@ def _get_clips(
         ffmpeg_kwargs["c"] = "copy"
 
     clips: Dict[str, List[bytes]] = {}
-    for k in Streams.__annotations__.keys():
-        if k not in streams:
-            continue
+    for k in streams.keys():
+        k = cast(Literal["audio", "video"], k)
         with tempfile.TemporaryDirectory() as tmpdir:
             stream_bytes = streams[k][0]  # pre-broadcast so only one
             if stream_bytes is None:

From 5a8957fce3285632bbf566c5c577f498b407415f Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Fri, 19 Jan 2024 00:00:31 -0500
Subject: [PATCH 10/23] Pylint has to cut it out, I swear to God

---
 video2dataset/subsamplers/clipping_subsampler.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 3c07e2de..2af9a93c 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -213,7 +213,7 @@ def _get_clips(
                     encode_format=encode_formats[k],
                     ffmpeg_kwargs=ffmpeg_kwargs,
                 )
-            except Exception as err:
+            except Exception as err:  # pylint: disable=broad-except
                 raise err
 
             clips[k] = []
@@ -306,7 +306,7 @@ class ClippingSubsampler(Subsampler):
                 oom_clip_count=self.oom_clip_count,
                 strtime_formatting=strtime_formatting,
             )
-        except Exception as err:
+        except Exception as err:  # pylint: disable=broad-except
             return {}, [], str(err)
 
         return clips, clip_metadata, None

From f0f01688fe3d60069d51af3fe61565d8e35eda04 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Fri, 19 Jan 2024 08:15:29 -0500
Subject: [PATCH 11/23] No real change, just relaunching unit tests which
 failed due to connection timeouts

---
 video2dataset/subsamplers/clipping_subsampler.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 2af9a93c..439fd7b9 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -1,13 +1,13 @@
 """
 clipping subsampler turns full videos into clips of videos according to clip_col
 """
-import os
+from collections.abc import Iterable
+from typing import Any, Union, List, Tuple, Dict, TypedDict, Literal, cast
 import copy
-import glob
 import ffmpeg
+import glob
+import os
 import tempfile
-from collections.abc import Iterable
-from typing import Any, Union, List, Tuple, Dict, TypedDict, Literal, cast
 import datetime
 
 from .subsampler import Subsampler

From 1df88dd6b1bc1a8f3236418a863703a200aaa019 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Sun, 21 Jan 2024 22:46:56 -0500
Subject: [PATCH 12/23] Linting issue

---
 video2dataset/subsamplers/clipping_subsampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index df68fb46..508c6ed8 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -140,7 +140,7 @@ def _process_stream(
 
 def _extract_subtitles(clip_span: ClipSpan, meta_clip: dict) -> List[dict]:
     """Extracts subtitles and groups them by language"""
-    clip_subtitles = []
+    clip_subtitles: list[dict] = []
     s_c, e_c = _get_seconds(clip_span[0]), _get_seconds(clip_span[1])
     for lang_id, (lang, subtitles) in enumerate(meta_clip["yt_meta_dict"]["subtitles"].items()):
         idx = 0

From 226fba3bbf5ae98c689dc1f95f911a0532b6fe5c Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Sun, 21 Jan 2024 22:51:59 -0500
Subject: [PATCH 13/23] Another linting issue

---
 video2dataset/subsamplers/clipping_subsampler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 508c6ed8..73eae18f 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -140,7 +140,7 @@ def _process_stream(
 
 def _extract_subtitles(clip_span: ClipSpan, meta_clip: dict) -> List[dict]:
     """Extracts subtitles and groups them by language"""
-    clip_subtitles: list[dict] = []
+    clip_subtitles: List[dict] = []
     s_c, e_c = _get_seconds(clip_span[0]), _get_seconds(clip_span[1])
     for lang_id, (lang, subtitles) in enumerate(meta_clip["yt_meta_dict"]["subtitles"].items()):
         idx = 0

From 8ed5074a2a1ded55c19f4048be0569e1d001ec10 Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Tue, 23 Jan 2024 21:52:25 -0500
Subject: [PATCH 14/23] Separated per-shard code from code that should only be
 executed once

---
 .../subsamplers/clipping_subsampler.py |  17 +-
 video2dataset/types.py                 |  11 +
 video2dataset/workers/subset_worker.py | 191 +++++++++---------
 3 files changed, 116 insertions(+), 103 deletions(-)
 create mode 100644 video2dataset/types.py

diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py
index 73eae18f..2e7c7d96 100644
--- a/video2dataset/subsamplers/clipping_subsampler.py
+++ b/video2dataset/subsamplers/clipping_subsampler.py
@@ -2,30 +2,21 @@
 clipping subsampler turns full videos into clips of videos according to clip_col
 """
 from collections.abc import Iterable
-from typing import Any, Union, List, Tuple, Dict, TypedDict, Literal, cast
 import copy
+import datetime
 import ffmpeg
 import glob
 import os
 import tempfile
+from typing import Any, Union, List, Tuple, Dict, Literal, cast
 
-import datetime
-from .subsampler import Subsampler
+from video2dataset.subsamplers.subsampler import Subsampler
+from video2dataset.types import EncodeFormats, Streams
 
 
 ClipSpan = List[float]  # [start, end]
 
 
-class EncodeFormats(TypedDict):
-    video: str
-    audio: str
-
-
-class Streams(TypedDict):
-    video: List[bytes]
-    audio: List[bytes]
-
-
 def _get_seconds(t: Union[str, float]) -> float:
     """Converts time to seconds"""
     if not isinstance(t, str):
diff --git a/video2dataset/types.py b/video2dataset/types.py
new file mode 100644
index 00000000..40fdf24e
--- /dev/null
+++ b/video2dataset/types.py
@@ -0,0 +1,11 @@
+from typing import List, TypedDict
+
+
+class EncodeFormats(TypedDict):
+    video: str
+    audio: str
+
+
+class Streams(TypedDict):
+    video: List[bytes]
+    audio: List[bytes]
diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py
index 06383074..d3dcafb3 100644
--- a/video2dataset/workers/subset_worker.py
+++ b/video2dataset/workers/subset_worker.py
@@ -1,4 +1,5 @@
 """creates a subset of an existing dataset inside the sample dimension"""
+from dataclasses import dataclass
 import time
 import json
 import pyarrow as pa
@@ -7,7 +8,7 @@
 import fsspec
 import numpy as np
 import webdataset as wds
-from typing import List, Any
+from typing import List, Any, Union
 
 from video2dataset.dataloader import get_video_dataset
 from video2dataset.logger import CappedCounter, write_stats
@@ -20,6 +21,56 @@
     ResolutionSubsampler,
     AudioRateSubsampler,
 )
+from video2dataset.types import EncodeFormats, Streams
+
+
+@dataclass
+class Subsamplers:
+    broadcast_subsampler: Union[ClippingSubsampler, NoOpSubsampler]
+
+
+def get_subsamplers(config: dict, encode_formats: EncodeFormats):
+    clipping_subsampler = ClippingSubsampler(
+        5,  # oom_clip_count
+        encode_formats,
+        **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"],
+    )
+
+    need_keyframes = clipping_subsampler.precision == "keyframe_adjusted"
+    ffprobe_subsampler = None
+    if "FFProbeSubsampler" in config["subsampling"] or need_keyframes:
+        ffprobe_subsampler = FFProbeSubsampler(
+            **config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]
+        )
+        ffprobe_subsampler.extract_keyframes |= need_keyframes
+
+    noop_subsampler = NoOpSubsampler()
+
+    video_subsamplers: List[Any] = []
+    if "ResolutionSubsampler" in config["subsampling"]:
+        video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"]))
+    if "FrameSubsampler" in config["subsampling"]:
+        video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"]))
+
+    audio_subsamplers: List[Any] = []
+    if "AudioRateSubsampler" in config["subsampling"]:
+        audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"]))
+
+    subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers}
+
+    cut_detection_subsampler = None
+    cuts_are_clips = False
+    if "CutDetectionSubsampler" in config["subsampling"]:
+        if "args" in config["subsampling"]["CutDetectionSubsampler"]:
+            cut_detection_subsampler = CutDetectionSubsampler(
+                **config["subsampling"]["CutDetectionSubsampler"]["args"]
+            )
+        cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False)
+
+    broadcast_subsampler = (
+        clipping_subsampler
+        if (config["storage"]["captions_are_subtitles"] or cuts_are_clips)
+        else noop_subsampler
+    )
+
+    return ffprobe_subsampler, subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler
 
 
 class SubsetWorker:
@@ -29,72 +80,43 @@ class SubsetWorker:
     def __init__(
         self,
         sample_writer_class,
        output_folder,
-        encode_formats,
+        encode_formats: EncodeFormats,
         config,
     ) -> None:
         self.sample_writer_class = sample_writer_class
-        self.save_caption = True
         self.output_folder = output_folder
-        self.encode_formats = encode_formats
         self.config = config
+        self.ffprobe_subsampler, self.subsamplers, self.cut_detection_subsampler, self.cuts_are_clips, self.broadcast_subsampler = get_subsamplers(config, encode_formats)
 
-        self.clipping_subsampler = ClippingSubsampler(
-            5,  # oom_clip_count
-            encode_formats,
-            **self.config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"],
-        )
-        need_keyframes = self.clipping_subsampler.precision == "keyframe_adjusted"
+        # set encoding formats
+        self.input_encode_formats = encode_formats
+        self.output_encode_formats = self.input_encode_formats.copy()
+        if self.subsamplers["audio"]:
+            assert (
+                len({s.encode_format for s in self.subsamplers["audio"]}) == 1
+            )  # assert that all audio subsamplers have the same output format
+            self.output_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format
+        if self.subsamplers["video"]:
+            assert (
+                len({s.encode_format for s in self.subsamplers["video"]}) == 1
+            )  # assert that all video subsamplers have the same output format
+            self.output_encode_formats["video"] = self.subsamplers["video"][0].encode_format
 
-        self.ffprobe_subsampler = None
-        if "FFProbeSubsampler" in self.config["subsampling"] or need_keyframes:
-            self.ffprobe_subsampler = FFProbeSubsampler(
-                **self.config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]
-            )
-            self.ffprobe_subsampler.extract_keyframes |= need_keyframes
-
-        self.cut_detector = None
-        self.cuts_are_clips = False
-        if "CutDetectionSubsampler" in self.config["subsampling"]:
-            if "args" in self.config["subsampling"]["CutDetectionSubsampler"]:
-                self.cut_detector = CutDetectionSubsampler(
-                    **self.config["subsampling"]["CutDetectionSubsampler"]["args"]
-                )
-            self.cuts_are_clips = self.config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False)
-
-        self.noop_subsampler = NoOpSubsampler()
-
-        video_subsamplers: List[Any] = []
-        if "ResolutionSubsampler" in self.config["subsampling"]:
-            video_subsamplers.append(ResolutionSubsampler(**self.config["subsampling"]["ResolutionSubsampler"]["args"]))
-        if "FrameSubsampler" in self.config["subsampling"]:
-            video_subsamplers.append(FrameSubsampler(**self.config["subsampling"]["FrameSubsampler"]["args"]))
-
-        audio_subsamplers: List[Any] = []
-        if "AudioRateSubsampler" in self.config["subsampling"]:
-            audio_subsamplers.append(AudioRateSubsampler(**self.config["subsampling"]["AudioRateSubsampler"]["args"]))
-        self.subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers}
 
     def __call__(
         self,
         row,
     ):
         try:
-            self.process_shard(row)
+            shard, shard_id = row
+            self.process_shard(shard, shard_id)
             return (True, row)
         except Exception as err:  # pylint: disable=broad-except
             traceback.print_exc()
             print(f"shard {row[0]} failed with error {err}")
             return (False, row)
 
-    def process_shard(
-        self,
-        row,
-    ):
-        """Function to start an video processing in one process"""
-
-        shard, shard_id = row
-        start_time = time.time()
-
+    def get_shard_processors(self, shard, shard_id):
         try:
             fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet")
@@ -108,55 +130,49 @@ def process_shard(
             pa.field("error_message", pa.string()),
         ]
         schema = pa.schema(fields)
-
-        status_dict = CappedCounter()
-
-        # The subsamplers might change the output format, so we need to update the writer
-        writer_encode_formats = self.encode_formats.copy()
-        if self.subsamplers["audio"]:
-            assert (
-                len({s.encode_format for s in self.subsamplers["audio"]}) == 1
-            )  # assert that all audio subsamplers have the same output format
-            writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format
-        if self.subsamplers["video"]:
-            assert (
-                len({s.encode_format for s in self.subsamplers["video"]}) == 1
-            )  # assert that all video subsamplers have the same output format
-            writer_encode_formats["video"] = self.subsamplers["video"][0].encode_format
-
-        # give schema to writer
-        sample_writer = self.sample_writer_class(
+        shard_sample_writer = self.sample_writer_class(
             shard_id,
             self.output_folder,
-            self.save_caption,
+            True,  # save_caption
            self.config["storage"]["oom_shard_count"],
             schema,
-            writer_encode_formats,
+            self.output_encode_formats,
         )
+        shard_dataloader = get_video_dataset(
+            urls=shard,
+            batch_size=1,
+            decoder_kwargs={},
+            enforce_additional_keys=[],
+            handler=wds.warn_and_continue,
+        )
+        return shard_sample_writer, shard_dataloader
 
+    def process_shard(
+        self,
+        shard,
+        shard_id,
+    ):
+        """Function to start an video processing in one process"""
+
+        start_time = time.time()
+
+        shard_sample_writer, shard_dataloader = self.get_shard_processors(shard, shard_id)
         successes = 0
         failed = {
             "failed_to_download": 0,
             "failed_to_subsample": 0,
         }
+        status_dict = CappedCounter()
         error_message = None
         count = 0
-
-        dataloader = get_video_dataset(
-            urls=shard,
-            batch_size=1,
-            decoder_kwargs={},
-            enforce_additional_keys=[],
-            handler=wds.warn_and_continue,
-        )
-        for sample in dataloader:
+        for sample in shard_dataloader:
             try:
                 count += 1
                 key = sample["__key__"]
                 caption = sample.get("txt", b"").decode("utf-8")
                 meta = json.loads(sample.get("json", b"{}").decode("utf-8"))
                 streams = {}
-                for mod, fmt in self.encode_formats.items():
+                for mod, fmt in self.input_encode_formats.items():
                     streams[mod] = [sample[fmt]]
 
                 if self.ffprobe_subsampler is not None:
@@ -167,8 +183,8 @@ def process_shard(
                 if self.config["storage"]["captions_are_subtitles"]:  # create clips
                     subtitles = meta["yt_meta_dict"]["subtitles"]
                     meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles]
-                elif self.cut_detector is not None:  # apply cut detection to get clips
-                    streams, cuts, error_message = self.cut_detector(streams)
+                elif self.cut_detection_subsampler is not None:  # apply cut detection to get clips
+                    streams, cuts, error_message = self.cut_detection_subsampler(streams)
                     if error_message is not None:
                         raise ValueError("failed_to_subsample")
 
@@ -180,12 +196,7 @@ def process_shard(
                     meta["clips"] = (np.array(cuts["cuts_original_fps"]) / native_fps).tolist()
 
                 # 1 video -> many videos (either clipping or noop which does identity broadcasting)
-                broadcast_subsampler = (
-                    self.clipping_subsampler
-                    if (self.config["storage"]["captions_are_subtitles"] or self.cuts_are_clips)
-                    else self.noop_subsampler
-                )
-                subsampled_streams, metas, error_message = broadcast_subsampler(streams, meta)
+                subsampled_streams, metas, error_message = self.broadcast_subsampler(streams, meta)
                 if error_message is not None:
                     meta["clips"] = []
                     raise ValueError("failed_to_subsample")
@@ -203,7 +214,7 @@ def process_shard(
                 subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())]
                 if len(subsampled_streams_list) == 0:  # no audio or video, just write meta
                     meta["status"] = status
-                    sample_writer.write(
+                    shard_sample_writer.write(
                         {},
                         key,
                         caption,
@@ -218,7 +229,7 @@ def process_shard(
                     if self.config["storage"]["captions_are_subtitles"]:
                         text_caption = meta.get("clip_subtitles")[0]["lines"][0]
 
-                    sample_writer.write(
+                    shard_sample_writer.write(
                        subsampled_streams,
                         meta["key"],
                         text_caption,
@@ -232,7 +243,7 @@ def process_shard(
                 status_dict.increment(error_message)
                 meta["status"] = status
                 meta["error_message"] = error_message
-                sample_writer.write(
+                shard_sample_writer.write(
                     {},
                     key,
                     caption,
@@ -242,7 +253,7 @@ def process_shard(
                 traceback.print_exc()
                 print(f"Sample {key} failed to download: {err}")
 
-        sample_writer.close()
+        shard_sample_writer.close()
         end_time = time.time()
 
         write_stats(

From e862eaacbc6363bbc75837270284cd42d82752be Mon Sep 17 00:00:00 2001
From: Matt Zhang
Date: Tue, 23 Jan 2024 22:01:03 -0500
Subject: [PATCH 15/23] Pulled ShardStatus parameters into their own data type

---
b/video2dataset/workers/subset_worker.py @@ -1,5 +1,5 @@ """creates a subset of an existing dataset inside the sample dimension""" -from dataclasses import dataclass +from dataclasses import dataclass, field import time import json import pyarrow as pa @@ -8,7 +8,7 @@ import fsspec import numpy as np import webdataset as wds -from typing import List, Any, Union +from typing import List, Any, Union, Optional from video2dataset.dataloader import get_video_dataset from video2dataset.logger import CappedCounter, write_stats @@ -73,6 +73,22 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): return ffprobe_subsampler, subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler +@dataclass +class ShardStatus: + successes: int = 0 + failed: dict = field( + default_factory=lambda: { + "failed_to_download": 0, + "failed_to_subsample": 0, + } + ) + status_dict: CappedCounter = field( + default_factory=CappedCounter + ) + error_message: Optional[str] = None + count: int = 0 + + class SubsetWorker: """The loader class reads the shards, then the selected data is chosen and writen by the writer""" @@ -155,19 +171,12 @@ def process_shard( """Function to start an video processing in one process""" start_time = time.time() - shard_sample_writer, shard_dataloader = self.get_shard_processors(shard, shard_id) - successes = 0 - failed = { - "failed_to_download": 0, - "failed_to_subsample": 0, - } - status_dict = CappedCounter() - error_message = None - count = 0 + shard_status = ShardStatus() + for sample in shard_dataloader: try: - count += 1 + shard_status.count += 1 key = sample["__key__"] caption = sample.get("txt", b"").decode("utf-8") meta = json.loads(sample.get("json", b"{}").decode("utf-8")) @@ -176,16 +185,16 @@ def process_shard( streams[mod] = [sample[fmt]] if self.ffprobe_subsampler is not None: - streams, meta, error_message = self.ffprobe_subsampler(streams, meta) - if error_message is not None: + streams, meta, shard_status.error_message = self.ffprobe_subsampler(streams, meta) + if shard_status.error_message is not None: raise ValueError("failed_to_subsample") if self.config["storage"]["captions_are_subtitles"]: # create clips subtitles = meta["yt_meta_dict"]["subtitles"] meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] elif self.cut_detection_subsampler is not None: # apply cut detection to get clips - streams, cuts, error_message = self.cut_detection_subsampler(streams) - if error_message is not None: + streams, cuts, shard_status.error_message = self.cut_detection_subsampler(streams) + if shard_status.error_message is not None: raise ValueError("failed_to_subsample") meta["cuts"] = cuts @@ -196,21 +205,21 @@ def process_shard( meta["clips"] = (np.array(cuts["cuts_original_fps"]) / native_fps).tolist() # 1 video -> many videos (either clipping or noop which does identity broadcasting) - subsampled_streams, metas, error_message = self.broadcast_subsampler(streams, meta) - if error_message is not None: + subsampled_streams, metas, shard_status.error_message = self.broadcast_subsampler(streams, meta) + if shard_status.error_message is not None: meta["clips"] = [] raise ValueError("failed_to_subsample") for modality in list(subsampled_streams.keys()): for modality_subsampler in self.subsamplers[modality]: - subsampled_streams, metas, error_message = modality_subsampler(subsampled_streams, metas) + subsampled_streams, metas, shard_status.error_message = modality_subsampler(subsampled_streams, metas) - if error_message is not 
None: + if shard_status.error_message is not None: raise ValueError("failed_to_subsample") - successes += 1 + shard_status.successes += 1 status = "success" - status_dict.increment(status) + shard_status.status_dict.increment(status) subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] if len(subsampled_streams_list) == 0: # no audio or video, just write meta meta["status"] = status @@ -239,10 +248,10 @@ def process_shard( except Exception as err: # pylint: disable=broad-except status = str(err) if status.startswith("failed_to_"): - failed[status] += 1 - status_dict.increment(error_message) + shard_status.failed[status] += 1 + shard_status.status_dict.increment(shard_status.error_message) meta["status"] = status - meta["error_message"] = error_message + meta["error_message"] = shard_status.error_message shard_sample_writer.write( {}, key, @@ -259,13 +268,13 @@ def process_shard( write_stats( self.output_folder, shard_id, - count, - successes, + shard_status.count, + shard_status.successes, 0, # failed to download - failed["failed_to_subsample"], + shard_status.failed["failed_to_subsample"], 0, # bytes downloaded start_time, end_time, - status_dict, + shard_status.status_dict, self.config["storage"]["oom_shard_count"], ) From d158106ca8640797dd5549e5c6bc9dfb5c57aa62 Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 22:20:05 -0500 Subject: [PATCH 16/23] Cleaned up shard processing error handling --- video2dataset/workers/subset_worker.py | 74 +++++++++++--------------- 1 file changed, 31 insertions(+), 43 deletions(-) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 350426a0..7c265860 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -76,12 +76,7 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): @dataclass class ShardStatus: successes: int = 0 - failed: dict = field( - default_factory=lambda: { - "failed_to_download": 0, - "failed_to_subsample": 0, - } - ) + failed_to_subsample: int = 0 status_dict: CappedCounter = field( default_factory=CappedCounter ) @@ -175,51 +170,53 @@ def process_shard( shard_status = ShardStatus() for sample in shard_dataloader: + shard_status.count += 1 + key = sample["__key__"] try: - shard_status.count += 1 - key = sample["__key__"] caption = sample.get("txt", b"").decode("utf-8") meta = json.loads(sample.get("json", b"{}").decode("utf-8")) - streams = {} - for mod, fmt in self.input_encode_formats.items(): - streams[mod] = [sample[fmt]] + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + return + + try: + streams: Streams = {"video": [], "audio": []} + for modality, format in self.input_encode_formats.items(): + streams[modality] = [sample[format]] if self.ffprobe_subsampler is not None: streams, meta, shard_status.error_message = self.ffprobe_subsampler(streams, meta) - if shard_status.error_message is not None: - raise ValueError("failed_to_subsample") + assert shard_status.error_message is None if self.config["storage"]["captions_are_subtitles"]: # create clips subtitles = meta["yt_meta_dict"]["subtitles"] meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] + elif self.cut_detection_subsampler is not None: # apply cut detection to get clips streams, cuts, shard_status.error_message = self.cut_detection_subsampler(streams) - if shard_status.error_message is not None: 
- raise ValueError("failed_to_subsample") - + assert shard_status.error_message is None meta["cuts"] = cuts if self.cuts_are_clips: cuts = meta["cuts"] - native_fps = cuts["original_fps"] - meta["clips"] = (np.array(cuts["cuts_original_fps"]) / native_fps).tolist() + meta["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() # 1 video -> many videos (either clipping or noop which does identity broadcasting) subsampled_streams, metas, shard_status.error_message = self.broadcast_subsampler(streams, meta) if shard_status.error_message is not None: meta["clips"] = [] - raise ValueError("failed_to_subsample") + assert False for modality in list(subsampled_streams.keys()): for modality_subsampler in self.subsamplers[modality]: subsampled_streams, metas, shard_status.error_message = modality_subsampler(subsampled_streams, metas) - - if shard_status.error_message is not None: - raise ValueError("failed_to_subsample") + assert shard_status.error_message is None shard_status.successes += 1 status = "success" shard_status.status_dict.increment(status) + subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] if len(subsampled_streams_list) == 0: # no audio or video, just write meta meta["status"] = status @@ -230,37 +227,28 @@ def process_shard( meta, ) continue - for subsampled_streams, meta in zip(subsampled_streams_list, metas): meta["status"] = status - text_caption = caption if self.config["storage"]["captions_are_subtitles"]: text_caption = meta.get("clip_subtitles")[0]["lines"][0] - shard_sample_writer.write( subsampled_streams, meta["key"], text_caption, meta, ) - - except Exception as err: # pylint: disable=broad-except - status = str(err) - if status.startswith("failed_to_"): - shard_status.failed[status] += 1 - shard_status.status_dict.increment(shard_status.error_message) - meta["status"] = status - meta["error_message"] = shard_status.error_message - shard_sample_writer.write( - {}, - key, - caption, - meta, - ) - else: - traceback.print_exc() - print(f"Sample {key} failed to download: {err}") + except Exception: # pylint: disable=broad-except + shard_status.failed_to_subsample += 1 + shard_status.status_dict.increment(shard_status.error_message) + meta["status"] = "failed_to_subsample" + meta["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + key, + caption, + meta, + ) shard_sample_writer.close() end_time = time.time() @@ -271,7 +259,7 @@ def process_shard( shard_status.count, shard_status.successes, 0, # failed to download - shard_status.failed["failed_to_subsample"], + shard_status.failed_to_subsample, 0, # bytes downloaded start_time, end_time, From 5cd53a97a3aa4e4e8dfe2613715a705a558f08eb Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 22:28:10 -0500 Subject: [PATCH 17/23] Cleaned up code --- video2dataset/workers/subset_worker.py | 66 ++++++++++++++------------ 1 file changed, 35 insertions(+), 31 deletions(-) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 7c265860..323e4a94 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -35,15 +35,30 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): encode_formats, **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], ) - need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" + + cut_detection_subsampler = None + cuts_are_clips = False + if "CutDetectionSubsampler" in 
config["subsampling"]: + if "args" in config["subsampling"]["CutDetectionSubsampler"]: + cut_detection_subsampler = CutDetectionSubsampler( + **config["subsampling"]["CutDetectionSubsampler"]["args"] + ) + cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + + broadcast_subsampler = ( + clipping_subsampler + if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) + else NoOpSubsampler() + ) + ffprobe_subsampler = None if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: ffprobe_subsampler = FFProbeSubsampler( **config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"] ) ffprobe_subsampler.extract_keyframes |= need_keyframes - noop_subsampler = NoOpSubsampler() + video_subsamplers: List[Any] = [] if "ResolutionSubsampler" in config["subsampling"]: video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) @@ -53,24 +68,10 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): audio_subsamplers: List[Any] = [] if "AudioRateSubsampler" in config["subsampling"]: audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) - subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} - cut_detection_subsampler = None - cuts_are_clips = False - if "CutDetectionSubsampler" in config["subsampling"]: - if "args" in config["subsampling"]["CutDetectionSubsampler"]: - cut_detection_subsampler = CutDetectionSubsampler( - **config["subsampling"]["CutDetectionSubsampler"]["args"] - ) - cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} - broadcast_subsampler = ( - clipping_subsampler - if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) - else noop_subsampler - ) - - return ffprobe_subsampler, subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler + return ffprobe_subsampler, modal_subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler @dataclass @@ -97,21 +98,21 @@ def __init__( self.sample_writer_class = sample_writer_class self.output_folder = output_folder self.config = config - self.ffprobe_subsampler, self.subsamplers, self.cut_detection_subsampler, self.cuts_are_clips, self.broadcast_subsampler = get_subsamplers(config, encode_formats) + self.ffprobe_subsampler, self.modal_subsamplers, self.cut_detection_subsampler, self.cuts_are_clips, self.broadcast_subsampler = get_subsamplers(config, encode_formats) # set encoding formats self.input_encode_formats = encode_formats self.output_encode_formats = self.input_encode_formats.copy() - if self.subsamplers["audio"]: + if self.modal_subsamplers["audio"]: assert ( - len({s.encode_format for s in self.subsamplers["audio"]}) == 1 + len({s.encode_format for s in self.modal_subsamplers["audio"]}) == 1 ) # assert that all audio subsamplers have the same output format - self.output_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format - if self.subsamplers["video"]: + self.output_encode_formats["audio"] = self.modal_subsamplers["audio"][0].encode_format + if self.modal_subsamplers["video"]: assert ( - len({s.encode_format for s in self.subsamplers["video"]}) == 1 + len({s.encode_format for s in self.modal_subsamplers["video"]}) == 1 ) # assert that all video subsamplers have the same output format - self.output_encode_formats["video"] = self.subsamplers["video"][0].encode_format + 
self.output_encode_formats["video"] = self.modal_subsamplers["video"][0].encode_format def __call__( @@ -127,14 +128,17 @@ def __call__( print(f"shard {row[0]} failed with error {err}") return (False, row) - def get_shard_processors(self, shard, shard_id): + def get_shard_processors( + self, + shard: Union[str, List[str]], + shard_id: int, + ): try: fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") - with fs.open(shard_path, "rb") as f: df = pa.parquet.read_table(f) schema = df.schema - except Exception as e: # pylint: disable=broad-except,unused-variable + except Exception: # pylint: disable=broad-except fields = [ pa.field("key", pa.string()), pa.field("status", pa.string()), @@ -160,8 +164,8 @@ def get_shard_processors(self, shard, shard_id): def process_shard( self, - shard, - shard_id, + shard: Union[str, List[str]], + shard_id: int, ): """Function to start an video processing in one process""" @@ -209,7 +213,7 @@ def process_shard( assert False for modality in list(subsampled_streams.keys()): - for modality_subsampler in self.subsamplers[modality]: + for modality_subsampler in self.modal_subsamplers[modality]: subsampled_streams, metas, shard_status.error_message = modality_subsampler(subsampled_streams, metas) assert shard_status.error_message is None From ffe0e716e601c140de206f371438338012a4ef8a Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 22:42:58 -0500 Subject: [PATCH 18/23] Bug fixes --- video2dataset/types.py | 4 ++-- video2dataset/workers/subset_worker.py | 10 ++++------ 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/video2dataset/types.py b/video2dataset/types.py index 40fdf24e..a648c86a 100644 --- a/video2dataset/types.py +++ b/video2dataset/types.py @@ -1,11 +1,11 @@ from typing import List, TypedDict -class EncodeFormats(TypedDict): +class EncodeFormats(TypedDict, total=False): video: str audio: str -class Streams(TypedDict): +class Streams(TypedDict, total=False): video: List[bytes] audio: List[bytes] diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 323e4a94..d9ee5690 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -185,7 +185,7 @@ def process_shard( return try: - streams: Streams = {"video": [], "audio": []} + streams: Streams = {} for modality, format in self.input_encode_formats.items(): streams[modality] = [sample[format]] @@ -196,15 +196,13 @@ def process_shard( if self.config["storage"]["captions_are_subtitles"]: # create clips subtitles = meta["yt_meta_dict"]["subtitles"] meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] - elif self.cut_detection_subsampler is not None: # apply cut detection to get clips streams, cuts, shard_status.error_message = self.cut_detection_subsampler(streams) assert shard_status.error_message is None meta["cuts"] = cuts - - if self.cuts_are_clips: - cuts = meta["cuts"] - meta["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() + assert cuts is not None + if self.cuts_are_clips: + meta["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() # 1 video -> many videos (either clipping or noop which does identity broadcasting) subsampled_streams, metas, shard_status.error_message = self.broadcast_subsampler(streams, meta) From 2c7daf8c2f13c4d8d4e7039e9e7dea4dfee22225 Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 22:43:42 -0500 Subject: [PATCH 19/23] Formatting --- 
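Note on the reflow below: collapsing status_dict: CappedCounter = field(default_factory=CappedCounter) onto one line is purely cosmetic, but the default_factory itself (introduced with ShardStatus in PATCH 15) is load-bearing. A bare default instance would be a single object shared by every ShardStatus, so counts from one shard would bleed into the next, and for dict-like defaults such as collections.Counter the dataclasses machinery rejects a bare mutable default outright. A minimal, self-contained sketch of the pattern, with collections.Counter standing in for CappedCounter (an illustrative substitution, not the repository's class):

    from collections import Counter
    from dataclasses import dataclass, field
    from typing import Optional

    @dataclass
    class ShardStatus:
        successes: int = 0
        failed_to_subsample: int = 0
        # default_factory builds a fresh counter per instance;
        # `status_dict: Counter = Counter()` would raise ValueError here
        status_dict: Counter = field(default_factory=Counter)
        error_message: Optional[str] = None
        count: int = 0

    first, second = ShardStatus(), ShardStatus()
    first.status_dict["success"] += 1
    assert second.status_dict["success"] == 0  # no state leaks between shard runs
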
video2dataset/workers/subset_worker.py | 29 +++++++++++++------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index d9ee5690..b25c4f6d 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -41,22 +41,16 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): cuts_are_clips = False if "CutDetectionSubsampler" in config["subsampling"]: if "args" in config["subsampling"]["CutDetectionSubsampler"]: - cut_detection_subsampler = CutDetectionSubsampler( - **config["subsampling"]["CutDetectionSubsampler"]["args"] - ) + cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) broadcast_subsampler = ( - clipping_subsampler - if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) - else NoOpSubsampler() + clipping_subsampler if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) else NoOpSubsampler() ) ffprobe_subsampler = None if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: - ffprobe_subsampler = FFProbeSubsampler( - **config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"] - ) + ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) ffprobe_subsampler.extract_keyframes |= need_keyframes video_subsamplers: List[Any] = [] @@ -78,9 +72,7 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): class ShardStatus: successes: int = 0 failed_to_subsample: int = 0 - status_dict: CappedCounter = field( - default_factory=CappedCounter - ) + status_dict: CappedCounter = field(default_factory=CappedCounter) error_message: Optional[str] = None count: int = 0 @@ -98,7 +90,13 @@ def __init__( self.sample_writer_class = sample_writer_class self.output_folder = output_folder self.config = config - self.ffprobe_subsampler, self.modal_subsamplers, self.cut_detection_subsampler, self.cuts_are_clips, self.broadcast_subsampler = get_subsamplers(config, encode_formats) + ( + self.ffprobe_subsampler, + self.modal_subsamplers, + self.cut_detection_subsampler, + self.cuts_are_clips, + self.broadcast_subsampler, + ) = get_subsamplers(config, encode_formats) # set encoding formats self.input_encode_formats = encode_formats @@ -114,7 +112,6 @@ def __init__( ) # assert that all video subsamplers have the same output format self.output_encode_formats["video"] = self.modal_subsamplers["video"][0].encode_format - def __call__( self, row, @@ -212,7 +209,9 @@ def process_shard( for modality in list(subsampled_streams.keys()): for modality_subsampler in self.modal_subsamplers[modality]: - subsampled_streams, metas, shard_status.error_message = modality_subsampler(subsampled_streams, metas) + subsampled_streams, metas, shard_status.error_message = modality_subsampler( + subsampled_streams, metas + ) assert shard_status.error_message is None shard_status.successes += 1 From ac5a35b88a1ef14999717aea59ff8fea8783bd10 Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 22:58:03 -0500 Subject: [PATCH 20/23] Fixed linting issues --- video2dataset/main.py | 17 +++++++++-------- video2dataset/workers/subset_worker.py | 12 ++++-------- 2 files changed, 13 insertions(+), 16 deletions(-) diff --git a/video2dataset/main.py b/video2dataset/main.py index 56fc197f..aa1d998b 100644 --- a/video2dataset/main.py +++ 
b/video2dataset/main.py @@ -9,25 +9,26 @@ from typing import List, Optional, Any import numpy as np # pylint: disable=unused-import -from .logger import LoggerProcess -from .data_writer import ( +from video2dataset.logger import LoggerProcess +from video2dataset.data_writer import ( WebDatasetSampleWriter, FilesSampleWriter, ParquetSampleWriter, TFRecordSampleWriter, DummySampleWriter, ) -from .input_sharder import InputSharder -from .output_sharder import OutputSharder -from .distributor import ( +from video2dataset.input_sharder import InputSharder +from video2dataset.output_sharder import OutputSharder +from video2dataset.distributor import ( no_distributor, multiprocessing_distributor, pyspark_distributor, SlurmDistributor, SlurmShardSampler, ) -from .workers import DownloadWorker, SubsetWorker, OpticalFlowWorker, CaptionWorker, WhisperWorker -from .configs import CONFIGS +from video2dataset.workers import DownloadWorker, SubsetWorker, OpticalFlowWorker, CaptionWorker, WhisperWorker +from video2dataset.configs import CONFIGS +from video2dataset.types import EncodeFormats def identity(x): @@ -42,7 +43,7 @@ def video2dataset( output_folder: str = "dataset", output_format: str = "files", input_format: str = "csv", - encode_formats: Optional[dict] = None, + encode_formats: Optional[EncodeFormats] = None, stage: str = "download", url_col: str = "url", caption_col: Optional[str] = None, diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index b25c4f6d..7be9cb65 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -8,7 +8,7 @@ import fsspec import numpy as np import webdataset as wds -from typing import List, Any, Union, Optional +from typing import List, Any, Union, Optional, Literal, cast from video2dataset.dataloader import get_video_dataset from video2dataset.logger import CappedCounter, write_stats @@ -24,11 +24,6 @@ from video2dataset.types import EncodeFormats, Streams -@dataclass -class Subsamplers: - broadcast_subsampler: Union[ClippingSubsampler, NoOpSubsampler] - - def get_subsamplers(config: dict, encode_formats: EncodeFormats): clipping_subsampler = ClippingSubsampler( 5, # oom_clip_count @@ -127,7 +122,7 @@ def __call__( def get_shard_processors( self, - shard: Union[str, List[str]], + shard: str, shard_id: int, ): try: @@ -161,7 +156,7 @@ def get_shard_processors( def process_shard( self, - shard: Union[str, List[str]], + shard: str, shard_id: int, ): """Function to start an video processing in one process""" @@ -184,6 +179,7 @@ def process_shard( try: streams: Streams = {} for modality, format in self.input_encode_formats.items(): + modality = cast(Literal["audio", "video"], modality) streams[modality] = [sample[format]] if self.ffprobe_subsampler is not None: From 5222f39761462def598d18961a4c2e2fdc5744cd Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 23:51:03 -0500 Subject: [PATCH 21/23] Fixing more damn linting --- video2dataset/workers/subset_worker.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 7be9cb65..ad4c9ebc 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -8,7 +8,7 @@ import fsspec import numpy as np import webdataset as wds -from typing import List, Any, Union, Optional, Literal, cast +from typing import List, Any, Optional, Literal, cast from video2dataset.dataloader import get_video_dataset 
from video2dataset.logger import CappedCounter, write_stats @@ -25,6 +25,8 @@ def get_subsamplers(config: dict, encode_formats: EncodeFormats): + """Initialize all subsamplers using config""" + clipping_subsampler = ClippingSubsampler( 5, # oom_clip_count encode_formats, @@ -125,6 +127,8 @@ def get_shard_processors( shard: str, shard_id: int, ): + """Get objects for loading and writing data""" + try: fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") with fs.open(shard_path, "rb") as f: @@ -178,9 +182,9 @@ def process_shard( try: streams: Streams = {} - for modality, format in self.input_encode_formats.items(): + for modality, encode_format in self.input_encode_formats.items(): modality = cast(Literal["audio", "video"], modality) - streams[modality] = [sample[format]] + streams[modality] = [sample[encode_format]] if self.ffprobe_subsampler is not None: streams, meta, shard_status.error_message = self.ffprobe_subsampler(streams, meta) From 6dc899170e4b204b9cd5171db8f0fd335aa74e35 Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Tue, 23 Jan 2024 23:58:01 -0500 Subject: [PATCH 22/23] Added a missing docstring --- video2dataset/types.py | 1 + 1 file changed, 1 insertion(+) diff --git a/video2dataset/types.py b/video2dataset/types.py index a648c86a..77240d32 100644 --- a/video2dataset/types.py +++ b/video2dataset/types.py @@ -1,3 +1,4 @@ +"""Type definitions for video2dataset.""" from typing import List, TypedDict From cd0d27c3b160b86241de6d9ea55cf410dc519c66 Mon Sep 17 00:00:00 2001 From: Matt Zhang Date: Wed, 24 Jan 2024 03:18:58 -0500 Subject: [PATCH 23/23] Removed git worktree folder (ugh) --- download_worker_refactoring | 1 - 1 file changed, 1 deletion(-) delete mode 160000 download_worker_refactoring diff --git a/download_worker_refactoring b/download_worker_refactoring deleted file mode 160000 index d5f3b19a..00000000 --- a/download_worker_refactoring +++ /dev/null @@ -1 +0,0 @@ -Subproject commit d5f3b19a2827f4f76be0bd2e5ec534a51eca21f7
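
A closing note on the typing idioms that PATCH 18, PATCH 20, and PATCH 21 converge on. Declaring EncodeFormats and Streams with total=False makes every key optional, which is what lets process_shard start from an empty dict (streams: Streams = {}) and fill in only the modalities a sample actually carries; the cast to Literal["audio", "video"] is then needed because iterating a TypedDict yields plain str keys, which type checkers will not accept as Streams keys. A minimal sketch of both idioms together; collect_streams and the fake sample dict are illustrative stand-ins, not code from the repository:

    from typing import List, Literal, TypedDict, cast

    class EncodeFormats(TypedDict, total=False):
        video: str
        audio: str

    class Streams(TypedDict, total=False):
        video: List[bytes]
        audio: List[bytes]

    def collect_streams(sample: dict, encode_formats: EncodeFormats) -> Streams:
        streams: Streams = {}  # only valid because total=False makes both keys optional
        for modality, encode_format in encode_formats.items():
            # .items() yields str keys; cast narrows them to the literal TypedDict keys
            modality = cast(Literal["video", "audio"], modality)
            streams[modality] = [sample[encode_format]]
        return streams

    # hypothetical sample: file-extension keys mapping to raw bytes
    sample = {"mp4": b"<video bytes>", "m4a": b"<audio bytes>"}
    print(collect_streams(sample, {"video": "mp4", "audio": "m4a"}))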