Skip to content

Commit

Permalink
merge main branch
Browse files Browse the repository at this point in the history
  • Loading branch information
HRashidi committed Jun 11, 2024
1 parent 530ae92 commit bc165e2
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 73 deletions.
1 change: 1 addition & 0 deletions .devcontainer/devcontainer.json
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
{
"name": "Ubuntu",
"runArgs": ["--name", "${localEnv:USER}_dev_container"],
"build": {
"dockerfile": "Dockerfile"
},
Expand Down
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ If you are using Visual Studio Code, you can run this repository in a
run everything you need for the repo in an isolated environment via docker on a host system.
Running it somewhere other than a Mobius dev server may cause issues due to the mounts of `/nas` and
`/nas2` inside the container, but you can specify the environment variables for VS Code `PATH_NAS` and
`PATH_NAS2` which will override the default locations used for these mount points (otherise they default
`PATH_NAS2` which will override the default locations used for these mount points (otherwise they default
to look for `/nas` and `/nas2`). You can read more about environment variables for dev containers
[here](https://containers.dev/implementors/json_reference/).

Expand Down Expand Up @@ -195,7 +195,7 @@ This will run the tests normally using GPU and save the deployment cache after r


## Databases
The project uses two databases: a vector database as well as a tradtional SQL database,
The project uses two databases: a vector database as well as a traditional SQL database,
referred to internally as vectorstore and datastore, respectively.

### Vectorstore
Expand All @@ -219,10 +219,10 @@ Higher level code for interacting with the ORM is available in `aana.repository.

## Settings

Here are the environment variables that can be used to configure the Aaana SDK:
Here are the environment variables that can be used to configure the Aana SDK:
- TMP_DATA_DIR: The directory to store temporary data. Default: `/tmp/aana`.
- NUM_WORKERS: The number of request workers. Default: `2`.
- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`.
- DB_CONFIG: The database configuration in the format `{"datastore_type": "sqlite", "datastore_config": {"path": "/path/to/sqlite.db"}}`. Currently, only SQLite and PostgreSQL are supported. Default: `{"datastore_type": "sqlite", "datastore_config": {"path": "/var/lib/aana_data"}}`.
- USE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will use the deployment cache to avoid downloading the models and running the deployments. Default: `false`.
- SAVE_DEPLOYMENT_CACHE (testing only): If set to `true`, the tests will save the deployment cache after running the deployments. Default: `false`.
- HF_HUB_ENABLE_HF_TRANSFER: If set to `1`, the HuggingFace Transformers will use the HF Transfer library to download the models from HuggingFace Hub to speed up the process. Recommended to always set it to `1`. Default: `0`.
32 changes: 9 additions & 23 deletions aana/core/models/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,17 @@
import uuid
from dataclasses import dataclass
from pathlib import Path
from typing import Annotated

import numpy as np
import PIL.Image
from pydantic import (
AfterValidator,
AnyUrl,
BaseModel,
ConfigDict,
Field,
ValidationError,
field_validator,
model_validator,
)
from pydantic_core import InitErrorDetails
Expand Down Expand Up @@ -255,15 +257,17 @@ class ImageInput(BaseModel):
Attributes:
path (str): the file path of the image
url (str): the URL of the image
url (AnyUrl): the URL of the image
content (bytes): the content of the image in bytes
numpy (bytes): the image as a numpy array
"""

path: str | None = Field(None, description="The file path of the image.")
url: str | None = Field(
None, description="The URL of the image."
) # TODO: validate url
url: Annotated[
AnyUrl | None,
AfterValidator(lambda x: str(x) if x else None),
Field(None, description="The URL of the image."),
]
content: bytes | None = Field(
None,
description=(
Expand All @@ -283,24 +287,6 @@ class ImageInput(BaseModel):
description="The ID of the image. If not provided, it will be generated automatically.",
)

@field_validator("media_id")
@classmethod
def media_id_must_not_be_empty(cls, media_id):
"""Validates that the media_id is not an empty string.
Args:
media_id (MediaId): The value of the media_id field.
Raises:
ValueError: If the media_id is an empty string.
Returns:
str: The non-empty media_id value.
"""
if media_id == "":
raise ValueError("media_id cannot be an empty string") # noqa: TRY003
return media_id

def set_file(self, file: bytes):
"""Sets the instance internal file data.
Expand Down
58 changes: 58 additions & 0 deletions aana/core/models/stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import uuid
from typing import Annotated

from pydantic import (
AfterValidator,
AnyUrl,
BaseModel,
ConfigDict,
Field,
)

from aana.core.models.media import MediaId


class StreamInput(BaseModel):
    """A video stream input.

    The 'url' must be provided; every other field has a default.

    Attributes:
        media_id (MediaId): the ID of the video stream. If not provided, it will be generated automatically.
        url (AnyUrl): the URL of the video stream (validated as a URL, then stored as a string)
        channel_number (int): the desired channel of stream to be processed (must be >= 0, default 0)
        extract_fps (float): the number of frames to extract per second (must be > 0, default 3.0)
    """

    # The value is first validated as AnyUrl; the AfterValidator then converts
    # it to a plain `str`, so the attribute read back from the model is a string.
    url: Annotated[
        AnyUrl,
        Field(description="The URL of the video stream."),
        AfterValidator(lambda x: str(x)),
    ]
    # Zero-based index of the channel to process; `ge=0` rejects negative values.
    channel_number: int = Field(
        default=0,
        ge=0,
        description=("the desired channel of stream"),
    )

    # Sampling rate for frame extraction; `gt=0.0` rejects zero and negatives.
    extract_fps: float = Field(
        default=3.0,
        gt=0.0,
        description=(
            "The number of frames to extract per second. "
            "Can be smaller than 1. For example, 0.5 means 1 frame every 2 seconds."
        ),
    )

    # A random UUID4 string is generated when the caller does not supply an ID.
    media_id: MediaId = Field(
        default_factory=lambda: str(uuid.uuid4()),
        description="The ID of the video. If not provided, it will be generated automatically.",
    )

    # NOTE(review): `file_upload` looks like a project-specific config key rather
    # than a standard pydantic option — confirm where it is consumed.
    model_config = ConfigDict(
        json_schema_extra={
            "description": ("A video Stream. \n" "The 'url' must be provided. \n")
        },
        validate_assignment=True,
        file_upload=False,
    )
54 changes: 10 additions & 44 deletions aana/core/models/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@
from pathlib import Path
import torch, decord # noqa: F401 # See https://github.com/dmlc/decord/issues/263
from decord import DECORDError

from typing import Annotated
from aana.configs.settings import settings
from aana.exceptions.io import VideoReadingException
from aana.core.models.media import Media
import uuid

from pydantic import (
AfterValidator,
AnyUrl,
BaseModel,
ConfigDict,
Field,
ValidationError,
field_validator,
model_validator,
)
from pydantic_core import InitErrorDetails
Expand Down Expand Up @@ -144,7 +145,7 @@ class VideoParams(BaseModel):
"""A pydantic model for video parameters.
Attributes:
extract_fps (int): the number of frames to extract per second
extract_fps (float): the number of frames to extract per second
fast_mode_enabled (bool): whether to use fast mode (keyframes only)
"""

Expand Down Expand Up @@ -179,14 +180,16 @@ class VideoInput(BaseModel):
Attributes:
media_id (MediaId): the ID of the video. If not provided, it will be generated automatically.
path (str): the file path of the video
url (str): the URL of the video (supports YouTube videos)
url (AnyUrl): the URL of the video (supports YouTube videos)
content (bytes): the content of the video in bytes
"""

path: str | None = Field(None, description="The file path of the video.")
url: str | None = Field(
None, description="The URL of the video (supports YouTube videos)."
)
url: Annotated[
AnyUrl | None,
AfterValidator(lambda x: str(x) if x else None),
Field(None, description="The URL of the video (supports YouTube videos)."),
]
content: bytes | None = Field(
None,
description=(
Expand All @@ -199,43 +202,6 @@ class VideoInput(BaseModel):
description="The ID of the video. If not provided, it will be generated automatically.",
)

@field_validator("url")
@classmethod
def check_url(cls, url: str) -> str:
"""Check that the URL is valid and supported.
Right now, we support normal URLs and youtube URLs.
Args:
url (str): the URL
Returns:
str: the valid URL
Raises:
ValueError: if the URL is invalid or unsupported
"""
# TODO: implement the youtube URL validation
return url

@field_validator("media_id")
@classmethod
def media_id_must_not_be_empty(cls, media_id):
"""Validates that the media_id is not an empty string.
Args:
media_id (MediaId): The value of the media_id field.
Raises:
ValueError: If the media_id is an empty string.
Returns:
str: The non-empty media_id value.
"""
if media_id == "":
raise ValueError("media_id cannot be an empty string") # noqa: TRY003
return media_id

@model_validator(mode="after")
def check_only_one_field(self) -> Self:
"""Check that exactly one of 'path', 'url', or 'content' is provided.
Expand Down
22 changes: 22 additions & 0 deletions aana/exceptions/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,3 +102,25 @@ class VideoReadingException(VideoException):
"""

pass

class StreamReadingException(BaseException):
    """Exception raised when there is an error reading a stream.

    Attributes:
        url (str): the URL of the stream that caused the exception
        msg (str): the error message
    """

    def __init__(self, url: str, msg: str = ""):
        """Initialize the exception.

        Args:
            url (str): the URL of the stream that caused the exception
            msg (str): the error message
        """
        # Only the URL is forwarded to the base class; the message is kept
        # on the instance alongside it.
        super().__init__(url=url)
        self.url, self.msg = url, msg

    def __reduce__(self):
        """Used for pickling."""
        # Rebuild the exception from the same positional arguments __init__ takes.
        constructor_args = (self.url, self.msg)
        return type(self), constructor_args
74 changes: 74 additions & 0 deletions aana/integrations/external/av.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,22 @@
import wave
from collections.abc import Generator
from pathlib import Path
from typing import TypedDict

import av
import numpy as np

from aana.core.libraries.audio import AbstractAudioLibrary
from aana.core.models.image import Image
from aana.core.models.stream import StreamInput
from aana.exceptions.io import StreamReadingException


class FramesDict(TypedDict):
    """Represents a set of frames with ids, timestamps.

    The three lists describe one batch of extracted frames.
    """

    # The extracted frames of the batch.
    frames: list[Image]
    # One timestamp per frame (fetch_stream_frames fills this with seconds).
    timestamps: list[float]
    # Integer identifiers of the frames in the batch.
    frame_ids: list[int]


def load_audio(file: Path | None, sample_rate: int = 16000) -> bytes:
Expand Down Expand Up @@ -120,6 +131,69 @@ def resample_frames(frames: Generator, resampler) -> Generator:
yield from resampler.resample(frame)


def fetch_stream_frames(
    stream_input: StreamInput, batch_size: int = 2
) -> Generator[FramesDict, None, None]:
    """Generate batches of frames from a video stream using PyAV.

    Frames are sampled from the selected video channel at (approximately)
    ``stream_input.extract_fps`` frames per second. If the requested rate is
    higher than the stream's average frame rate, every frame is used.

    Args:
        stream_input (StreamInput): the video stream to fetch frames from
        batch_size (int): the number of frames to yield at each iteration

    Yields:
        FramesDict: a dictionary containing the extracted frames, frame ids and
            timestamps (in seconds) for each batch. Frame ids count the extracted
            frames sequentially starting at 0. The last batch may contain fewer
            than ``batch_size`` frames if the stream ends mid-batch.

    Raises:
        StreamReadingException: if the stream cannot be opened or the selected
            channel does not exist.
    """
    stream_url = stream_input.url
    channel = stream_input.channel_number
    extraction_fps = stream_input.extract_fps

    try:
        stream_container = av.open(stream_url)
    except Exception as e:
        raise StreamReadingException(stream_url) from e

    # From here on the container is open: make sure it is closed whether the
    # generator is exhausted, closed early by the consumer, or fails.
    try:
        available_streams = [s for s in stream_container.streams if s.type == "video"]

        # Check that the selected stream channel is valid
        if not available_streams or channel >= len(available_streams):
            raise StreamReadingException(
                stream_url,
                msg=f"selected channel does not exist: {channel + 1} from {len(available_streams)}",
            )
        video_stream = available_streams[channel]

        # NOTE(review): average_rate may be unknown for some streams, in which
        # case this conversion fails — confirm against the sources in use.
        avg_rate = float(video_stream.average_rate)

        # Cannot extract more frames per second than the stream provides.
        extraction_fps = min(extraction_fps, avg_rate)

        # Keep one frame out of every `frame_step` decoded frames.
        frame_step = int(avg_rate / extraction_fps)

        frame_number = 0  # index of the decoded frame within the stream
        next_frame_id = 0  # id of the first frame of the next yielded batch
        batch_frames = []
        batch_timestamps = []

        for packet in stream_container.demux(video_stream):
            for frame in packet.decode():
                if frame_number % frame_step == 0:
                    img = Image(numpy=frame.to_rgb().to_ndarray())
                    frame_timestamp = float(frame.pts * frame.time_base)  # in seconds
                    batch_frames.append(img)
                    batch_timestamps.append(frame_timestamp)
                frame_number += 1
                if len(batch_frames) == batch_size:
                    yield FramesDict(
                        frames=batch_frames,
                        frame_ids=list(
                            range(next_frame_id, next_frame_id + batch_size)
                        ),
                        timestamps=batch_timestamps,
                    )
                    next_frame_id += batch_size
                    batch_frames = []
                    batch_timestamps = []

        # The stream ended in the middle of a batch: emit the remaining frames
        # instead of silently dropping them.
        if batch_frames:
            yield FramesDict(
                frames=batch_frames,
                frame_ids=list(
                    range(next_frame_id, next_frame_id + len(batch_frames))
                ),
                timestamps=batch_timestamps,
            )
    finally:
        stream_container.close()

class pyAVWrapper(AbstractAudioLibrary):
"""Class for audio handling using PyAV library."""

Expand Down
2 changes: 1 addition & 1 deletion aana/projects/chat_with_video/app.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from aana.configs.deployments import (
meta_llama3_8b_instruct_deployment,
hf_blip2_opt_2_7b_deployment,
meta_llama3_8b_instruct_deployment,
vad_deployment,
whisper_medium_deployment,
)
Expand Down
Empty file.
Loading

0 comments on commit bc165e2

Please sign in to comment.