From c2479826062ddf7879df6499ee394aa9bc96cd91 Mon Sep 17 00:00:00 2001 From: ChristopherSpelt Date: Mon, 11 Nov 2024 11:57:26 +0100 Subject: [PATCH] Add user id to user table in db --- amt/api/routes/auth.py | 15 ++++- .../versions/22298f3aac77_drop_users_table.py | 38 +++++++++++++ ...24222_create_user_table_with_uuid_as_pk.py | 37 +++++++++++++ amt/models/task.py | 10 ++-- amt/models/user.py | 6 +- amt/repositories/users.py | 55 +++++++++++++++++++ amt/services/users.py | 25 +++++++++ compose.override.yml | 10 ---- tests/constants.py | 10 ++-- tests/models/test_model_user.py | 3 +- tests/repositories/test_users.py | 50 +++++++++++++++++ tests/services/test_tasks_service.py | 4 +- tests/services/test_users_service.py | 46 ++++++++++++++++ 13 files changed, 284 insertions(+), 25 deletions(-) create mode 100644 amt/migrations/versions/22298f3aac77_drop_users_table.py create mode 100644 amt/migrations/versions/69243fd24222_create_user_table_with_uuid_as_pk.py create mode 100644 amt/repositories/users.py create mode 100644 amt/services/users.py delete mode 100644 compose.override.yml create mode 100644 tests/repositories/test_users.py create mode 100644 tests/services/test_users_service.py diff --git a/amt/api/routes/auth.py b/amt/api/routes/auth.py index 897a989c..d8b71908 100644 --- a/amt/api/routes/auth.py +++ b/amt/api/routes/auth.py @@ -1,14 +1,18 @@ import hashlib import logging +from typing import Annotated from urllib.parse import quote_plus +from uuid import UUID from authlib.integrations.starlette_client import OAuth, OAuthError # type: ignore -from fastapi import APIRouter, Request +from fastapi import APIRouter, Depends, Request from fastapi.responses import HTMLResponse, RedirectResponse, Response from amt.api.deps import templates from amt.core.authorization import get_user from amt.core.exceptions import AMTAuthorizationFlowError +from amt.models.user import User +from amt.services.users import UsersService router = APIRouter() logger = logging.getLogger(__name__) @@ -42,7 +46,10 @@ async def logout(request: Request) -> RedirectResponse: # pragma: no cover @router.get("/callback", response_class=Response) -async def auth_callback(request: Request) -> Response: # pragma: no cover +async def auth_callback( + request: Request, + users_service: Annotated[UsersService, Depends(UsersService)], +) -> Response: # pragma: no cover oauth: OAuth = request.app.state.oauth try: token = await oauth.keycloak.authorize_access_token(request) # type: ignore @@ -58,6 +65,10 @@ async def auth_callback(request: Request) -> Response: # pragma: no cover name: str = str(user["name"]).strip().lower() # type: ignore user["name_encoded"] = quote_plus(name) + if "sub" in user and isinstance(user["sub"], str): + new_user = User(id=UUID(user["sub"]), name=user["name"]) # type: ignore + new_user = await users_service.create_or_update(new_user) + if user: request.session["user"] = dict(user) # type: ignore request.session["id_token"] = token["id_token"] # type: ignore diff --git a/amt/migrations/versions/22298f3aac77_drop_users_table.py b/amt/migrations/versions/22298f3aac77_drop_users_table.py new file mode 100644 index 00000000..a295a1ff --- /dev/null +++ b/amt/migrations/versions/22298f3aac77_drop_users_table.py @@ -0,0 +1,38 @@ +"""drop users table + +Revision ID: 22298f3aac77 +Revises: 7f20f8562007 +Create Date: 2024-11-12 09:33:50.853310 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "22298f3aac77" +down_revision: str | None = "6581a03aabec" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.drop_constraint("fk_task_user_id_user", type_="foreignkey") + op.drop_column("task", "user_id") + op.drop_table("user") + + +def downgrade() -> None: + op.create_table( + "user", + sa.Column("id", sa.INTEGER(), nullable=False), + sa.Column("name", sa.VARCHAR(length=255), nullable=False), + sa.Column("avatar", sa.VARCHAR(length=255), nullable=True), + sa.PrimaryKeyConstraint("id", name="pk_user"), + ) + op.add_column("task", sa.Column("user_id", sa.INTEGER(), nullable=True)) + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.create_foreign_key("fk_task_user_id_user", "user", ["user_id"], ["id"]) diff --git a/amt/migrations/versions/69243fd24222_create_user_table_with_uuid_as_pk.py b/amt/migrations/versions/69243fd24222_create_user_table_with_uuid_as_pk.py new file mode 100644 index 00000000..98ffa692 --- /dev/null +++ b/amt/migrations/versions/69243fd24222_create_user_table_with_uuid_as_pk.py @@ -0,0 +1,37 @@ +"""create user table with uuid as pk + +Revision ID: 69243fd24222 +Revises: 22298f3aac77 +Create Date: 2024-11-12 09:49:53.558089 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "69243fd24222" +down_revision: str | None = "22298f3aac77" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + op.create_table( + "user", + sa.Column("id", sa.UUID(), nullable=False), + sa.Column("name", sa.String(), nullable=False), + sa.PrimaryKeyConstraint("id", name=op.f("pk_user")), + ) + op.add_column("task", sa.Column("user_id", sa.UUID(), nullable=True)) + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.create_foreign_key(op.f("fk_task_user_id_user"), "user", ["user_id"], ["id"]) + + +def downgrade() -> None: + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.drop_constraint(op.f("fk_task_user_id_user"), type_="foreignkey") + op.drop_column("task", "user_id") + op.drop_table("user") diff --git a/amt/models/task.py b/amt/models/task.py index 4fbd0f32..3b84816d 100644 --- a/amt/models/task.py +++ b/amt/models/task.py @@ -1,3 +1,5 @@ +from uuid import UUID + from sqlalchemy import ForeignKey, String from sqlalchemy.orm import Mapped, mapped_column @@ -12,9 +14,9 @@ class Task(Base): description: Mapped[str] sort_order: Mapped[float] status_id: Mapped[int | None] = mapped_column(default=None) - user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id")) - # TODO: (Christopher) SQLModel does not allow to give the below restraint an name - # which is needed for alembic. This results in changing the migration file - # manually to give the restrain a name. + user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id")) + ## TODO: (Christopher) SQLModel does not allow to give the below restraint an name + ## which is needed for alembic. This results in changing the migration file + ## manually to give the restrain a name. project_id: Mapped[int | None] = mapped_column(ForeignKey("project.id")) # todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id diff --git a/amt/models/user.py b/amt/models/user.py index 0c8f4e98..5f31d27d 100644 --- a/amt/models/user.py +++ b/amt/models/user.py @@ -1,3 +1,6 @@ +from uuid import UUID + +from sqlalchemy import UUID as SQLAlchemyUUID from sqlalchemy.orm import Mapped, mapped_column from amt.models.base import Base @@ -6,6 +9,5 @@ class User(Base): __tablename__ = "user" - id: Mapped[int] = mapped_column(primary_key=True) + id: Mapped[UUID] = mapped_column(SQLAlchemyUUID(as_uuid=True), primary_key=True) name: Mapped[str] - avatar: Mapped[str | None] = mapped_column(default=None) diff --git a/amt/repositories/users.py b/amt/repositories/users.py new file mode 100644 index 00000000..f1a001f4 --- /dev/null +++ b/amt/repositories/users.py @@ -0,0 +1,55 @@ +import logging +from typing import Annotated +from uuid import UUID + +from fastapi import Depends +from sqlalchemy import select +from sqlalchemy.exc import NoResultFound, SQLAlchemyError +from sqlalchemy.ext.asyncio import AsyncSession + +from amt.core.exceptions import AMTRepositoryError +from amt.models.user import User +from amt.repositories.deps import get_session + +logger = logging.getLogger(__name__) + + +class UsersRepository: + """ + The UsersRepository provides access to the repository layer. + """ + + def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None: + self.session = session + + async def find_by_id(self, id: UUID) -> User | None: + """ + Returns the user with the given id. + :param id: the id of the user to find + :return: the user with the given id or an exception if no user was found + """ + statement = select(User).where(User.id == id) + try: + return (await self.session.execute(statement)).scalars().one() + except NoResultFound: + return None + + async def upsert(self, user: User) -> User: + """ + Upserts (create or update) a user. + :param user: the user to upsert. + :return: the upserted user. + """ + try: + existing_user = await self.find_by_id(user.id) + if existing_user: + existing_user.name = user.name + else: + self.session.add(user) + await self.session.commit() + except SQLAlchemyError as e: # pragma: no cover + logger.exception("Error saving user") + await self.session.rollback() + raise AMTRepositoryError from e + + return user diff --git a/amt/services/users.py b/amt/services/users.py new file mode 100644 index 00000000..921d082e --- /dev/null +++ b/amt/services/users.py @@ -0,0 +1,25 @@ +import logging +from typing import Annotated +from uuid import UUID + +from fastapi import Depends + +from amt.models.user import User +from amt.repositories.users import UsersRepository + +logger = logging.getLogger(__name__) + + +class UsersService: + def __init__( + self, + repository: Annotated[UsersRepository, Depends(UsersRepository)], + ) -> None: + self.repository = repository + + async def get(self, id: str | UUID) -> User | None: + id = UUID(id) if isinstance(id, str) else id + return await self.repository.find_by_id(id) + + async def create_or_update(self, user: User) -> User: + return await self.repository.upsert(user) diff --git a/compose.override.yml b/compose.override.yml deleted file mode 100644 index e03e6d97..00000000 --- a/compose.override.yml +++ /dev/null @@ -1,10 +0,0 @@ -services: - amt-test: - build: - context: . - dockerfile: Dockerfile - target: test - image: ghcr.io/minbzk/amt-test:latest - env_file: - - path: prod.env - required: true diff --git a/tests/constants.py b/tests/constants.py index f2143ab5..1e81f8e9 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -1,4 +1,5 @@ import json +from uuid import UUID from amt.api.lifecycles import Lifecycles from amt.api.navigation import BaseNavigationItem, DisplayText @@ -18,14 +19,15 @@ def default_base_navigation_item( return BaseNavigationItem(display_text=display_text, url=url, custom_display_text=custom_display_text, icon=icon) -def default_user(name: str = "default user", avatar: str | None = None) -> User: - return User(name=name, avatar=avatar) - - def default_project(name: str = "default project") -> Project: return Project(name=name) +def default_user(id: str | UUID = "00494b4d-bcdf-425a-8140-bea0f3cbd3c2", name: str = "John Smith") -> User: + id = UUID(id) if isinstance(id, str) else id + return User(id=id, name=name) + + def default_project_with_system_card(name: str = "default project") -> Project: with open("resources/system_card_templates/AMT_Template_1.json") as f: system_card_from_template = json.load(f) diff --git a/tests/models/test_model_user.py b/tests/models/test_model_user.py index 1176e580..ce121685 100644 --- a/tests/models/test_model_user.py +++ b/tests/models/test_model_user.py @@ -4,7 +4,8 @@ def test_model_basic_user(): # given - user = User(name="Test user", avatar=None) + user = User(id="70419b8d-29ff-4c51-9822-05a46f6c916e", name="Test user") # then + assert user.id == "70419b8d-29ff-4c51-9822-05a46f6c916e" assert user.name == "Test user" diff --git a/tests/repositories/test_users.py b/tests/repositories/test_users.py new file mode 100644 index 00000000..d6400777 --- /dev/null +++ b/tests/repositories/test_users.py @@ -0,0 +1,50 @@ +from unittest.mock import AsyncMock + +import pytest +from amt.core.exceptions import AMTRepositoryError +from amt.repositories.users import UsersRepository +from sqlalchemy.exc import SQLAlchemyError +from tests.constants import default_user +from tests.database_test_utils import DatabaseTestUtils + + +@pytest.mark.asyncio +async def test_find_by_id(db: DatabaseTestUtils): + await db.given([default_user()]) + users_repository = UsersRepository(db.get_session()) + result = await users_repository.find_by_id(default_user().id) + assert result is not None + assert result.id == default_user().id + assert result.name == default_user().name + + +@pytest.mark.asyncio +async def test_upsert_new(db: DatabaseTestUtils): + new_user = default_user() + users_repository = UsersRepository(db.get_session()) + await users_repository.upsert(new_user) + result = await users_repository.find_by_id(new_user.id) + assert result is not None + assert result.id == new_user.id + assert result.name == new_user.name + + +@pytest.mark.asyncio +async def test_upsert_existing(db: DatabaseTestUtils): + await db.given([default_user()]) + new_user = default_user(name="John Smith New") + users_repository = UsersRepository(db.get_session()) + await users_repository.upsert(new_user) + result = await users_repository.find_by_id(new_user.id) + assert result is not None + assert result.id == new_user.id + assert result.name == new_user.name + + +@pytest.mark.asyncio +async def test_upsert_error(db: DatabaseTestUtils): + new_user = default_user(name="John Smith New") + users_repository = UsersRepository(db.get_session()) + users_repository.find_by_id = AsyncMock(side_effect=SQLAlchemyError("Database error")) + with pytest.raises(AMTRepositoryError): + await users_repository.upsert(new_user) diff --git a/tests/services/test_tasks_service.py b/tests/services/test_tasks_service.py index 9df971c7..19305a79 100644 --- a/tests/services/test_tasks_service.py +++ b/tests/services/test_tasks_service.py @@ -84,9 +84,9 @@ async def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock @pytest.mark.asyncio async def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository): task1: Task = await mock_tasks_repository.find_by_id(1) - user1: User = User(id=1, name="User 1", avatar="none.jpg") + user1: User = User(id="1", name="User 1") await tasks_service_with_mock.assign_task(task1, user1) - assert task1.user_id == 1 + assert task1.user_id == "1" @pytest.mark.asyncio diff --git a/tests/services/test_users_service.py b/tests/services/test_users_service.py new file mode 100644 index 00000000..be09795d --- /dev/null +++ b/tests/services/test_users_service.py @@ -0,0 +1,46 @@ +from uuid import UUID + +import pytest +from amt.repositories.users import UsersRepository +from amt.services.users import UsersService +from pytest_mock import MockFixture +from tests.constants import default_user + + +@pytest.mark.asyncio +async def test_get_user(mocker: MockFixture): + # Given + id = UUID("3d284d80-fc47-41ab-9696-fab562bacbd5") + name = "John Smith" + users_service = UsersService( + repository=mocker.AsyncMock(spec=UsersRepository), + ) + users_service.repository.find_by_id.return_value = default_user(id=id, name=name) # type: ignore + + # When + user = await users_service.get(id) + + # Then + assert user is not None + assert user.id == id + assert user.name == name + users_service.repository.find_by_id.assert_awaited_once_with(id) # type: ignore + + +@pytest.mark.asyncio +async def test_create_or_update(mocker: MockFixture): + # Given + user = default_user() + users_service = UsersService( + repository=mocker.AsyncMock(spec=UsersRepository), + ) + users_service.repository.upsert.return_value = user # type: ignore + + # When + retreived_user = await users_service.create_or_update(user) + + # Then + assert retreived_user is not None + assert retreived_user.id == user.id + assert retreived_user.name == user.name + users_service.repository.upsert.assert_awaited_once_with(user) # type: ignore