Skip to content

Commit

Permalink
91 make tad database connection async (#306)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt authored Oct 28, 2024
2 parents c885873 + 6920fcb commit e4c2387
Show file tree
Hide file tree
Showing 38 changed files with 723 additions and 430 deletions.
46 changes: 29 additions & 17 deletions amt/api/routes/project.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import asyncio
import functools
import logging
from collections.abc import Sequence
from pathlib import Path
from typing import Annotated, Any, cast

Expand All @@ -18,6 +20,7 @@
from amt.core.exceptions import AMTNotFound, AMTRepositoryError
from amt.enums.status import Status
from amt.models import Project
from amt.models.task import Task
from amt.schema.measure import ExtendedMeasureTask, MeasureTask
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import SystemCard
Expand Down Expand Up @@ -65,10 +68,10 @@ def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
}


def get_project_or_error(project_id: int, projects_service: ProjectsService, request: Request) -> Project:
async def get_project_or_error(project_id: int, projects_service: ProjectsService, request: Request) -> Project:
try:
logger.debug(f"getting project with id {project_id}")
project = projects_service.get(project_id)
project = await projects_service.get(project_id)
request.state.path_variables = {"project_id": project_id}
except AMTRepositoryError as e:
raise AMTNotFound from e
Expand Down Expand Up @@ -98,17 +101,26 @@ def get_projects_submenu_items() -> list[BaseNavigationItem]:
]


async def gather_project_tasks(project_id: int, task_service: TasksService) -> dict[Status, Sequence[Task]]:
fetch_tasks = [task_service.get_tasks_for_project(project_id, status + 0) for status in Status]

results = await asyncio.gather(*fetch_tasks)

return dict(zip(Status, results, strict=True))


