diff --git a/aana/api/api_generation.py b/aana/api/api_generation.py index 5fd1f0ec..d18b7bf6 100644 --- a/aana/api/api_generation.py +++ b/aana/api/api_generation.py @@ -19,8 +19,8 @@ from aana.exceptions.runtime import ( MultipleFileUploadNotAllowed, ) -from aana.storage.engine import engine from aana.storage.services.task import create_task +from aana.storage.session import get_session def get_default_values(func): @@ -77,7 +77,7 @@ async def initialize(self): self.asr_handle = await AanaDeploymentHandle.create("whisper_deployment") ``` """ - self.session = Session(engine) + self.session = get_session() self.initialized = True async def run(self, *args, **kwargs): diff --git a/aana/api/request_handler.py b/aana/api/request_handler.py index f79df0a4..94566085 100644 --- a/aana/api/request_handler.py +++ b/aana/api/request_handler.py @@ -19,13 +19,13 @@ from aana.core.models.sampling import SamplingParams from aana.core.models.task import TaskId, TaskInfo from aana.deployments.aana_deployment_handle import AanaDeploymentHandle -from aana.storage.engine import engine from aana.storage.repository.task import TaskRepository +from aana.storage.session import get_session def get_db(): """Get a database session.""" - db = Session(engine) + db = get_session() try: yield db finally: diff --git a/aana/core/models/chat.py b/aana/core/models/chat.py index 41ce38af..4c009f37 100644 --- a/aana/core/models/chat.py +++ b/aana/core/models/chat.py @@ -1,7 +1,6 @@ from typing import Annotated, Literal from pydantic import BaseModel, ConfigDict, Field -from typing_extensions import Doc __all__ = [ "Role", diff --git a/aana/core/models/whisper.py b/aana/core/models/whisper.py index 2e0629d0..54c1a615 100644 --- a/aana/core/models/whisper.py +++ b/aana/core/models/whisper.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator -class MyConfig(ConfigDict, total=False): # noqa: D101 +class MyConfig(ConfigDict, total=False): json_schema_extra: dict diff --git a/aana/deployments/task_queue_deployment.py b/aana/deployments/task_queue_deployment.py index d872b928..3e63ea9d 100644 --- a/aana/deployments/task_queue_deployment.py +++ b/aana/deployments/task_queue_deployment.py @@ -6,13 +6,13 @@ import ray from pydantic import BaseModel, Field from ray import serve -from sqlalchemy.orm import Session from aana.api.exception_handler import custom_exception_handler from aana.configs.settings import settings as aana_settings from aana.deployments.base_deployment import BaseDeployment from aana.storage.models.task import Status as TaskStatus from aana.storage.repository.task import TaskRepository +from aana.storage.session import get_session from aana.utils.asyncio import run_async @@ -35,10 +35,7 @@ def __init__(self): self.loop_task.add_done_callback( lambda fut: fut.result() if not fut.cancelled() else None ) - - from aana.storage.engine import engine - - self.session = Session(engine) + self.session = get_session() self.task_repo = TaskRepository(self.session) def check_health(self): diff --git a/aana/storage/__init__.py b/aana/storage/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/aana/storage/models/__init__.py b/aana/storage/models/__init__.py index fc77ecec..11d60631 100644 --- a/aana/storage/models/__init__.py +++ b/aana/storage/models/__init__.py @@ -18,3 +18,11 @@ from aana.storage.models.task import TaskEntity from aana.storage.models.transcript import TranscriptEntity from aana.storage.models.video import VideoEntity + +__all__ = [ + "BaseEntity", + "MediaEntity", + "VideoEntity", + "CaptionEntity", + "TranscriptEntity", +] diff --git a/aana/storage/models/caption.py b/aana/storage/models/caption.py index d93685e0..a53b860e 100644 --- a/aana/storage/models/caption.py +++ b/aana/storage/models/caption.py @@ -12,7 +12,16 @@ class CaptionEntity(BaseEntity, TimeStampEntity): - """ORM model for video captions.""" + """ORM model for video captions. + + Attributes: + id (int): Unique identifier for the caption. + model (str): Name of the model used to generate the caption. + frame_id (int): The 0-based frame id of video for caption. + caption (str): Frame caption. + timestamp (float): Frame timestamp in seconds. + caption_type (str): The type of caption (populated automatically by ORM based on `polymorphic_identity` of subclass). + """ __tablename__ = "caption" @@ -44,7 +53,17 @@ def from_caption_output( frame_id: int, timestamp: float, ) -> CaptionEntity: - """Converts a Caption pydantic model to a CaptionEntity.""" + """Converts a Caption pydantic model to a CaptionEntity. + + Args: + model_name (str): Name of the model used to generate the caption. + caption (Caption): Caption pydantic model. + frame_id (int): The 0-based frame id of video for caption. + timestamp (float): Frame timestamp in seconds. + + Returns: + CaptionEntity: ORM model for video captions. + """ return CaptionEntity( model=model_name, frame_id=frame_id, diff --git a/aana/storage/models/media.py b/aana/storage/models/media.py index 1f2123b9..e5fe8b20 100644 --- a/aana/storage/models/media.py +++ b/aana/storage/models/media.py @@ -7,7 +7,14 @@ class MediaEntity(BaseEntity, TimeStampEntity): - """Table for media items.""" + """Base ORM class for media (e.g. videos, images, etc.). + + This class is meant to be subclassed by other media types. + + Attributes: + id (MediaId): Unique identifier for the media. + media_type (str): The type of media (populated automatically by ORM based on `polymorphic_identity` of subclass). + """ __tablename__ = "media" id: Mapped[MediaId] = mapped_column( diff --git a/aana/storage/models/transcript.py b/aana/storage/models/transcript.py index 30570cee..d2e34c61 100644 --- a/aana/storage/models/transcript.py +++ b/aana/storage/models/transcript.py @@ -16,7 +16,17 @@ class TranscriptEntity(BaseEntity, TimeStampEntity): - """ORM class for media transcripts generated by a model.""" + """ORM class for media transcripts generated by a model. + + Attributes: + id (int): Unique identifier for the transcript. + model (str): Name of the model used to generate the transcript. + transcript (str): Full text transcript of the media. + segments (dict): Segments of the transcript. + language (str): Language of the transcript as predicted by the model. + language_confidence (float): Confidence score of language prediction. + transcript_type (str): The type of transcript (populated automatically by ORM based on `polymorphic_identity` of subclass). + """ __tablename__ = "transcript" @@ -50,7 +60,17 @@ def from_asr_output( transcription: AsrTranscription, segments: AsrSegments, ) -> TranscriptEntity: - """Converts an AsrTranscriptionInfo and AsrTranscription to a single Transcript entity.""" + """Converts an AsrTranscriptionInfo and AsrTranscription to a single Transcript entity. + + Args: + model_name (str): Name of the model used to generate the transcript. + info (AsrTranscriptionInfo): Information about the transcription. + transcription (AsrTranscription): The full transcription. + segments (AsrSegments): Segments of the transcription. + + Returns: + TranscriptEntity: A new instance of the TranscriptEntity class. + """ return TranscriptEntity( model=model_name, language=info.language, diff --git a/aana/storage/models/video.py b/aana/storage/models/video.py index 2f4eedfa..c06d1ba7 100644 --- a/aana/storage/models/video.py +++ b/aana/storage/models/video.py @@ -1,4 +1,3 @@ - from sqlalchemy import ForeignKey from sqlalchemy.orm import Mapped, mapped_column @@ -7,7 +6,15 @@ class VideoEntity(MediaEntity): - """Base ORM class for videos.""" + """Base ORM class for videos. + + Attributes: + id (MediaId): Unique identifier for the video. + path (str): Path to the video file. + url (str): URL to the video file. + title (str): Title of the video. + description (str): Description of the video. + """ __tablename__ = "video" diff --git a/aana/storage/repository/__init__.py b/aana/storage/repository/__init__.py index e69de29b..01579195 100644 --- a/aana/storage/repository/__init__.py +++ b/aana/storage/repository/__init__.py @@ -0,0 +1,13 @@ +from aana.storage.repository.base import BaseRepository +from aana.storage.repository.caption import CaptionRepository +from aana.storage.repository.media import MediaRepository +from aana.storage.repository.transcript import TranscriptRepository +from aana.storage.repository.video import VideoRepository + +__all__ = [ + "BaseRepository", + "MediaRepository", + "VideoRepository", + "CaptionRepository", + "TranscriptRepository", +] diff --git a/aana/storage/session.py b/aana/storage/session.py index 8a12e369..4b51b87c 100644 --- a/aana/storage/session.py +++ b/aana/storage/session.py @@ -2,7 +2,13 @@ from aana.storage.engine import engine +__all__ = ["get_session"] + def get_session() -> Session: - """Provides a SQLAlchemy Session object.""" + """Get a new SQLAlchemy Session object. + + Returns: + Session: SQLAlchemy Session object. + """ return Session(engine) diff --git a/aana/tests/conftest.py b/aana/tests/conftest.py index 5f4651bc..f16a573f 100644 --- a/aana/tests/conftest.py +++ b/aana/tests/conftest.py @@ -150,7 +150,7 @@ def db_session(): # Configure the database to use the temporary file settings.db_config.datastore_config = SQLiteConfig(path=tmp.name) # Reset the engine - settings.db_config.engine = None + settings.db_config._engine = None # Run migrations to set up the schema run_alembic_migrations(settings) diff --git a/docs/reference/storage/index.md b/docs/reference/storage/index.md new file mode 100644 index 00000000..8b6c43ff --- /dev/null +++ b/docs/reference/storage/index.md @@ -0,0 +1,84 @@ +# Storage + +Aana SDK provides an integration with an SQL database to store and retrieve data. + +Currently, Aana SDK supports SQLite (default) and PostgreSQL databases. See [Database Configuration](/reference/settings/#aana.configs.DbSettings) for more information. + +The database integration is based on the [SQLAlchemy](https://www.sqlalchemy.org/) library and consists of two main components: + +- [Models](/reference/storage/models/) - Database models (entities) that represent tables in the database. +- [Repositories](/reference/storage/repositories/) - Classes that provide an interface to interact with the database models. + +To use the database integration, you can either: + +- Use the provided models and repositories. +- Create your own models and repositories by extending the provided ones (for example, extending the VideoEntity model to add custom fields). +- Create your own models and repositories from scratch/base classes (for example, creating a new model for a new entity). + +## How to Use Provided Models and Repositories + +If you want to use the provided models and repositories, you can use the following steps: + +### Get session object + +You can use `get_session` method from the `aana.storage.session` module: + +```python +from aana.storage.session import get_session + +session = get_session() +``` + + + +If you are using Endpoint, you can use the `session` attribute that is available after the endpoint is initialized: + +```python +from aana.api import Endpoint + +class TranscribeVideoEndpoint(Endpoint): + async def initialize(self): + await super().initialize() + # self.session is available here after the endpoint is initialized + + async def run(self, video: VideoInput) -> WhisperOutput: + # self.session is available here as well +``` + + +### Create a repository object and use it to interact with the database. + +You can use the provided repositories from the `aana.storage.repository` module. See [Repositories](/reference/storage/repositories/) for the list of available repositories. + +For example, to work with the `VideoEntity` model, you can create a `VideoRepository` object: + +```python +from aana.storage.repository import VideoRepository + +video_repository = VideoRepository(session) +``` + +And then use the repository object to interact with the database (for example, save a video): + +```python +from aana.core.models import Video + +video = Video(title="My Video", url="https://example.com/video.mp4") +video_repository.save(video) +``` + +Or, if you are using Endpoint, you can create a repository object in the `initialize` method: + +```python +from aana.api import Endpoint + +class TranscribeVideoEndpoint(Endpoint): + async def initialize(self): + await super().initialize() + self.video_repository = VideoRepository(self.session) + + async def run(self, video: VideoInput) -> WhisperOutput: + video_obj: Video = await run_remote(download_video)(video_input=video) + self.video_repository.save(video_obj) + # ... +``` \ No newline at end of file diff --git a/docs/reference/storage/models.md b/docs/reference/storage/models.md new file mode 100644 index 00000000..92a32543 --- /dev/null +++ b/docs/reference/storage/models.md @@ -0,0 +1,3 @@ +# Models + +::: aana.storage.models \ No newline at end of file diff --git a/docs/reference/storage/repositories.md b/docs/reference/storage/repositories.md new file mode 100644 index 00000000..1bba69fd --- /dev/null +++ b/docs/reference/storage/repositories.md @@ -0,0 +1,3 @@ +# Repositories + +::: aana.storage.repository \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index 11c60ab1..7993c11e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -131,6 +131,10 @@ nav: - reference/models/vad.md - reference/models/video.md - reference/models/whisper.md + - Storage: + - reference/storage/index.md + - reference/storage/models.md + - reference/storage/repositories.md - reference/integrations.md - reference/processors.md - reference/exceptions.md