diff --git a/video2dataset/main.py b/video2dataset/main.py index 56fc197f..aa1d998b 100644 --- a/video2dataset/main.py +++ b/video2dataset/main.py @@ -9,25 +9,26 @@ from typing import List, Optional, Any import numpy as np # pylint: disable=unused-import -from .logger import LoggerProcess -from .data_writer import ( +from video2dataset.logger import LoggerProcess +from video2dataset.data_writer import ( WebDatasetSampleWriter, FilesSampleWriter, ParquetSampleWriter, TFRecordSampleWriter, DummySampleWriter, ) -from .input_sharder import InputSharder -from .output_sharder import OutputSharder -from .distributor import ( +from video2dataset.input_sharder import InputSharder +from video2dataset.output_sharder import OutputSharder +from video2dataset.distributor import ( no_distributor, multiprocessing_distributor, pyspark_distributor, SlurmDistributor, SlurmShardSampler, ) -from .workers import DownloadWorker, SubsetWorker, OpticalFlowWorker, CaptionWorker, WhisperWorker -from .configs import CONFIGS +from video2dataset.workers import DownloadWorker, SubsetWorker, OpticalFlowWorker, CaptionWorker, WhisperWorker +from video2dataset.configs import CONFIGS +from video2dataset.types import EncodeFormats def identity(x): @@ -42,7 +43,7 @@ def video2dataset( output_folder: str = "dataset", output_format: str = "files", input_format: str = "csv", - encode_formats: Optional[dict] = None, + encode_formats: Optional[EncodeFormats] = None, stage: str = "download", url_col: str = "url", caption_col: Optional[str] = None, diff --git a/video2dataset/subsamplers/clipping_subsampler.py b/video2dataset/subsamplers/clipping_subsampler.py index 73eae18f..2e7c7d96 100644 --- a/video2dataset/subsamplers/clipping_subsampler.py +++ b/video2dataset/subsamplers/clipping_subsampler.py @@ -2,30 +2,21 @@ clipping subsampler turns full videos into clips of videos according to clip_col """ from collections.abc import Iterable -from typing import Any, Union, List, Tuple, Dict, TypedDict, Literal, cast import copy +import datetime import ffmpeg import glob import os import tempfile +from typing import Any, Union, List, Tuple, Dict, Literal, cast -import datetime -from .subsampler import Subsampler +from video2dataset.subsamplers.subsampler import Subsampler +from video2dataset.types import EncodeFormats, Streams ClipSpan = List[float] # [start, end] -class EncodeFormats(TypedDict): - video: str - audio: str - - -class Streams(TypedDict): - video: List[bytes] - audio: List[bytes] - - def _get_seconds(t: Union[str, float]) -> float: """Converts time to seconds""" if not isinstance(t, str): diff --git a/video2dataset/types.py b/video2dataset/types.py new file mode 100644 index 00000000..77240d32 --- /dev/null +++ b/video2dataset/types.py @@ -0,0 +1,12 @@ +"""Type definitions for video2dataset.""" +from typing import List, TypedDict + + +class EncodeFormats(TypedDict, total=False): + video: str + audio: str + + +class Streams(TypedDict, total=False): + video: List[bytes] + audio: List[bytes] diff --git a/video2dataset/workers/subset_worker.py b/video2dataset/workers/subset_worker.py index 06383074..ad4c9ebc 100644 --- a/video2dataset/workers/subset_worker.py +++ b/video2dataset/workers/subset_worker.py @@ -1,4 +1,5 @@ """creates a subset of an existing dataset inside the sample dimension""" +from dataclasses import dataclass, field import time import json import pyarrow as pa @@ -7,7 +8,7 @@ import fsspec import numpy as np import webdataset as wds -from typing import List, Any +from typing import List, Any, Optional, Literal, cast from video2dataset.dataloader import get_video_dataset from video2dataset.logger import CappedCounter, write_stats @@ -20,6 +21,57 @@ ResolutionSubsampler, AudioRateSubsampler, ) +from video2dataset.types import EncodeFormats, Streams + + +def get_subsamplers(config: dict, encode_formats: EncodeFormats): + """Initialize all subsamplers using config""" + + clipping_subsampler = ClippingSubsampler( + 5, # oom_clip_count + encode_formats, + **config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], + ) + need_keyframes = clipping_subsampler.precision == "keyframe_adjusted" + + cut_detection_subsampler = None + cuts_are_clips = False + if "CutDetectionSubsampler" in config["subsampling"]: + if "args" in config["subsampling"]["CutDetectionSubsampler"]: + cut_detection_subsampler = CutDetectionSubsampler(**config["subsampling"]["CutDetectionSubsampler"]["args"]) + cuts_are_clips = config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) + + broadcast_subsampler = ( + clipping_subsampler if (config["storage"]["captions_are_subtitles"] or cuts_are_clips) else NoOpSubsampler() + ) + + ffprobe_subsampler = None + if "FFProbeSubsampler" in config["subsampling"] or need_keyframes: + ffprobe_subsampler = FFProbeSubsampler(**config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"]) + ffprobe_subsampler.extract_keyframes |= need_keyframes + + video_subsamplers: List[Any] = [] + if "ResolutionSubsampler" in config["subsampling"]: + video_subsamplers.append(ResolutionSubsampler(**config["subsampling"]["ResolutionSubsampler"]["args"])) + if "FrameSubsampler" in config["subsampling"]: + video_subsamplers.append(FrameSubsampler(**config["subsampling"]["FrameSubsampler"]["args"])) + + audio_subsamplers: List[Any] = [] + if "AudioRateSubsampler" in config["subsampling"]: + audio_subsamplers.append(AudioRateSubsampler(**config["subsampling"]["AudioRateSubsampler"]["args"])) + + modal_subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} + + return ffprobe_subsampler, modal_subsamplers, cut_detection_subsampler, cuts_are_clips, broadcast_subsampler + + +@dataclass +class ShardStatus: + successes: int = 0 + failed_to_subsample: int = 0 + status_dict: CappedCounter = field(default_factory=CappedCounter) + error_message: Optional[str] = None + count: int = 0 class SubsetWorker: @@ -29,232 +81,189 @@ def __init__( self, sample_writer_class, output_folder, - encode_formats, + encode_formats: EncodeFormats, config, ) -> None: self.sample_writer_class = sample_writer_class - self.save_caption = True self.output_folder = output_folder - self.encode_formats = encode_formats self.config = config - - self.clipping_subsampler = ClippingSubsampler( - 5, # oom_clip_count - encode_formats, - **self.config["subsampling"].get("ClippingSubsampler", {"args": {}})["args"], - ) - need_keyframes = self.clipping_subsampler.precision == "keyframe_adjusted" - - self.ffprobe_subsampler = None - if "FFProbeSubsampler" in self.config["subsampling"] or need_keyframes: - self.ffprobe_subsampler = FFProbeSubsampler( - **self.config["subsampling"].get("FFProbeSubsampler", {"args": {}})["args"] - ) - self.ffprobe_subsampler.extract_keyframes |= need_keyframes - - self.cut_detector = None - self.cuts_are_clips = False - if "CutDetectionSubsampler" in self.config["subsampling"]: - if "args" in self.config["subsampling"]["CutDetectionSubsampler"]: - self.cut_detector = CutDetectionSubsampler( - **self.config["subsampling"]["CutDetectionSubsampler"]["args"] - ) - self.cuts_are_clips = self.config["subsampling"]["CutDetectionSubsampler"].get("cuts_are_clips", False) - - self.noop_subsampler = NoOpSubsampler() - - video_subsamplers: List[Any] = [] - if "ResolutionSubsampler" in self.config["subsampling"]: - video_subsamplers.append(ResolutionSubsampler(**self.config["subsampling"]["ResolutionSubsampler"]["args"])) - if "FrameSubsampler" in self.config["subsampling"]: - video_subsamplers.append(FrameSubsampler(**self.config["subsampling"]["FrameSubsampler"]["args"])) - - audio_subsamplers: List[Any] = [] - if "AudioRateSubsampler" in self.config["subsampling"]: - audio_subsamplers.append(AudioRateSubsampler(**self.config["subsampling"]["AudioRateSubsampler"]["args"])) - self.subsamplers = {"video": video_subsamplers, "audio": audio_subsamplers} + ( + self.ffprobe_subsampler, + self.modal_subsamplers, + self.cut_detection_subsampler, + self.cuts_are_clips, + self.broadcast_subsampler, + ) = get_subsamplers(config, encode_formats) + + # set encoding formats + self.input_encode_formats = encode_formats + self.output_encode_formats = self.input_encode_formats.copy() + if self.modal_subsamplers["audio"]: + assert ( + len({s.encode_format for s in self.modal_subsamplers["audio"]}) == 1 + ) # assert that all audio subsamplers have the same output format + self.output_encode_formats["audio"] = self.modal_subsamplers["audio"][0].encode_format + if self.modal_subsamplers["video"]: + assert ( + len({s.encode_format for s in self.modal_subsamplers["video"]}) == 1 + ) # assert that all video subsamplers have the same output format + self.output_encode_formats["video"] = self.modal_subsamplers["video"][0].encode_format def __call__( self, row, ): try: - self.process_shard(row) + shard, shard_id = row + self.process_shard(shard, shard_id) return (True, row) except Exception as err: # pylint: disable=broad-except traceback.print_exc() print(f"shard {row[0]} failed with error {err}") return (False, row) - def process_shard( + def get_shard_processors( self, - row, + shard: str, + shard_id: int, ): - """Function to start an video processing in one process""" - - shard, shard_id = row - start_time = time.time() + """Get objects for loading and writing data""" try: fs, shard_path = fsspec.core.url_to_fs(shard[: -len(".tar")] + ".parquet") - with fs.open(shard_path, "rb") as f: df = pa.parquet.read_table(f) schema = df.schema - except Exception as e: # pylint: disable=broad-except,unused-variable + except Exception: # pylint: disable=broad-except fields = [ pa.field("key", pa.string()), pa.field("status", pa.string()), pa.field("error_message", pa.string()), ] schema = pa.schema(fields) - - status_dict = CappedCounter() - - # The subsamplers might change the output format, so we need to update the writer - writer_encode_formats = self.encode_formats.copy() - if self.subsamplers["audio"]: - assert ( - len({s.encode_format for s in self.subsamplers["audio"]}) == 1 - ) # assert that all audio subsamplers have the same output format - writer_encode_formats["audio"] = self.subsamplers["audio"][0].encode_format - if self.subsamplers["video"]: - assert ( - len({s.encode_format for s in self.subsamplers["video"]}) == 1 - ) # assert that all video subsamplers have the same output format - writer_encode_formats["video"] = self.subsamplers["video"][0].encode_format - - # give schema to writer - sample_writer = self.sample_writer_class( + shard_sample_writer = self.sample_writer_class( shard_id, self.output_folder, - self.save_caption, + True, # save_caption self.config["storage"]["oom_shard_count"], schema, - writer_encode_formats, + self.output_encode_formats, ) - - successes = 0 - failed = { - "failed_to_download": 0, - "failed_to_subsample": 0, - } - error_message = None - - dataloader = get_video_dataset( + shard_dataloader = get_video_dataset( urls=shard, batch_size=1, decoder_kwargs={}, enforce_additional_keys=[], handler=wds.warn_and_continue, ) - count = 0 - for sample in dataloader: + return shard_sample_writer, shard_dataloader + + def process_shard( + self, + shard: str, + shard_id: int, + ): + """Function to start an video processing in one process""" + + start_time = time.time() + shard_sample_writer, shard_dataloader = self.get_shard_processors(shard, shard_id) + shard_status = ShardStatus() + + for sample in shard_dataloader: + shard_status.count += 1 + key = sample["__key__"] try: - count += 1 - key = sample["__key__"] caption = sample.get("txt", b"").decode("utf-8") meta = json.loads(sample.get("json", b"{}").decode("utf-8")) - streams = {} - for mod, fmt in self.encode_formats.items(): - streams[mod] = [sample[fmt]] + except Exception as err: # pylint: disable=broad-except + traceback.print_exc() + print(f"Sample {key} failed to download: {err}") + return + + try: + streams: Streams = {} + for modality, encode_format in self.input_encode_formats.items(): + modality = cast(Literal["audio", "video"], modality) + streams[modality] = [sample[encode_format]] if self.ffprobe_subsampler is not None: - streams, meta, error_message = self.ffprobe_subsampler(streams, meta) - if error_message is not None: - raise ValueError("failed_to_subsample") + streams, meta, shard_status.error_message = self.ffprobe_subsampler(streams, meta) + assert shard_status.error_message is None if self.config["storage"]["captions_are_subtitles"]: # create clips subtitles = meta["yt_meta_dict"]["subtitles"] meta["clips"] = [[line_dict["start"], line_dict["end"]] for line_dict in subtitles] - elif self.cut_detector is not None: # apply cut detection to get clips - streams, cuts, error_message = self.cut_detector(streams) - if error_message is not None: - raise ValueError("failed_to_subsample") - + elif self.cut_detection_subsampler is not None: # apply cut detection to get clips + streams, cuts, shard_status.error_message = self.cut_detection_subsampler(streams) + assert shard_status.error_message is None meta["cuts"] = cuts - - if self.cuts_are_clips: - cuts = meta["cuts"] - native_fps = cuts["original_fps"] - meta["clips"] = (np.array(cuts["cuts_original_fps"]) / native_fps).tolist() + assert cuts is not None + if self.cuts_are_clips: + meta["clips"] = (np.array(cuts["cuts_original_fps"]) / cuts["original_fps"]).tolist() # 1 video -> many videos (either clipping or noop which does identity broadcasting) - broadcast_subsampler = ( - self.clipping_subsampler - if (self.config["storage"]["captions_are_subtitles"] or self.cuts_are_clips) - else self.noop_subsampler - ) - subsampled_streams, metas, error_message = broadcast_subsampler(streams, meta) - if error_message is not None: + subsampled_streams, metas, shard_status.error_message = self.broadcast_subsampler(streams, meta) + if shard_status.error_message is not None: meta["clips"] = [] - raise ValueError("failed_to_subsample") + assert False for modality in list(subsampled_streams.keys()): - for modality_subsampler in self.subsamplers[modality]: - subsampled_streams, metas, error_message = modality_subsampler(subsampled_streams, metas) + for modality_subsampler in self.modal_subsamplers[modality]: + subsampled_streams, metas, shard_status.error_message = modality_subsampler( + subsampled_streams, metas + ) + assert shard_status.error_message is None - if error_message is not None: - raise ValueError("failed_to_subsample") - - successes += 1 + shard_status.successes += 1 status = "success" - status_dict.increment(status) + shard_status.status_dict.increment(status) + subsampled_streams_list = [dict(zip(subsampled_streams, s)) for s in zip(*subsampled_streams.values())] if len(subsampled_streams_list) == 0: # no audio or video, just write meta meta["status"] = status - sample_writer.write( + shard_sample_writer.write( {}, key, caption, meta, ) continue - for subsampled_streams, meta in zip(subsampled_streams_list, metas): meta["status"] = status - text_caption = caption if self.config["storage"]["captions_are_subtitles"]: text_caption = meta.get("clip_subtitles")[0]["lines"][0] - - sample_writer.write( + shard_sample_writer.write( subsampled_streams, meta["key"], text_caption, meta, ) + except Exception: # pylint: disable=broad-except + shard_status.failed_to_subsample += 1 + shard_status.status_dict.increment(shard_status.error_message) + meta["status"] = "failed_to_subsample" + meta["error_message"] = shard_status.error_message + shard_sample_writer.write( + {}, + key, + caption, + meta, + ) - except Exception as err: # pylint: disable=broad-except - status = str(err) - if status.startswith("failed_to_"): - failed[status] += 1 - status_dict.increment(error_message) - meta["status"] = status - meta["error_message"] = error_message - sample_writer.write( - {}, - key, - caption, - meta, - ) - else: - traceback.print_exc() - print(f"Sample {key} failed to download: {err}") - - sample_writer.close() + shard_sample_writer.close() end_time = time.time() write_stats( self.output_folder, shard_id, - count, - successes, + shard_status.count, + shard_status.successes, 0, # failed to download - failed["failed_to_subsample"], + shard_status.failed_to_subsample, 0, # bytes downloaded start_time, end_time, - status_dict, + shard_status.status_dict, self.config["storage"]["oom_shard_count"], )