@router.get("/{project_id}/details/tasks")
async def get_tasks(
request: Request,
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
tasks_service: Annotated[TasksService, Depends(TasksService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
tab_items = get_project_details_tabs(request)
tasks_by_status = await gather_project_tasks(project_id, task_service=tasks_service)

breadcrumbs = resolve_base_navigation_items(
[
Expand All @@ -124,7 +136,7 @@ async def get_tasks(
context = {
"instrument_state": instrument_state,
"requirements_state": requirements_state,
"tasks_service": tasks_service,
"tasks_by_status": tasks_by_status,
"statuses": Status,
"project": project,
"project_id": project.id,
Expand Down Expand Up @@ -153,7 +165,7 @@ async def move_task(
moved_task.next_sibling_id = None
if moved_task.previous_sibling_id == -1:
moved_task.previous_sibling_id = None
task = tasks_service.move_task(
task = await tasks_service.move_task(
moved_task.id,
moved_task.status_id,
moved_task.previous_sibling_id,
Expand All @@ -168,7 +180,7 @@ async def move_task(
async def get_project_context(
project_id: int, projects_service: ProjectsService, request: Request
) -> tuple[Project, dict[str, Any]]:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
system_card_data = get_system_card_data()
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
Expand Down Expand Up @@ -268,7 +280,7 @@ async def get_project_update(
) -> HTMLResponse:
project, context = await get_project_context(project_id, projects_service, request)
set_path(project, path, update_data.value)
projects_service.update(project)
await projects_service.update(project)
context["path"] = path.replace("/", ".")
return templates.TemplateResponse(request, "parts/view_cell.html.j2", context)

Expand All @@ -284,7 +296,7 @@ async def get_system_card(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)

Expand Down Expand Up @@ -326,7 +338,7 @@ async def get_system_card(
async def get_project_inference(
request: Request, project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)]
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)

breadcrumbs = resolve_base_navigation_items(
[
Expand Down Expand Up @@ -370,7 +382,7 @@ async def get_system_card_requirements(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
# TODO: This tab is fairly slow, fix in later releases
Expand Down Expand Up @@ -464,7 +476,7 @@ async def get_measure(
measure_urn: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
measure = task_registry.fetch_measures([measure_urn])
measure_task = find_measure_task(project.system_card, measure_urn)

Expand All @@ -491,7 +503,7 @@ async def update_measure_value(
measure_update: MeasureUpdate,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)

measure_task = find_measure_task(project.system_card, measure_urn)
measure_task.state = measure_update.measure_state # pyright: ignore [reportOptionalMemberAccess]
Expand All @@ -517,7 +529,7 @@ async def update_measure_value(
else:
requirement_task.state = "in progress" # pyright: ignore [reportOptionalMemberAccess]

projects_service.update(project)
await projects_service.update(project)
# TODO: FIX THIS!! The page now reloads at the top, which is annoying
return templates.Redirect(request, f"/algorithm-system/{project_id}/details/system_card/requirements")

Expand All @@ -533,7 +545,7 @@ async def get_system_card_data_page(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)

Expand Down Expand Up @@ -573,7 +585,7 @@ async def get_system_card_instruments(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)

Expand Down Expand Up @@ -614,7 +626,7 @@ async def get_assessment_card(
assessment_card: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)

Expand Down Expand Up @@ -665,7 +677,7 @@ async def get_model_card(
model_card: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)

Expand Down
4 changes: 2 additions & 2 deletions amt/api/routes/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def get_root(
k.removeprefix("sort-by-"): v for k, v in request.query_params.items() if k.startswith("sort-by-") and v != ""
}

projects = projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort_by)
projects = await projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort_by)
# todo: the lifecycle has to be 'localized', maybe for display 'Project' should become a different object
for project in projects:
project.lifecycle = get_localized_lifecycle(project.lifecycle, request) # pyright: ignore [reportAttributeAccessIssue]
Expand Down Expand Up @@ -131,6 +131,6 @@ async def post_new(
project_new.systemic_risk = "geen systeemrisico"
project_new.open_source = "open-source"

project = projects_service.create(project_new)
project = await projects_service.create(project_new)
response = templates.Redirect(request, f"/algorithm-system/{project.id}/details/tasks")
return response
2 changes: 1 addition & 1 deletion amt/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class Settings(BaseSettings):

# todo(berry): create submodel for database settings
APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite"
APP_DATABASE_DRIVER: str | None = None
APP_DATABASE_DRIVER: str | None = "aiosqlite"

APP_DATABASE_SERVER: str = "db"
APP_DATABASE_PORT: int = 5432
Expand Down
21 changes: 11 additions & 10 deletions amt/core/db.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,40 @@
import logging

from sqlalchemy import create_engine, select
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine

from amt.core.config import get_settings
from amt.models.base import Base

logger = logging.getLogger(__name__)


def get_engine() -> Engine:
def get_engine() -> AsyncEngine:
settings = get_settings()
connect_args = {"check_same_thread": False} if settings.APP_DATABASE_SCHEME == "sqlite" else {}

return create_engine(
return create_async_engine(
settings.SQLALCHEMY_DATABASE_URI, # pyright: ignore [reportArgumentType]
connect_args=connect_args,
echo=settings.SQLALCHEMY_ECHO,
)


def check_db() -> None:
async def check_db() -> None:
logger.info("Checking database connection")
with Session(get_engine()) as session:
session.execute(select(1))
async with AsyncSession(get_engine()) as session:
await session.execute(select(1))

logger.info("Finish Checking database connection")


def init_db() -> None:
async def init_db() -> None:
logger.info("Initializing database")

if get_settings().AUTO_CREATE_SCHEMA: # pragma: no cover
logger.info("Creating database schema")
Base.metadata.create_all(get_engine())
engine = get_engine()
async with engine.begin() as connection:
await connection.run_sync(Base.metadata.create_all)

logger.info("Finished initializing database")
31 changes: 18 additions & 13 deletions amt/migrations/env.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import os
from logging.config import fileConfig

from alembic import context
from amt.models import * # noqa
from sqlalchemy import engine_from_config, pool
from sqlalchemy import Connection, pool
from sqlalchemy.ext.asyncio import async_engine_from_config
from sqlalchemy.schema import MetaData

from amt.models.base import Base
Expand All @@ -18,17 +20,18 @@

def get_url() -> str:
scheme = os.getenv("APP_DATABASE_SCHEME", "sqlite")
driver = os.getenv("APP_DATABASE_DRIVER", "aiosqlite")

if scheme == "sqlite":
file = os.getenv("APP_DATABASE_FILE", "database.sqlite3")
return f"{scheme}:///{file}"
return f"{scheme}+{driver}:///{file}"

user = os.getenv("APP_DATABASE_USER", "amt")
password = os.getenv("APP_DATABASE_PASSWORD", "")
server = os.getenv("APP_DATABASE_SERVER", "db")
port = os.getenv("APP_DATABASE_PORT", "5432")
db = os.getenv("APP_DATABASE_DB", "amt")
return f"{scheme}://{user}:{password}@{server}:{port}/{db}"
return f"{scheme}+{driver}://{user}:{password}@{server}:{port}/{db}"


def run_migrations_offline() -> None:
Expand All @@ -52,7 +55,14 @@ def run_migrations_offline() -> None:
context.run_migrations()


def run_migrations_online() -> None:
def do_run_migrations(connection: Connection) -> None:
context.configure(connection=connection, target_metadata=target_metadata)

with context.begin_transaction():
context.run_migrations()


async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
Expand All @@ -63,18 +73,13 @@ def run_migrations_online() -> None:
if configuration is None:
raise Exception("Failed to get configuration section") # noqa: TRY002
configuration["sqlalchemy.url"] = get_url()
connectable = engine_from_config(configuration, prefix="sqlalchemy.", poolclass=pool.NullPool)

with connectable.connect() as connection:
context.configure(
connection=connection, target_metadata=target_metadata, compare_type=True, render_as_batch=True
)
connectable = async_engine_from_config(configuration, prefix="sqlalchemy.", poolclass=pool.NullPool)

with context.begin_transaction():
context.run_migrations()
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)


if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()
asyncio.run(run_migrations_online())
2 changes: 1 addition & 1 deletion amt/models/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class Project(Base):

id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(255))
lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles), nullable=True)
lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles, name="lifecycle"), nullable=True)
system_card_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
last_edited: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now(), nullable=False)

Expand Down
15 changes: 10 additions & 5 deletions amt/repositories/deps.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from collections.abc import Generator
from collections.abc import AsyncGenerator

from sqlalchemy.orm import Session
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

from amt.core.db import get_engine


def get_session() -> Generator[Session, None, None]:
with Session(get_engine()) as session:
yield session
async def get_session() -> AsyncGenerator[AsyncSession, None]:
async_session_factory = async_sessionmaker(
get_engine(),
expire_on_commit=False,
class_=AsyncSession,
)
async with async_session_factory() as async_session:
yield async_session
Loading

0 comments on commit e4c2387

Please sign in to comment.