diff --git a/amt/api/routes/project.py b/amt/api/routes/project.py
index 9e241c25..58a5d5b4 100644
--- a/amt/api/routes/project.py
+++ b/amt/api/routes/project.py
@@ -1,5 +1,7 @@
+import asyncio
import functools
import logging
+from collections.abc import Sequence
from pathlib import Path
from typing import Annotated, Any, cast
@@ -18,6 +20,7 @@
from amt.core.exceptions import AMTNotFound, AMTRepositoryError
from amt.enums.status import Status
from amt.models import Project
+from amt.models.task import Task
from amt.schema.measure import ExtendedMeasureTask, MeasureTask
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import SystemCard
@@ -65,10 +68,10 @@ def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
}
-def get_project_or_error(project_id: int, projects_service: ProjectsService, request: Request) -> Project:
+async def get_project_or_error(project_id: int, projects_service: ProjectsService, request: Request) -> Project:
try:
logger.debug(f"getting project with id {project_id}")
- project = projects_service.get(project_id)
+ project = await projects_service.get(project_id)
request.state.path_variables = {"project_id": project_id}
except AMTRepositoryError as e:
raise AMTNotFound from e
@@ -98,6 +101,14 @@ def get_projects_submenu_items() -> list[BaseNavigationItem]:
]
+async def gather_project_tasks(project_id: int, task_service: TasksService) -> dict[Status, Sequence[Task]]:
+ fetch_tasks = [task_service.get_tasks_for_project(project_id, status + 0) for status in Status]
+
+ results = await asyncio.gather(*fetch_tasks)
+
+ return dict(zip(Status, results, strict=True))
+
+
@router.get("/{project_id}/details/tasks")
async def get_tasks(
request: Request,
@@ -105,10 +116,11 @@ async def get_tasks(
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
tasks_service: Annotated[TasksService, Depends(TasksService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
tab_items = get_project_details_tabs(request)
+ tasks_by_status = await gather_project_tasks(project_id, task_service=tasks_service)
breadcrumbs = resolve_base_navigation_items(
[
@@ -124,7 +136,7 @@ async def get_tasks(
context = {
"instrument_state": instrument_state,
"requirements_state": requirements_state,
- "tasks_service": tasks_service,
+ "tasks_by_status": tasks_by_status,
"statuses": Status,
"project": project,
"project_id": project.id,
@@ -153,7 +165,7 @@ async def move_task(
moved_task.next_sibling_id = None
if moved_task.previous_sibling_id == -1:
moved_task.previous_sibling_id = None
- task = tasks_service.move_task(
+ task = await tasks_service.move_task(
moved_task.id,
moved_task.status_id,
moved_task.previous_sibling_id,
@@ -168,7 +180,7 @@ async def move_task(
async def get_project_context(
project_id: int, projects_service: ProjectsService, request: Request
) -> tuple[Project, dict[str, Any]]:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
system_card_data = get_system_card_data()
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
@@ -268,7 +280,7 @@ async def get_project_update(
) -> HTMLResponse:
project, context = await get_project_context(project_id, projects_service, request)
set_path(project, path, update_data.value)
- projects_service.update(project)
+ await projects_service.update(project)
context["path"] = path.replace("/", ".")
return templates.TemplateResponse(request, "parts/view_cell.html.j2", context)
@@ -284,7 +296,7 @@ async def get_system_card(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
@@ -326,7 +338,7 @@ async def get_system_card(
async def get_project_inference(
request: Request, project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)]
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
breadcrumbs = resolve_base_navigation_items(
[
@@ -370,7 +382,7 @@ async def get_system_card_requirements(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
# TODO: This tab is fairly slow, fix in later releases
@@ -464,7 +476,7 @@ async def get_measure(
measure_urn: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
measure = task_registry.fetch_measures([measure_urn])
measure_task = find_measure_task(project.system_card, measure_urn)
@@ -491,7 +503,7 @@ async def update_measure_value(
measure_update: MeasureUpdate,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
measure_task = find_measure_task(project.system_card, measure_urn)
measure_task.state = measure_update.measure_state # pyright: ignore [reportOptionalMemberAccess]
@@ -517,7 +529,7 @@ async def update_measure_value(
else:
requirement_task.state = "in progress" # pyright: ignore [reportOptionalMemberAccess]
- projects_service.update(project)
+ await projects_service.update(project)
# TODO: FIX THIS!! The page now reloads at the top, which is annoying
return templates.Redirect(request, f"/algorithm-system/{project_id}/details/system_card/requirements")
@@ -533,7 +545,7 @@ async def get_system_card_data_page(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
@@ -573,7 +585,7 @@ async def get_system_card_instruments(
project_id: int,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
@@ -614,7 +626,7 @@ async def get_assessment_card(
assessment_card: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
@@ -665,7 +677,7 @@ async def get_model_card(
model_card: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
- project = get_project_or_error(project_id, projects_service, request)
+ project = await get_project_or_error(project_id, projects_service, request)
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)
diff --git a/amt/api/routes/projects.py b/amt/api/routes/projects.py
index f0b0dfc4..a7028b1a 100644
--- a/amt/api/routes/projects.py
+++ b/amt/api/routes/projects.py
@@ -62,7 +62,7 @@ async def get_root(
k.removeprefix("sort-by-"): v for k, v in request.query_params.items() if k.startswith("sort-by-") and v != ""
}
- projects = projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort_by)
+ projects = await projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort_by)
# todo: the lifecycle has to be 'localized', maybe for display 'Project' should become a different object
for project in projects:
project.lifecycle = get_localized_lifecycle(project.lifecycle, request) # pyright: ignore [reportAttributeAccessIssue]
@@ -131,6 +131,6 @@ async def post_new(
project_new.systemic_risk = "geen systeemrisico"
project_new.open_source = "open-source"
- project = projects_service.create(project_new)
+ project = await projects_service.create(project_new)
response = templates.Redirect(request, f"/algorithm-system/{project.id}/details/tasks")
return response
diff --git a/amt/core/config.py b/amt/core/config.py
index 014dba9b..6a8413bc 100644
--- a/amt/core/config.py
+++ b/amt/core/config.py
@@ -44,7 +44,7 @@ class Settings(BaseSettings):
# todo(berry): create submodel for database settings
APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite"
- APP_DATABASE_DRIVER: str | None = None
+ APP_DATABASE_DRIVER: str | None = "aiosqlite"
APP_DATABASE_SERVER: str = "db"
APP_DATABASE_PORT: int = 5432
diff --git a/amt/core/db.py b/amt/core/db.py
index 5bb05d08..26933508 100644
--- a/amt/core/db.py
+++ b/amt/core/db.py
@@ -1,8 +1,7 @@
import logging
-from sqlalchemy import create_engine, select
-from sqlalchemy.engine import Engine
-from sqlalchemy.orm import Session
+from sqlalchemy import select
+from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine
from amt.core.config import get_settings
from amt.models.base import Base
@@ -10,30 +9,32 @@
logger = logging.getLogger(__name__)
-def get_engine() -> Engine:
+def get_engine() -> AsyncEngine:
settings = get_settings()
connect_args = {"check_same_thread": False} if settings.APP_DATABASE_SCHEME == "sqlite" else {}
- return create_engine(
+ return create_async_engine(
settings.SQLALCHEMY_DATABASE_URI, # pyright: ignore [reportArgumentType]
connect_args=connect_args,
echo=settings.SQLALCHEMY_ECHO,
)
-def check_db() -> None:
+async def check_db() -> None:
logger.info("Checking database connection")
- with Session(get_engine()) as session:
- session.execute(select(1))
+ async with AsyncSession(get_engine()) as session:
+ await session.execute(select(1))
logger.info("Finish Checking database connection")
-def init_db() -> None:
+async def init_db() -> None:
logger.info("Initializing database")
if get_settings().AUTO_CREATE_SCHEMA: # pragma: no cover
logger.info("Creating database schema")
- Base.metadata.create_all(get_engine())
+ engine = get_engine()
+ async with engine.begin() as connection:
+ await connection.run_sync(Base.metadata.create_all)
logger.info("Finished initializing database")
diff --git a/amt/migrations/env.py b/amt/migrations/env.py
index 5058d8f2..9f4b92b0 100644
--- a/amt/migrations/env.py
+++ b/amt/migrations/env.py
@@ -1,9 +1,11 @@
+import asyncio
import os
from logging.config import fileConfig
from alembic import context
from amt.models import * # noqa
-from sqlalchemy import engine_from_config, pool
+from sqlalchemy import Connection, pool
+from sqlalchemy.ext.asyncio import async_engine_from_config
from sqlalchemy.schema import MetaData
from amt.models.base import Base
@@ -18,17 +20,18 @@
def get_url() -> str:
scheme = os.getenv("APP_DATABASE_SCHEME", "sqlite")
+ driver = os.getenv("APP_DATABASE_DRIVER", "aiosqlite")
if scheme == "sqlite":
file = os.getenv("APP_DATABASE_FILE", "database.sqlite3")
- return f"{scheme}:///{file}"
+ return f"{scheme}+{driver}:///{file}"
user = os.getenv("APP_DATABASE_USER", "amt")
password = os.getenv("APP_DATABASE_PASSWORD", "")
server = os.getenv("APP_DATABASE_SERVER", "db")
port = os.getenv("APP_DATABASE_PORT", "5432")
db = os.getenv("APP_DATABASE_DB", "amt")
- return f"{scheme}://{user}:{password}@{server}:{port}/{db}"
+ return f"{scheme}+{driver}://{user}:{password}@{server}:{port}/{db}"
def run_migrations_offline() -> None:
@@ -52,7 +55,14 @@ def run_migrations_offline() -> None:
context.run_migrations()
-def run_migrations_online() -> None:
+def do_run_migrations(connection: Connection) -> None:
+ context.configure(connection=connection, target_metadata=target_metadata)
+
+ with context.begin_transaction():
+ context.run_migrations()
+
+
+async def run_migrations_online() -> None:
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
@@ -63,18 +73,13 @@ def run_migrations_online() -> None:
if configuration is None:
raise Exception("Failed to get configuration section") # noqa: TRY002
configuration["sqlalchemy.url"] = get_url()
- connectable = engine_from_config(configuration, prefix="sqlalchemy.", poolclass=pool.NullPool)
-
- with connectable.connect() as connection:
- context.configure(
- connection=connection, target_metadata=target_metadata, compare_type=True, render_as_batch=True
- )
+ connectable = async_engine_from_config(configuration, prefix="sqlalchemy.", poolclass=pool.NullPool)
- with context.begin_transaction():
- context.run_migrations()
+ async with connectable.connect() as connection:
+ await connection.run_sync(do_run_migrations)
if context.is_offline_mode():
run_migrations_offline()
else:
- run_migrations_online()
+ asyncio.run(run_migrations_online())
diff --git a/amt/models/project.py b/amt/models/project.py
index 9fe88845..b56f0b4d 100644
--- a/amt/models/project.py
+++ b/amt/models/project.py
@@ -56,7 +56,7 @@ class Project(Base):
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(255))
- lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles), nullable=True)
+ lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles, name="lifecycle"), nullable=True)
system_card_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
last_edited: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now(), nullable=False)
diff --git a/amt/repositories/deps.py b/amt/repositories/deps.py
index bf68aeac..f7bf0db2 100644
--- a/amt/repositories/deps.py
+++ b/amt/repositories/deps.py
@@ -1,10 +1,15 @@
-from collections.abc import Generator
+from collections.abc import AsyncGenerator
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from amt.core.db import get_engine
-def get_session() -> Generator[Session, None, None]:
- with Session(get_engine()) as session:
- yield session
+async def get_session() -> AsyncGenerator[AsyncSession, None]:
+ async_session_factory = async_sessionmaker(
+ get_engine(),
+ expire_on_commit=False,
+ class_=AsyncSession,
+ )
+ async with async_session_factory() as async_session:
+ yield async_session
diff --git a/amt/repositories/projects.py b/amt/repositories/projects.py
index 0df74d43..cb433ed6 100644
--- a/amt/repositories/projects.py
+++ b/amt/repositories/projects.py
@@ -5,7 +5,7 @@
from fastapi import Depends
from sqlalchemy import func, select
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy_utils import escape_like # pyright: ignore[reportMissingTypeStubs, reportUnknownVariableType]
from amt.api.publication_category import PublicationCategories
@@ -31,47 +31,49 @@ def sort_by_lifecycle_reversed(project: Project) -> int:
class ProjectsRepository:
- def __init__(self, session: Annotated[Session, Depends(get_session)]) -> None:
+ def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
self.session = session
- def find_all(self) -> Sequence[Project]:
- return self.session.execute(select(Project)).scalars().all()
+ async def find_all(self) -> Sequence[Project]:
+ result = await self.session.execute(select(Project))
+ return result.scalars().all()
- def delete(self, project: Project) -> None:
+ async def delete(self, project: Project) -> None:
"""
Deletes the given status in the repository.
:param status: the status to store
:return: the updated status after storing
"""
try:
- self.session.delete(project)
- self.session.commit()
+ await self.session.delete(project)
+ await self.session.commit()
except Exception as e:
logger.exception("Error deleting project")
- self.session.rollback()
+ await self.session.rollback()
raise AMTRepositoryError from e
return None
- def save(self, project: Project) -> Project:
+ async def save(self, project: Project) -> Project:
try:
self.session.add(project)
- self.session.commit()
- self.session.refresh(project)
+ await self.session.commit()
+ await self.session.refresh(project)
except SQLAlchemyError as e:
logger.exception("Error saving project")
- self.session.rollback()
+ await self.session.rollback()
raise AMTRepositoryError from e
return project
- def find_by_id(self, project_id: int) -> Project:
+ async def find_by_id(self, project_id: int) -> Project:
try:
statement = select(Project).where(Project.id == project_id)
- return self.session.execute(statement).scalars().one()
+ result = await self.session.execute(statement)
+ return result.scalars().one()
except NoResultFound as e:
logger.exception("Project not found")
raise AMTRepositoryError from e
- def paginate( # noqa
+ async def paginate( # noqa
self, skip: int, limit: int, search: str, filters: dict[str, str], sort: dict[str, str]
) -> list[Project]:
try:
@@ -102,7 +104,8 @@ def paginate( # noqa
else:
statement = statement.order_by(func.lower(Project.name))
statement = statement.offset(skip).limit(limit)
- result = list(self.session.execute(statement).scalars())
+ db_result = await self.session.execute(statement)
+ result = list(db_result.scalars())
# todo: the good way to do sorting is to use an enum field (or any int field)
# in the database so we can sort on that
if result and sort and "lifecycle" in sort:
diff --git a/amt/repositories/tasks.py b/amt/repositories/tasks.py
index 735094fb..1e3a649b 100644
--- a/amt/repositories/tasks.py
+++ b/amt/repositories/tasks.py
@@ -5,7 +5,7 @@
from fastapi import Depends
from sqlalchemy import and_, select
from sqlalchemy.exc import NoResultFound
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
from amt.core.exceptions import AMTRepositoryError
from amt.models import Task
@@ -19,26 +19,26 @@ class TasksRepository:
The TasksRepository provides access to the repository layer.
"""
- def __init__(self, session: Annotated[Session, Depends(get_session)]) -> None:
+ def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
self.session = session
- def find_all(self) -> Sequence[Task]:
+ async def find_all(self) -> Sequence[Task]:
"""
Returns all tasks in the repository.
:return: all tasks in the repository
"""
- return self.session.execute(select(Task)).scalars().all()
+ return (await self.session.execute(select(Task))).scalars().all()
- def find_by_status_id(self, status_id: int) -> Sequence[Task]:
+ async def find_by_status_id(self, status_id: int) -> Sequence[Task]:
"""
Returns all tasks in the repository for the given status_id.
:param status_id: the status_id to filter on
:return: a list of tasks in the repository for the given status_id
"""
statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order)
- return self.session.execute(statement).scalars().all()
+ return (await self.session.execute(statement)).scalars().all()
- def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]:
+ async def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]:
"""
Returns all tasks in the repository for the given project_id.
:param project_id: the project_id to filter on
@@ -49,9 +49,9 @@ def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> S
.where(and_(Task.status_id == status_id, Task.project_id == project_id))
.order_by(Task.sort_order)
)
- return self.session.execute(statement).scalars().all()
+ return (await self.session.execute(statement)).scalars().all()
- def save(self, task: Task) -> Task:
+ async def save(self, task: Task) -> Task:
"""
Stores the given task in the repository or throws a RepositoryException
:param task: the task to store
@@ -59,15 +59,15 @@ def save(self, task: Task) -> Task:
"""
try:
self.session.add(task)
- self.session.commit()
- self.session.refresh(task)
+ await self.session.commit()
+ await self.session.refresh(task)
except Exception as e:
logger.exception("Could not store task")
- self.session.rollback()
+ await self.session.rollback()
raise AMTRepositoryError from e
return task
- def save_all(self, tasks: Sequence[Task]) -> None:
+ async def save_all(self, tasks: Sequence[Task]) -> None:
"""
Stores the given tasks in the repository or throws a RepositoryException
:param tasks: the tasks to store
@@ -75,28 +75,28 @@ def save_all(self, tasks: Sequence[Task]) -> None:
"""
try:
self.session.add_all(tasks)
- self.session.commit()
+ await self.session.commit()
except Exception as e:
logger.exception("Could not store all tasks")
- self.session.rollback()
+ await self.session.rollback()
raise AMTRepositoryError from e
- def delete(self, task: Task) -> None:
+ async def delete(self, task: Task) -> None:
"""
Deletes the given task in the repository or throws a RepositoryException
:param task: the task to store
:return: the updated task after storing
"""
try:
- self.session.delete(task)
- self.session.commit()
+ await self.session.delete(task)
+ await self.session.commit()
except Exception as e:
logger.exception("Could not delete task")
- self.session.rollback()
+ await self.session.rollback()
raise AMTRepositoryError from e
return None
- def find_by_id(self, task_id: int) -> Task:
+ async def find_by_id(self, task_id: int) -> Task:
"""
Returns the task with the given id.
:param task_id: the id of the task to find
@@ -104,7 +104,7 @@ def find_by_id(self, task_id: int) -> Task:
"""
statement = select(Task).where(Task.id == task_id)
try:
- return self.session.execute(statement).scalars().one()
+ return (await self.session.execute(statement)).scalars().one()
except NoResultFound as e:
logger.exception("Task not found")
raise AMTRepositoryError from e
diff --git a/amt/server.py b/amt/server.py
index a758dc9e..78d6e480 100644
--- a/amt/server.py
+++ b/amt/server.py
@@ -32,8 +32,8 @@
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
mask = Mask(mask_keywords=["database_uri"])
- check_db()
- init_db()
+ await check_db()
+ await init_db()
logger.info(f"Starting {PROJECT_NAME} version {VERSION}")
logger.info(f"Settings: {mask.secrets(get_settings().model_dump())}")
yield
diff --git a/amt/services/projects.py b/amt/services/projects.py
index a8ec837f..42b1be95 100644
--- a/amt/services/projects.py
+++ b/amt/services/projects.py
@@ -26,10 +26,11 @@ def __init__(
self.instrument_service = instrument_service
self.task_service = task_service
- def get(self, project_id: int) -> Project:
- return self.repository.find_by_id(project_id)
+ async def get(self, project_id: int) -> Project:
+ project = await self.repository.find_by_id(project_id)
+ return project
- def create(self, project_new: ProjectNew) -> Project:
+ async def create(self, project_new: ProjectNew) -> Project:
instruments: list[InstrumentBase] = [
InstrumentBase(urn=instrument_urn) for instrument_urn in project_new.instruments
]
@@ -54,20 +55,22 @@ def create(self, project_new: ProjectNew) -> Project:
)
project = Project(name=project_new.name, lifecycle=project_new.lifecycle, system_card=system_card)
- project = self.update(project)
+ project = await self.update(project)
selected_instruments = self.instrument_service.fetch_instruments(project_new.instruments) # type: ignore
for instrument in selected_instruments:
- self.task_service.create_instrument_tasks(instrument.tasks, project)
+ await self.task_service.create_instrument_tasks(instrument.tasks, project)
return project
- def paginate(
+ async def paginate(
self, skip: int, limit: int, search: str, filters: dict[str, str], sort: dict[str, str]
) -> list[Project]:
- return self.repository.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort)
+ projects = await self.repository.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort)
+ return projects
- def update(self, project: Project) -> Project:
+ async def update(self, project: Project) -> Project:
# TODO: Is this the right place to sync system cards: system_card and system_card_json?
project.sync_system_card()
- return self.repository.save(project)
+ project = await self.repository.save(project)
+ return project
diff --git a/amt/services/tasks.py b/amt/services/tasks.py
index aafae4d7..1b369026 100644
--- a/amt/services/tasks.py
+++ b/amt/services/tasks.py
@@ -25,21 +25,24 @@ def __init__(
self.storage_writer = StorageFactory.init(storage_type="file", location="./output", filename="system_card.yaml")
self.system_card = SystemCard()
- def get_tasks(self, status_id: int) -> Sequence[Task]:
- return self.repository.find_by_status_id(status_id)
+ async def get_tasks(self, status_id: int) -> Sequence[Task]:
+ task = await self.repository.find_by_status_id(status_id)
+ return task
- def get_tasks_for_project(self, project_id: int, status_id: int) -> Sequence[Task]:
- return self.repository.find_by_project_id_and_status_id(project_id, status_id)
+ async def get_tasks_for_project(self, project_id: int, status_id: int) -> Sequence[Task]:
+ tasks = await self.repository.find_by_project_id_and_status_id(project_id, status_id)
+ return tasks
- def assign_task(self, task: Task, user: User) -> Task:
+ async def assign_task(self, task: Task, user: User) -> Task:
task.user_id = user.id
- return self.repository.save(task)
+ task = await self.repository.save(task)
+ return task
- def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], project: Project) -> None:
+ async def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], project: Project) -> None:
# TODO: (Christopher) At this moment a status has to be retrieved from the DB. In the future
# we will have static statuses, so this will need to change.
status = Status.TODO
- self.repository.save_all(
+ await self.repository.save_all(
[
# TODO: (Christopher) The ticket does not specify what to do when question type is not an
# open questions, hence for now all titles will be set to task.question.
@@ -50,7 +53,7 @@ def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], project: Proj
]
)
- def move_task(
+ async def move_task(
self, task_id: int, status_id: int, previous_sibling_id: int | None = None, next_sibling_id: int | None = None
) -> Task:
"""
@@ -61,7 +64,7 @@ def move_task(
:param next_sibling_id: the id of the next sibling of the task or None
:return: the updated task
"""
- task = self.repository.find_by_id(task_id)
+ task = await self.repository.find_by_id(task_id)
if status_id == Status.DONE:
# TODO: This seems off, tasks should be written to the correct location in the system card.
@@ -73,17 +76,18 @@ def move_task(
# update order position of the card
if previous_sibling_id and next_sibling_id:
- previous_task = self.repository.find_by_id(previous_sibling_id)
- next_task = self.repository.find_by_id(next_sibling_id)
+ previous_task = await self.repository.find_by_id(previous_sibling_id)
+ next_task = await self.repository.find_by_id(next_sibling_id)
new_sort_order = previous_task.sort_order + ((next_task.sort_order - previous_task.sort_order) / 2)
task.sort_order = new_sort_order
elif previous_sibling_id and not next_sibling_id:
- previous_task = self.repository.find_by_id(previous_sibling_id)
+ previous_task = await self.repository.find_by_id(previous_sibling_id)
task.sort_order = previous_task.sort_order + 10
elif not previous_sibling_id and next_sibling_id:
- next_task = self.repository.find_by_id(next_sibling_id)
+ next_task = await self.repository.find_by_id(next_sibling_id)
task.sort_order = next_task.sort_order / 2
else:
task.sort_order = 10
- return self.repository.save(task)
+ task = await self.repository.save(task)
+ return task
diff --git a/amt/site/templates/macros/tasks.html.j2 b/amt/site/templates/macros/tasks.html.j2
index 327d34f9..077c56cd 100644
--- a/amt/site/templates/macros/tasks.html.j2
+++ b/amt/site/templates/macros/tasks.html.j2
@@ -39,9 +39,9 @@
data-id="{{ status.value }}"
id="column-{{ status.value }}">
{% if project is defined %}
- {% for task in tasks_service.get_tasks_for_project(project.id, status) %}{{ render_task_card_full(task) }}{% endfor %}
+ {% for task in tasks_by_status[status] %}{{ render_task_card_full(task) }}{% endfor %}
{% else %}
- {% for task in tasks_service.get_tasks(status) %}{{ render_task_card_full(task) }}{% endfor %}
+ {% for task in tasks_by_status[status] %}{{ render_task_card_full(task) }}{% endfor %}
{% endif %}
diff --git a/amt/site/templates/projects/tasks.html.j2 b/amt/site/templates/projects/tasks.html.j2
index e738f52a..12f95f0d 100644
--- a/amt/site/templates/projects/tasks.html.j2
+++ b/amt/site/templates/projects/tasks.html.j2
@@ -21,7 +21,7 @@
- {% for status in statuses %}{{ render.column(project, status, translations, tasks_service) }}{% endfor %}
+ {% for status in statuses %}{{ render.column(project, status, translations, tasks_by_status) }}{% endfor %}
diff --git a/compose.override.yml b/compose.override.yml
new file mode 100644
index 00000000..e03e6d97
--- /dev/null
+++ b/compose.override.yml
@@ -0,0 +1,10 @@
+services:
+ amt-test:
+ build:
+ context: .
+ dockerfile: Dockerfile
+ target: test
+ image: ghcr.io/minbzk/amt-test:latest
+ env_file:
+ - path: prod.env
+ required: true
diff --git a/poetry.lock b/poetry.lock
index b05f5110..6b4f4c0e 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,22 @@
-# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand.
+
+[[package]]
+name = "aiosqlite"
+version = "0.20.0"
+description = "asyncio bridge to the standard sqlite3 module"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"},
+ {file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"},
+]
+
+[package.dependencies]
+typing_extensions = ">=4.0"
+
+[package.extras]
+dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"]
+docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"]
[[package]]
name = "alembic"
@@ -50,6 +68,69 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)",
test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"]
trio = ["trio (>=0.26.1)"]
+[[package]]
+name = "asyncpg"
+version = "0.30.0"
+description = "An asyncio PostgreSQL driver"
+optional = false
+python-versions = ">=3.8.0"
+files = [
+ {file = "asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e"},
+ {file = "asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0"},
+ {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f"},
+ {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af"},
+ {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75"},
+ {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f"},
+ {file = "asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf"},
+ {file = "asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50"},
+ {file = "asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a"},
+ {file = "asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed"},
+ {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a"},
+ {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956"},
+ {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056"},
+ {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454"},
+ {file = "asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d"},
+ {file = "asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f"},
+ {file = "asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e"},
+ {file = "asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a"},
+ {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3"},
+ {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737"},
+ {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a"},
+ {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af"},
+ {file = "asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e"},
+ {file = "asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305"},
+ {file = "asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70"},
+ {file = "asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3"},
+ {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33"},
+ {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4"},
+ {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4"},
+ {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba"},
+ {file = "asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590"},
+ {file = "asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e"},
+ {file = "asyncpg-0.30.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:29ff1fc8b5bf724273782ff8b4f57b0f8220a1b2324184846b39d1ab4122031d"},
+ {file = "asyncpg-0.30.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64e899bce0600871b55368b8483e5e3e7f1860c9482e7f12e0a771e747988168"},
+ {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b290f4726a887f75dcd1b3006f484252db37602313f806e9ffc4e5996cfe5cb"},
+ {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f86b0e2cd3f1249d6fe6fd6cfe0cd4538ba994e2d8249c0491925629b9104d0f"},
+ {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:393af4e3214c8fa4c7b86da6364384c0d1b3298d45803375572f415b6f673f38"},
+ {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fd4406d09208d5b4a14db9a9dbb311b6d7aeeab57bded7ed2f8ea41aeef39b34"},
+ {file = "asyncpg-0.30.0-cp38-cp38-win32.whl", hash = "sha256:0b448f0150e1c3b96cb0438a0d0aa4871f1472e58de14a3ec320dbb2798fb0d4"},
+ {file = "asyncpg-0.30.0-cp38-cp38-win_amd64.whl", hash = "sha256:f23b836dd90bea21104f69547923a02b167d999ce053f3d502081acea2fba15b"},
+ {file = "asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad"},
+ {file = "asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff"},
+ {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708"},
+ {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144"},
+ {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb"},
+ {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547"},
+ {file = "asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a"},
+ {file = "asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773"},
+ {file = "asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851"},
+]
+
+[package.extras]
+docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"]
+gssauth = ["gssapi", "sspilib"]
+test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi", "k5test", "mypy (>=1.8.0,<1.9.0)", "sspilib", "uvloop (>=0.15.3)"]
+
[[package]]
name = "authlib"
version = "1.3.2"
@@ -1027,6 +1108,17 @@ files = [
{file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"},
]
+[[package]]
+name = "nest-asyncio"
+version = "1.6.0"
+description = "Patch asyncio to allow nested event loops"
+optional = false
+python-versions = ">=3.5"
+files = [
+ {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"},
+ {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"},
+]
+
[[package]]
name = "nodeenv"
version = "1.9.1"
@@ -1428,6 +1520,24 @@ pluggy = ">=1.5,<2"
[package.extras]
dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"]
+[[package]]
+name = "pytest-asyncio"
+version = "0.24.0"
+description = "Pytest support for asyncio"
+optional = false
+python-versions = ">=3.8"
+files = [
+ {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"},
+ {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"},
+]
+
+[package.dependencies]
+pytest = ">=8.2,<9"
+
+[package.extras]
+docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"]
+testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"]
+
[[package]]
name = "pytest-base-url"
version = "2.1.0"
@@ -1909,7 +2019,7 @@ files = [
]
[package.dependencies]
-greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"}
+greenlet = {version = "!=0.4.17", optional = true, markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") or extra == \"asyncio\""}
typing-extensions = ">=4.6.0"
[package.extras]
@@ -1967,13 +2077,13 @@ url = ["furl (>=0.4.1)"]
[[package]]
name = "starlette"
-version = "0.40.0"
+version = "0.41.2"
description = "The little ASGI library that shines."
optional = false
python-versions = ">=3.8"
files = [
- {file = "starlette-0.40.0-py3-none-any.whl", hash = "sha256:c494a22fae73805376ea6bf88439783ecfba9aac88a43911b48c653437e784c4"},
- {file = "starlette-0.40.0.tar.gz", hash = "sha256:1a3139688fb298ce5e2d661d37046a66ad996ce94be4d4983be019a23a04ea35"},
+ {file = "starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d"},
+ {file = "starlette-0.41.2.tar.gz", hash = "sha256:9834fd799d1a87fd346deb76158668cfa0b0d56f85caefe8268e2d97c3468b62"},
]
[package.dependencies]
@@ -2352,4 +2462,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.12"
-content-hash = "06db522ccdbb19c0c0061b5533ca30e03ec51755ba3a2888b80f77a5c28c7e45"
+content-hash = "a342f51a83d9cd19dd54d039ad978433c49edb17b98b03f34891f27933228b4e"
diff --git a/prod.env b/prod.env
index 001d4690..47f619a3 100644
--- a/prod.env
+++ b/prod.env
@@ -5,6 +5,7 @@ ENVIRONMENT=production
BACKEND_CORS_ORIGINS="http://localhost,https://localhost,http://127.0.0.1,https://127.0.0.1"
SECRET_KEY=changethis
APP_DATABASE_SCHEME="postgresql"
+APP_DATABASE_DRIVER="asyncpg"
APP_DATABASE_USER=amt
APP_DATABASE_DB=amt
APP_DATABASE_PASSWORD=changethis
diff --git a/pyproject.toml b/pyproject.toml
index bd617d90..b52952de 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,15 +36,19 @@ pyyaml-include = "^2.1"
click = "^8.1.7"
python-ulid = "^3.0.0"
fastapi-csrf-protect = "^0.3.4"
-sqlalchemy = "^2.0.36"
+sqlalchemy = {extras = ["asyncio"], version = "^2.0.36"}
sqlalchemy-utils = "^0.41.2"
liccheck = "^0.9.2"
authlib = "^1.3.2"
-pytest-mock = "^3.14.0"
+aiosqlite = "^0.20.0"
+asyncpg = "^0.30.0"
[tool.poetry.group.test.dependencies]
pytest = "^8.3.3"
+pytest-asyncio = "^0.24.0"
+nest-asyncio = "^1.6.0"
+pytest-mock = "^3.14.0"
coverage = "^7.6.4"
playwright = "^1.47.0"
pytest-playwright = "^0.5.2"
@@ -137,6 +141,7 @@ markers = [
"enable_auth: marks tests that require authentication"
]
+
[tool.liccheck]
level = "PARANOID"
dependencies = true
diff --git a/tests/api/routes/test_health.py b/tests/api/routes/test_health.py
index 49edab92..626c1866 100644
--- a/tests/api/routes/test_health.py
+++ b/tests/api/routes/test_health.py
@@ -1,8 +1,10 @@
-from fastapi.testclient import TestClient
+import pytest
+from httpx import AsyncClient
-def test_health_ready(client: TestClient) -> None:
- response = client.get(
+@pytest.mark.asyncio
+async def test_health_ready(client: AsyncClient) -> None:
+ response = await client.get(
"/health/ready",
)
assert response.status_code == 200
@@ -10,8 +12,9 @@ def test_health_ready(client: TestClient) -> None:
assert response.json() == {"status": "ok"}
-def test_health_live(client: TestClient) -> None:
- response = client.get(
+@pytest.mark.asyncio
+async def test_health_live(client: AsyncClient) -> None:
+ response = await client.get(
"/health/live",
)
assert response.status_code == 200
diff --git a/tests/api/routes/test_pages.py b/tests/api/routes/test_pages.py
index 59a0f9d7..cb5acfc2 100644
--- a/tests/api/routes/test_pages.py
+++ b/tests/api/routes/test_pages.py
@@ -1,11 +1,13 @@
-from fastapi.testclient import TestClient
+import pytest
+from httpx import AsyncClient
from tests.database_test_utils import DatabaseTestUtils
-def test_get_main_page(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_main_page(client: AsyncClient, db: DatabaseTestUtils) -> None:
# when
- response = client.get("/pages/")
+ response = await client.get("/pages/")
# then
assert response.status_code == 200
diff --git a/tests/api/routes/test_project.py b/tests/api/routes/test_project.py
index 463ba775..8e662279 100644
--- a/tests/api/routes/test_project.py
+++ b/tests/api/routes/test_project.py
@@ -3,16 +3,17 @@
import pytest
from amt.api.routes.project import set_path
from amt.models import Project
-from fastapi.testclient import TestClient
+from httpx import AsyncClient
from pytest_mock import MockFixture
from tests.constants import default_project, default_task
from tests.database_test_utils import DatabaseTestUtils
-def test_get_unknown_project(client: TestClient) -> None:
+@pytest.mark.asyncio
+async def test_get_unknown_project(client: AsyncClient) -> None:
# when
- response = client.get("/algorithm-system/1")
+ response = await client.get("/algorithm-system/1")
# then
assert response.status_code == 404
@@ -20,12 +21,13 @@ def test_get_unknown_project(client: TestClient) -> None:
assert b"The requested page or resource could not be found." in response.content
-def test_get_project_tasks(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_project_tasks(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
+ await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
# when
- response = client.get("/algorithm-system/1/details/tasks")
+ response = await client.get("/algorithm-system/1/details/tasks")
# then
assert response.status_code == 200
@@ -36,12 +38,13 @@ def test_get_project_tasks(client: TestClient, db: DatabaseTestUtils) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_system_card(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_system_card(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/details/system_card")
+ response = await client.get("/algorithm-system/1/details/system_card")
# then
assert response.status_code == 200
@@ -52,9 +55,10 @@ def test_get_system_card(client: TestClient, db: DatabaseTestUtils) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_system_card_unknown_project(client: TestClient) -> None:
+@pytest.mark.asyncio
+async def test_get_system_card_unknown_project(client: AsyncClient) -> None:
# when
- response = client.get("/algorithm-system/1/details/system_card")
+ response = await client.get("/algorithm-system/1/details/system_card")
# then
assert response.status_code == 404
@@ -65,12 +69,13 @@ def test_get_system_card_unknown_project(client: TestClient) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_assessment_card(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_assessment_card(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/details/system_card/assessments/iama")
+ response = await client.get("/algorithm-system/1/details/system_card/assessments/iama")
# then
assert response.status_code == 200
@@ -81,9 +86,10 @@ def test_get_assessment_card(client: TestClient, db: DatabaseTestUtils) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_assessment_card_unknown_project(client: TestClient) -> None:
+@pytest.mark.asyncio
+async def test_get_assessment_card_unknown_project(client: AsyncClient) -> None:
# when
- response = client.get("/algorithm-system/1/details/system_card/assessments/iama")
+ response = await client.get("/algorithm-system/1/details/system_card/assessments/iama")
# then
assert response.status_code == 404
@@ -94,12 +100,13 @@ def test_get_assessment_card_unknown_project(client: TestClient) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_assessment_card_unknown_assessment(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_assessment_card_unknown_assessment(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/details/system_card/assessments/nonexistent")
+ response = await client.get("/algorithm-system/1/details/system_card/assessments/nonexistent")
# then
assert response.status_code == 404
@@ -110,12 +117,13 @@ def test_get_assessment_card_unknown_assessment(client: TestClient, db: Database
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_model_card(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_model_card(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/details/system_card/models/logres_iris")
+ response = await client.get("/algorithm-system/1/details/system_card/models/logres_iris")
# then
assert response.status_code == 200
@@ -126,9 +134,10 @@ def test_get_model_card(client: TestClient, db: DatabaseTestUtils) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_model_card_unknown_project(client: TestClient) -> None:
+@pytest.mark.asyncio
+async def test_get_model_card_unknown_project(client: AsyncClient) -> None:
# when
- response = client.get("/algorithm-system/1/details/system_card/models/logres_iris")
+ response = await client.get("/algorithm-system/1/details/system_card/models/logres_iris")
# then
assert response.status_code == 404
@@ -139,12 +148,13 @@ def test_get_model_card_unknown_project(client: TestClient) -> None:
# TODO: Test are now have hard coded URL paths because the system card
# is fixed for now. Tests need to be refactored and made proper once
# the actual stored system card in a project is being rendered.
-def test_get_assessment_card_unknown_model_card(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_assessment_card_unknown_model_card(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/details/system_card/models/nonexistent")
+ response = await client.get("/algorithm-system/1/details/system_card/models/nonexistent")
# then
assert response.status_code == 404
@@ -152,12 +162,13 @@ def test_get_assessment_card_unknown_model_card(client: TestClient, db: Database
assert b"The requested page or resource could not be found." in response.content
-def test_get_project_details(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_project_details(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
+ await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
# when
- response = client.get("/algorithm-system/1/details")
+ response = await client.get("/algorithm-system/1/details")
# then
assert response.status_code == 200
@@ -165,12 +176,13 @@ def test_get_project_details(client: TestClient, db: DatabaseTestUtils) -> None:
assert b"Details" in response.content
-def test_get_system_card_requirements(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_system_card_requirements(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
+ await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
# when
- response = client.get("/algorithm-system/1/details/system_card/requirements")
+ response = await client.get("/algorithm-system/1/details/system_card/requirements")
# then
assert response.status_code == 200
@@ -178,12 +190,13 @@ def test_get_system_card_requirements(client: TestClient, db: DatabaseTestUtils)
assert b"0" in response.content
-def test_get_system_card_data_page(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_system_card_data_page(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
+ await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
# when
- response = client.get("/algorithm-system/1/details/system_card/data")
+ response = await client.get("/algorithm-system/1/details/system_card/data")
# then
assert response.status_code == 200
@@ -191,12 +204,13 @@ def test_get_system_card_data_page(client: TestClient, db: DatabaseTestUtils) ->
assert b"To be implemented" in response.content
-def test_get_system_card_instruments(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_system_card_instruments(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
+ await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)])
# when
- response = client.get("/algorithm-system/1/details/system_card/instruments")
+ response = await client.get("/algorithm-system/1/details/system_card/instruments")
# then
assert response.status_code == 200
@@ -204,12 +218,13 @@ def test_get_system_card_instruments(client: TestClient, db: DatabaseTestUtils)
assert b"To be implemented" in response.content
-def test_get_project_edit(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_project_edit(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/edit/system_card/lifecycle")
+ response = await client.get("/algorithm-system/1/edit/system_card/lifecycle")
# then
assert response.status_code == 200
@@ -218,12 +233,13 @@ def test_get_project_edit(client: TestClient, db: DatabaseTestUtils) -> None:
assert b"lifecycle" in response.content
-def test_get_project_cancel(client: TestClient, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_project_cancel(client: AsyncClient, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
# when
- response = client.get("/algorithm-system/1/cancel/system_card/lifecycle")
+ response = await client.get("/algorithm-system/1/cancel/system_card/lifecycle")
# then
assert response.status_code == 200
@@ -232,14 +248,17 @@ def test_get_project_cancel(client: TestClient, db: DatabaseTestUtils) -> None:
assert b"lifecycle" in response.content
-def test_get_project_update(client: TestClient, mocker: MockFixture, db: DatabaseTestUtils) -> None:
+@pytest.mark.asyncio
+async def test_get_project_update(client: AsyncClient, mocker: MockFixture, db: DatabaseTestUtils) -> None:
# given
- db.given([default_project("testproject1")])
+ await db.given([default_project("testproject1")])
client.cookies["fastapi-csrf-token"] = "1"
mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock)
# when
- response = client.put("/algorithm-system/1/update/name", json={"value": "Test Name"}, headers={"X-CSRF-Token": "1"})
+ response = await client.put(
+ "/algorithm-system/1/update/name", json={"value": "Test Name"}, headers={"X-CSRF-Token": "1"}
+ )
# then
assert response.status_code == 200
@@ -248,7 +267,7 @@ def test_get_project_update(client: TestClient, mocker: MockFixture, db: Databas
assert b"Test Name" in response.content
# Verify that the project was updated in the database
- updated_projects = db.get(Project, "id", 1)
+ updated_projects = await db.get(Project, "id", 1)
assert len(updated_projects) == 1
updated_project = updated_projects[0]
assert updated_project.name == "Test Name" # type: ignore
diff --git a/tests/api/routes/test_projects.py b/tests/api/routes/test_projects.py
index cdd09045..82e884e4 100644
--- a/tests/api/routes/test_projects.py
+++ b/tests/api/routes/test_projects.py
@@ -1,5 +1,6 @@
from typing import cast
+import pytest
from amt.api.routes.projects import get_localized_value
from amt.models import Project
from amt.models.base import Base
@@ -8,35 +9,39 @@
from amt.schema.system_card import SystemCard
from amt.services.task_registry import get_requirements_and_measures
from fastapi.requests import Request
-from fastapi.testclient import TestClient
+from httpx import AsyncClient
from pytest_mock import MockFixture
from tests.constants import default_instrument
from tests.database_test_utils import DatabaseTestUtils
-def test_projects_get_root(client: TestClient) -> None:
- response = client.get("/algorithm-systems/")
+@pytest.mark.asyncio
+async def test_projects_get_root(client: AsyncClient) -> None:
+ response = await client.get("/algorithm-systems/")
assert response.status_code == 200
assert b'' in response.content
-def test_projects_get_root_missing_slash(client: TestClient) -> None:
- response = client.get("/algorithm-systems")
+@pytest.mark.asyncio
+async def test_projects_get_root_missing_slash(client: AsyncClient) -> None:
+ response = await client.get("/algorithm-systems", follow_redirects=True)
assert response.status_code == 200
assert b'
' in response.content
-def test_projects_get_root_htmx(client: TestClient) -> None:
- response = client.get("/algorithm-systems/", headers={"HX-Request": "true"})
+@pytest.mark.asyncio
+async def test_projects_get_root_htmx(client: AsyncClient) -> None:
+ response = await client.get("/algorithm-systems/", headers={"HX-Request": "true"})
assert response.status_code == 200
assert b'
' not in response.content
-def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None:
+@pytest.mark.asyncio
+async def test_get_new_projects(client: AsyncClient, mocker: MockFixture) -> None:
# given
mocker.patch(
"amt.services.instruments.InstrumentsService.fetch_instruments",
@@ -44,7 +49,7 @@ def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None:
)
# when
- response = client.get("/algorithm-systems/new")
+ response = await client.get("/algorithm-systems/new")
assert response.status_code == 200
assert response.headers["content-type"] == "text/html; charset=utf-8"
content = " ".join(response.content.decode().split())
@@ -59,13 +64,14 @@ def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None:
)
-def test_post_new_projects_bad_request(client: TestClient, mocker: MockFixture) -> None:
+@pytest.mark.asyncio
+async def test_post_new_projects_bad_request(client: AsyncClient, mocker: MockFixture) -> None:
# given
mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock)
# when
client.cookies["fastapi-csrf-token"] = "1"
- response = client.post("/algorithm-systems/new", json={}, headers={"X-CSRF-Token": "1"})
+ response = await client.post("/algorithm-systems/new", json={}, headers={"X-CSRF-Token": "1"})
# then
assert response.status_code == 400
@@ -73,7 +79,8 @@ def test_post_new_projects_bad_request(client: TestClient, mocker: MockFixture)
assert b"Field required" in response.content
-def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None:
+@pytest.mark.asyncio
+async def test_post_new_projects(client: AsyncClient, mocker: MockFixture) -> None:
client.cookies["fastapi-csrf-token"] = "1"
new_project = ProjectNew(
name="default project",
@@ -93,7 +100,7 @@ def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None:
)
# when
- response = client.post("/algorithm-systems/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"})
+ response = await client.post("/algorithm-systems/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"})
# then
assert response.status_code == 200
@@ -101,8 +108,9 @@ def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None:
assert response.headers["HX-Redirect"] == "/algorithm-system/1/details/tasks"
-def test_post_new_projects_write_system_card(
- client: TestClient,
+@pytest.mark.asyncio
+async def test_post_new_projects_write_system_card(
+ client: AsyncClient,
mocker: MockFixture,
db: DatabaseTestUtils,
) -> None:
@@ -148,10 +156,10 @@ def test_post_new_projects_write_system_card(
)
# when
- client.post("/algorithm-systems/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"})
+ await client.post("/algorithm-systems/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"})
# then
- base_projects: list[Base] = db.get(Project, "name", name)
+ base_projects: list[Base] = await db.get(Project, "name", name)
projects: list[Project] = cast(list[Project], base_projects)
assert any(project.system_card == system_card for project in projects if project.system_card is not None)
diff --git a/tests/api/routes/test_root.py b/tests/api/routes/test_root.py
index 99d7a9c2..bd41694a 100644
--- a/tests/api/routes/test_root.py
+++ b/tests/api/routes/test_root.py
@@ -1,8 +1,10 @@
-from fastapi.testclient import TestClient
+import pytest
+from httpx import AsyncClient
-def test_get_root(client: TestClient) -> None:
- response = client.get(
+@pytest.mark.asyncio
+async def test_get_root(client: AsyncClient) -> None:
+ response = await client.get(
"/",
follow_redirects=False,
)
diff --git a/tests/api/routes/test_static.py b/tests/api/routes/test_static.py
index a73c255c..9f6d392f 100644
--- a/tests/api/routes/test_static.py
+++ b/tests/api/routes/test_static.py
@@ -1,10 +1,12 @@
from pathlib import Path
-from fastapi.testclient import TestClient
+import pytest
+from httpx import AsyncClient
-def test_static_css(client: TestClient) -> None:
+@pytest.mark.asyncio
+async def test_static_css(client: AsyncClient) -> None:
files = Path("amt/site/static/dist").glob("*.js")
for filename in files:
- response = client.get(f"/static/dist/{filename.name}")
+ response = await client.get(f"/static/dist/{filename.name}")
assert response.status_code == 200
diff --git a/tests/api/test_http_headers.py b/tests/api/test_http_headers.py
index b18d4668..21b797e0 100644
--- a/tests/api/test_http_headers.py
+++ b/tests/api/test_http_headers.py
@@ -1,8 +1,10 @@
-from fastapi.testclient import TestClient
+import pytest
+from httpx import AsyncClient
-def test_sts_header(client: TestClient) -> None:
- response = client.get(
+@pytest.mark.asyncio
+async def test_sts_header(client: AsyncClient) -> None:
+ response = await client.get(
"/",
)
assert response.status_code == 200
diff --git a/tests/conftest.py b/tests/conftest.py
index 25b57bde..ae1c1631 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,26 +1,33 @@
+import asyncio
import logging
import os
import re
-from collections.abc import Callable, Generator
+from collections.abc import AsyncIterator, Callable, Generator
from multiprocessing import Process
from pathlib import Path
from typing import Any
import httpx
+import nest_asyncio # type: ignore [(reportMissingTypeStubs)]
import pytest
+import pytest_asyncio
import uvicorn
from amt.models.base import Base
from amt.server import create_app
-from fastapi.testclient import TestClient
+from httpx import ASGITransport, AsyncClient
from playwright.sync_api import Browser
-from sqlalchemy import create_engine, text
-from sqlalchemy.orm import Session
+from sqlalchemy import text
+from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy.ext.asyncio.session import async_sessionmaker
from tests.database_e2e_setup import setup_database_e2e
from tests.database_test_utils import DatabaseTestUtils
logger = logging.getLogger(__name__)
+# Dubious choice here: allow nested event loops.
+nest_asyncio.apply() # type: ignore [(reportUnknownMemberType)]
+
def run_server_uvicorn(database_file: Path, host: str = "127.0.0.1", port: int = 3462) -> None:
os.environ["APP_DATABASE_FILE"] = "/" + str(database_file)
@@ -32,32 +39,40 @@ def run_server_uvicorn(database_file: Path, host: str = "127.0.0.1", port: int =
@pytest.fixture(scope="session")
-def setup_db_and_server(
+async def setup_db_and_server(
tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest
-) -> Generator[Any, None, None]:
+) -> AsyncIterator[str]:
test_dir = tmp_path_factory.mktemp("e2e_database")
database_file = test_dir / "test.sqlite3"
if request.config.getoption("--db") == "postgresql":
- engine = create_engine(get_db_uri())
+ engine = create_async_engine(get_db_uri())
else:
- engine = create_engine(f"sqlite:///{database_file}", connect_args={"check_same_thread": False})
+ engine = create_async_engine(f"sqlite+aiosqlite:///{database_file}", connect_args={"check_same_thread": False})
+
metadata = Base.metadata
- metadata.create_all(engine)
- with Session(engine, expire_on_commit=False) as session:
- setup_database_e2e(session)
+ async with engine.begin() as connection:
+ await connection.run_sync(metadata.create_all)
+
+ async_session = async_sessionmaker(engine, expire_on_commit=False)
+
+ async with async_session() as session:
+ await setup_database_e2e(session)
process = Process(target=run_server_uvicorn, args=(database_file,))
process.start()
yield "http://127.0.0.1:3462"
process.terminate()
- metadata.drop_all(engine)
+
+ async with engine.begin() as conn:
+ await conn.run_sync(metadata.drop_all)
@pytest.fixture(autouse=True)
def disable_auth(request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch) -> None: # noqa: PT004
- marker = request.node.get_closest_marker("enable_auth") # type: ignore
+ marker = request.node.get_closest_marker("enable_auth") # type: ignore [(reportUnknownMemberType)]
+
if not marker:
monkeypatch.setenv("DISABLE_AUTH", "true")
return
@@ -94,20 +109,20 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config
items[:] = tests + e2e_tests
-@pytest.fixture
-def client(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch) -> Generator[TestClient, None, None]:
+@pytest_asyncio.fixture # type: ignore [(reportUnknownMemberType)]
+async def client(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch) -> AsyncIterator[AsyncClient]:
# overwrite db url
monkeypatch.setenv("APP_DATABASE_FILE", "/" + str(db.get_database_file()))
from amt.repositories.deps import get_session
app = create_app()
- with TestClient(app, raise_server_exceptions=True) as c:
+ async with AsyncClient(
+ transport=ASGITransport(app=app, raise_app_exceptions=True), base_url="http://testserver/"
+ ) as ac:
app.dependency_overrides[get_session] = db.get_session
-
- c.timeout = 5
-
- yield c
+ ac.timeout = 5
+ yield ac
@pytest.fixture(scope="session")
@@ -117,43 +132,54 @@ def browser_context_args(browser_context_args: dict[str, Any]) -> dict[str, Any]
@pytest.fixture(scope="session")
def browser(
- launch_browser: Callable[[], Browser], setup_db_and_server: Generator[str, None, None]
+ launch_browser: Callable[[], Browser],
+ setup_db_and_server: AsyncIterator[str],
) -> Generator[Browser, None, None]:
transport = httpx.HTTPTransport(retries=5)
+
+ loop = asyncio.get_event_loop()
+ url = loop.run_until_complete(setup_db_and_server.__anext__())
+
with httpx.Client(transport=transport, verify=False, timeout=0.7) as client: # noqa: S501
- client.get(f"{setup_db_and_server}/")
+ client.get(f"{url}/")
browser = launch_browser()
yield browser
browser.close()
+ # Clean up by consuming the rest of the generator.
+ try: # noqa
+ loop.run_until_complete(setup_db_and_server.__anext__())
+ except StopAsyncIteration:
+ pass
def get_db_uri() -> str:
user = os.getenv("APP_DATABASE_USER", "amt")
password = os.getenv("APP_DATABASE_PASSWORD", "changethis")
server = os.getenv("APP_DATABASE_SERVER", "db")
+ driver = os.getenv("APP_DATABASE_DRIVER", "asyncpg")
port = os.getenv("APP_DATABASE_PORT", "5432")
db = os.getenv("APP_DATABASE_DB", "amt")
- return f"postgresql://{user}:{password}@{server}:{port}/{db}"
+ return f"postgresql+{driver}://{user}:{password}@{server}:{port}/{db}"
-def create_db(new_db: str) -> str:
+async def create_db(new_db: str) -> str:
url = get_db_uri()
- engine = create_engine(url, isolation_level="AUTOCOMMIT")
+ engine = create_async_engine(url, isolation_level="AUTOCOMMIT")
if new_db == os.getenv("APP_DATABASE_USER", "amt"):
return url
logger.info(f"Creating database {new_db}")
user = os.getenv("APP_DATABASE_USER", "amt")
- with Session(engine) as session:
- session.execute(text(f"DROP DATABASE IF EXISTS {new_db};")) # type: ignore
- session.execute(text(f"CREATE DATABASE {new_db} OWNER {user};")) # type: ignore
- session.commit()
- path = Path(url)
+ async with engine.connect() as conn:
+ await conn.execute(text(f"DROP DATABASE IF EXISTS {new_db};")) # type: ignore
+ await conn.execute(text(f"CREATE DATABASE {new_db} OWNER {user};")) # type: ignore
+ await conn.commit()
- return str(path.parent / new_db).replace("postgresql:/", "postgresql://")
+ path = Path(url)
+ return str(path.parent / new_db).replace("postgresql+asyncpg:/", "postgresql+asyncpg://")
def generate_db_name(request: pytest.FixtureRequest) -> str:
@@ -172,20 +198,23 @@ def generate_db_name(request: pytest.FixtureRequest) -> str:
return sanitized_name
-@pytest.fixture
-def db(
+@pytest_asyncio.fixture # type: ignore [(reportUnknownMemberType)]
+async def db(
tmp_path: Path, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch
-) -> Generator[DatabaseTestUtils, None, None]:
+) -> AsyncIterator[DatabaseTestUtils]:
database_file = tmp_path / "test.sqlite3"
if request.config.getoption("--db") == "postgresql":
db_name: str = generate_db_name(request)
- url = create_db(db_name)
+ url = await create_db(db_name)
monkeypatch.setenv("APP_DATABASE_DB", db_name)
- engine = create_engine(url)
+ engine = create_async_engine(url)
else:
- engine = create_engine(f"sqlite:///{database_file}", connect_args={"check_same_thread": False})
- Base.metadata.create_all(engine)
+ engine = create_async_engine(f"sqlite+aiosqlite:///{database_file}", connect_args={"check_same_thread": False})
+
+ async with engine.begin() as conn:
+ await conn.run_sync(Base.metadata.create_all)
- with Session(engine, expire_on_commit=False) as session:
+ async_session = async_sessionmaker(engine, expire_on_commit=False)
+ async with async_session() as session:
yield DatabaseTestUtils(session, database_file)
diff --git a/tests/core/test_config.py b/tests/core/test_config.py
index 84246983..1040b10b 100644
--- a/tests/core/test_config.py
+++ b/tests/core/test_config.py
@@ -7,6 +7,7 @@ def test_environment_settings(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setenv("ENVIRONMENT", "production")
monkeypatch.setenv("SECRET_KEY", "mysecret")
monkeypatch.setenv("APP_DATABASE_SCHEME", "postgresql")
+ monkeypatch.setenv("APP_DATABASE_DRIVER", "asyncpg")
monkeypatch.setenv("APP_DATABASE_USER", "amt2")
monkeypatch.setenv("APP_DATABASE_DB", "amt2")
monkeypatch.setenv("APP_DATABASE_PASSWORD", "mypassword")
@@ -18,11 +19,12 @@ def test_environment_settings(monkeypatch: pytest.MonkeyPatch):
assert settings.ENVIRONMENT == "production"
assert settings.LOGGING_LEVEL == "INFO"
assert settings.APP_DATABASE_SCHEME == "postgresql"
+ assert settings.APP_DATABASE_DRIVER == "asyncpg"
assert settings.APP_DATABASE_SERVER == "db"
assert settings.APP_DATABASE_PORT == 5432
assert settings.APP_DATABASE_USER == "amt2"
assert settings.APP_DATABASE_DB == "amt2"
- assert settings.SQLALCHEMY_DATABASE_URI == "postgresql://amt2:mypassword@db:5432/amt2"
+ assert settings.SQLALCHEMY_DATABASE_URI == "postgresql+asyncpg://amt2:mypassword@db:5432/amt2"
def test_environment_settings_production_sqlite_error(monkeypatch: pytest.MonkeyPatch):
diff --git a/tests/core/test_db.py b/tests/core/test_db.py
index 83036149..43442e81 100644
--- a/tests/core/test_db.py
+++ b/tests/core/test_db.py
@@ -7,18 +7,19 @@
)
from pytest_mock import MockFixture
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
-def test_check_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, mocker: MockFixture):
+@pytest.mark.asyncio
+async def test_check_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, mocker: MockFixture):
database_file = tmp_path / "database.sqlite3"
monkeypatch.setenv("APP_DATABASE_FILE", str(database_file))
- org_exec = Session.execute
- Session.execute = mocker.MagicMock()
- check_db()
+ org_exec = AsyncSession.execute
+ AsyncSession.execute = mocker.AsyncMock()
+ await check_db()
- assert Session.execute.call_args is not None
- assert str(select(1)) == str(Session.execute.call_args.args[0])
- Session.execute = org_exec
+ assert AsyncSession.execute.call_args is not None
+ assert str(select(1)) == str(AsyncSession.execute.call_args.args[0])
+ AsyncSession.execute = org_exec
diff --git a/tests/core/test_exception_handlers.py b/tests/core/test_exception_handlers.py
index 2dec95ba..c75fb73e 100644
--- a/tests/core/test_exception_handlers.py
+++ b/tests/core/test_exception_handlers.py
@@ -2,25 +2,25 @@
from amt.core.exceptions import AMTCSRFProtectError
from amt.schema.project import ProjectNew
from fastapi import status
-from fastapi.testclient import TestClient
+from httpx import AsyncClient
-def test_http_exception_handler(client: TestClient):
- response = client.get("/raise-http-exception")
+async def test_http_exception_handler(client: AsyncClient):
+ response = await client.get("/raise-http-exception")
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.headers["content-type"] == "text/html; charset=utf-8"
-def test_request_validation_exception_handler(client: TestClient):
- response = client.get("/algorithm-systems/?skip=a")
+async def test_request_validation_exception_handler(client: AsyncClient):
+ response = await client.get("/algorithm-systems/?skip=a")
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.headers["content-type"] == "text/html; charset=utf-8"
-def test_request_csrf_protect_exception_handler_invalid_token_in_header(client: TestClient):
- data = client.get("/algorithm-systems/new")
+async def test_request_csrf_protect_exception_handler_invalid_token_in_header(client: AsyncClient):
+ data = await client.get("/algorithm-systems/new")
new_project = ProjectNew(name="default project", lifecycle="DATA_EXPLORATION_AND_PREPARATION")
with pytest.raises(AMTCSRFProtectError):
_response = client.post(
@@ -28,22 +28,22 @@ def test_request_csrf_protect_exception_handler_invalid_token_in_header(client:
)
-def test_http_exception_handler_htmx(client: TestClient):
- response = client.get("/raise-http-exception", headers={"HX-Request": "true"})
+async def test_http_exception_handler_htmx(client: AsyncClient):
+ response = await client.get("/raise-http-exception", headers={"HX-Request": "true"})
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.headers["content-type"] == "text/html; charset=utf-8"
-def test_request_validation_exception_handler_htmx(client: TestClient):
- response = client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"})
+async def test_request_validation_exception_handler_htmx(client: AsyncClient):
+ response = await client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.headers["content-type"] == "text/html; charset=utf-8"
-def test_request_csrf_protect_exception_handler_invalid_token(client: TestClient):
- data = client.get("/algorithm-systems/new")
+async def test_request_csrf_protect_exception_handler_invalid_token(client: AsyncClient):
+ data = await client.get("/algorithm-systems/new")
new_project = ProjectNew(name="default project", lifecycle="DATA_EXPLORATION_AND_PREPARATION")
with pytest.raises(AMTCSRFProtectError):
_response = client.post(
@@ -54,8 +54,8 @@ def test_request_csrf_protect_exception_handler_invalid_token(client: TestClient
)
-def test_(client: TestClient):
- response = client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"})
+async def test_(client: AsyncClient):
+ response = await client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.headers["content-type"] == "text/html; charset=utf-8"
diff --git a/tests/database_e2e_setup.py b/tests/database_e2e_setup.py
index 7c89d9c4..59355614 100644
--- a/tests/database_e2e_setup.py
+++ b/tests/database_e2e_setup.py
@@ -1,24 +1,24 @@
from amt.enums.status import Status
from amt.models import Project
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio.session import AsyncSession
from tests.constants import default_project, default_task, default_user
from tests.database_test_utils import DatabaseTestUtils
-def setup_database_e2e(session: Session) -> None:
+async def setup_database_e2e(session: AsyncSession) -> None:
db_e2e = DatabaseTestUtils(session)
- db_e2e.given([default_user()])
+ await db_e2e.given([default_user()])
projects: list[Project] = []
for idx in range(120):
projects.append(default_project(name=f"Project {idx}"))
- db_e2e.given([*projects])
+ await db_e2e.given([*projects])
task1 = default_task(title="task1", status_id=Status.TODO, sort_order=-3, project_id=projects[0].id)
task2 = default_task(title="task2", status_id=Status.TODO, sort_order=-2, project_id=projects[0].id)
task3 = default_task(title="task3", status_id=Status.TODO, sort_order=-1, project_id=projects[0].id)
- db_e2e.given([task1, task2, task3])
+ await db_e2e.given([task1, task2, task3])
diff --git a/tests/database_test_utils.py b/tests/database_test_utils.py
index a20e99e1..e36577a6 100644
--- a/tests/database_test_utils.py
+++ b/tests/database_test_utils.py
@@ -3,38 +3,42 @@
from amt.models.base import Base
from sqlalchemy import select
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
logger = logging.getLogger(__name__)
class DatabaseTestUtils:
- def __init__(self, session: Session, database_file: Path | None = None) -> None:
- self.session: Session = session
+ def __init__(self, session: AsyncSession, database_file: Path | None = None) -> None:
+ self.session: AsyncSession = session
self.database_file: Path | None = database_file
- def given(self, models: list[Base]) -> None:
+ async def given(self, models: list[Base]) -> None:
session = self.get_session()
session.add_all(models)
- session.commit()
+ await session.commit()
for model in models:
- session.refresh(model) # inefficient, but needed to create correlations between models
+ await session.refresh(model) # inefficient, but needed to create correlations between models
- def get_session(self) -> Session:
+ def get_session(self) -> AsyncSession:
return self.session
def get_database_file(self) -> Path | None:
return self.database_file
- def exists(self, model: type[Base], model_field: str, field_value: str | int) -> Base | None:
- return self.get_session().execute(select(model).where(model_field == field_value)).one() is not None # type: ignore
+ async def exists(self, model: type[Base], model_field: str, field_value: str | int) -> bool:
+ try:
+ result = await self.get_session().execute(select(model).where(getattr(model, model_field) == field_value))
+ return result.scalar_one_or_none() is not None
+ except AttributeError as err:
+ raise ValueError(f"Field '{model_field}' does not exist in model {model.__name__}") from err
- def get(self, model: type[Base], model_field: str, field_value: str | int) -> list[Base]:
+ async def get(self, model: type[Base], model_field: str, field_value: str | int) -> list[Base]:
try:
query = select(model).where(getattr(model, model_field) == field_value)
- result = self.get_session().execute(query)
+ result = await self.get_session().execute(query)
return list(result.scalars().all())
except AttributeError as err:
raise ValueError(f"Field '{model_field}' does not exist in model {model.__name__}") from err
diff --git a/tests/middleware/test_authorization.py b/tests/middleware/test_authorization.py
index 043802ca..61480836 100644
--- a/tests/middleware/test_authorization.py
+++ b/tests/middleware/test_authorization.py
@@ -1,10 +1,11 @@
import pytest
-from fastapi.testclient import TestClient
+from httpx import AsyncClient
+@pytest.mark.asyncio
@pytest.mark.enable_auth
-def test_auth_not_project(client: TestClient) -> None:
- response = client.get("/projects/")
+async def test_auth_not_project(client: AsyncClient) -> None:
+ response = await client.get("/projects/", follow_redirects=True)
assert response.status_code == 200
assert response.url == "http://testserver/"
diff --git a/tests/repositories/test_deps.py b/tests/repositories/test_deps.py
index 228fb794..f4ff4c85 100644
--- a/tests/repositories/test_deps.py
+++ b/tests/repositories/test_deps.py
@@ -1,10 +1,16 @@
+import pytest
from amt.repositories.deps import get_session
-from sqlalchemy.orm import Session
+from sqlalchemy.ext.asyncio import AsyncSession
-def test_get_session():
+@pytest.mark.asyncio
+async def test_get_session():
session_generator = get_session()
+ session = await anext(session_generator)
+ assert isinstance(session, AsyncSession)
- session = next(session_generator)
-
- assert isinstance(session, Session)
+ # Clean up by consuming the rest of the generator
+ try: # noqa
+ await session_generator.aclose()
+ except StopAsyncIteration:
+ pass
diff --git a/tests/repositories/test_projects.py b/tests/repositories/test_projects.py
index b0abefc9..5d02aebb 100644
--- a/tests/repositories/test_projects.py
+++ b/tests/repositories/test_projects.py
@@ -6,102 +6,113 @@
from tests.database_test_utils import DatabaseTestUtils
-def test_find_all(db: DatabaseTestUtils):
- db.given([default_project(), default_project()])
+@pytest.mark.asyncio
+async def test_find_all(db: DatabaseTestUtils):
+ await db.given([default_project(), default_project()])
project_repository = ProjectsRepository(db.get_session())
- results = project_repository.find_all()
+ results = await project_repository.find_all()
assert results[0].id == 1
assert results[1].id == 2
assert len(results) == 2
-def test_find_all_no_results(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_find_all_no_results(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
- results = project_repository.find_all()
+ results = await project_repository.find_all()
assert len(results) == 0
-def test_save(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_save(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
project = default_project()
- project_repository.save(project)
+ await project_repository.save(project)
- result = project_repository.find_by_id(1)
+ result = await project_repository.find_by_id(1)
- project_repository.delete(project) # cleanup
+ await project_repository.delete(project) # cleanup
assert result.id == 1
assert result.name == default_project().name
-def test_delete(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_delete(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
project = default_project()
- project_repository.save(project)
- project_repository.delete(project)
+ await project_repository.save(project)
+ await project_repository.delete(project)
- results = project_repository.find_all()
+ results = await project_repository.find_all()
assert len(results) == 0
-def test_save_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_save_failed(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
project = default_project()
project.id = 1
project_duplicate = default_project()
project_duplicate.id = 1
- project_repository.save(project)
+ await project_repository.save(project)
with pytest.raises(AMTRepositoryError):
- project_repository.save(project_duplicate)
+ await project_repository.save(project_duplicate)
- project_repository.delete(project) # cleanup
+ await project_repository.delete(project) # cleanup
-def test_delete_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_delete_failed(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
project = default_project()
with pytest.raises(AMTRepositoryError):
- project_repository.delete(project)
+ await project_repository.delete(project)
-def test_find_by_id(db: DatabaseTestUtils):
- db.given([default_project()])
+@pytest.mark.asyncio
+async def test_find_by_id(db: DatabaseTestUtils):
+ await db.given([default_project()])
project_repository = ProjectsRepository(db.get_session())
- result = project_repository.find_by_id(1)
+ result = await project_repository.find_by_id(1)
assert result.id == 1
assert result.name == default_project().name
-def test_find_by_id_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_find_by_id_failed(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
with pytest.raises(AMTRepositoryError):
- project_repository.find_by_id(1)
+ await project_repository.find_by_id(1)
-def test_paginate(db: DatabaseTestUtils):
- db.given([default_project()])
+@pytest.mark.asyncio
+async def test_paginate(db: DatabaseTestUtils):
+ await db.given([default_project()])
project_repository = ProjectsRepository(db.get_session())
- result: list[Project] = project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={})
+ result: list[Project] = await project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={})
assert len(result) == 1
-def test_paginate_more(db: DatabaseTestUtils):
- db.given([default_project(), default_project(), default_project(), default_project()])
+@pytest.mark.asyncio
+async def test_paginate_more(db: DatabaseTestUtils):
+ await db.given([default_project(), default_project(), default_project(), default_project()])
project_repository = ProjectsRepository(db.get_session())
- result: list[Project] = project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={})
+ result: list[Project] = await project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={})
assert len(result) == 3
-def test_paginate_capitalize(db: DatabaseTestUtils):
- db.given(
+@pytest.mark.asyncio
+async def test_paginate_capitalize(db: DatabaseTestUtils):
+ await db.given(
[
default_project(name="Project1"),
default_project(name="bbb"),
@@ -111,7 +122,7 @@ def test_paginate_capitalize(db: DatabaseTestUtils):
)
project_repository = ProjectsRepository(db.get_session())
- result: list[Project] = project_repository.paginate(skip=0, limit=4, search="", filters={}, sort={})
+ result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="", filters={}, sort={})
assert len(result) == 4
assert result[0].name == "Aaa"
@@ -120,8 +131,9 @@ def test_paginate_capitalize(db: DatabaseTestUtils):
assert result[3].name == "Project1"
-def test_search(db: DatabaseTestUtils):
- db.given(
+@pytest.mark.asyncio
+async def test_search(db: DatabaseTestUtils):
+ await db.given(
[
default_project(name="Project1"),
default_project(name="bbb"),
@@ -131,14 +143,15 @@ def test_search(db: DatabaseTestUtils):
)
project_repository = ProjectsRepository(db.get_session())
- result: list[Project] = project_repository.paginate(skip=0, limit=4, search="bbb", filters={}, sort={})
+ result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="bbb", filters={}, sort={})
assert len(result) == 1
assert result[0].name == "bbb"
-def test_search_multiple(db: DatabaseTestUtils):
- db.given(
+@pytest.mark.asyncio
+async def test_search_multiple(db: DatabaseTestUtils):
+ await db.given(
[
default_project(name="Project1"),
default_project(name="bbb"),
@@ -148,22 +161,24 @@ def test_search_multiple(db: DatabaseTestUtils):
)
project_repository = ProjectsRepository(db.get_session())
- result: list[Project] = project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={})
+ result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={})
assert len(result) == 2
assert result[0].name == "Aaa"
assert result[1].name == "aba"
-def test_search_no_results(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_search_no_results(db: DatabaseTestUtils):
project_repository = ProjectsRepository(db.get_session())
- result: list[Project] = project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={})
+ result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={})
assert len(result) == 0
-def test_raises_exception(db: DatabaseTestUtils):
- db.given([default_project()])
+@pytest.mark.asyncio
+async def test_raises_exception(db: DatabaseTestUtils):
+ await db.given([default_project()])
project_repository = ProjectsRepository(db.get_session())
with pytest.raises(AMTRepositoryError):
- project_repository.paginate(skip="a", limit=3, search="", filters={}, sort={}) # type: ignore
+ await project_repository.paginate(skip="a", limit=3, search="", filters={}, sort={}) # type: ignore
diff --git a/tests/repositories/test_tasks.py b/tests/repositories/test_tasks.py
index 6ecad641..418d1f6c 100644
--- a/tests/repositories/test_tasks.py
+++ b/tests/repositories/test_tasks.py
@@ -7,43 +7,47 @@
from tests.database_test_utils import DatabaseTestUtils
-def test_find_all(db: DatabaseTestUtils):
- db.given([default_task(), default_task()])
+@pytest.mark.asyncio
+async def test_find_all(db: DatabaseTestUtils):
+ await db.given([default_task(), default_task()])
tasks_repository: TasksRepository = TasksRepository(db.get_session())
- results = tasks_repository.find_all()
+ results = await tasks_repository.find_all()
assert results[0].id == 1
assert results[1].id == 2
assert len(results) == 2
-def test_find_all_no_results(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_find_all_no_results(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
- results = tasks_repository.find_all()
+ results = await tasks_repository.find_all()
assert len(results) == 0
-def test_save(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_save(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10)
- tasks_repository.save(task)
- result = tasks_repository.find_by_id(1)
+ await tasks_repository.save(task)
+ result = await tasks_repository.find_by_id(1)
assert result.id == 1
assert result.title == "Test title"
assert result.description == "Test description"
assert result.sort_order == 10
- tasks_repository.delete(task) # cleanup
+ await tasks_repository.delete(task) # cleanup
-def test_save_all(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_save_all(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
task_1: Task = Task(id=1, title="Test title 1", description="Test description 1", sort_order=10)
task_2: Task = Task(id=2, title="Test title 2", description="Test description 2", sort_order=11)
- tasks_repository.save_all([task_1, task_2])
- result_1 = tasks_repository.find_by_id(1)
- result_2 = tasks_repository.find_by_id(2)
+ await tasks_repository.save_all([task_1, task_2])
+ result_1 = await tasks_repository.find_by_id(1)
+ result_2 = await tasks_repository.find_by_id(2)
assert result_1.id == 1
assert result_1.title == "Test title 1"
@@ -55,84 +59,92 @@ def test_save_all(db: DatabaseTestUtils):
assert result_2.description == "Test description 2"
assert result_2.sort_order == 11
- tasks_repository.delete(task_1) # cleanup
- tasks_repository.delete(task_2) # cleanup
+ await tasks_repository.delete(task_1) # cleanup
+ await tasks_repository.delete(task_2) # cleanup
-def test_save_all_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_save_all_failed(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10)
- tasks_repository.save_all([task])
+ await tasks_repository.save_all([task])
task_duplicate: Task = Task(id=1, title="Test title duplicate", description="Test description", sort_order=10)
with pytest.raises(AMTRepositoryError):
- tasks_repository.save_all([task_duplicate])
+ await tasks_repository.save_all([task_duplicate])
- tasks_repository.delete(task) # cleanup
+ await tasks_repository.delete(task) # cleanup
-def test_delete_task(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_delete_task(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10)
- tasks_repository.save(task)
- tasks_repository.delete(task) # cleanup
+ await tasks_repository.save(task)
+ await tasks_repository.delete(task) # cleanup
- results = tasks_repository.find_all()
+ results = await tasks_repository.find_all()
assert len(results) == 0
-def test_save_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_save_failed(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10)
- tasks_repository.save(task)
+ await tasks_repository.save(task)
task_duplicate: Task = Task(id=1, title="Test title duplicate", description="Test description", sort_order=10)
with pytest.raises(AMTRepositoryError):
- tasks_repository.save(task_duplicate)
+ await tasks_repository.save(task_duplicate)
- tasks_repository.delete(task) # cleanup
+ await tasks_repository.delete(task) # cleanup
-def test_delete_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_delete_failed(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10)
with pytest.raises(AMTRepositoryError):
- tasks_repository.delete(task)
+ await tasks_repository.delete(task)
-def test_find_by_id(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_find_by_id(db: DatabaseTestUtils):
task = default_task()
- db.given([task])
+ await db.given([task])
tasks_repository: TasksRepository = TasksRepository(db.get_session())
- result: Task = tasks_repository.find_by_id(1)
+ result: Task = await tasks_repository.find_by_id(1)
assert result.id == 1
assert result.title == "Default Task"
-def test_find_by_id_failed(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_find_by_id_failed(db: DatabaseTestUtils):
tasks_repository: TasksRepository = TasksRepository(db.get_session())
with pytest.raises(AMTRepositoryError):
- tasks_repository.find_by_id(1)
+ await tasks_repository.find_by_id(1)
-def test_find_by_status_id(db: DatabaseTestUtils):
+@pytest.mark.asyncio
+async def test_find_by_status_id(db: DatabaseTestUtils):
task = default_task(status_id=Status.TODO)
- db.given([task, default_task()])
+ await db.given([task, default_task()])
tasks_repository: TasksRepository = TasksRepository(db.get_session())
- results = tasks_repository.find_by_status_id(1)
+ results = await tasks_repository.find_by_status_id(1)
assert len(results) == 1
assert results[0].id == 1
-def test_find_by_project_id_and_status_id(db: DatabaseTestUtils):
- db.given([default_project()])
+@pytest.mark.asyncio
+async def test_find_by_project_id_and_status_id(db: DatabaseTestUtils):
+ await db.given([default_project()])
task = default_task(status_id=Status.TODO, project_id=1)
- db.given([task, default_task()])
+ await db.given([task, default_task()])
tasks_repository: TasksRepository = TasksRepository(db.get_session())
- results = tasks_repository.find_by_project_id_and_status_id(1, 1)
+ results = await tasks_repository.find_by_project_id_and_status_id(1, 1)
assert len(results) == 1
assert results[0].id == 1
assert results[0].project_id == 1
diff --git a/tests/services/test_projects_service.py b/tests/services/test_projects_service.py
index 8efc584e..02bd592c 100644
--- a/tests/services/test_projects_service.py
+++ b/tests/services/test_projects_service.py
@@ -1,3 +1,4 @@
+import pytest
from amt.models.project import Project
from amt.repositories.projects import ProjectsRepository
from amt.schema.project import ProjectNew
@@ -9,40 +10,42 @@
from tests.constants import default_instrument
-def test_get_project(mocker: MockFixture):
+@pytest.mark.asyncio
+async def test_get_project(mocker: MockFixture):
# Given
project_id = 1
project_name = "Project 1"
project_lifecycle = "development"
projects_service = ProjectsService(
- repository=mocker.Mock(spec=ProjectsRepository),
- task_service=mocker.Mock(spec=TasksService),
- instrument_service=mocker.Mock(spec=InstrumentsService),
+ repository=mocker.AsyncMock(spec=ProjectsRepository),
+ task_service=mocker.AsyncMock(spec=TasksService),
+ instrument_service=mocker.AsyncMock(spec=InstrumentsService),
)
projects_service.repository.find_by_id.return_value = Project( # type: ignore
id=project_id, name=project_name, lifecycle=project_lifecycle
)
# When
- project = projects_service.get(project_id)
+ project = await projects_service.get(project_id)
# Then
assert project.id == project_id
assert project.name == project_name
assert project.lifecycle == project_lifecycle
- projects_service.repository.find_by_id.assert_called_once_with(project_id) # type: ignore
+ projects_service.repository.find_by_id.assert_awaited_once_with(project_id) # type: ignore
-def test_create_project(mocker: MockFixture):
+@pytest.mark.asyncio
+async def test_create_project(mocker: MockFixture):
project_id = 1
project_name = "Project 1"
project_lifecycle = "development"
system_card = SystemCard(name=project_name)
projects_service = ProjectsService(
- repository=mocker.Mock(spec=ProjectsRepository),
- task_service=mocker.Mock(spec=TasksService),
- instrument_service=mocker.Mock(spec=InstrumentsService),
+ repository=mocker.AsyncMock(spec=ProjectsRepository),
+ task_service=mocker.AsyncMock(spec=TasksService),
+ instrument_service=mocker.AsyncMock(spec=InstrumentsService),
)
projects_service.repository.save.return_value = Project( # type: ignore
id=project_id, name=project_name, lifecycle=project_lifecycle, system_card=system_card
@@ -61,10 +64,10 @@ def test_create_project(mocker: MockFixture):
transparency_obligations="project_transparency_obligations",
role="project_role",
)
- project = projects_service.create(project_new)
+ project = await projects_service.create(project_new)
# Then
assert project.id == project_id
assert project.name == project_name
assert project.lifecycle == project_lifecycle
- projects_service.repository.save.assert_called() # type: ignore
+ projects_service.repository.save.assert_awaited() # type: ignore
diff --git a/tests/services/test_tasks_service.py b/tests/services/test_tasks_service.py
index 90149ded..9df971c7 100644
--- a/tests/services/test_tasks_service.py
+++ b/tests/services/test_tasks_service.py
@@ -38,19 +38,19 @@ def reset(self):
def find_all(self):
return self._tasks
- def find_by_status_id(self, status_id: int) -> Sequence[Task]:
+ async def find_by_status_id(self, status_id: int) -> Sequence[Task]:
return list(filter(lambda x: x.status_id == status_id, self._tasks))
- def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]:
+ async def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]:
return list(filter(lambda x: x.status_id == status_id and x.project_id == project_id, self._tasks))
- def find_by_id(self, task_id: int) -> Task:
+ async def find_by_id(self, task_id: int) -> Task:
return next(filter(lambda x: x.id == task_id, self._tasks))
- def save(self, task: Task) -> Task:
+ async def save(self, task: Task) -> Task:
return task
- def save_all(self, tasks: Sequence[Task]) -> None:
+ async def save_all(self, tasks: Sequence[Task]) -> None:
for task in tasks:
self._tasks.append(task)
return None
@@ -69,22 +69,30 @@ def tasks_service_with_mock(mock_tasks_repository: TasksRepository):
return tasks_service
-def test_get_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
- assert len(tasks_service_with_mock.get_tasks(1)) == 3
+@pytest.mark.asyncio
+async def test_get_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
+ tasks = await tasks_service_with_mock.get_tasks(1)
+ assert len(tasks) == 3
-def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
- assert len(tasks_service_with_mock.get_tasks_for_project(1, 1)) == 2
+@pytest.mark.asyncio
+async def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
+ tasks = await tasks_service_with_mock.get_tasks_for_project(1, 1)
+ assert len(tasks) == 2
-def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
- task1: Task = mock_tasks_repository.find_by_id(1)
+@pytest.mark.asyncio
+async def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
+ task1: Task = await mock_tasks_repository.find_by_id(1)
user1: User = User(id=1, name="User 1", avatar="none.jpg")
- tasks_service_with_mock.assign_task(task1, user1)
+ await tasks_service_with_mock.assign_task(task1, user1)
assert task1.user_id == 1
-def test_create_instrument_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository):
+@pytest.mark.asyncio
+async def test_create_instrument_tasks(
+ tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository
+):
# Given
task_1 = InstrumentTask(question="question_1", urn="instrument_1_task_1", lifecycle=[])
task_2 = InstrumentTask(question="question_1", urn="instrument_1_task_1", lifecycle=[])
@@ -95,48 +103,61 @@ def test_create_instrument_tasks(tasks_service_with_mock: TasksService, mock_tas
# When
mock_tasks_repository.clear()
- tasks_service_with_mock.create_instrument_tasks([task_1, task_2], project)
+ await tasks_service_with_mock.create_instrument_tasks([task_1, task_2], project)
# Then
- assert len(mock_tasks_repository.find_all()) == 2
- assert mock_tasks_repository.find_all()[0].project_id == 1
- assert mock_tasks_repository.find_all()[0].title == task_1.question
- assert mock_tasks_repository.find_all()[1].project_id == 1
- assert mock_tasks_repository.find_all()[1].title == task_2.question
+ tasks = mock_tasks_repository.find_all()
+ assert len(tasks) == 2
+ assert tasks[0].project_id == 1
+ assert tasks[0].title == task_1.question
+ assert tasks[1].project_id == 1
+ assert tasks[1].title == task_2.question
-def test_move_task(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository):
+@pytest.mark.asyncio
+async def test_move_task(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository):
# test changing order between 2 cards
mock_tasks_repository.reset()
- assert mock_tasks_repository.find_by_id(1).sort_order == 10
- tasks_service_with_mock.move_task(1, 1, 2, 3)
- assert mock_tasks_repository.find_by_id(1).sort_order == 25
+ task = await mock_tasks_repository.find_by_id(1)
+ assert task.sort_order == 10
+ await tasks_service_with_mock.move_task(1, 1, 2, 3)
+ task = await mock_tasks_repository.find_by_id(1)
+ assert task.sort_order == 25
# test changing order, after the last card
mock_tasks_repository.reset()
- tasks_service_with_mock.move_task(1, 1, 3, None)
- assert mock_tasks_repository.find_by_id(1).sort_order == 40
+ await tasks_service_with_mock.move_task(1, 1, 3, None)
+ task = await mock_tasks_repository.find_by_id(1)
+ assert task.sort_order == 40
# test changing order, before the first card
mock_tasks_repository.reset()
- assert mock_tasks_repository.find_by_id(3).sort_order == 30
- tasks_service_with_mock.move_task(3, 1, None, 1)
- assert mock_tasks_repository.find_by_id(3).sort_order == 5
+ task = await mock_tasks_repository.find_by_id(3)
+ assert task.sort_order == 30
+ await tasks_service_with_mock.move_task(3, 1, None, 1)
+ task = await mock_tasks_repository.find_by_id(3)
+ assert task.sort_order == 5
# test moving to in progress
mock_tasks_repository.reset()
- mock_tasks_repository.find_by_id(1).sort_order = 0
- tasks_service_with_mock.move_task(1, 2)
- assert mock_tasks_repository.find_by_id(1).sort_order == 10
+ task = await mock_tasks_repository.find_by_id(1)
+ task.sort_order = 0
+ await tasks_service_with_mock.move_task(1, 2)
+ task = await mock_tasks_repository.find_by_id(1)
+ assert task.sort_order == 10
# test moving to todo
mock_tasks_repository.reset()
- mock_tasks_repository.find_by_id(1).sort_order = 0
- tasks_service_with_mock.move_task(1, 4)
- assert mock_tasks_repository.find_by_id(1).sort_order == 10
+ task = await mock_tasks_repository.find_by_id(1)
+ task.sort_order = 0
+ await tasks_service_with_mock.move_task(1, 4)
+ task = await mock_tasks_repository.find_by_id(1)
+ assert task.sort_order == 10
# test moving move under other card
mock_tasks_repository.reset()
- mock_tasks_repository.find_by_id(2).sort_order = 10
- tasks_service_with_mock.move_task(2, 1, None, 1)
- assert mock_tasks_repository.find_by_id(2).sort_order == 5
+ task = await mock_tasks_repository.find_by_id(2)
+ task.sort_order = 10
+ await tasks_service_with_mock.move_task(2, 1, None, 1)
+ task = await mock_tasks_repository.find_by_id(2)
+ assert task.sort_order == 5
diff --git a/tests/test_main.py b/tests/test_main.py
index 226f9e5f..18e92866 100644
--- a/tests/test_main.py
+++ b/tests/test_main.py
@@ -1,8 +1,10 @@
-from fastapi.testclient import TestClient
+import pytest
+from httpx import AsyncClient
-def test_get_non_exisiting(client: TestClient) -> None:
- response = client.get(
+@pytest.mark.asyncio
+async def test_get_non_exisiting(client: AsyncClient) -> None:
+ response = await client.get(
"/pathdoesnotexist/",
)
assert response.status_code == 404