diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py
index 0f656308..18363555 100644
--- a/aana/api/api_generation.py
+++ b/aana/api/api_generation.py
@@ -10,6 +10,7 @@
 from aana.exceptions.general import MultipleFileUploadNotAllowed
 from aana.models.pydantic.exception_response import ExceptionResponseModel
+from pydantic import ValidationError


 async def run_pipeline(
@@ -93,6 +94,20 @@ class FileUploadField:
     description: str


+@dataclass
+class EndpointOutput:
+    """
+    Class used to represent an endpoint output.
+
+    Attributes:
+        name (str): Name of the output that should be returned by the endpoint.
+        output (str): Output of the pipeline that should be returned by the endpoint.
+    """
+
+    name: str
+    output: str
+
+
 @dataclass
 class Endpoint:
     """
@@ -102,8 +117,7 @@
         name (str): Name of the endpoint.
         path (str): Path of the endpoint.
         summary (str): Description of the endpoint that will be shown in the API documentation.
-        outputs (List[str]): List of required outputs from the pipeline that should be returned
-            by the endpoint.
+        outputs (List[EndpointOutput]): List of outputs that should be returned by the endpoint.
         output_filter (Optional[OutputFilter]): The parameter will be added to the request and
             will allow to choose subset of `outputs` to return.
         streaming (bool): Whether the endpoint outputs a stream of data.
@@ -112,10 +126,19 @@
     name: str
     path: str
     summary: str
-    outputs: List[str]
+    outputs: list[EndpointOutput]
     output_filter: Optional[OutputFilter] = None
     streaming: bool = False

+    def __post_init__(self):
+        """
+        Post init method.
+
+        Creates dictionaries for fast lookup of outputs.
+        """
+        self.name_to_output = {output.name: output.output for output in self.outputs}
+        self.output_to_name = {output.output: output.name for output in self.outputs}
+
     def generate_model_name(self, suffix: str) -> str:
         """
         Generate a Pydantic model name based on a given suffix.
@@ -145,21 +168,25 @@ def socket_to_field(self, socket: Socket) -> Tuple[Any, Any]:
             data_model = Any
             return (data_model, Field(None))

-        # check if any of the fields are required
-        if any(field.required for field in data_model.__fields__.values()):
+        # try to instantiate the data model
+        # to see if any of the fields are required
+        try:
+            data_model_instance = data_model()
+            return (data_model, data_model_instance)
+        except ValidationError:
+            # if we can't instantiate the data model
+            # it means that it has required fields
             return (data_model, ...)

-        return (data_model, data_model())
-
-    def get_fields(self, sockets: List[Socket]) -> Dict[str, Tuple[Any, Any]]:
+    def get_input_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
         """
-        Generate fields for the Pydantic model based on the provided sockets.
+        Generate fields for the request Pydantic model based on the provided sockets.

         Parameters:
-            sockets (List[Socket]): List of sockets.
+            sockets (list[Socket]): List of sockets.

         Returns:
-            Dict[str, Tuple[Any, Field]]: Dictionary of fields for the Pydantic model.
+            dict[str, tuple[Any, Field]]: Dictionary of fields for the request Pydantic model.
         """
         fields = {}
         for socket in sockets:
@@ -167,6 +194,23 @@ def get_fields(self, sockets: List[Socket]) -> Dict[str, Tuple[Any, Any]]:
             fields[socket.name] = field
         return fields

+    def get_output_fields(self, sockets: list[Socket]) -> dict[str, tuple[Any, Any]]:
+        """
+        Generate fields for the response Pydantic model based on the provided sockets.
+
+        Parameters:
+            sockets (list[Socket]): List of sockets.
+
+        Returns:
+            dict[str, tuple[Any, Field]]: Dictionary of fields for the response Pydantic model.
+        """
+        fields = {}
+        for socket in sockets:
+            field = self.socket_to_field(socket)
+            name = self.output_to_name[socket.name]
+            fields[name] = field
+        return fields
+
     def get_file_upload_field(
         self, input_sockets: List[Socket]
     ) -> Optional[FileUploadField]:
@@ -219,7 +263,9 @@ def get_output_filter_field(self) -> Optional[Tuple[Any, Any]]:
         description = self.output_filter.description
         outputs_enum_name = self.generate_model_name("Outputs")
         outputs_enum = Enum(  # type: ignore
-            outputs_enum_name, [(output, output) for output in self.outputs], type=str
+            outputs_enum_name,
+            [(output.name, output.name) for output in self.outputs],
+            type=str,
         )
         field = (Optional[List[outputs_enum]], Field(None, description=description))
         return field
@@ -235,7 +281,7 @@ def get_request_model(self, input_sockets: List[Socket]) -> Type[BaseModel]:
             Type[BaseModel]: Pydantic model for the request.
         """
         model_name = self.generate_model_name("Request")
-        input_fields = self.get_fields(input_sockets)
+        input_fields = self.get_input_fields(input_sockets)
         output_filter_field = self.get_output_filter_field()
         if output_filter_field and self.output_filter:
             input_fields[self.output_filter.name] = output_filter_field
@@ -253,10 +299,30 @@ def get_response_model(self, output_sockets: List[Socket]) -> Type[BaseModel]:
             Type[BaseModel]: Pydantic model for the response.
         """
         model_name = self.generate_model_name("Response")
-        output_fields = self.get_fields(output_sockets)
+        output_fields = self.get_output_fields(output_sockets)
         ResponseModel = create_model(model_name, **output_fields)
         return ResponseModel

+    def process_output(self, output: dict[str, Any]) -> dict[str, Any]:
+        """
+        Process the output of the pipeline.
+
+        Maps the output names of the pipeline to the names defined in the endpoint outputs.
+
+        For example, maps videos_captions_hf_blip2_opt_2_7b to captions.
+
+        Args:
+            output (dict): The output of the pipeline.
+
+        Returns:
+            dict: The processed output.
+        """
+        output = {
+            self.output_to_name.get(output_name, output_name): output_value
+            for output_name, output_value in output.items()
+        }
+        return output
+
     def create_endpoint_func(
         self,
         pipeline: Pipeline,
@@ -293,10 +359,12 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None):
             if requested_outputs:
                 # get values for requested outputs because it's a list of enums
                 requested_outputs = [output.value for output in requested_outputs]
-                outputs = requested_outputs
+                # map the requested outputs to the actual outputs
+                # for example, videos_captions_hf_blip2_opt_2_7b to captions
+                outputs = [self.name_to_output[output] for output in requested_outputs]
             # otherwise use the required outputs from the config (all outputs endpoint provides)
             else:
-                outputs = self.outputs
+                outputs = [output.output for output in self.outputs]

             # remove the output filter parameter from the data
             if self.output_filter and self.output_filter.name in data_dict:
@@ -304,15 +372,23 @@ async def route_func_body(body: str, files: Optional[List[UploadFile]] = None):

             # run the pipeline
             if self.streaming:
+
                 async def generator_wrapper():
                     """
                     Serializes the output of the generator using ORJSONResponseCustom
                     """
-                    async for output in run_pipeline_streaming(pipeline, data_dict, outputs):
+                    async for output in run_pipeline_streaming(
+                        pipeline, data_dict, outputs
+                    ):
+                        output = self.process_output(output)
                         yield AanaJSONResponse(content=output).body
-                return StreamingResponse(generator_wrapper(), media_type="application/json")
+
+                return StreamingResponse(
+                    generator_wrapper(), media_type="application/json"
+                )
             else:
                 output = await run_pipeline(pipeline, data_dict, outputs)
+                output = self.process_output(output)
                 return AanaJSONResponse(content=output)

         if file_upload_field:
@@ -336,7 +412,9 @@ def register(
             pipeline (Pipeline): Pipeline to register the endpoint to.
             custom_schemas (Dict[str, Dict]): Dictionary of custom schemas.
""" - input_sockets, output_sockets = pipeline.get_sockets(self.outputs) + input_sockets, output_sockets = pipeline.get_sockets( + [output.output for output in self.outputs] + ) RequestModel = self.get_request_model(input_sockets) ResponseModel = self.get_response_model(output_sockets) file_upload_field = self.get_file_upload_field(input_sockets) diff --git a/aana/configs/build.py b/aana/configs/build.py index 8cd724db..50b8857c 100644 --- a/aana/configs/build.py +++ b/aana/configs/build.py @@ -43,7 +43,7 @@ def get_configuration(target: str, endpoints, nodes, deployments) -> Dict: # Target endpoints require the following outputs endpoint_outputs = [] for endpoint in target_endpoints: - endpoint_outputs += endpoint.outputs + endpoint_outputs += [output.output for output in endpoint.outputs] # Build the output graph for the whole pipeline node_definitions = [NodeDefinition.from_dict(node_dict) for node_dict in nodes] diff --git a/aana/configs/endpoints.py b/aana/configs/endpoints.py index e0d1b98e..cd62d22d 100644 --- a/aana/configs/endpoints.py +++ b/aana/configs/endpoints.py @@ -1,4 +1,4 @@ -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput endpoints = { @@ -7,13 +7,19 @@ name="llm_generate", path="/llm/generate", summary="Generate text using LLaMa2 7B Chat", - outputs=["vllm_llama2_7b_chat_output"], + outputs=[ + EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output") + ], ), Endpoint( name="llm_generate_stream", path="/llm/generate_stream", summary="Generate text using LLaMa2 7B Chat (streaming)", - outputs=["vllm_llama2_7b_chat_output_stream"], + outputs=[ + EndpointOutput( + name="completion", output="vllm_llama2_7b_chat_output_stream" + ) + ], streaming=True, ), ], @@ -22,7 +28,9 @@ name="zephyr_generate", path="/llm/generate", summary="Generate text using Zephyr 7B Beta", - outputs=["vllm_zephyr_7b_beta_output"], + outputs=[ + EndpointOutput(name="completion", output="vllm_zephyr_7b_beta_output") + ], ) ], "blip2": [ @@ -30,13 +38,20 @@ name="blip2_generate", path="/image/generate_captions", summary="Generate captions for images using BLIP2 OPT-2.7B", - outputs=["captions_hf_blip2_opt_2_7b"], + outputs=[ + EndpointOutput(name="captions", output="captions_hf_blip2_opt_2_7b") + ], ), Endpoint( name="blip2_video_generate", path="/video/generate_captions", summary="Generate captions for videos using BLIP2 OPT-2.7B", - outputs=["video_captions_hf_blip2_opt_2_7b", "timestamps"], + outputs=[ + EndpointOutput( + name="captions", output="videos_captions_hf_blip2_opt_2_7b" + ), + EndpointOutput(name="timestamps", output="timestamps"), + ], ), ], "video": [ @@ -44,7 +59,10 @@ name="video_extract_frames", path="/video/extract_frames", summary="Extract frames from a video", - outputs=["timestamps", "duration"], + outputs=[ + EndpointOutput(name="timestamps", output="timestamps"), + EndpointOutput(name="duration", output="duration"), + ], ) ], "whisper": [ @@ -53,10 +71,68 @@ path="/video/transcribe", summary="Transcribe a video using Whisper Medium", outputs=[ - "video_transcriptions_whisper_medium", - "video_transcriptions_segments_whisper_medium", - "video_transcriptions_info_whisper_medium", + EndpointOutput( + name="transcription", output="videos_transcriptions_whisper_medium" + ), + EndpointOutput( + name="segments", + output="videos_transcriptions_segments_whisper_medium", + ), + EndpointOutput( + name="info", output="videos_transcriptions_info_whisper_medium" + ), ], ) ], + "chat_with_video": [ + Endpoint( + 
name="blip2_video_generate", + path="/video/generate_captions", + summary="Generate captions for videos using BLIP2 OPT-2.7B", + outputs=[ + EndpointOutput( + name="captions", output="video_captions_hf_blip2_opt_2_7b" + ), + EndpointOutput(name="timestamps", output="video_timestamps"), + ], + streaming=True, + ), + Endpoint( + name="whisper_transcribe", + path="/video/transcribe", + summary="Transcribe a video using Whisper Medium", + outputs=[ + EndpointOutput( + name="transcription", output="video_transcriptions_whisper_medium" + ), + EndpointOutput( + name="segments", + output="video_transcriptions_segments_whisper_medium", + ), + EndpointOutput( + name="info", output="video_transcriptions_info_whisper_medium" + ), + ], + streaming=True, + ), + Endpoint( + name="llm_generate", + path="/llm/generate", + summary="Generate text using LLaMa2 7B Chat", + outputs=[ + EndpointOutput(name="completion", output="vllm_llama2_7b_chat_output") + ], + ), + Endpoint( + name="llm_generate_stream", + path="/llm/generate_stream", + summary="Generate text using LLaMa2 7B Chat (streaming)", + outputs=[ + EndpointOutput( + name="completion", output="vllm_llama2_7b_chat_output_stream" + ) + ], + streaming=True, + ), + ], } diff --git a/aana/configs/pipeline.py b/aana/configs/pipeline.py index 8ad8465a..99b1d397 100644 --- a/aana/configs/pipeline.py +++ b/aana/configs/pipeline.py @@ -12,7 +12,7 @@ from aana.models.pydantic.image_input import ImageInputList from aana.models.pydantic.prompt import Prompt from aana.models.pydantic.sampling_params import SamplingParams -from aana.models.pydantic.video_input import VideoInputList +from aana.models.pydantic.video_input import VideoInput, VideoInputList from aana.models.pydantic.video_params import VideoParams from aana.models.pydantic.whisper_params import WhisperParams @@ -211,7 +211,7 @@ ], }, { - "name": "download_video", + "name": "download_videos", "type": "ray_task", "function": "aana.utils.video.download_video", "batched": True, @@ -278,7 +278,7 @@ ], }, { - "name": "hf_blip2_opt_2_7b_video", + "name": "hf_blip2_opt_2_7b_videos", "type": "ray_deployment", "deployment_name": "hf_blip2_deployment_opt_2_7b", "method": "generate_batch", @@ -292,7 +292,7 @@ ], "outputs": [ { - "name": "video_captions_hf_blip2_opt_2_7b", + "name": "videos_captions_hf_blip2_opt_2_7b", "key": "captions", "path": "video_batch.videos.[*].frames.[*].caption_hf_blip2_opt_2_7b", "data_model": VideoCaptionsList, @@ -330,6 +330,134 @@ "data_model": WhisperParams, }, ], + "outputs": [ + { + "name": "videos_transcriptions_segments_whisper_medium", + "key": "segments", + "path": "video_batch.videos.[*].segments", + "data_model": AsrSegmentsList, + }, + { + "name": "videos_transcriptions_info_whisper_medium", + "key": "transcription_info", + "path": "video_batch.videos.[*].transcription_info", + "data_model": AsrTranscriptionInfoList, + }, + { + "name": "videos_transcriptions_whisper_medium", + "key": "transcription", + "path": "video_batch.videos.[*].transcription", + "data_model": AsrTranscriptionList, + }, + ], + }, + { + "name": "video", + "type": "input", + "inputs": [], + "outputs": [ + { + "name": "video", + "key": "video", + "path": "video.video_input", + "data_model": VideoInput, + } + ], + }, + { + "name": "download_video", + "type": "ray_task", + "function": "aana.utils.video.download_video", + "dict_output": False, + "inputs": [ + { + "name": "video", + "key": "video_input", + "path": "video.video_input", + }, + ], + "outputs": [ + { + "name": "video_object", + "key": "output", + 
"path": "video.video", + }, + ], + }, + { + "name": "generate_frames_for_video", + "type": "ray_task", + "function": "aana.utils.video.generate_frames_decord", + "data_type": "generator", + "generator_path": "video", + "inputs": [ + { + "name": "video_object", + "key": "video", + "path": "video.video", + }, + {"name": "video_params", "key": "params", "path": "video_batch.params"}, + ], + "outputs": [ + { + "name": "video_frames", + "key": "frames", + "path": "video.frames.[*].image", + }, + { + "name": "video_timestamps", + "key": "timestamps", + "path": "video.timestamps", + }, + { + "name": "video_duration", + "key": "duration", + "path": "video.duration", + }, + ], + }, + { + "name": "hf_blip2_opt_2_7b_video", + "type": "ray_deployment", + "deployment_name": "hf_blip2_deployment_opt_2_7b", + "method": "generate_batch", + "flatten_by": "video.frames.[*]", + "inputs": [ + { + "name": "video_frames", + "key": "images", + "path": "video.frames.[*].image", + } + ], + "outputs": [ + { + "name": "video_captions_hf_blip2_opt_2_7b", + "key": "captions", + "path": "video.frames.[*].caption_hf_blip2_opt_2_7b", + "data_model": VideoCaptionsList, + } + ], + }, + { + "name": "whisper_medium_transcribe_video", + "type": "ray_deployment", + "deployment_name": "whisper_deployment_medium", + "data_type": "generator", + "generator_path": "video", + "method": "transcribe_stream", + "inputs": [ + { + "name": "video_object", + "key": "media", + "path": "video.video", + }, + { + "name": "whisper_params", + "key": "params", + "path": "video_batch.whisper_params", + "data_model": WhisperParams, + }, + ], "outputs": [ { "name": "video_transcriptions_segments_whisper_medium", diff --git a/aana/deployments/hf_blip2_deployment.py b/aana/deployments/hf_blip2_deployment.py index f5b5c677..af0b225d 100644 --- a/aana/deployments/hf_blip2_deployment.py +++ b/aana/deployments/hf_blip2_deployment.py @@ -27,26 +27,6 @@ class HFBlip2Config(BaseModel): batch_size: int = Field(default=1) num_processing_threads: int = Field(default=1) - @validator("dtype", pre=True, always=True) - def validate_dtype(cls, value: Dtype) -> Dtype: - """ - Validate the data type. For BLIP2 only "float32" and "float16" are supported. - - Args: - value (Dtype): the data type - - Returns: - Dtype: the validated data type - - Raises: - ValueError: if the data type is not supported - """ - if value not in {Dtype.AUTO, Dtype.FLOAT32, Dtype.FLOAT16}: - raise ValueError( - f"Invalid dtype: {value}. BLIP2 only supports 'auto', 'float32', and 'float16'." 
-            )
-        return value


 class CaptioningOutput(TypedDict):
     """
@@ -105,10 +85,17 @@ async def apply_config(self, config: Dict[str, Any]):
         # Load the model and processor for BLIP2 from HuggingFace
         self.model_id = config_obj.model
         self.dtype = config_obj.dtype
-        self.torch_dtype = self.dtype.to_torch()
+        if self.dtype == Dtype.INT8:
+            load_in_8bit = True
+            self.torch_dtype = Dtype.FLOAT16.to_torch()
+        else:
+            load_in_8bit = False
+            self.torch_dtype = self.dtype.to_torch()
         self.model = Blip2ForConditionalGeneration.from_pretrained(
-            self.model_id, torch_dtype=self.torch_dtype
+            self.model_id, torch_dtype=self.torch_dtype, load_in_8bit=load_in_8bit
         )
+        self.model = torch.compile(self.model)
+        self.model.eval()
         self.processor = Blip2Processor.from_pretrained(self.model_id)
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
diff --git a/aana/deployments/whisper_deployment.py b/aana/deployments/whisper_deployment.py
index bf661912..4ee6927f 100644
--- a/aana/deployments/whisper_deployment.py
+++ b/aana/deployments/whisper_deployment.py
@@ -147,7 +147,7 @@ async def apply_config(self, config: Dict[str, Any]):
     # TODO: add audio support
     async def transcribe(
-        self, media: Video, params: WhisperParams = WhisperParams()
+        self, media: Video, params: WhisperParams | None = None
     ) -> WhisperOutput:
         """
         Transcribe the media with the whisper model.
@@ -165,6 +165,8 @@
         Raises:
             InferenceException: If the inference fails.
         """
+        if params is None:
+            params = WhisperParams()
         media_path: str = str(media.path)

         try:
@@ -183,6 +185,29 @@
             transcription=asr_transcription,
         )

+    async def transcribe_stream(
+        self, media: Video, params: WhisperParams | None = None
+    ) -> WhisperOutput:
+        """
+        Transcribe the media with the whisper model in a streaming fashion.
+
+        Right now this is the same as transcribe, but we will add support for
+        streaming in the future to support larger media and to make the ASR more responsive.
+
+        Args:
+            media (Video): The media to transcribe.
+            params (WhisperParams): The parameters for the whisper model.
+
+        Yields:
+            WhisperOutput: The transcription output as a dictionary:
+                segments (List[AsrSegment]): The ASR segments.
+                transcription_info (AsrTranscriptionInfo): The ASR transcription info.
+                transcription (AsrTranscription): The ASR transcription.
+ """ + # TODO: add streaming support + output = await self.transcribe(media, params) + yield output + async def transcribe_batch( self, media_batch: List[Video], params: WhisperParams = WhisperParams() ) -> WhisperBatchOutput: diff --git a/aana/tests/test_api_generation.py b/aana/tests/test_api_generation.py index ccfb2b03..edba374e 100644 --- a/aana/tests/test_api_generation.py +++ b/aana/tests/test_api_generation.py @@ -6,7 +6,7 @@ from pydantic import BaseModel, Field, Extra -from aana.api.api_generation import Endpoint +from aana.api.api_generation import Endpoint, EndpointOutput from aana.exceptions.general import MultipleFileUploadNotAllowed @@ -49,7 +49,7 @@ def test_get_request_model(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) input_sockets = [ @@ -81,7 +81,12 @@ def test_get_response_model(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output", "output_without_datamodel"], + outputs=[ + EndpointOutput(name="output", output="output"), + EndpointOutput( + name="output_without_datamodel", output="output_without_datamodel" + ), + ], ) output_sockets = [ @@ -109,7 +114,7 @@ def test_get_response_model(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) output_sockets = [ @@ -135,7 +140,7 @@ def test_get_file_upload_field(): name="test_endpoint", summary="Test endpoint", path="/test_endpoint", - outputs=["output"], + outputs=[EndpointOutput(name="output", output="output")], ) input_sockets = [ @@ -155,6 +160,7 @@ def test_get_file_upload_field(): # Check that the file upload field has the correct description assert file_upload_field.description == "Upload image files." 
+
 def test_get_file_upload_field_multiple_file_uploads():
     """Test the get_file_upload_field function with multiple file uploads."""
@@ -162,7 +168,7 @@
         name="test_endpoint",
         summary="Test endpoint",
         path="/test_endpoint",
-        outputs=["output"],
+        outputs=[EndpointOutput(name="output", output="output")],
     )

     input_sockets = [
diff --git a/aana/tests/test_app.py b/aana/tests/test_app.py
index 7e9ef534..0def3694 100644
--- a/aana/tests/test_app.py
+++ b/aana/tests/test_app.py
@@ -4,7 +4,7 @@
 from ray import serve
 import requests

-from aana.api.api_generation import Endpoint
+from aana.api.api_generation import Endpoint, EndpointOutput
 from aana.api.request_handler import RequestHandler
@@ -62,7 +62,7 @@ async def lower(self, text: str) -> dict:
         name="lowercase",
         path="/lowercase",
         summary="Lowercase text",
-        outputs=["lowercase_text"],
+        outputs=[EndpointOutput(name="text", output="lowercase_text")],
     )
 ]
@@ -103,5 +103,5 @@ def test_app(ray_setup):
         data={"body": json.dumps(data)},
     )
     assert response.status_code == 200
-    lowercase_text = response.json().get("lowercase_text")
+    lowercase_text = response.json().get("text")
     assert lowercase_text == ["hello world!", "this is a test."]
diff --git a/aana/tests/test_app_streaming.py b/aana/tests/test_app_streaming.py
index 386c3c64..aed1cdd7 100644
--- a/aana/tests/test_app_streaming.py
+++ b/aana/tests/test_app_streaming.py
@@ -7,7 +7,7 @@
 from ray import serve
 import requests

-from aana.api.api_generation import Endpoint
+from aana.api.api_generation import Endpoint, EndpointOutput
 from aana.api.request_handler import RequestHandler
@@ -71,7 +71,7 @@ async def lower_stream(self, text: str) -> AsyncGenerator[dict, None]:
         name="lowercase",
         path="/lowercase",
         summary="Lowercase text",
-        outputs=["lowercase_text"],
+        outputs=[EndpointOutput(name="text", output="lowercase_text")],
         streaming=True,
     )
 ]
@@ -121,7 +121,7 @@ def test_app_streaming(ray_setup):
     offset = 0
     for chunk in response.iter_content(chunk_size=None):
         json_data = json.loads(chunk)
-        lowercase_text_chunk = json_data["lowercase_text"]
+        lowercase_text_chunk = json_data["text"]
         lowercase_text += lowercase_text_chunk

         chunk_size = len(lowercase_text_chunk)
diff --git a/aana/tests/test_build.py b/aana/tests/test_build.py
index 0e069f3d..0e8f7ef9 100644
--- a/aana/tests/test_build.py
+++ b/aana/tests/test_build.py
@@ -1,7 +1,7 @@
 from mobius_pipeline.exceptions import OutputNotFoundException
 import pytest

-from aana.api.api_generation import Endpoint
+from aana.api.api_generation import Endpoint, EndpointOutput
 from aana.configs.build import get_configuration

 nodes = [
@@ -70,7 +70,7 @@
             name="lowercase",
             path="/lowercase",
             summary="Lowercase text",
-            outputs=["lowercase_text"],
+            outputs=[EndpointOutput(name="lowercase_text", output="lowercase_text")],
         )
     ],
     "uppercase": [
@@ -78,7 +78,7 @@
             name="uppercase",
             path="/uppercase",
             summary="Uppercase text",
-            outputs=["uppercase_text"],
+            outputs=[EndpointOutput(name="uppercase_text", output="uppercase_text")],
         )
     ],
     "both": [
@@ -86,13 +86,13 @@
             name="lowercase",
             path="/lowercase",
             summary="Lowercase text",
-            outputs=["lowercase_text"],
+            outputs=[EndpointOutput(name="lowercase_text", output="lowercase_text")],
         ),
         Endpoint(
             name="uppercase",
             path="/uppercase",
             summary="Uppercase text",
-            outputs=["uppercase_text"],
+            outputs=[EndpointOutput(name="uppercase_text", output="uppercase_text")],
         ),
     ],
     "non_existent": [
@@ -100,7 +100,7 @@
             name="non_existent",
             path="/non_existent",
             summary="Non existent endpoint",
endpoint", - outputs=["non_existent"], + outputs=[EndpointOutput(name="non_existent", output="non_existent")], ) ], "capitalize": [ @@ -108,7 +108,7 @@ name="capitalize", path="/capitalize", summary="Capitalize text", - outputs=["capitalize_text"], + outputs=[EndpointOutput(name="capitalize_text", output="capitalize_text")], ) ], } diff --git a/aana/utils/video.py b/aana/utils/video.py index 234ca3d5..4c904d88 100644 --- a/aana/utils/video.py +++ b/aana/utils/video.py @@ -3,7 +3,7 @@ import numpy as np import yt_dlp from yt_dlp.utils import DownloadError -from typing import List, TypedDict +from typing import Generator, List, TypedDict from aana.configs.settings import settings from aana.exceptions.general import DownloadException, VideoReadingException from aana.models.core.image import Image @@ -62,7 +62,59 @@ def extract_frames_decord(video: Video, params: VideoParams) -> FramesDict: return FramesDict(frames=frames, timestamps=timestamps, duration=duration) -def download_video(video_input: VideoInput) -> Video: +def generate_frames_decord( + video: Video, params: VideoParams, batch_size: int = 8 +) -> Generator[FramesDict, None, None]: + """ + Generate frames from a video using decord. + + Args: + video (Video): the video to extract frames from + params (VideoParams): the parameters of the video extraction + batch_size (int): the number of frames to yield at each iteration + + Yields: + FramesDict: a dictionary containing the extracted frames, timestamps, + and duration for each batch + """ + + device = decord.cpu(0) + num_threads = 1 # TODO: see if we can use more threads + + num_fps: float = params.extract_fps + try: + video_reader = decord.VideoReader( + str(video.path), ctx=device, num_threads=num_threads + ) + except Exception as e: + raise VideoReadingException(video) from e + + video_fps = video_reader.get_avg_fps() + num_frames = len(video_reader) + duration = num_frames / video_fps + + if params.fast_mode_enabled: + indexes = video_reader.get_key_indices() + else: + # num_fps can be smaller than 1 (e.g. 0.5 means 1 frame every 2 seconds) + indexes = np.arange(0, num_frames, int(video_fps / num_fps)) + timestamps = video_reader.get_frame_timestamp(indexes)[:, 0].tolist() + + for i in range(0, len(indexes), batch_size): + batch = indexes[i : i + batch_size] + batch_frames_array = video_reader.get_batch(batch).asnumpy() + batch_frames = [] + for _, frame in enumerate(batch_frames_array): + img = Image(numpy=frame) + batch_frames.append(img) + + batch_timestamps = timestamps[i : i + batch_size] + yield FramesDict( + frames=batch_frames, timestamps=batch_timestamps, duration=duration + ) + + +def download_video(video_input: VideoInput | Video) -> Video: """ Downloads videos for a VideoInput object. 
@@ -72,6 +124,8 @@ def download_video(video_input: VideoInput) -> Video:
     Returns:
         Video: the video object
     """
+    if isinstance(video_input, Video):
+        return video_input
     if video_input.url is not None:
         video_source: VideoSource = VideoSource.from_url(video_input.url)
         if video_source == VideoSource.YOUTUBE:
diff --git a/mobius-pipeline b/mobius-pipeline
index f455c656..386943bd 160000
--- a/mobius-pipeline
+++ b/mobius-pipeline
@@ -1 +1 @@
-Subproject commit f455c656da566866f031d74d8676bef0f558b4f3
+Subproject commit 386943bd78d8c3617013ac52bd18a92be0e19c5e
diff --git a/pyproject.toml b/pyproject.toml
index 7f1fc97d..42f759c3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,9 @@
 faster-whisper = "^0.9.0"
 onnxruntime = "1.16.1"
 deepdiff = "^6.7.0"
 yt-dlp = "^2023.10.13"
+qdrant-client = "^1.6.9"
+bitsandbytes = "^0.41.2.post2"
+accelerate = "^0.24.1"

 [tool.poetry.group.dev.dependencies]
 ipykernel = "^6.25.2"
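Usage note: the sketch below illustrates how the EndpointOutput mapping introduced in this diff is meant to be used. It only relies on names that appear above (Endpoint, EndpointOutput, name_to_output, process_output, and the blip2 entries in aana/configs/endpoints.py); the caption and timestamp values are placeholder data, not real pipeline output.

# Sketch of the EndpointOutput mapping added in aana/api/api_generation.py.
# The endpoint definition mirrors the "blip2" group in aana/configs/endpoints.py.
from aana.api.api_generation import Endpoint, EndpointOutput

endpoint = Endpoint(
    name="blip2_video_generate",
    path="/video/generate_captions",
    summary="Generate captions for videos using BLIP2 OPT-2.7B",
    outputs=[
        # name = field returned by the endpoint, output = pipeline output it maps to
        EndpointOutput(name="captions", output="videos_captions_hf_blip2_opt_2_7b"),
        EndpointOutput(name="timestamps", output="timestamps"),
    ],
)

# __post_init__ builds lookup tables in both directions.
assert endpoint.name_to_output["captions"] == "videos_captions_hf_blip2_opt_2_7b"

# process_output renames raw pipeline outputs to the endpoint's public field names,
# so the response model exposes "captions" instead of the internal pipeline name.
raw = {"videos_captions_hf_blip2_opt_2_7b": [["a cat on a sofa"]], "timestamps": [[0.0, 2.0]]}
assert endpoint.process_output(raw) == {"captions": [["a cat on a sofa"]], "timestamps": [[0.0, 2.0]]}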