diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 1e7e8816..06eccb77 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,5 +1,6 @@ { "name": "Ubuntu", + "runArgs": ["--name", "${localEnv:USER}_dev_container"], "build": { "dockerfile": "Dockerfile" }, diff --git a/README.md b/README.md index 1dae7f24..e332240f 100644 --- a/README.md +++ b/README.md @@ -109,7 +109,7 @@ If you are using Visual Studio Code, you can run this repository in a run everything you need for the repo in an isolated environment via docker on a host system. Running it somewhere other than a Mobius dev server may cause issues due to the mounts of `/nas` and `/nas2` inside the container, but you can specify the environment variables for VS Code `PATH_NAS` and -`PATH_NAS2` which will override the default locations used for these mount points (otherise they default +`PATH_NAS2` which will override the default locations used for these mount points (otherwise they default to look for `/nas` and `/nas2`). You can read more about environment variables for dev containers [here](https://containers.dev/implementors/json_reference/). @@ -195,7 +195,7 @@ This will run the tests normally using GPU and save the deployment cache after r ## Databases -The project uses two databases: a vector database as well as a tradtional SQL database, +The project uses two databases: a vector database as well as a traditional SQL database, referred to internally as vectorstore and datastore, respectively. ### Vectorstore @@ -219,10 +219,10 @@ Higher level code for interacting with the ORM is available in `aana.repository. ## Settings -Here are the environment variables that can be used to configure the Aaana SDK: +Here are the environment variables that can be used to configure the Aana SDK: - TMP_DATA_DIR: The directory to store temporary data. Default: `/tmp/aana`. - NUM_WORKERS: The number of request workers. Default: `2`. 
-- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`. +- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently, only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`. - USE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will use the deployment cache to avoid downloading the models and running the deployments. Default: `false`. - SAVE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will save the deployment cache after running the deployments. Default: `false`. - HF_HUB_ENABLE_HF_TRANSFER: If set to `1`, the HuggingFace Transformers will use the HF Transfer library to download the models from HuggingFace Hub to speed up the process. Recommended to always set to it `1`. Default: `0`. 
diff --git a/aana/core/models/image.py b/aana/core/models/image.py index 25697b3f..4fe20b83 100644 --- a/aana/core/models/image.py +++ b/aana/core/models/image.py @@ -3,15 +3,17 @@ import uuid from dataclasses import dataclass from pathlib import Path +from typing import Annotated import numpy as np import PIL.Image from pydantic import ( + AfterValidator, + AnyUrl, BaseModel, ConfigDict, Field, ValidationError, - field_validator, model_validator, ) from pydantic_core import InitErrorDetails @@ -255,15 +257,17 @@ class ImageInput(BaseModel): Attributes: path (str): the file path of the image - url (str): the URL of the image + url (AnyUrl): the URL of the image content (bytes): the content of the image in bytes numpy (bytes): the image as a numpy array """ path: str | None = Field(None, description="The file path of the image.") - url: str | None = Field( - None, description="The URL of the image." - ) # TODO: validate url + url: Annotated[ + AnyUrl | None, + AfterValidator(lambda x: str(x) if x else None), + Field(None, description="The URL of the image."), + ] content: bytes | None = Field( None, description=( @@ -283,24 +287,6 @@ class ImageInput(BaseModel): description="The ID of the image. If not provided, it will be generated automatically.", ) - @field_validator("media_id") - @classmethod - def media_id_must_not_be_empty(cls, media_id): - """Validates that the media_id is not an empty string. - - Args: - media_id (MediaId): The value of the media_id field. - - Raises: - ValueError: If the media_id is an empty string. - - Returns: - str: The non-empty media_id value. - """ - if media_id == "": - raise ValueError("media_id cannot be an empty string") # noqa: TRY003 - return media_id - def set_file(self, file: bytes): """Sets the instance internal file data. 
diff --git a/aana/core/models/stream.py b/aana/core/models/stream.py new file mode 100644 index 00000000..7fb2c940 --- /dev/null +++ b/aana/core/models/stream.py @@ -0,0 +1,58 @@ +import uuid +from typing import Annotated + +from pydantic import ( + AfterValidator, + AnyUrl, + BaseModel, + ConfigDict, + Field, +) + +from aana.core.models.media import MediaId + + +class StreamInput(BaseModel): + """A video stream input. + + The 'url' must be provided. + + Attributes: + media_id (MediaId): the ID of the video stream. If not provided, it will be generated automatically. + url (AnyUrl): the URL of the video stream + channel_number (int): the desired channel of stream to be processed + extract_fps (float): the number of frames to extract per second + """ + + url: Annotated[ + AnyUrl, + Field(description="The URL of the video stream."), + AfterValidator(lambda x: str(x)), + ] + channel_number: int = Field( + default=0, + ge=0, + description=("the desired channel of stream"), + ) + + extract_fps: float = Field( + default=3.0, + gt=0.0, + description=( + "The number of frames to extract per second. " + "Can be smaller than 1. For example, 0.5 means 1 frame every 2 seconds." + ), + ) + + media_id: MediaId = Field( + default_factory=lambda: str(uuid.uuid4()), + description="The ID of the video. If not provided, it will be generated automatically.", + ) + + model_config = ConfigDict( + json_schema_extra={ + "description": ("A video Stream. \n" "The 'url' must be provided. 
\n") + }, + validate_assignment=True, + file_upload=False, + ) diff --git a/aana/core/models/video.py b/aana/core/models/video.py index d19caaec..093a359a 100644 --- a/aana/core/models/video.py +++ b/aana/core/models/video.py @@ -3,18 +3,19 @@ from pathlib import Path import torch, decord # noqa: F401 # See https://github.com/dmlc/decord/issues/263 from decord import DECORDError - +from typing import Annotated from aana.configs.settings import settings from aana.exceptions.io import VideoReadingException from aana.core.models.media import Media import uuid from pydantic import ( + AfterValidator, + AnyUrl, BaseModel, ConfigDict, Field, ValidationError, - field_validator, model_validator, ) from pydantic_core import InitErrorDetails @@ -144,7 +145,7 @@ class VideoParams(BaseModel): """A pydantic model for video parameters. Attributes: - extract_fps (int): the number of frames to extract per second + extract_fps (float): the number of frames to extract per second fast_mode_enabled (bool): whether to use fast mode (keyframes only) """ @@ -179,14 +180,16 @@ class VideoInput(BaseModel): Attributes: media_id (MediaId): the ID of the video. If not provided, it will be generated automatically. path (str): the file path of the video - url (str): the URL of the video (supports YouTube videos) + url (AnyUrl): the URL of the video (supports YouTube videos) content (bytes): the content of the video in bytes """ path: str | None = Field(None, description="The file path of the video.") - url: str | None = Field( - None, description="The URL of the video (supports YouTube videos)." - ) + url: Annotated[ + AnyUrl | None, + AfterValidator(lambda x: str(x) if x else None), + Field(None, description="The URL of the video (supports YouTube videos)."), + ] content: bytes | None = Field( None, description=( @@ -199,43 +202,6 @@ class VideoInput(BaseModel): description="The ID of the video. 
If not provided, it will be generated automatically.", ) - @field_validator("url") - @classmethod - def check_url(cls, url: str) -> str: - """Check that the URL is valid and supported. - - Right now, we support normal URLs and youtube URLs. - - Args: - url (str): the URL - - Returns: - str: the valid URL - - Raises: - ValueError: if the URL is invalid or unsupported - """ - # TODO: implement the youtube URL validation - return url - - @field_validator("media_id") - @classmethod - def media_id_must_not_be_empty(cls, media_id): - """Validates that the media_id is not an empty string. - - Args: - media_id (MediaId): The value of the media_id field. - - Raises: - ValueError: If the media_id is an empty string. - - Returns: - str: The non-empty media_id value. - """ - if media_id == "": - raise ValueError("media_id cannot be an empty string") # noqa: TRY003 - return media_id - @model_validator(mode="after") def check_only_one_field(self) -> Self: """Check that exactly one of 'path', 'url', or 'content' is provided. diff --git a/aana/exceptions/io.py b/aana/exceptions/io.py index 306254fd..bff90c1e 100644 --- a/aana/exceptions/io.py +++ b/aana/exceptions/io.py @@ -102,3 +102,26 @@ class VideoReadingException(VideoException): """ pass + +class StreamReadingException(BaseException): + """Exception raised when there is an error reading a stream. + + Attributes: + url (str): the URL of the stream that caused the exception + msg (str): the error message + """ + + def __init__(self, url: str, msg: str = ""): + """Initialize the exception. 
+ + Args: + url (str): the URL of the stream that caused the exception + msg (str): the error message + """ + super().__init__(url=url) + self.url = url + self.msg = msg + + def __reduce__(self): + """Used for pickling.""" + return (self.__class__, (self.url, self.msg)) diff --git a/aana/integrations/external/av.py b/aana/integrations/external/av.py index d09e1445..9cc37fb5 100644 --- a/aana/integrations/external/av.py +++ b/aana/integrations/external/av.py @@ -4,11 +4,22 @@ import wave from collections.abc import Generator from pathlib import Path +from typing import TypedDict import av import numpy as np from aana.core.libraries.audio import AbstractAudioLibrary +from aana.core.models.image import Image +from aana.core.models.stream import StreamInput +from aana.exceptions.io import StreamReadingException + + +class FramesDict(TypedDict): + """Represents a set of frames with ids, timestamps.""" + frames: list[Image] + timestamps: list[float] + frame_ids: list[int] def load_audio(file: Path | None, sample_rate: int = 16000) -> bytes: @@ -120,6 +131,69 @@ def resample_frames(frames: Generator, resampler) -> Generator: yield from resampler.resample(frame) +def fetch_stream_frames( + stream_input: StreamInput, batch_size: int = 2 +) -> Generator[FramesDict, None, None]: + """Generate batches of frames from a video stream using PyAV. 
+ + Args: + stream_input (StreamInput): the video stream to fetch frames from + batch_size (int): the number of frames to yield at each iteration + Yields: + FramesDict: a dictionary containing the extracted frames, frame ids, and timestamps for each batch + """ + stream_url = stream_input.url + channel = stream_input.channel_number + extraction_fps = stream_input.extract_fps + + try: + stream_container = av.open(stream_url) + except Exception as e: + raise StreamReadingException(stream_url) from e + + available_streams = [s for s in stream_container.streams if s.type == "video"] + + # Check that the requested video channel exists in the stream + if len(available_streams) == 0 or channel >= len(available_streams): + raise StreamReadingException( + stream_url, + msg=f"selected channel does not exist: {channel + 1} from {len(available_streams)}", + ) + video_stream = available_streams[channel] + + avg_rate = float(video_stream.average_rate) + + if extraction_fps > avg_rate: + extraction_fps = avg_rate + + frame_rate = int(avg_rate / extraction_fps) + + # read frames from the stream, keeping every frame_rate-th decoded frame + frame_number = 0 + batch_frames = [] + batch_timestamps = [] + num_batches = 0 + + for packet in stream_container.demux(video_stream): + for frame in packet.decode(): + if frame_number % frame_rate == 0: + img = Image(numpy=frame.to_rgb().to_ndarray()) + packet_timestamp = float(frame.pts * frame.time_base) # in seconds + batch_frames.append(img) + batch_timestamps.append(packet_timestamp) + frame_number += 1 + if len(batch_frames) == batch_size: + num_batches += 1 + yield FramesDict( + frames=batch_frames, + frame_ids=list( + range((num_batches - 1) * batch_size, num_batches * batch_size) + ), + timestamps=batch_timestamps, + ) + batch_frames = [] + batch_timestamps = [] + class pyAVWrapper(AbstractAudioLibrary): """Class for audio handling using PyAV library.""" diff --git a/aana/projects/chat_with_video/app.py b/aana/projects/chat_with_video/app.py index 3aece801..aea78ae6 100644 --- 
a/aana/projects/chat_with_video/app.py +++ b/aana/projects/chat_with_video/app.py @@ -1,6 +1,6 @@ from aana.configs.deployments import ( - meta_llama3_8b_instruct_deployment, hf_blip2_opt_2_7b_deployment, + meta_llama3_8b_instruct_deployment, vad_deployment, whisper_medium_deployment, ) diff --git a/aana/projects/process_stream/__init__.py b/aana/projects/process_stream/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/projects/process_stream/app.py b/aana/projects/process_stream/app.py new file mode 100644 index 00000000..1e7315e9 --- /dev/null +++ b/aana/projects/process_stream/app.py @@ -0,0 +1,48 @@ +import argparse + +from aana.configs.deployments import hf_blip2_opt_2_7b_deployment +from aana.projects.process_stream.endpoints import ( + CaptionStreamEndpoint, +) +from aana.sdk import AanaSDK + +deployments = [ + { + "name": "captioning_deployment", + "instance": hf_blip2_opt_2_7b_deployment, + }, +] + +endpoints = [ + { + "name": "caption_live_stream", + "path": "/stream/caption_stream", + "summary": "Process a live stream and return the captions", + "endpoint_cls": CaptionStreamEndpoint, + }, +] + +if __name__ == "__main__": + """Runs the application.""" + arg_parser = argparse.ArgumentParser() + arg_parser.add_argument("--port", type=int, default=8000) + arg_parser.add_argument("--host", type=str, default="127.0.0.1") + args = arg_parser.parse_args() + + aana_app = AanaSDK(port=args.port, host=args.host, show_logs=True) + + for deployment in deployments: + aana_app.register_deployment( + name=deployment["name"], + instance=deployment["instance"], + ) + + for endpoint in endpoints: + aana_app.register_endpoint( + name=endpoint["name"], + path=endpoint["path"], + summary=endpoint["summary"], + endpoint_cls=endpoint["endpoint_cls"], + ) + + aana_app.deploy(blocking=True) \ No newline at end of file diff --git a/aana/projects/process_stream/const.py b/aana/projects/process_stream/const.py new file mode 100644 index 00000000..46d039ed 
--- /dev/null +++ b/aana/projects/process_stream/const.py @@ -0,0 +1 @@ +captioning_model_name = "hf_blip2_opt_2_7b" diff --git a/aana/projects/process_stream/endpoints.py b/aana/projects/process_stream/endpoints.py new file mode 100644 index 00000000..5ee2af19 --- /dev/null +++ b/aana/projects/process_stream/endpoints.py @@ -0,0 +1,46 @@ +from collections.abc import AsyncGenerator +from typing import Annotated, TypedDict + +from pydantic import Field + +from aana.api.api_generation import Endpoint +from aana.core.models.stream import StreamInput +from aana.deployments.aana_deployment_handle import AanaDeploymentHandle +from aana.integrations.external.av import fetch_stream_frames +from aana.processors.remote import run_remote + + +class CaptionStreamOutput(TypedDict): + """The output of the caption stream endpoint.""" + + captions: Annotated[list[str], Field(..., description="Captions")] + timestamps: Annotated[ + list[float], Field(..., description="Timestamps for each caption in seconds") + ] + + +class CaptionStreamEndpoint(Endpoint): + """Caption stream endpoint.""" + + async def initialize(self): + """Initialize the endpoint.""" + self.captioning_handle = await AanaDeploymentHandle.create( + "captioning_deployment" + ) + + async def run( + self, + stream: StreamInput, + ) -> AsyncGenerator[CaptionStreamOutput, None]: + """Caption the frames of a video stream in batches.""" + async for frames_dict in run_remote(fetch_stream_frames)( + stream_input=stream, batch_size=2 + ): + captioning_output = await self.captioning_handle.generate_batch( + images=frames_dict["frames"] + ) + + yield { + "captions": captioning_output["captions"], + "timestamps": frames_dict["timestamps"], + } \ No newline at end of file diff --git a/aana/tests/units/test_frame_extraction.py b/aana/tests/units/test_frame_extraction.py index 03d5b3dd..c485729d 100644 --- a/aana/tests/units/test_frame_extraction.py +++ b/aana/tests/units/test_frame_extraction.py @@ -4,8 +4,10 @@ import pytest from 
aana.core.models.image import Image +from aana.core.models.stream import StreamInput from aana.core.models.video import Video, VideoParams -from aana.exceptions.io import VideoReadingException +from aana.exceptions.io import StreamReadingException, VideoReadingException +from aana.integrations.external.av import fetch_stream_frames from aana.integrations.external.decord import extract_frames, generate_frames @@ -89,3 +91,64 @@ def test_extract_frames_failure(): invalid_video = Video(path=path) params = VideoParams(extract_fps=1.0, fast_mode_enabled=False) extract_frames(video=invalid_video, params=params) + +@pytest.mark.parametrize( + "mode, url, channel_number, extract_fps", + [ + ( + "hls", + "https://live-par-2-cdn-alt.livepush.io/live/bigbuckbunnyclip/index.m3u8", + 0, + 3, + ), + ( + "dash", + "https://live-par-2-cdn-alt.livepush.io/live/bigbuckbunnyclip/index.mpd", + 0, + 3, + ), + ( + "mp4", + "https://live-par-2-abr.livepush.io/vod/bigbuckbunnyclip.mp4", + 0, + 3, + ), + ], +) +def test_fetch_stream_frames(mode, url, channel_number, extract_fps): + """Test fetch_stream_frames. + + fetch_stream_frames is a generator function that yields a dictionary + containing the frames, timestamps and frame_ids of the stream. 
+ """ + stream_input = StreamInput( + url=url, channel_number=channel_number, extract_fps=extract_fps + ) + gen_frame = fetch_stream_frames(stream_input, batch_size=1) + total_frames = 0 + for result in gen_frame: + assert "frames" in result + assert "frame_ids" in result + assert "timestamps" in result + assert isinstance(result["frames"], list) + assert isinstance(result["frame_ids"], list) + assert isinstance(result["timestamps"], list) + + assert isinstance(result["frames"][0], Image) + assert len(result["frames"]) == 1 # batch_size = 1 + assert len(result["timestamps"]) == 1 # batch_size = 1 + + total_frames += 1 + if total_frames > 10: + return + print(f"{mode} is supported") + + +def test_fetch_stream_frames_failure(): + """Test that frames cannot be extracted from a youtube video.""" + url = "https://www.youtube.com/watch?v=T98dnE2vPdY" + stream_input = StreamInput(url=url, channel_number=0, extract_fps=3) + with pytest.raises(StreamReadingException): + gen_frame = fetch_stream_frames(stream_input, batch_size=1) + for _ in gen_frame: + return diff --git a/aana/tests/units/test_stream_input.py b/aana/tests/units/test_stream_input.py new file mode 100644 index 00000000..ab4cefe0 --- /dev/null +++ b/aana/tests/units/test_stream_input.py @@ -0,0 +1,37 @@ +# ruff: noqa: S101 +import pytest +from pydantic import ValidationError + +from aana.core.models.stream import StreamInput + + +def test_new_stream_input_success(): + """Test that StreamInput can be created successfully.""" + stream_input = StreamInput(url="http://example.com/stream.m3u8") + assert stream_input.url == "http://example.com/stream.m3u8" + + +def test_stream_input_invalid_media_id(): + """Test that StreamInput can't be created if media_id is invalid.""" + with pytest.raises(ValidationError): + StreamInput(url="http://example.com/stream.m3u8", media_id="") + + +@pytest.mark.parametrize( + "url, extract_fps", + [("http://example.com/stream.m3u8", 0), ("http://example.com/stream.m3u8", -1)], +) +def 
test_stream_input_invalid_extract_fps(url, extract_fps): + """Test that StreamInput can't be created if extract_fps is invalid.""" + with pytest.raises(ValidationError): + StreamInput(url=url, extract_fps=extract_fps) + + +@pytest.mark.parametrize( + "url, channel_number", + [("http://example.com/stream.m3u8", -1), ("http://example.com/stream.m3u8", 0.3)], +) +def test_stream_input_invalid_channel(url, channel_number): + """Test that StreamInput can't be created if channel number is invalid.""" + with pytest.raises(ValidationError): + StreamInput(url=url, channel_number=channel_number)