diff --git a/amt/api/routes/project.py b/amt/api/routes/project.py index 9e241c25..58a5d5b4 100644 --- a/amt/api/routes/project.py +++ b/amt/api/routes/project.py @@ -1,5 +1,7 @@ +import asyncio import functools import logging +from collections.abc import Sequence from pathlib import Path from typing import Annotated, Any, cast @@ -18,6 +20,7 @@ from amt.core.exceptions import AMTNotFound, AMTRepositoryError from amt.enums.status import Status from amt.models import Project +from amt.models.task import Task from amt.schema.measure import ExtendedMeasureTask, MeasureTask from amt.schema.requirement import RequirementTask from amt.schema.system_card import SystemCard @@ -65,10 +68,10 @@ def get_requirements_state(system_card: SystemCard) -> dict[str, Any]: } -def get_project_or_error(project_id: int, projects_service: ProjectsService, request: Request) -> Project: +async def get_project_or_error(project_id: int, projects_service: ProjectsService, request: Request) -> Project: try: logger.debug(f"getting project with id {project_id}") - project = projects_service.get(project_id) + project = await projects_service.get(project_id) request.state.path_variables = {"project_id": project_id} except AMTRepositoryError as e: raise AMTNotFound from e @@ -98,6 +101,14 @@ def get_projects_submenu_items() -> list[BaseNavigationItem]: ] +async def gather_project_tasks(project_id: int, task_service: TasksService) -> dict[Status, Sequence[Task]]: + fetch_tasks = [task_service.get_tasks_for_project(project_id, status + 0) for status in Status] + + results = await asyncio.gather(*fetch_tasks) + + return dict(zip(Status, results, strict=True)) + + @router.get("/{project_id}/details/tasks") async def get_tasks( request: Request, @@ -105,10 +116,11 @@ async def get_tasks( projects_service: Annotated[ProjectsService, Depends(ProjectsService)], tasks_service: Annotated[TasksService, Depends(TasksService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) tab_items = get_project_details_tabs(request) + tasks_by_status = await gather_project_tasks(project_id, task_service=tasks_service) breadcrumbs = resolve_base_navigation_items( [ @@ -124,7 +136,7 @@ async def get_tasks( context = { "instrument_state": instrument_state, "requirements_state": requirements_state, - "tasks_service": tasks_service, + "tasks_by_status": tasks_by_status, "statuses": Status, "project": project, "project_id": project.id, @@ -153,7 +165,7 @@ async def move_task( moved_task.next_sibling_id = None if moved_task.previous_sibling_id == -1: moved_task.previous_sibling_id = None - task = tasks_service.move_task( + task = await tasks_service.move_task( moved_task.id, moved_task.status_id, moved_task.previous_sibling_id, @@ -168,7 +180,7 @@ async def move_task( async def get_project_context( project_id: int, projects_service: ProjectsService, request: Request ) -> tuple[Project, dict[str, Any]]: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) system_card_data = get_system_card_data() instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) @@ -268,7 +280,7 @@ async def get_project_update( ) -> HTMLResponse: project, context = await get_project_context(project_id, projects_service, request) set_path(project, path, update_data.value) - projects_service.update(project) + await projects_service.update(project) context["path"] = path.replace("/", ".") return templates.TemplateResponse(request, "parts/view_cell.html.j2", context) @@ -284,7 +296,7 @@ async def get_system_card( project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) @@ -326,7 +338,7 @@ async def get_system_card( async def get_project_inference( request: Request, project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)] ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) breadcrumbs = resolve_base_navigation_items( [ @@ -370,7 +382,7 @@ async def get_system_card_requirements( project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) # TODO: This tab is fairly slow, fix in later releases @@ -464,7 +476,7 @@ async def get_measure( measure_urn: str, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) measure = task_registry.fetch_measures([measure_urn]) measure_task = find_measure_task(project.system_card, measure_urn) @@ -491,7 +503,7 @@ async def update_measure_value( measure_update: MeasureUpdate, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) measure_task = find_measure_task(project.system_card, measure_urn) measure_task.state = measure_update.measure_state # pyright: ignore [reportOptionalMemberAccess] @@ -517,7 +529,7 @@ async def update_measure_value( else: requirement_task.state = "in progress" # pyright: ignore [reportOptionalMemberAccess] - projects_service.update(project) + await projects_service.update(project) # TODO: FIX THIS!! The page now reloads at the top, which is annoying return templates.Redirect(request, f"/algorithm-system/{project_id}/details/system_card/requirements") @@ -533,7 +545,7 @@ async def get_system_card_data_page( project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) @@ -573,7 +585,7 @@ async def get_system_card_instruments( project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) @@ -614,7 +626,7 @@ async def get_assessment_card( assessment_card: str, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) @@ -665,7 +677,7 @@ async def get_model_card( model_card: str, projects_service: Annotated[ProjectsService, Depends(ProjectsService)], ) -> HTMLResponse: - project = get_project_or_error(project_id, projects_service, request) + project = await get_project_or_error(project_id, projects_service, request) instrument_state = get_instrument_state() requirements_state = get_requirements_state(project.system_card) diff --git a/amt/api/routes/projects.py b/amt/api/routes/projects.py index f0b0dfc4..a7028b1a 100644 --- a/amt/api/routes/projects.py +++ b/amt/api/routes/projects.py @@ -62,7 +62,7 @@ async def get_root( k.removeprefix("sort-by-"): v for k, v in request.query_params.items() if k.startswith("sort-by-") and v != "" } - projects = projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort_by) + projects = await projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort_by) # todo: the lifecycle has to be 'localized', maybe for display 'Project' should become a different object for project in projects: project.lifecycle = get_localized_lifecycle(project.lifecycle, request) # pyright: ignore [reportAttributeAccessIssue] @@ -131,6 +131,6 @@ async def post_new( project_new.systemic_risk = "geen systeemrisico" project_new.open_source = "open-source" - project = projects_service.create(project_new) + project = await projects_service.create(project_new) response = templates.Redirect(request, f"/algorithm-system/{project.id}/details/tasks") return response diff --git a/amt/core/config.py b/amt/core/config.py index 014dba9b..6a8413bc 100644 --- a/amt/core/config.py +++ b/amt/core/config.py @@ -44,7 +44,7 @@ class Settings(BaseSettings): # todo(berry): create submodel for database settings APP_DATABASE_SCHEME: DatabaseSchemaType = "sqlite" - APP_DATABASE_DRIVER: str | None = None + APP_DATABASE_DRIVER: str | None = "aiosqlite" APP_DATABASE_SERVER: str = "db" APP_DATABASE_PORT: int = 5432 diff --git a/amt/core/db.py b/amt/core/db.py index 5bb05d08..26933508 100644 --- a/amt/core/db.py +++ b/amt/core/db.py @@ -1,8 +1,7 @@ import logging -from sqlalchemy import create_engine, select -from sqlalchemy.engine import Engine -from sqlalchemy.orm import Session +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, create_async_engine from amt.core.config import get_settings from amt.models.base import Base @@ -10,30 +9,32 @@ logger = logging.getLogger(__name__) -def get_engine() -> Engine: +def get_engine() -> AsyncEngine: settings = get_settings() connect_args = {"check_same_thread": False} if settings.APP_DATABASE_SCHEME == "sqlite" else {} - return create_engine( + return create_async_engine( settings.SQLALCHEMY_DATABASE_URI, # pyright: ignore [reportArgumentType] connect_args=connect_args, echo=settings.SQLALCHEMY_ECHO, ) -def check_db() -> None: +async def check_db() -> None: logger.info("Checking database connection") - with Session(get_engine()) as session: - session.execute(select(1)) + async with AsyncSession(get_engine()) as session: + await session.execute(select(1)) logger.info("Finish Checking database connection") -def init_db() -> None: +async def init_db() -> None: logger.info("Initializing database") if get_settings().AUTO_CREATE_SCHEMA: # pragma: no cover logger.info("Creating database schema") - Base.metadata.create_all(get_engine()) + engine = get_engine() + async with engine.begin() as connection: + await connection.run_sync(Base.metadata.create_all) logger.info("Finished initializing database") diff --git a/amt/migrations/env.py b/amt/migrations/env.py index 5058d8f2..9f4b92b0 100644 --- a/amt/migrations/env.py +++ b/amt/migrations/env.py @@ -1,9 +1,11 @@ +import asyncio import os from logging.config import fileConfig from alembic import context from amt.models import * # noqa -from sqlalchemy import engine_from_config, pool +from sqlalchemy import Connection, pool +from sqlalchemy.ext.asyncio import async_engine_from_config from sqlalchemy.schema import MetaData from amt.models.base import Base @@ -18,17 +20,18 @@ def get_url() -> str: scheme = os.getenv("APP_DATABASE_SCHEME", "sqlite") + driver = os.getenv("APP_DATABASE_DRIVER", "aiosqlite") if scheme == "sqlite": file = os.getenv("APP_DATABASE_FILE", "database.sqlite3") - return f"{scheme}:///{file}" + return f"{scheme}+{driver}:///{file}" user = os.getenv("APP_DATABASE_USER", "amt") password = os.getenv("APP_DATABASE_PASSWORD", "") server = os.getenv("APP_DATABASE_SERVER", "db") port = os.getenv("APP_DATABASE_PORT", "5432") db = os.getenv("APP_DATABASE_DB", "amt") - return f"{scheme}://{user}:{password}@{server}:{port}/{db}" + return f"{scheme}+{driver}://{user}:{password}@{server}:{port}/{db}" def run_migrations_offline() -> None: @@ -52,7 +55,14 @@ def run_migrations_offline() -> None: context.run_migrations() -def run_migrations_online() -> None: +def do_run_migrations(connection: Connection) -> None: + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online() -> None: """Run migrations in 'online' mode. In this scenario we need to create an Engine @@ -63,18 +73,13 @@ def run_migrations_online() -> None: if configuration is None: raise Exception("Failed to get configuration section") # noqa: TRY002 configuration["sqlalchemy.url"] = get_url() - connectable = engine_from_config(configuration, prefix="sqlalchemy.", poolclass=pool.NullPool) - - with connectable.connect() as connection: - context.configure( - connection=connection, target_metadata=target_metadata, compare_type=True, render_as_batch=True - ) + connectable = async_engine_from_config(configuration, prefix="sqlalchemy.", poolclass=pool.NullPool) - with context.begin_transaction(): - context.run_migrations() + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) if context.is_offline_mode(): run_migrations_offline() else: - run_migrations_online() + asyncio.run(run_migrations_online()) diff --git a/amt/models/project.py b/amt/models/project.py index 9fe88845..b56f0b4d 100644 --- a/amt/models/project.py +++ b/amt/models/project.py @@ -56,7 +56,7 @@ class Project(Base): id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String(255)) - lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles), nullable=True) + lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles, name="lifecycle"), nullable=True) system_card_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict) last_edited: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now(), nullable=False) diff --git a/amt/repositories/deps.py b/amt/repositories/deps.py index bf68aeac..f7bf0db2 100644 --- a/amt/repositories/deps.py +++ b/amt/repositories/deps.py @@ -1,10 +1,15 @@ -from collections.abc import Generator +from collections.abc import AsyncGenerator -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from amt.core.db import get_engine -def get_session() -> Generator[Session, None, None]: - with Session(get_engine()) as session: - yield session +async def get_session() -> AsyncGenerator[AsyncSession, None]: + async_session_factory = async_sessionmaker( + get_engine(), + expire_on_commit=False, + class_=AsyncSession, + ) + async with async_session_factory() as async_session: + yield async_session diff --git a/amt/repositories/projects.py b/amt/repositories/projects.py index 0df74d43..cb433ed6 100644 --- a/amt/repositories/projects.py +++ b/amt/repositories/projects.py @@ -5,7 +5,7 @@ from fastapi import Depends from sqlalchemy import func, select from sqlalchemy.exc import NoResultFound, SQLAlchemyError -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy_utils import escape_like # pyright: ignore[reportMissingTypeStubs, reportUnknownVariableType] from amt.api.publication_category import PublicationCategories @@ -31,47 +31,49 @@ def sort_by_lifecycle_reversed(project: Project) -> int: class ProjectsRepository: - def __init__(self, session: Annotated[Session, Depends(get_session)]) -> None: + def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None: self.session = session - def find_all(self) -> Sequence[Project]: - return self.session.execute(select(Project)).scalars().all() + async def find_all(self) -> Sequence[Project]: + result = await self.session.execute(select(Project)) + return result.scalars().all() - def delete(self, project: Project) -> None: + async def delete(self, project: Project) -> None: """ Deletes the given status in the repository. :param status: the status to store :return: the updated status after storing """ try: - self.session.delete(project) - self.session.commit() + await self.session.delete(project) + await self.session.commit() except Exception as e: logger.exception("Error deleting project") - self.session.rollback() + await self.session.rollback() raise AMTRepositoryError from e return None - def save(self, project: Project) -> Project: + async def save(self, project: Project) -> Project: try: self.session.add(project) - self.session.commit() - self.session.refresh(project) + await self.session.commit() + await self.session.refresh(project) except SQLAlchemyError as e: logger.exception("Error saving project") - self.session.rollback() + await self.session.rollback() raise AMTRepositoryError from e return project - def find_by_id(self, project_id: int) -> Project: + async def find_by_id(self, project_id: int) -> Project: try: statement = select(Project).where(Project.id == project_id) - return self.session.execute(statement).scalars().one() + result = await self.session.execute(statement) + return result.scalars().one() except NoResultFound as e: logger.exception("Project not found") raise AMTRepositoryError from e - def paginate( # noqa + async def paginate( # noqa self, skip: int, limit: int, search: str, filters: dict[str, str], sort: dict[str, str] ) -> list[Project]: try: @@ -102,7 +104,8 @@ def paginate( # noqa else: statement = statement.order_by(func.lower(Project.name)) statement = statement.offset(skip).limit(limit) - result = list(self.session.execute(statement).scalars()) + db_result = await self.session.execute(statement) + result = list(db_result.scalars()) # todo: the good way to do sorting is to use an enum field (or any int field) # in the database so we can sort on that if result and sort and "lifecycle" in sort: diff --git a/amt/repositories/tasks.py b/amt/repositories/tasks.py index 735094fb..1e3a649b 100644 --- a/amt/repositories/tasks.py +++ b/amt/repositories/tasks.py @@ -5,7 +5,7 @@ from fastapi import Depends from sqlalchemy import and_, select from sqlalchemy.exc import NoResultFound -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession from amt.core.exceptions import AMTRepositoryError from amt.models import Task @@ -19,26 +19,26 @@ class TasksRepository: The TasksRepository provides access to the repository layer. """ - def __init__(self, session: Annotated[Session, Depends(get_session)]) -> None: + def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None: self.session = session - def find_all(self) -> Sequence[Task]: + async def find_all(self) -> Sequence[Task]: """ Returns all tasks in the repository. :return: all tasks in the repository """ - return self.session.execute(select(Task)).scalars().all() + return (await self.session.execute(select(Task))).scalars().all() - def find_by_status_id(self, status_id: int) -> Sequence[Task]: + async def find_by_status_id(self, status_id: int) -> Sequence[Task]: """ Returns all tasks in the repository for the given status_id. :param status_id: the status_id to filter on :return: a list of tasks in the repository for the given status_id """ statement = select(Task).where(Task.status_id == status_id).order_by(Task.sort_order) - return self.session.execute(statement).scalars().all() + return (await self.session.execute(statement)).scalars().all() - def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]: + async def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]: """ Returns all tasks in the repository for the given project_id. :param project_id: the project_id to filter on @@ -49,9 +49,9 @@ def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> S .where(and_(Task.status_id == status_id, Task.project_id == project_id)) .order_by(Task.sort_order) ) - return self.session.execute(statement).scalars().all() + return (await self.session.execute(statement)).scalars().all() - def save(self, task: Task) -> Task: + async def save(self, task: Task) -> Task: """ Stores the given task in the repository or throws a RepositoryException :param task: the task to store @@ -59,15 +59,15 @@ def save(self, task: Task) -> Task: """ try: self.session.add(task) - self.session.commit() - self.session.refresh(task) + await self.session.commit() + await self.session.refresh(task) except Exception as e: logger.exception("Could not store task") - self.session.rollback() + await self.session.rollback() raise AMTRepositoryError from e return task - def save_all(self, tasks: Sequence[Task]) -> None: + async def save_all(self, tasks: Sequence[Task]) -> None: """ Stores the given tasks in the repository or throws a RepositoryException :param tasks: the tasks to store @@ -75,28 +75,28 @@ def save_all(self, tasks: Sequence[Task]) -> None: """ try: self.session.add_all(tasks) - self.session.commit() + await self.session.commit() except Exception as e: logger.exception("Could not store all tasks") - self.session.rollback() + await self.session.rollback() raise AMTRepositoryError from e - def delete(self, task: Task) -> None: + async def delete(self, task: Task) -> None: """ Deletes the given task in the repository or throws a RepositoryException :param task: the task to store :return: the updated task after storing """ try: - self.session.delete(task) - self.session.commit() + await self.session.delete(task) + await self.session.commit() except Exception as e: logger.exception("Could not delete task") - self.session.rollback() + await self.session.rollback() raise AMTRepositoryError from e return None - def find_by_id(self, task_id: int) -> Task: + async def find_by_id(self, task_id: int) -> Task: """ Returns the task with the given id. :param task_id: the id of the task to find @@ -104,7 +104,7 @@ def find_by_id(self, task_id: int) -> Task: """ statement = select(Task).where(Task.id == task_id) try: - return self.session.execute(statement).scalars().one() + return (await self.session.execute(statement)).scalars().one() except NoResultFound as e: logger.exception("Task not found") raise AMTRepositoryError from e diff --git a/amt/server.py b/amt/server.py index a758dc9e..78d6e480 100644 --- a/amt/server.py +++ b/amt/server.py @@ -32,8 +32,8 @@ @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: mask = Mask(mask_keywords=["database_uri"]) - check_db() - init_db() + await check_db() + await init_db() logger.info(f"Starting {PROJECT_NAME} version {VERSION}") logger.info(f"Settings: {mask.secrets(get_settings().model_dump())}") yield diff --git a/amt/services/projects.py b/amt/services/projects.py index a8ec837f..42b1be95 100644 --- a/amt/services/projects.py +++ b/amt/services/projects.py @@ -26,10 +26,11 @@ def __init__( self.instrument_service = instrument_service self.task_service = task_service - def get(self, project_id: int) -> Project: - return self.repository.find_by_id(project_id) + async def get(self, project_id: int) -> Project: + project = await self.repository.find_by_id(project_id) + return project - def create(self, project_new: ProjectNew) -> Project: + async def create(self, project_new: ProjectNew) -> Project: instruments: list[InstrumentBase] = [ InstrumentBase(urn=instrument_urn) for instrument_urn in project_new.instruments ] @@ -54,20 +55,22 @@ def create(self, project_new: ProjectNew) -> Project: ) project = Project(name=project_new.name, lifecycle=project_new.lifecycle, system_card=system_card) - project = self.update(project) + project = await self.update(project) selected_instruments = self.instrument_service.fetch_instruments(project_new.instruments) # type: ignore for instrument in selected_instruments: - self.task_service.create_instrument_tasks(instrument.tasks, project) + await self.task_service.create_instrument_tasks(instrument.tasks, project) return project - def paginate( + async def paginate( self, skip: int, limit: int, search: str, filters: dict[str, str], sort: dict[str, str] ) -> list[Project]: - return self.repository.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort) + projects = await self.repository.paginate(skip=skip, limit=limit, search=search, filters=filters, sort=sort) + return projects - def update(self, project: Project) -> Project: + async def update(self, project: Project) -> Project: # TODO: Is this the right place to sync system cards: system_card and system_card_json? project.sync_system_card() - return self.repository.save(project) + project = await self.repository.save(project) + return project diff --git a/amt/services/tasks.py b/amt/services/tasks.py index aafae4d7..1b369026 100644 --- a/amt/services/tasks.py +++ b/amt/services/tasks.py @@ -25,21 +25,24 @@ def __init__( self.storage_writer = StorageFactory.init(storage_type="file", location="./output", filename="system_card.yaml") self.system_card = SystemCard() - def get_tasks(self, status_id: int) -> Sequence[Task]: - return self.repository.find_by_status_id(status_id) + async def get_tasks(self, status_id: int) -> Sequence[Task]: + task = await self.repository.find_by_status_id(status_id) + return task - def get_tasks_for_project(self, project_id: int, status_id: int) -> Sequence[Task]: - return self.repository.find_by_project_id_and_status_id(project_id, status_id) + async def get_tasks_for_project(self, project_id: int, status_id: int) -> Sequence[Task]: + tasks = await self.repository.find_by_project_id_and_status_id(project_id, status_id) + return tasks - def assign_task(self, task: Task, user: User) -> Task: + async def assign_task(self, task: Task, user: User) -> Task: task.user_id = user.id - return self.repository.save(task) + task = await self.repository.save(task) + return task - def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], project: Project) -> None: + async def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], project: Project) -> None: # TODO: (Christopher) At this moment a status has to be retrieved from the DB. In the future # we will have static statuses, so this will need to change. status = Status.TODO - self.repository.save_all( + await self.repository.save_all( [ # TODO: (Christopher) The ticket does not specify what to do when question type is not an # open questions, hence for now all titles will be set to task.question. @@ -50,7 +53,7 @@ def create_instrument_tasks(self, tasks: Sequence[InstrumentTask], project: Proj ] ) - def move_task( + async def move_task( self, task_id: int, status_id: int, previous_sibling_id: int | None = None, next_sibling_id: int | None = None ) -> Task: """ @@ -61,7 +64,7 @@ def move_task( :param next_sibling_id: the id of the next sibling of the task or None :return: the updated task """ - task = self.repository.find_by_id(task_id) + task = await self.repository.find_by_id(task_id) if status_id == Status.DONE: # TODO: This seems off, tasks should be written to the correct location in the system card. @@ -73,17 +76,18 @@ def move_task( # update order position of the card if previous_sibling_id and next_sibling_id: - previous_task = self.repository.find_by_id(previous_sibling_id) - next_task = self.repository.find_by_id(next_sibling_id) + previous_task = await self.repository.find_by_id(previous_sibling_id) + next_task = await self.repository.find_by_id(next_sibling_id) new_sort_order = previous_task.sort_order + ((next_task.sort_order - previous_task.sort_order) / 2) task.sort_order = new_sort_order elif previous_sibling_id and not next_sibling_id: - previous_task = self.repository.find_by_id(previous_sibling_id) + previous_task = await self.repository.find_by_id(previous_sibling_id) task.sort_order = previous_task.sort_order + 10 elif not previous_sibling_id and next_sibling_id: - next_task = self.repository.find_by_id(next_sibling_id) + next_task = await self.repository.find_by_id(next_sibling_id) task.sort_order = next_task.sort_order / 2 else: task.sort_order = 10 - return self.repository.save(task) + task = await self.repository.save(task) + return task diff --git a/amt/site/templates/macros/tasks.html.j2 b/amt/site/templates/macros/tasks.html.j2 index 327d34f9..077c56cd 100644 --- a/amt/site/templates/macros/tasks.html.j2 +++ b/amt/site/templates/macros/tasks.html.j2 @@ -39,9 +39,9 @@ data-id="{{ status.value }}" id="column-{{ status.value }}"> {% if project is defined %} - {% for task in tasks_service.get_tasks_for_project(project.id, status) %}{{ render_task_card_full(task) }}{% endfor %} + {% for task in tasks_by_status[status] %}{{ render_task_card_full(task) }}{% endfor %} {% else %} - {% for task in tasks_service.get_tasks(status) %}{{ render_task_card_full(task) }}{% endfor %} + {% for task in tasks_by_status[status] %}{{ render_task_card_full(task) }}{% endfor %} {% endif %} diff --git a/amt/site/templates/projects/tasks.html.j2 b/amt/site/templates/projects/tasks.html.j2 index e738f52a..12f95f0d 100644 --- a/amt/site/templates/projects/tasks.html.j2 +++ b/amt/site/templates/projects/tasks.html.j2 @@ -21,7 +21,7 @@
- {% for status in statuses %}{{ render.column(project, status, translations, tasks_service) }}{% endfor %} + {% for status in statuses %}{{ render.column(project, status, translations, tasks_by_status) }}{% endfor %}
diff --git a/compose.override.yml b/compose.override.yml new file mode 100644 index 00000000..e03e6d97 --- /dev/null +++ b/compose.override.yml @@ -0,0 +1,10 @@ +services: + amt-test: + build: + context: . + dockerfile: Dockerfile + target: test + image: ghcr.io/minbzk/amt-test:latest + env_file: + - path: prod.env + required: true diff --git a/poetry.lock b/poetry.lock index b05f5110..6b4f4c0e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,22 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.2 and should not be changed by hand. + +[[package]] +name = "aiosqlite" +version = "0.20.0" +description = "asyncio bridge to the standard sqlite3 module" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiosqlite-0.20.0-py3-none-any.whl", hash = "sha256:36a1deaca0cac40ebe32aac9977a6e2bbc7f5189f23f4a54d5908986729e5bd6"}, + {file = "aiosqlite-0.20.0.tar.gz", hash = "sha256:6d35c8c256637f4672f843c31021464090805bf925385ac39473fb16eaaca3d7"}, +] + +[package.dependencies] +typing_extensions = ">=4.0" + +[package.extras] +dev = ["attribution (==1.7.0)", "black (==24.2.0)", "coverage[toml] (==7.4.1)", "flake8 (==7.0.0)", "flake8-bugbear (==24.2.6)", "flit (==3.9.0)", "mypy (==1.8.0)", "ufmt (==2.3.0)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==7.2.6)", "sphinx-mdinclude (==0.5.3)"] [[package]] name = "alembic" @@ -50,6 +68,69 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21.0b1)"] trio = ["trio (>=0.26.1)"] +[[package]] +name = "asyncpg" +version = "0.30.0" +description = "An asyncio PostgreSQL driver" +optional = false +python-versions = ">=3.8.0" +files = [ + {file = "asyncpg-0.30.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:bfb4dd5ae0699bad2b233672c8fc5ccbd9ad24b89afded02341786887e37927e"}, + {file = "asyncpg-0.30.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:dc1f62c792752a49f88b7e6f774c26077091b44caceb1983509edc18a2222ec0"}, + {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3152fef2e265c9c24eec4ee3d22b4f4d2703d30614b0b6753e9ed4115c8a146f"}, + {file = "asyncpg-0.30.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c7255812ac85099a0e1ffb81b10dc477b9973345793776b128a23e60148dd1af"}, + {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:578445f09f45d1ad7abddbff2a3c7f7c291738fdae0abffbeb737d3fc3ab8b75"}, + {file = "asyncpg-0.30.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:c42f6bb65a277ce4d93f3fba46b91a265631c8df7250592dd4f11f8b0152150f"}, + {file = "asyncpg-0.30.0-cp310-cp310-win32.whl", hash = "sha256:aa403147d3e07a267ada2ae34dfc9324e67ccc4cdca35261c8c22792ba2b10cf"}, + {file = "asyncpg-0.30.0-cp310-cp310-win_amd64.whl", hash = "sha256:fb622c94db4e13137c4c7f98834185049cc50ee01d8f657ef898b6407c7b9c50"}, + {file = "asyncpg-0.30.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5e0511ad3dec5f6b4f7a9e063591d407eee66b88c14e2ea636f187da1dcfff6a"}, + {file = "asyncpg-0.30.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:915aeb9f79316b43c3207363af12d0e6fd10776641a7de8a01212afd95bdf0ed"}, + {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1c198a00cce9506fcd0bf219a799f38ac7a237745e1d27f0e1f66d3707c84a5a"}, + {file = "asyncpg-0.30.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3326e6d7381799e9735ca2ec9fd7be4d5fef5dcbc3cb555d8a463d8460607956"}, + {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:51da377487e249e35bd0859661f6ee2b81db11ad1f4fc036194bc9cb2ead5056"}, + {file = "asyncpg-0.30.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:bc6d84136f9c4d24d358f3b02be4b6ba358abd09f80737d1ac7c444f36108454"}, + {file = "asyncpg-0.30.0-cp311-cp311-win32.whl", hash = "sha256:574156480df14f64c2d76450a3f3aaaf26105869cad3865041156b38459e935d"}, + {file = "asyncpg-0.30.0-cp311-cp311-win_amd64.whl", hash = "sha256:3356637f0bd830407b5597317b3cb3571387ae52ddc3bca6233682be88bbbc1f"}, + {file = "asyncpg-0.30.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c902a60b52e506d38d7e80e0dd5399f657220f24635fee368117b8b5fce1142e"}, + {file = "asyncpg-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aca1548e43bbb9f0f627a04666fedaca23db0a31a84136ad1f868cb15deb6e3a"}, + {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c2a2ef565400234a633da0eafdce27e843836256d40705d83ab7ec42074efb3"}, + {file = "asyncpg-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1292b84ee06ac8a2ad8e51c7475aa309245874b61333d97411aab835c4a2f737"}, + {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0f5712350388d0cd0615caec629ad53c81e506b1abaaf8d14c93f54b35e3595a"}, + {file = "asyncpg-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:db9891e2d76e6f425746c5d2da01921e9a16b5a71a1c905b13f30e12a257c4af"}, + {file = "asyncpg-0.30.0-cp312-cp312-win32.whl", hash = "sha256:68d71a1be3d83d0570049cd1654a9bdfe506e794ecc98ad0873304a9f35e411e"}, + {file = "asyncpg-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:9a0292c6af5c500523949155ec17b7fe01a00ace33b68a476d6b5059f9630305"}, + {file = "asyncpg-0.30.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:05b185ebb8083c8568ea8a40e896d5f7af4b8554b64d7719c0eaa1eb5a5c3a70"}, + {file = "asyncpg-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c47806b1a8cbb0a0db896f4cd34d89942effe353a5035c62734ab13b9f938da3"}, + {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9b6fde867a74e8c76c71e2f64f80c64c0f3163e687f1763cfaf21633ec24ec33"}, + {file = "asyncpg-0.30.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46973045b567972128a27d40001124fbc821c87a6cade040cfcd4fa8a30bcdc4"}, + {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:9110df111cabc2ed81aad2f35394a00cadf4f2e0635603db6ebbd0fc896f46a4"}, + {file = "asyncpg-0.30.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04ff0785ae7eed6cc138e73fc67b8e51d54ee7a3ce9b63666ce55a0bf095f7ba"}, + {file = "asyncpg-0.30.0-cp313-cp313-win32.whl", hash = "sha256:ae374585f51c2b444510cdf3595b97ece4f233fde739aa14b50e0d64e8a7a590"}, + {file = "asyncpg-0.30.0-cp313-cp313-win_amd64.whl", hash = "sha256:f59b430b8e27557c3fb9869222559f7417ced18688375825f8f12302c34e915e"}, + {file = "asyncpg-0.30.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:29ff1fc8b5bf724273782ff8b4f57b0f8220a1b2324184846b39d1ab4122031d"}, + {file = "asyncpg-0.30.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64e899bce0600871b55368b8483e5e3e7f1860c9482e7f12e0a771e747988168"}, + {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b290f4726a887f75dcd1b3006f484252db37602313f806e9ffc4e5996cfe5cb"}, + {file = "asyncpg-0.30.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f86b0e2cd3f1249d6fe6fd6cfe0cd4538ba994e2d8249c0491925629b9104d0f"}, + {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:393af4e3214c8fa4c7b86da6364384c0d1b3298d45803375572f415b6f673f38"}, + {file = "asyncpg-0.30.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:fd4406d09208d5b4a14db9a9dbb311b6d7aeeab57bded7ed2f8ea41aeef39b34"}, + {file = "asyncpg-0.30.0-cp38-cp38-win32.whl", hash = "sha256:0b448f0150e1c3b96cb0438a0d0aa4871f1472e58de14a3ec320dbb2798fb0d4"}, + {file = "asyncpg-0.30.0-cp38-cp38-win_amd64.whl", hash = "sha256:f23b836dd90bea21104f69547923a02b167d999ce053f3d502081acea2fba15b"}, + {file = "asyncpg-0.30.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:6f4e83f067b35ab5e6371f8a4c93296e0439857b4569850b178a01385e82e9ad"}, + {file = "asyncpg-0.30.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5df69d55add4efcd25ea2a3b02025b669a285b767bfbf06e356d68dbce4234ff"}, + {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a3479a0d9a852c7c84e822c073622baca862d1217b10a02dd57ee4a7a081f708"}, + {file = "asyncpg-0.30.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26683d3b9a62836fad771a18ecf4659a30f348a561279d6227dab96182f46144"}, + {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:1b982daf2441a0ed314bd10817f1606f1c28b1136abd9e4f11335358c2c631cb"}, + {file = "asyncpg-0.30.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:1c06a3a50d014b303e5f6fc1e5f95eb28d2cee89cf58384b700da621e5d5e547"}, + {file = "asyncpg-0.30.0-cp39-cp39-win32.whl", hash = "sha256:1b11a555a198b08f5c4baa8f8231c74a366d190755aa4f99aacec5970afe929a"}, + {file = "asyncpg-0.30.0-cp39-cp39-win_amd64.whl", hash = "sha256:8b684a3c858a83cd876f05958823b68e8d14ec01bb0c0d14a6704c5bf9711773"}, + {file = "asyncpg-0.30.0.tar.gz", hash = "sha256:c551e9928ab6707602f44811817f82ba3c446e018bfe1d3abecc8ba5f3eac851"}, +] + +[package.extras] +docs = ["Sphinx (>=8.1.3,<8.2.0)", "sphinx-rtd-theme (>=1.2.2)"] +gssauth = ["gssapi", "sspilib"] +test = ["distro (>=1.9.0,<1.10.0)", "flake8 (>=6.1,<7.0)", "flake8-pyi (>=24.1.0,<24.2.0)", "gssapi", "k5test", "mypy (>=1.8.0,<1.9.0)", "sspilib", "uvloop (>=0.15.3)"] + [[package]] name = "authlib" version = "1.3.2" @@ -1027,6 +1108,17 @@ files = [ {file = "markupsafe-3.0.2.tar.gz", hash = "sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0"}, ] +[[package]] +name = "nest-asyncio" +version = "1.6.0" +description = "Patch asyncio to allow nested event loops" +optional = false +python-versions = ">=3.5" +files = [ + {file = "nest_asyncio-1.6.0-py3-none-any.whl", hash = "sha256:87af6efd6b5e897c81050477ef65c62e2b2f35d51703cae01aff2905b1852e1c"}, + {file = "nest_asyncio-1.6.0.tar.gz", hash = "sha256:6f172d5449aca15afd6c646851f4e31e02c598d553a667e38cafa997cfec55fe"}, +] + [[package]] name = "nodeenv" version = "1.9.1" @@ -1428,6 +1520,24 @@ pluggy = ">=1.5,<2" [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-asyncio" +version = "0.24.0" +description = "Pytest support for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, +] + +[package.dependencies] +pytest = ">=8.2,<9" + +[package.extras] +docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] +testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] + [[package]] name = "pytest-base-url" version = "2.1.0" @@ -1909,7 +2019,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} +greenlet = {version = "!=0.4.17", optional = true, markers = "python_version < \"3.13\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\") or extra == \"asyncio\""} typing-extensions = ">=4.6.0" [package.extras] @@ -1967,13 +2077,13 @@ url = ["furl (>=0.4.1)"] [[package]] name = "starlette" -version = "0.40.0" +version = "0.41.2" description = "The little ASGI library that shines." optional = false python-versions = ">=3.8" files = [ - {file = "starlette-0.40.0-py3-none-any.whl", hash = "sha256:c494a22fae73805376ea6bf88439783ecfba9aac88a43911b48c653437e784c4"}, - {file = "starlette-0.40.0.tar.gz", hash = "sha256:1a3139688fb298ce5e2d661d37046a66ad996ce94be4d4983be019a23a04ea35"}, + {file = "starlette-0.41.2-py3-none-any.whl", hash = "sha256:fbc189474b4731cf30fcef52f18a8d070e3f3b46c6a04c97579e85e6ffca942d"}, + {file = "starlette-0.41.2.tar.gz", hash = "sha256:9834fd799d1a87fd346deb76158668cfa0b0d56f85caefe8268e2d97c3468b62"}, ] [package.dependencies] @@ -2352,4 +2462,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.12" -content-hash = "06db522ccdbb19c0c0061b5533ca30e03ec51755ba3a2888b80f77a5c28c7e45" +content-hash = "a342f51a83d9cd19dd54d039ad978433c49edb17b98b03f34891f27933228b4e" diff --git a/prod.env b/prod.env index 001d4690..47f619a3 100644 --- a/prod.env +++ b/prod.env @@ -5,6 +5,7 @@ ENVIRONMENT=production BACKEND_CORS_ORIGINS="http://localhost,https://localhost,http://127.0.0.1,https://127.0.0.1" SECRET_KEY=changethis APP_DATABASE_SCHEME="postgresql" +APP_DATABASE_DRIVER="asyncpg" APP_DATABASE_USER=amt APP_DATABASE_DB=amt APP_DATABASE_PASSWORD=changethis diff --git a/pyproject.toml b/pyproject.toml index bd617d90..b52952de 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,15 +36,19 @@ pyyaml-include = "^2.1" click = "^8.1.7" python-ulid = "^3.0.0" fastapi-csrf-protect = "^0.3.4" -sqlalchemy = "^2.0.36" +sqlalchemy = {extras = ["asyncio"], version = "^2.0.36"} sqlalchemy-utils = "^0.41.2" liccheck = "^0.9.2" authlib = "^1.3.2" -pytest-mock = "^3.14.0" +aiosqlite = "^0.20.0" +asyncpg = "^0.30.0" [tool.poetry.group.test.dependencies] pytest = "^8.3.3" +pytest-asyncio = "^0.24.0" +nest-asyncio = "^1.6.0" +pytest-mock = "^3.14.0" coverage = "^7.6.4" playwright = "^1.47.0" pytest-playwright = "^0.5.2" @@ -137,6 +141,7 @@ markers = [ "enable_auth: marks tests that require authentication" ] + [tool.liccheck] level = "PARANOID" dependencies = true diff --git a/tests/api/routes/test_health.py b/tests/api/routes/test_health.py index 49edab92..626c1866 100644 --- a/tests/api/routes/test_health.py +++ b/tests/api/routes/test_health.py @@ -1,8 +1,10 @@ -from fastapi.testclient import TestClient +import pytest +from httpx import AsyncClient -def test_health_ready(client: TestClient) -> None: - response = client.get( +@pytest.mark.asyncio +async def test_health_ready(client: AsyncClient) -> None: + response = await client.get( "/health/ready", ) assert response.status_code == 200 @@ -10,8 +12,9 @@ def test_health_ready(client: TestClient) -> None: assert response.json() == {"status": "ok"} -def test_health_live(client: TestClient) -> None: - response = client.get( +@pytest.mark.asyncio +async def test_health_live(client: AsyncClient) -> None: + response = await client.get( "/health/live", ) assert response.status_code == 200 diff --git a/tests/api/routes/test_pages.py b/tests/api/routes/test_pages.py index 59a0f9d7..cb5acfc2 100644 --- a/tests/api/routes/test_pages.py +++ b/tests/api/routes/test_pages.py @@ -1,11 +1,13 @@ -from fastapi.testclient import TestClient +import pytest +from httpx import AsyncClient from tests.database_test_utils import DatabaseTestUtils -def test_get_main_page(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_main_page(client: AsyncClient, db: DatabaseTestUtils) -> None: # when - response = client.get("/pages/") + response = await client.get("/pages/") # then assert response.status_code == 200 diff --git a/tests/api/routes/test_project.py b/tests/api/routes/test_project.py index 463ba775..8e662279 100644 --- a/tests/api/routes/test_project.py +++ b/tests/api/routes/test_project.py @@ -3,16 +3,17 @@ import pytest from amt.api.routes.project import set_path from amt.models import Project -from fastapi.testclient import TestClient +from httpx import AsyncClient from pytest_mock import MockFixture from tests.constants import default_project, default_task from tests.database_test_utils import DatabaseTestUtils -def test_get_unknown_project(client: TestClient) -> None: +@pytest.mark.asyncio +async def test_get_unknown_project(client: AsyncClient) -> None: # when - response = client.get("/algorithm-system/1") + response = await client.get("/algorithm-system/1") # then assert response.status_code == 404 @@ -20,12 +21,13 @@ def test_get_unknown_project(client: TestClient) -> None: assert b"The requested page or resource could not be found." in response.content -def test_get_project_tasks(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_project_tasks(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) + await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) # when - response = client.get("/algorithm-system/1/details/tasks") + response = await client.get("/algorithm-system/1/details/tasks") # then assert response.status_code == 200 @@ -36,12 +38,13 @@ def test_get_project_tasks(client: TestClient, db: DatabaseTestUtils) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_system_card(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_system_card(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/details/system_card") + response = await client.get("/algorithm-system/1/details/system_card") # then assert response.status_code == 200 @@ -52,9 +55,10 @@ def test_get_system_card(client: TestClient, db: DatabaseTestUtils) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_system_card_unknown_project(client: TestClient) -> None: +@pytest.mark.asyncio +async def test_get_system_card_unknown_project(client: AsyncClient) -> None: # when - response = client.get("/algorithm-system/1/details/system_card") + response = await client.get("/algorithm-system/1/details/system_card") # then assert response.status_code == 404 @@ -65,12 +69,13 @@ def test_get_system_card_unknown_project(client: TestClient) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_assessment_card(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_assessment_card(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/details/system_card/assessments/iama") + response = await client.get("/algorithm-system/1/details/system_card/assessments/iama") # then assert response.status_code == 200 @@ -81,9 +86,10 @@ def test_get_assessment_card(client: TestClient, db: DatabaseTestUtils) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_assessment_card_unknown_project(client: TestClient) -> None: +@pytest.mark.asyncio +async def test_get_assessment_card_unknown_project(client: AsyncClient) -> None: # when - response = client.get("/algorithm-system/1/details/system_card/assessments/iama") + response = await client.get("/algorithm-system/1/details/system_card/assessments/iama") # then assert response.status_code == 404 @@ -94,12 +100,13 @@ def test_get_assessment_card_unknown_project(client: TestClient) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_assessment_card_unknown_assessment(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_assessment_card_unknown_assessment(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/details/system_card/assessments/nonexistent") + response = await client.get("/algorithm-system/1/details/system_card/assessments/nonexistent") # then assert response.status_code == 404 @@ -110,12 +117,13 @@ def test_get_assessment_card_unknown_assessment(client: TestClient, db: Database # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_model_card(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_model_card(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/details/system_card/models/logres_iris") + response = await client.get("/algorithm-system/1/details/system_card/models/logres_iris") # then assert response.status_code == 200 @@ -126,9 +134,10 @@ def test_get_model_card(client: TestClient, db: DatabaseTestUtils) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_model_card_unknown_project(client: TestClient) -> None: +@pytest.mark.asyncio +async def test_get_model_card_unknown_project(client: AsyncClient) -> None: # when - response = client.get("/algorithm-system/1/details/system_card/models/logres_iris") + response = await client.get("/algorithm-system/1/details/system_card/models/logres_iris") # then assert response.status_code == 404 @@ -139,12 +148,13 @@ def test_get_model_card_unknown_project(client: TestClient) -> None: # TODO: Test are now have hard coded URL paths because the system card # is fixed for now. Tests need to be refactored and made proper once # the actual stored system card in a project is being rendered. -def test_get_assessment_card_unknown_model_card(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_assessment_card_unknown_model_card(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/details/system_card/models/nonexistent") + response = await client.get("/algorithm-system/1/details/system_card/models/nonexistent") # then assert response.status_code == 404 @@ -152,12 +162,13 @@ def test_get_assessment_card_unknown_model_card(client: TestClient, db: Database assert b"The requested page or resource could not be found." in response.content -def test_get_project_details(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_project_details(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) + await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) # when - response = client.get("/algorithm-system/1/details") + response = await client.get("/algorithm-system/1/details") # then assert response.status_code == 200 @@ -165,12 +176,13 @@ def test_get_project_details(client: TestClient, db: DatabaseTestUtils) -> None: assert b"Details" in response.content -def test_get_system_card_requirements(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_system_card_requirements(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) + await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) # when - response = client.get("/algorithm-system/1/details/system_card/requirements") + response = await client.get("/algorithm-system/1/details/system_card/requirements") # then assert response.status_code == 200 @@ -178,12 +190,13 @@ def test_get_system_card_requirements(client: TestClient, db: DatabaseTestUtils) assert b"0" in response.content -def test_get_system_card_data_page(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_system_card_data_page(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) + await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) # when - response = client.get("/algorithm-system/1/details/system_card/data") + response = await client.get("/algorithm-system/1/details/system_card/data") # then assert response.status_code == 200 @@ -191,12 +204,13 @@ def test_get_system_card_data_page(client: TestClient, db: DatabaseTestUtils) -> assert b"To be implemented" in response.content -def test_get_system_card_instruments(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_system_card_instruments(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) + await db.given([default_project("testproject1"), default_task(project_id=1, status_id=1)]) # when - response = client.get("/algorithm-system/1/details/system_card/instruments") + response = await client.get("/algorithm-system/1/details/system_card/instruments") # then assert response.status_code == 200 @@ -204,12 +218,13 @@ def test_get_system_card_instruments(client: TestClient, db: DatabaseTestUtils) assert b"To be implemented" in response.content -def test_get_project_edit(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_project_edit(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/edit/system_card/lifecycle") + response = await client.get("/algorithm-system/1/edit/system_card/lifecycle") # then assert response.status_code == 200 @@ -218,12 +233,13 @@ def test_get_project_edit(client: TestClient, db: DatabaseTestUtils) -> None: assert b"lifecycle" in response.content -def test_get_project_cancel(client: TestClient, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_project_cancel(client: AsyncClient, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) # when - response = client.get("/algorithm-system/1/cancel/system_card/lifecycle") + response = await client.get("/algorithm-system/1/cancel/system_card/lifecycle") # then assert response.status_code == 200 @@ -232,14 +248,17 @@ def test_get_project_cancel(client: TestClient, db: DatabaseTestUtils) -> None: assert b"lifecycle" in response.content -def test_get_project_update(client: TestClient, mocker: MockFixture, db: DatabaseTestUtils) -> None: +@pytest.mark.asyncio +async def test_get_project_update(client: AsyncClient, mocker: MockFixture, db: DatabaseTestUtils) -> None: # given - db.given([default_project("testproject1")]) + await db.given([default_project("testproject1")]) client.cookies["fastapi-csrf-token"] = "1" mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock) # when - response = client.put("/algorithm-system/1/update/name", json={"value": "Test Name"}, headers={"X-CSRF-Token": "1"}) + response = await client.put( + "/algorithm-system/1/update/name", json={"value": "Test Name"}, headers={"X-CSRF-Token": "1"} + ) # then assert response.status_code == 200 @@ -248,7 +267,7 @@ def test_get_project_update(client: TestClient, mocker: MockFixture, db: Databas assert b"Test Name" in response.content # Verify that the project was updated in the database - updated_projects = db.get(Project, "id", 1) + updated_projects = await db.get(Project, "id", 1) assert len(updated_projects) == 1 updated_project = updated_projects[0] assert updated_project.name == "Test Name" # type: ignore diff --git a/tests/api/routes/test_projects.py b/tests/api/routes/test_projects.py index cdd09045..82e884e4 100644 --- a/tests/api/routes/test_projects.py +++ b/tests/api/routes/test_projects.py @@ -1,5 +1,6 @@ from typing import cast +import pytest from amt.api.routes.projects import get_localized_value from amt.models import Project from amt.models.base import Base @@ -8,35 +9,39 @@ from amt.schema.system_card import SystemCard from amt.services.task_registry import get_requirements_and_measures from fastapi.requests import Request -from fastapi.testclient import TestClient +from httpx import AsyncClient from pytest_mock import MockFixture from tests.constants import default_instrument from tests.database_test_utils import DatabaseTestUtils -def test_projects_get_root(client: TestClient) -> None: - response = client.get("/algorithm-systems/") +@pytest.mark.asyncio +async def test_projects_get_root(client: AsyncClient) -> None: + response = await client.get("/algorithm-systems/") assert response.status_code == 200 assert b'
' in response.content -def test_projects_get_root_missing_slash(client: TestClient) -> None: - response = client.get("/algorithm-systems") +@pytest.mark.asyncio +async def test_projects_get_root_missing_slash(client: AsyncClient) -> None: + response = await client.get("/algorithm-systems", follow_redirects=True) assert response.status_code == 200 assert b'
' in response.content -def test_projects_get_root_htmx(client: TestClient) -> None: - response = client.get("/algorithm-systems/", headers={"HX-Request": "true"}) +@pytest.mark.asyncio +async def test_projects_get_root_htmx(client: AsyncClient) -> None: + response = await client.get("/algorithm-systems/", headers={"HX-Request": "true"}) assert response.status_code == 200 assert b'' not in response.content -def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None: +@pytest.mark.asyncio +async def test_get_new_projects(client: AsyncClient, mocker: MockFixture) -> None: # given mocker.patch( "amt.services.instruments.InstrumentsService.fetch_instruments", @@ -44,7 +49,7 @@ def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None: ) # when - response = client.get("/algorithm-systems/new") + response = await client.get("/algorithm-systems/new") assert response.status_code == 200 assert response.headers["content-type"] == "text/html; charset=utf-8" content = " ".join(response.content.decode().split()) @@ -59,13 +64,14 @@ def test_get_new_projects(client: TestClient, mocker: MockFixture) -> None: ) -def test_post_new_projects_bad_request(client: TestClient, mocker: MockFixture) -> None: +@pytest.mark.asyncio +async def test_post_new_projects_bad_request(client: AsyncClient, mocker: MockFixture) -> None: # given mocker.patch("fastapi_csrf_protect.CsrfProtect.validate_csrf", new_callable=mocker.AsyncMock) # when client.cookies["fastapi-csrf-token"] = "1" - response = client.post("/algorithm-systems/new", json={}, headers={"X-CSRF-Token": "1"}) + response = await client.post("/algorithm-systems/new", json={}, headers={"X-CSRF-Token": "1"}) # then assert response.status_code == 400 @@ -73,7 +79,8 @@ def test_post_new_projects_bad_request(client: TestClient, mocker: MockFixture) assert b"Field required" in response.content -def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None: +@pytest.mark.asyncio +async def test_post_new_projects(client: AsyncClient, mocker: MockFixture) -> None: client.cookies["fastapi-csrf-token"] = "1" new_project = ProjectNew( name="default project", @@ -93,7 +100,7 @@ def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None: ) # when - response = client.post("/algorithm-systems/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"}) + response = await client.post("/algorithm-systems/new", json=new_project.model_dump(), headers={"X-CSRF-Token": "1"}) # then assert response.status_code == 200 @@ -101,8 +108,9 @@ def test_post_new_projects(client: TestClient, mocker: MockFixture) -> None: assert response.headers["HX-Redirect"] == "/algorithm-system/1/details/tasks" -def test_post_new_projects_write_system_card( - client: TestClient, +@pytest.mark.asyncio +async def test_post_new_projects_write_system_card( + client: AsyncClient, mocker: MockFixture, db: DatabaseTestUtils, ) -> None: @@ -148,10 +156,10 @@ def test_post_new_projects_write_system_card( ) # when - client.post("/algorithm-systems/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"}) + await client.post("/algorithm-systems/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"}) # then - base_projects: list[Base] = db.get(Project, "name", name) + base_projects: list[Base] = await db.get(Project, "name", name) projects: list[Project] = cast(list[Project], base_projects) assert any(project.system_card == system_card for project in projects if project.system_card is not None) diff --git a/tests/api/routes/test_root.py b/tests/api/routes/test_root.py index 99d7a9c2..bd41694a 100644 --- a/tests/api/routes/test_root.py +++ b/tests/api/routes/test_root.py @@ -1,8 +1,10 @@ -from fastapi.testclient import TestClient +import pytest +from httpx import AsyncClient -def test_get_root(client: TestClient) -> None: - response = client.get( +@pytest.mark.asyncio +async def test_get_root(client: AsyncClient) -> None: + response = await client.get( "/", follow_redirects=False, ) diff --git a/tests/api/routes/test_static.py b/tests/api/routes/test_static.py index a73c255c..9f6d392f 100644 --- a/tests/api/routes/test_static.py +++ b/tests/api/routes/test_static.py @@ -1,10 +1,12 @@ from pathlib import Path -from fastapi.testclient import TestClient +import pytest +from httpx import AsyncClient -def test_static_css(client: TestClient) -> None: +@pytest.mark.asyncio +async def test_static_css(client: AsyncClient) -> None: files = Path("amt/site/static/dist").glob("*.js") for filename in files: - response = client.get(f"/static/dist/{filename.name}") + response = await client.get(f"/static/dist/{filename.name}") assert response.status_code == 200 diff --git a/tests/api/test_http_headers.py b/tests/api/test_http_headers.py index b18d4668..21b797e0 100644 --- a/tests/api/test_http_headers.py +++ b/tests/api/test_http_headers.py @@ -1,8 +1,10 @@ -from fastapi.testclient import TestClient +import pytest +from httpx import AsyncClient -def test_sts_header(client: TestClient) -> None: - response = client.get( +@pytest.mark.asyncio +async def test_sts_header(client: AsyncClient) -> None: + response = await client.get( "/", ) assert response.status_code == 200 diff --git a/tests/conftest.py b/tests/conftest.py index 25b57bde..ae1c1631 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,26 +1,33 @@ +import asyncio import logging import os import re -from collections.abc import Callable, Generator +from collections.abc import AsyncIterator, Callable, Generator from multiprocessing import Process from pathlib import Path from typing import Any import httpx +import nest_asyncio # type: ignore [(reportMissingTypeStubs)] import pytest +import pytest_asyncio import uvicorn from amt.models.base import Base from amt.server import create_app -from fastapi.testclient import TestClient +from httpx import ASGITransport, AsyncClient from playwright.sync_api import Browser -from sqlalchemy import create_engine, text -from sqlalchemy.orm import Session +from sqlalchemy import text +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.ext.asyncio.session import async_sessionmaker from tests.database_e2e_setup import setup_database_e2e from tests.database_test_utils import DatabaseTestUtils logger = logging.getLogger(__name__) +# Dubious choice here: allow nested event loops. +nest_asyncio.apply() # type: ignore [(reportUnknownMemberType)] + def run_server_uvicorn(database_file: Path, host: str = "127.0.0.1", port: int = 3462) -> None: os.environ["APP_DATABASE_FILE"] = "/" + str(database_file) @@ -32,32 +39,40 @@ def run_server_uvicorn(database_file: Path, host: str = "127.0.0.1", port: int = @pytest.fixture(scope="session") -def setup_db_and_server( +async def setup_db_and_server( tmp_path_factory: pytest.TempPathFactory, request: pytest.FixtureRequest -) -> Generator[Any, None, None]: +) -> AsyncIterator[str]: test_dir = tmp_path_factory.mktemp("e2e_database") database_file = test_dir / "test.sqlite3" if request.config.getoption("--db") == "postgresql": - engine = create_engine(get_db_uri()) + engine = create_async_engine(get_db_uri()) else: - engine = create_engine(f"sqlite:///{database_file}", connect_args={"check_same_thread": False}) + engine = create_async_engine(f"sqlite+aiosqlite:///{database_file}", connect_args={"check_same_thread": False}) + metadata = Base.metadata - metadata.create_all(engine) - with Session(engine, expire_on_commit=False) as session: - setup_database_e2e(session) + async with engine.begin() as connection: + await connection.run_sync(metadata.create_all) + + async_session = async_sessionmaker(engine, expire_on_commit=False) + + async with async_session() as session: + await setup_database_e2e(session) process = Process(target=run_server_uvicorn, args=(database_file,)) process.start() yield "http://127.0.0.1:3462" process.terminate() - metadata.drop_all(engine) + + async with engine.begin() as conn: + await conn.run_sync(metadata.drop_all) @pytest.fixture(autouse=True) def disable_auth(request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch) -> None: # noqa: PT004 - marker = request.node.get_closest_marker("enable_auth") # type: ignore + marker = request.node.get_closest_marker("enable_auth") # type: ignore [(reportUnknownMemberType)] + if not marker: monkeypatch.setenv("DISABLE_AUTH", "true") return @@ -94,20 +109,20 @@ def pytest_collection_modifyitems(session: pytest.Session, config: pytest.Config items[:] = tests + e2e_tests -@pytest.fixture -def client(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch) -> Generator[TestClient, None, None]: +@pytest_asyncio.fixture # type: ignore [(reportUnknownMemberType)] +async def client(db: DatabaseTestUtils, monkeypatch: pytest.MonkeyPatch) -> AsyncIterator[AsyncClient]: # overwrite db url monkeypatch.setenv("APP_DATABASE_FILE", "/" + str(db.get_database_file())) from amt.repositories.deps import get_session app = create_app() - with TestClient(app, raise_server_exceptions=True) as c: + async with AsyncClient( + transport=ASGITransport(app=app, raise_app_exceptions=True), base_url="http://testserver/" + ) as ac: app.dependency_overrides[get_session] = db.get_session - - c.timeout = 5 - - yield c + ac.timeout = 5 + yield ac @pytest.fixture(scope="session") @@ -117,43 +132,54 @@ def browser_context_args(browser_context_args: dict[str, Any]) -> dict[str, Any] @pytest.fixture(scope="session") def browser( - launch_browser: Callable[[], Browser], setup_db_and_server: Generator[str, None, None] + launch_browser: Callable[[], Browser], + setup_db_and_server: AsyncIterator[str], ) -> Generator[Browser, None, None]: transport = httpx.HTTPTransport(retries=5) + + loop = asyncio.get_event_loop() + url = loop.run_until_complete(setup_db_and_server.__anext__()) + with httpx.Client(transport=transport, verify=False, timeout=0.7) as client: # noqa: S501 - client.get(f"{setup_db_and_server}/") + client.get(f"{url}/") browser = launch_browser() yield browser browser.close() + # Clean up by consuming the rest of the generator. + try: # noqa + loop.run_until_complete(setup_db_and_server.__anext__()) + except StopAsyncIteration: + pass def get_db_uri() -> str: user = os.getenv("APP_DATABASE_USER", "amt") password = os.getenv("APP_DATABASE_PASSWORD", "changethis") server = os.getenv("APP_DATABASE_SERVER", "db") + driver = os.getenv("APP_DATABASE_DRIVER", "asyncpg") port = os.getenv("APP_DATABASE_PORT", "5432") db = os.getenv("APP_DATABASE_DB", "amt") - return f"postgresql://{user}:{password}@{server}:{port}/{db}" + return f"postgresql+{driver}://{user}:{password}@{server}:{port}/{db}" -def create_db(new_db: str) -> str: +async def create_db(new_db: str) -> str: url = get_db_uri() - engine = create_engine(url, isolation_level="AUTOCOMMIT") + engine = create_async_engine(url, isolation_level="AUTOCOMMIT") if new_db == os.getenv("APP_DATABASE_USER", "amt"): return url logger.info(f"Creating database {new_db}") user = os.getenv("APP_DATABASE_USER", "amt") - with Session(engine) as session: - session.execute(text(f"DROP DATABASE IF EXISTS {new_db};")) # type: ignore - session.execute(text(f"CREATE DATABASE {new_db} OWNER {user};")) # type: ignore - session.commit() - path = Path(url) + async with engine.connect() as conn: + await conn.execute(text(f"DROP DATABASE IF EXISTS {new_db};")) # type: ignore + await conn.execute(text(f"CREATE DATABASE {new_db} OWNER {user};")) # type: ignore + await conn.commit() - return str(path.parent / new_db).replace("postgresql:/", "postgresql://") + path = Path(url) + return str(path.parent / new_db).replace("postgresql+asyncpg:/", "postgresql+asyncpg://") def generate_db_name(request: pytest.FixtureRequest) -> str: @@ -172,20 +198,23 @@ def generate_db_name(request: pytest.FixtureRequest) -> str: return sanitized_name -@pytest.fixture -def db( +@pytest_asyncio.fixture # type: ignore [(reportUnknownMemberType)] +async def db( tmp_path: Path, request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch -) -> Generator[DatabaseTestUtils, None, None]: +) -> AsyncIterator[DatabaseTestUtils]: database_file = tmp_path / "test.sqlite3" if request.config.getoption("--db") == "postgresql": db_name: str = generate_db_name(request) - url = create_db(db_name) + url = await create_db(db_name) monkeypatch.setenv("APP_DATABASE_DB", db_name) - engine = create_engine(url) + engine = create_async_engine(url) else: - engine = create_engine(f"sqlite:///{database_file}", connect_args={"check_same_thread": False}) - Base.metadata.create_all(engine) + engine = create_async_engine(f"sqlite+aiosqlite:///{database_file}", connect_args={"check_same_thread": False}) + + async with engine.begin() as conn: + await conn.run_sync(Base.metadata.create_all) - with Session(engine, expire_on_commit=False) as session: + async_session = async_sessionmaker(engine, expire_on_commit=False) + async with async_session() as session: yield DatabaseTestUtils(session, database_file) diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 84246983..1040b10b 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -7,6 +7,7 @@ def test_environment_settings(monkeypatch: pytest.MonkeyPatch): monkeypatch.setenv("ENVIRONMENT", "production") monkeypatch.setenv("SECRET_KEY", "mysecret") monkeypatch.setenv("APP_DATABASE_SCHEME", "postgresql") + monkeypatch.setenv("APP_DATABASE_DRIVER", "asyncpg") monkeypatch.setenv("APP_DATABASE_USER", "amt2") monkeypatch.setenv("APP_DATABASE_DB", "amt2") monkeypatch.setenv("APP_DATABASE_PASSWORD", "mypassword") @@ -18,11 +19,12 @@ def test_environment_settings(monkeypatch: pytest.MonkeyPatch): assert settings.ENVIRONMENT == "production" assert settings.LOGGING_LEVEL == "INFO" assert settings.APP_DATABASE_SCHEME == "postgresql" + assert settings.APP_DATABASE_DRIVER == "asyncpg" assert settings.APP_DATABASE_SERVER == "db" assert settings.APP_DATABASE_PORT == 5432 assert settings.APP_DATABASE_USER == "amt2" assert settings.APP_DATABASE_DB == "amt2" - assert settings.SQLALCHEMY_DATABASE_URI == "postgresql://amt2:mypassword@db:5432/amt2" + assert settings.SQLALCHEMY_DATABASE_URI == "postgresql+asyncpg://amt2:mypassword@db:5432/amt2" def test_environment_settings_production_sqlite_error(monkeypatch: pytest.MonkeyPatch): diff --git a/tests/core/test_db.py b/tests/core/test_db.py index 83036149..43442e81 100644 --- a/tests/core/test_db.py +++ b/tests/core/test_db.py @@ -7,18 +7,19 @@ ) from pytest_mock import MockFixture from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession logger = logging.getLogger(__name__) -def test_check_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, mocker: MockFixture): +@pytest.mark.asyncio +async def test_check_database(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, mocker: MockFixture): database_file = tmp_path / "database.sqlite3" monkeypatch.setenv("APP_DATABASE_FILE", str(database_file)) - org_exec = Session.execute - Session.execute = mocker.MagicMock() - check_db() + org_exec = AsyncSession.execute + AsyncSession.execute = mocker.AsyncMock() + await check_db() - assert Session.execute.call_args is not None - assert str(select(1)) == str(Session.execute.call_args.args[0]) - Session.execute = org_exec + assert AsyncSession.execute.call_args is not None + assert str(select(1)) == str(AsyncSession.execute.call_args.args[0]) + AsyncSession.execute = org_exec diff --git a/tests/core/test_exception_handlers.py b/tests/core/test_exception_handlers.py index 2dec95ba..c75fb73e 100644 --- a/tests/core/test_exception_handlers.py +++ b/tests/core/test_exception_handlers.py @@ -2,25 +2,25 @@ from amt.core.exceptions import AMTCSRFProtectError from amt.schema.project import ProjectNew from fastapi import status -from fastapi.testclient import TestClient +from httpx import AsyncClient -def test_http_exception_handler(client: TestClient): - response = client.get("/raise-http-exception") +async def test_http_exception_handler(client: AsyncClient): + response = await client.get("/raise-http-exception") assert response.status_code == status.HTTP_404_NOT_FOUND assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_request_validation_exception_handler(client: TestClient): - response = client.get("/algorithm-systems/?skip=a") +async def test_request_validation_exception_handler(client: AsyncClient): + response = await client.get("/algorithm-systems/?skip=a") assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_request_csrf_protect_exception_handler_invalid_token_in_header(client: TestClient): - data = client.get("/algorithm-systems/new") +async def test_request_csrf_protect_exception_handler_invalid_token_in_header(client: AsyncClient): + data = await client.get("/algorithm-systems/new") new_project = ProjectNew(name="default project", lifecycle="DATA_EXPLORATION_AND_PREPARATION") with pytest.raises(AMTCSRFProtectError): _response = client.post( @@ -28,22 +28,22 @@ def test_request_csrf_protect_exception_handler_invalid_token_in_header(client: ) -def test_http_exception_handler_htmx(client: TestClient): - response = client.get("/raise-http-exception", headers={"HX-Request": "true"}) +async def test_http_exception_handler_htmx(client: AsyncClient): + response = await client.get("/raise-http-exception", headers={"HX-Request": "true"}) assert response.status_code == status.HTTP_404_NOT_FOUND assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_request_validation_exception_handler_htmx(client: TestClient): - response = client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"}) +async def test_request_validation_exception_handler_htmx(client: AsyncClient): + response = await client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"}) assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.headers["content-type"] == "text/html; charset=utf-8" -def test_request_csrf_protect_exception_handler_invalid_token(client: TestClient): - data = client.get("/algorithm-systems/new") +async def test_request_csrf_protect_exception_handler_invalid_token(client: AsyncClient): + data = await client.get("/algorithm-systems/new") new_project = ProjectNew(name="default project", lifecycle="DATA_EXPLORATION_AND_PREPARATION") with pytest.raises(AMTCSRFProtectError): _response = client.post( @@ -54,8 +54,8 @@ def test_request_csrf_protect_exception_handler_invalid_token(client: TestClient ) -def test_(client: TestClient): - response = client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"}) +async def test_(client: AsyncClient): + response = await client.get("/algorithm-systems/?skip=a", headers={"HX-Request": "true"}) assert response.status_code == status.HTTP_400_BAD_REQUEST assert response.headers["content-type"] == "text/html; charset=utf-8" diff --git a/tests/database_e2e_setup.py b/tests/database_e2e_setup.py index 7c89d9c4..59355614 100644 --- a/tests/database_e2e_setup.py +++ b/tests/database_e2e_setup.py @@ -1,24 +1,24 @@ from amt.enums.status import Status from amt.models import Project -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio.session import AsyncSession from tests.constants import default_project, default_task, default_user from tests.database_test_utils import DatabaseTestUtils -def setup_database_e2e(session: Session) -> None: +async def setup_database_e2e(session: AsyncSession) -> None: db_e2e = DatabaseTestUtils(session) - db_e2e.given([default_user()]) + await db_e2e.given([default_user()]) projects: list[Project] = [] for idx in range(120): projects.append(default_project(name=f"Project {idx}")) - db_e2e.given([*projects]) + await db_e2e.given([*projects]) task1 = default_task(title="task1", status_id=Status.TODO, sort_order=-3, project_id=projects[0].id) task2 = default_task(title="task2", status_id=Status.TODO, sort_order=-2, project_id=projects[0].id) task3 = default_task(title="task3", status_id=Status.TODO, sort_order=-1, project_id=projects[0].id) - db_e2e.given([task1, task2, task3]) + await db_e2e.given([task1, task2, task3]) diff --git a/tests/database_test_utils.py b/tests/database_test_utils.py index a20e99e1..e36577a6 100644 --- a/tests/database_test_utils.py +++ b/tests/database_test_utils.py @@ -3,38 +3,42 @@ from amt.models.base import Base from sqlalchemy import select -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession logger = logging.getLogger(__name__) class DatabaseTestUtils: - def __init__(self, session: Session, database_file: Path | None = None) -> None: - self.session: Session = session + def __init__(self, session: AsyncSession, database_file: Path | None = None) -> None: + self.session: AsyncSession = session self.database_file: Path | None = database_file - def given(self, models: list[Base]) -> None: + async def given(self, models: list[Base]) -> None: session = self.get_session() session.add_all(models) - session.commit() + await session.commit() for model in models: - session.refresh(model) # inefficient, but needed to create correlations between models + await session.refresh(model) # inefficient, but needed to create correlations between models - def get_session(self) -> Session: + def get_session(self) -> AsyncSession: return self.session def get_database_file(self) -> Path | None: return self.database_file - def exists(self, model: type[Base], model_field: str, field_value: str | int) -> Base | None: - return self.get_session().execute(select(model).where(model_field == field_value)).one() is not None # type: ignore + async def exists(self, model: type[Base], model_field: str, field_value: str | int) -> bool: + try: + result = await self.get_session().execute(select(model).where(getattr(model, model_field) == field_value)) + return result.scalar_one_or_none() is not None + except AttributeError as err: + raise ValueError(f"Field '{model_field}' does not exist in model {model.__name__}") from err - def get(self, model: type[Base], model_field: str, field_value: str | int) -> list[Base]: + async def get(self, model: type[Base], model_field: str, field_value: str | int) -> list[Base]: try: query = select(model).where(getattr(model, model_field) == field_value) - result = self.get_session().execute(query) + result = await self.get_session().execute(query) return list(result.scalars().all()) except AttributeError as err: raise ValueError(f"Field '{model_field}' does not exist in model {model.__name__}") from err diff --git a/tests/middleware/test_authorization.py b/tests/middleware/test_authorization.py index 043802ca..61480836 100644 --- a/tests/middleware/test_authorization.py +++ b/tests/middleware/test_authorization.py @@ -1,10 +1,11 @@ import pytest -from fastapi.testclient import TestClient +from httpx import AsyncClient +@pytest.mark.asyncio @pytest.mark.enable_auth -def test_auth_not_project(client: TestClient) -> None: - response = client.get("/projects/") +async def test_auth_not_project(client: AsyncClient) -> None: + response = await client.get("/projects/", follow_redirects=True) assert response.status_code == 200 assert response.url == "http://testserver/" diff --git a/tests/repositories/test_deps.py b/tests/repositories/test_deps.py index 228fb794..f4ff4c85 100644 --- a/tests/repositories/test_deps.py +++ b/tests/repositories/test_deps.py @@ -1,10 +1,16 @@ +import pytest from amt.repositories.deps import get_session -from sqlalchemy.orm import Session +from sqlalchemy.ext.asyncio import AsyncSession -def test_get_session(): +@pytest.mark.asyncio +async def test_get_session(): session_generator = get_session() + session = await anext(session_generator) + assert isinstance(session, AsyncSession) - session = next(session_generator) - - assert isinstance(session, Session) + # Clean up by consuming the rest of the generator + try: # noqa + await session_generator.aclose() + except StopAsyncIteration: + pass diff --git a/tests/repositories/test_projects.py b/tests/repositories/test_projects.py index b0abefc9..5d02aebb 100644 --- a/tests/repositories/test_projects.py +++ b/tests/repositories/test_projects.py @@ -6,102 +6,113 @@ from tests.database_test_utils import DatabaseTestUtils -def test_find_all(db: DatabaseTestUtils): - db.given([default_project(), default_project()]) +@pytest.mark.asyncio +async def test_find_all(db: DatabaseTestUtils): + await db.given([default_project(), default_project()]) project_repository = ProjectsRepository(db.get_session()) - results = project_repository.find_all() + results = await project_repository.find_all() assert results[0].id == 1 assert results[1].id == 2 assert len(results) == 2 -def test_find_all_no_results(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_find_all_no_results(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) - results = project_repository.find_all() + results = await project_repository.find_all() assert len(results) == 0 -def test_save(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_save(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) project = default_project() - project_repository.save(project) + await project_repository.save(project) - result = project_repository.find_by_id(1) + result = await project_repository.find_by_id(1) - project_repository.delete(project) # cleanup + await project_repository.delete(project) # cleanup assert result.id == 1 assert result.name == default_project().name -def test_delete(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_delete(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) project = default_project() - project_repository.save(project) - project_repository.delete(project) + await project_repository.save(project) + await project_repository.delete(project) - results = project_repository.find_all() + results = await project_repository.find_all() assert len(results) == 0 -def test_save_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_save_failed(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) project = default_project() project.id = 1 project_duplicate = default_project() project_duplicate.id = 1 - project_repository.save(project) + await project_repository.save(project) with pytest.raises(AMTRepositoryError): - project_repository.save(project_duplicate) + await project_repository.save(project_duplicate) - project_repository.delete(project) # cleanup + await project_repository.delete(project) # cleanup -def test_delete_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_delete_failed(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) project = default_project() with pytest.raises(AMTRepositoryError): - project_repository.delete(project) + await project_repository.delete(project) -def test_find_by_id(db: DatabaseTestUtils): - db.given([default_project()]) +@pytest.mark.asyncio +async def test_find_by_id(db: DatabaseTestUtils): + await db.given([default_project()]) project_repository = ProjectsRepository(db.get_session()) - result = project_repository.find_by_id(1) + result = await project_repository.find_by_id(1) assert result.id == 1 assert result.name == default_project().name -def test_find_by_id_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_find_by_id_failed(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) with pytest.raises(AMTRepositoryError): - project_repository.find_by_id(1) + await project_repository.find_by_id(1) -def test_paginate(db: DatabaseTestUtils): - db.given([default_project()]) +@pytest.mark.asyncio +async def test_paginate(db: DatabaseTestUtils): + await db.given([default_project()]) project_repository = ProjectsRepository(db.get_session()) - result: list[Project] = project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={}) + result: list[Project] = await project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={}) assert len(result) == 1 -def test_paginate_more(db: DatabaseTestUtils): - db.given([default_project(), default_project(), default_project(), default_project()]) +@pytest.mark.asyncio +async def test_paginate_more(db: DatabaseTestUtils): + await db.given([default_project(), default_project(), default_project(), default_project()]) project_repository = ProjectsRepository(db.get_session()) - result: list[Project] = project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={}) + result: list[Project] = await project_repository.paginate(skip=0, limit=3, search="", filters={}, sort={}) assert len(result) == 3 -def test_paginate_capitalize(db: DatabaseTestUtils): - db.given( +@pytest.mark.asyncio +async def test_paginate_capitalize(db: DatabaseTestUtils): + await db.given( [ default_project(name="Project1"), default_project(name="bbb"), @@ -111,7 +122,7 @@ def test_paginate_capitalize(db: DatabaseTestUtils): ) project_repository = ProjectsRepository(db.get_session()) - result: list[Project] = project_repository.paginate(skip=0, limit=4, search="", filters={}, sort={}) + result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="", filters={}, sort={}) assert len(result) == 4 assert result[0].name == "Aaa" @@ -120,8 +131,9 @@ def test_paginate_capitalize(db: DatabaseTestUtils): assert result[3].name == "Project1" -def test_search(db: DatabaseTestUtils): - db.given( +@pytest.mark.asyncio +async def test_search(db: DatabaseTestUtils): + await db.given( [ default_project(name="Project1"), default_project(name="bbb"), @@ -131,14 +143,15 @@ def test_search(db: DatabaseTestUtils): ) project_repository = ProjectsRepository(db.get_session()) - result: list[Project] = project_repository.paginate(skip=0, limit=4, search="bbb", filters={}, sort={}) + result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="bbb", filters={}, sort={}) assert len(result) == 1 assert result[0].name == "bbb" -def test_search_multiple(db: DatabaseTestUtils): - db.given( +@pytest.mark.asyncio +async def test_search_multiple(db: DatabaseTestUtils): + await db.given( [ default_project(name="Project1"), default_project(name="bbb"), @@ -148,22 +161,24 @@ def test_search_multiple(db: DatabaseTestUtils): ) project_repository = ProjectsRepository(db.get_session()) - result: list[Project] = project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={}) + result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={}) assert len(result) == 2 assert result[0].name == "Aaa" assert result[1].name == "aba" -def test_search_no_results(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_search_no_results(db: DatabaseTestUtils): project_repository = ProjectsRepository(db.get_session()) - result: list[Project] = project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={}) + result: list[Project] = await project_repository.paginate(skip=0, limit=4, search="A", filters={}, sort={}) assert len(result) == 0 -def test_raises_exception(db: DatabaseTestUtils): - db.given([default_project()]) +@pytest.mark.asyncio +async def test_raises_exception(db: DatabaseTestUtils): + await db.given([default_project()]) project_repository = ProjectsRepository(db.get_session()) with pytest.raises(AMTRepositoryError): - project_repository.paginate(skip="a", limit=3, search="", filters={}, sort={}) # type: ignore + await project_repository.paginate(skip="a", limit=3, search="", filters={}, sort={}) # type: ignore diff --git a/tests/repositories/test_tasks.py b/tests/repositories/test_tasks.py index 6ecad641..418d1f6c 100644 --- a/tests/repositories/test_tasks.py +++ b/tests/repositories/test_tasks.py @@ -7,43 +7,47 @@ from tests.database_test_utils import DatabaseTestUtils -def test_find_all(db: DatabaseTestUtils): - db.given([default_task(), default_task()]) +@pytest.mark.asyncio +async def test_find_all(db: DatabaseTestUtils): + await db.given([default_task(), default_task()]) tasks_repository: TasksRepository = TasksRepository(db.get_session()) - results = tasks_repository.find_all() + results = await tasks_repository.find_all() assert results[0].id == 1 assert results[1].id == 2 assert len(results) == 2 -def test_find_all_no_results(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_find_all_no_results(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) - results = tasks_repository.find_all() + results = await tasks_repository.find_all() assert len(results) == 0 -def test_save(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_save(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) - tasks_repository.save(task) - result = tasks_repository.find_by_id(1) + await tasks_repository.save(task) + result = await tasks_repository.find_by_id(1) assert result.id == 1 assert result.title == "Test title" assert result.description == "Test description" assert result.sort_order == 10 - tasks_repository.delete(task) # cleanup + await tasks_repository.delete(task) # cleanup -def test_save_all(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_save_all(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) task_1: Task = Task(id=1, title="Test title 1", description="Test description 1", sort_order=10) task_2: Task = Task(id=2, title="Test title 2", description="Test description 2", sort_order=11) - tasks_repository.save_all([task_1, task_2]) - result_1 = tasks_repository.find_by_id(1) - result_2 = tasks_repository.find_by_id(2) + await tasks_repository.save_all([task_1, task_2]) + result_1 = await tasks_repository.find_by_id(1) + result_2 = await tasks_repository.find_by_id(2) assert result_1.id == 1 assert result_1.title == "Test title 1" @@ -55,84 +59,92 @@ def test_save_all(db: DatabaseTestUtils): assert result_2.description == "Test description 2" assert result_2.sort_order == 11 - tasks_repository.delete(task_1) # cleanup - tasks_repository.delete(task_2) # cleanup + await tasks_repository.delete(task_1) # cleanup + await tasks_repository.delete(task_2) # cleanup -def test_save_all_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_save_all_failed(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) - tasks_repository.save_all([task]) + await tasks_repository.save_all([task]) task_duplicate: Task = Task(id=1, title="Test title duplicate", description="Test description", sort_order=10) with pytest.raises(AMTRepositoryError): - tasks_repository.save_all([task_duplicate]) + await tasks_repository.save_all([task_duplicate]) - tasks_repository.delete(task) # cleanup + await tasks_repository.delete(task) # cleanup -def test_delete_task(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_delete_task(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) - tasks_repository.save(task) - tasks_repository.delete(task) # cleanup + await tasks_repository.save(task) + await tasks_repository.delete(task) # cleanup - results = tasks_repository.find_all() + results = await tasks_repository.find_all() assert len(results) == 0 -def test_save_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_save_failed(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) - tasks_repository.save(task) + await tasks_repository.save(task) task_duplicate: Task = Task(id=1, title="Test title duplicate", description="Test description", sort_order=10) with pytest.raises(AMTRepositoryError): - tasks_repository.save(task_duplicate) + await tasks_repository.save(task_duplicate) - tasks_repository.delete(task) # cleanup + await tasks_repository.delete(task) # cleanup -def test_delete_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_delete_failed(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) task: Task = Task(id=1, title="Test title", description="Test description", sort_order=10) with pytest.raises(AMTRepositoryError): - tasks_repository.delete(task) + await tasks_repository.delete(task) -def test_find_by_id(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_find_by_id(db: DatabaseTestUtils): task = default_task() - db.given([task]) + await db.given([task]) tasks_repository: TasksRepository = TasksRepository(db.get_session()) - result: Task = tasks_repository.find_by_id(1) + result: Task = await tasks_repository.find_by_id(1) assert result.id == 1 assert result.title == "Default Task" -def test_find_by_id_failed(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_find_by_id_failed(db: DatabaseTestUtils): tasks_repository: TasksRepository = TasksRepository(db.get_session()) with pytest.raises(AMTRepositoryError): - tasks_repository.find_by_id(1) + await tasks_repository.find_by_id(1) -def test_find_by_status_id(db: DatabaseTestUtils): +@pytest.mark.asyncio +async def test_find_by_status_id(db: DatabaseTestUtils): task = default_task(status_id=Status.TODO) - db.given([task, default_task()]) + await db.given([task, default_task()]) tasks_repository: TasksRepository = TasksRepository(db.get_session()) - results = tasks_repository.find_by_status_id(1) + results = await tasks_repository.find_by_status_id(1) assert len(results) == 1 assert results[0].id == 1 -def test_find_by_project_id_and_status_id(db: DatabaseTestUtils): - db.given([default_project()]) +@pytest.mark.asyncio +async def test_find_by_project_id_and_status_id(db: DatabaseTestUtils): + await db.given([default_project()]) task = default_task(status_id=Status.TODO, project_id=1) - db.given([task, default_task()]) + await db.given([task, default_task()]) tasks_repository: TasksRepository = TasksRepository(db.get_session()) - results = tasks_repository.find_by_project_id_and_status_id(1, 1) + results = await tasks_repository.find_by_project_id_and_status_id(1, 1) assert len(results) == 1 assert results[0].id == 1 assert results[0].project_id == 1 diff --git a/tests/services/test_projects_service.py b/tests/services/test_projects_service.py index 8efc584e..02bd592c 100644 --- a/tests/services/test_projects_service.py +++ b/tests/services/test_projects_service.py @@ -1,3 +1,4 @@ +import pytest from amt.models.project import Project from amt.repositories.projects import ProjectsRepository from amt.schema.project import ProjectNew @@ -9,40 +10,42 @@ from tests.constants import default_instrument -def test_get_project(mocker: MockFixture): +@pytest.mark.asyncio +async def test_get_project(mocker: MockFixture): # Given project_id = 1 project_name = "Project 1" project_lifecycle = "development" projects_service = ProjectsService( - repository=mocker.Mock(spec=ProjectsRepository), - task_service=mocker.Mock(spec=TasksService), - instrument_service=mocker.Mock(spec=InstrumentsService), + repository=mocker.AsyncMock(spec=ProjectsRepository), + task_service=mocker.AsyncMock(spec=TasksService), + instrument_service=mocker.AsyncMock(spec=InstrumentsService), ) projects_service.repository.find_by_id.return_value = Project( # type: ignore id=project_id, name=project_name, lifecycle=project_lifecycle ) # When - project = projects_service.get(project_id) + project = await projects_service.get(project_id) # Then assert project.id == project_id assert project.name == project_name assert project.lifecycle == project_lifecycle - projects_service.repository.find_by_id.assert_called_once_with(project_id) # type: ignore + projects_service.repository.find_by_id.assert_awaited_once_with(project_id) # type: ignore -def test_create_project(mocker: MockFixture): +@pytest.mark.asyncio +async def test_create_project(mocker: MockFixture): project_id = 1 project_name = "Project 1" project_lifecycle = "development" system_card = SystemCard(name=project_name) projects_service = ProjectsService( - repository=mocker.Mock(spec=ProjectsRepository), - task_service=mocker.Mock(spec=TasksService), - instrument_service=mocker.Mock(spec=InstrumentsService), + repository=mocker.AsyncMock(spec=ProjectsRepository), + task_service=mocker.AsyncMock(spec=TasksService), + instrument_service=mocker.AsyncMock(spec=InstrumentsService), ) projects_service.repository.save.return_value = Project( # type: ignore id=project_id, name=project_name, lifecycle=project_lifecycle, system_card=system_card @@ -61,10 +64,10 @@ def test_create_project(mocker: MockFixture): transparency_obligations="project_transparency_obligations", role="project_role", ) - project = projects_service.create(project_new) + project = await projects_service.create(project_new) # Then assert project.id == project_id assert project.name == project_name assert project.lifecycle == project_lifecycle - projects_service.repository.save.assert_called() # type: ignore + projects_service.repository.save.assert_awaited() # type: ignore diff --git a/tests/services/test_tasks_service.py b/tests/services/test_tasks_service.py index 90149ded..9df971c7 100644 --- a/tests/services/test_tasks_service.py +++ b/tests/services/test_tasks_service.py @@ -38,19 +38,19 @@ def reset(self): def find_all(self): return self._tasks - def find_by_status_id(self, status_id: int) -> Sequence[Task]: + async def find_by_status_id(self, status_id: int) -> Sequence[Task]: return list(filter(lambda x: x.status_id == status_id, self._tasks)) - def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]: + async def find_by_project_id_and_status_id(self, project_id: int, status_id: int) -> Sequence[Task]: return list(filter(lambda x: x.status_id == status_id and x.project_id == project_id, self._tasks)) - def find_by_id(self, task_id: int) -> Task: + async def find_by_id(self, task_id: int) -> Task: return next(filter(lambda x: x.id == task_id, self._tasks)) - def save(self, task: Task) -> Task: + async def save(self, task: Task) -> Task: return task - def save_all(self, tasks: Sequence[Task]) -> None: + async def save_all(self, tasks: Sequence[Task]) -> None: for task in tasks: self._tasks.append(task) return None @@ -69,22 +69,30 @@ def tasks_service_with_mock(mock_tasks_repository: TasksRepository): return tasks_service -def test_get_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): - assert len(tasks_service_with_mock.get_tasks(1)) == 3 +@pytest.mark.asyncio +async def test_get_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): + tasks = await tasks_service_with_mock.get_tasks(1) + assert len(tasks) == 3 -def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): - assert len(tasks_service_with_mock.get_tasks_for_project(1, 1)) == 2 +@pytest.mark.asyncio +async def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): + tasks = await tasks_service_with_mock.get_tasks_for_project(1, 1) + assert len(tasks) == 2 -def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): - task1: Task = mock_tasks_repository.find_by_id(1) +@pytest.mark.asyncio +async def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): + task1: Task = await mock_tasks_repository.find_by_id(1) user1: User = User(id=1, name="User 1", avatar="none.jpg") - tasks_service_with_mock.assign_task(task1, user1) + await tasks_service_with_mock.assign_task(task1, user1) assert task1.user_id == 1 -def test_create_instrument_tasks(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository): +@pytest.mark.asyncio +async def test_create_instrument_tasks( + tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository +): # Given task_1 = InstrumentTask(question="question_1", urn="instrument_1_task_1", lifecycle=[]) task_2 = InstrumentTask(question="question_1", urn="instrument_1_task_1", lifecycle=[]) @@ -95,48 +103,61 @@ def test_create_instrument_tasks(tasks_service_with_mock: TasksService, mock_tas # When mock_tasks_repository.clear() - tasks_service_with_mock.create_instrument_tasks([task_1, task_2], project) + await tasks_service_with_mock.create_instrument_tasks([task_1, task_2], project) # Then - assert len(mock_tasks_repository.find_all()) == 2 - assert mock_tasks_repository.find_all()[0].project_id == 1 - assert mock_tasks_repository.find_all()[0].title == task_1.question - assert mock_tasks_repository.find_all()[1].project_id == 1 - assert mock_tasks_repository.find_all()[1].title == task_2.question + tasks = mock_tasks_repository.find_all() + assert len(tasks) == 2 + assert tasks[0].project_id == 1 + assert tasks[0].title == task_1.question + assert tasks[1].project_id == 1 + assert tasks[1].title == task_2.question -def test_move_task(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository): +@pytest.mark.asyncio +async def test_move_task(tasks_service_with_mock: TasksService, mock_tasks_repository: MockTasksRepository): # test changing order between 2 cards mock_tasks_repository.reset() - assert mock_tasks_repository.find_by_id(1).sort_order == 10 - tasks_service_with_mock.move_task(1, 1, 2, 3) - assert mock_tasks_repository.find_by_id(1).sort_order == 25 + task = await mock_tasks_repository.find_by_id(1) + assert task.sort_order == 10 + await tasks_service_with_mock.move_task(1, 1, 2, 3) + task = await mock_tasks_repository.find_by_id(1) + assert task.sort_order == 25 # test changing order, after the last card mock_tasks_repository.reset() - tasks_service_with_mock.move_task(1, 1, 3, None) - assert mock_tasks_repository.find_by_id(1).sort_order == 40 + await tasks_service_with_mock.move_task(1, 1, 3, None) + task = await mock_tasks_repository.find_by_id(1) + assert task.sort_order == 40 # test changing order, before the first card mock_tasks_repository.reset() - assert mock_tasks_repository.find_by_id(3).sort_order == 30 - tasks_service_with_mock.move_task(3, 1, None, 1) - assert mock_tasks_repository.find_by_id(3).sort_order == 5 + task = await mock_tasks_repository.find_by_id(3) + assert task.sort_order == 30 + await tasks_service_with_mock.move_task(3, 1, None, 1) + task = await mock_tasks_repository.find_by_id(3) + assert task.sort_order == 5 # test moving to in progress mock_tasks_repository.reset() - mock_tasks_repository.find_by_id(1).sort_order = 0 - tasks_service_with_mock.move_task(1, 2) - assert mock_tasks_repository.find_by_id(1).sort_order == 10 + task = await mock_tasks_repository.find_by_id(1) + task.sort_order = 0 + await tasks_service_with_mock.move_task(1, 2) + task = await mock_tasks_repository.find_by_id(1) + assert task.sort_order == 10 # test moving to todo mock_tasks_repository.reset() - mock_tasks_repository.find_by_id(1).sort_order = 0 - tasks_service_with_mock.move_task(1, 4) - assert mock_tasks_repository.find_by_id(1).sort_order == 10 + task = await mock_tasks_repository.find_by_id(1) + task.sort_order = 0 + await tasks_service_with_mock.move_task(1, 4) + task = await mock_tasks_repository.find_by_id(1) + assert task.sort_order == 10 # test moving move under other card mock_tasks_repository.reset() - mock_tasks_repository.find_by_id(2).sort_order = 10 - tasks_service_with_mock.move_task(2, 1, None, 1) - assert mock_tasks_repository.find_by_id(2).sort_order == 5 + task = await mock_tasks_repository.find_by_id(2) + task.sort_order = 10 + await tasks_service_with_mock.move_task(2, 1, None, 1) + task = await mock_tasks_repository.find_by_id(2) + assert task.sort_order == 5 diff --git a/tests/test_main.py b/tests/test_main.py index 226f9e5f..18e92866 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -1,8 +1,10 @@ -from fastapi.testclient import TestClient +import pytest +from httpx import AsyncClient -def test_get_non_exisiting(client: TestClient) -> None: - response = client.get( +@pytest.mark.asyncio +async def test_get_non_exisiting(client: AsyncClient) -> None: + response = await client.get( "/pathdoesnotexist/", ) assert response.status_code == 404