Skip to content

Commit

Permalink
Add user id to user table in db
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 13, 2024
1 parent 1393a5e commit c247982
Show file tree
Hide file tree
Showing 13 changed files with 284 additions and 25 deletions.
15 changes: 13 additions & 2 deletions amt/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import hashlib
import logging
from typing import Annotated
from urllib.parse import quote_plus
from uuid import UUID

from authlib.integrations.starlette_client import OAuth, OAuthError # type: ignore
from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse, RedirectResponse, Response

from amt.api.deps import templates
from amt.core.authorization import get_user
from amt.core.exceptions import AMTAuthorizationFlowError
from amt.models.user import User
from amt.services.users import UsersService

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,7 +46,10 @@ async def logout(request: Request) -> RedirectResponse: # pragma: no cover


@router.get("/callback", response_class=Response)
async def auth_callback(request: Request) -> Response: # pragma: no cover
async def auth_callback(
request: Request,
users_service: Annotated[UsersService, Depends(UsersService)],
) -> Response: # pragma: no cover
oauth: OAuth = request.app.state.oauth
try:
token = await oauth.keycloak.authorize_access_token(request) # type: ignore
Expand All @@ -58,6 +65,10 @@ async def auth_callback(request: Request) -> Response: # pragma: no cover
name: str = str(user["name"]).strip().lower() # type: ignore
user["name_encoded"] = quote_plus(name)

if "sub" in user and isinstance(user["sub"], str):
new_user = User(id=UUID(user["sub"]), name=user["name"]) # type: ignore
new_user = await users_service.create_or_update(new_user)

if user:
request.session["user"] = dict(user) # type: ignore
request.session["id_token"] = token["id_token"] # type: ignore
Expand Down
38 changes: 38 additions & 0 deletions amt/migrations/versions/22298f3aac77_drop_users_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""drop users table
Revision ID: 22298f3aac77
Revises: 7f20f8562007
Create Date: 2024-11-12 09:33:50.853310
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "22298f3aac77"
down_revision: str | None = "6581a03aabec"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.drop_constraint("fk_task_user_id_user", type_="foreignkey")
op.drop_column("task", "user_id")
op.drop_table("user")


def downgrade() -> None:
op.create_table(
"user",
sa.Column("id", sa.INTEGER(), nullable=False),
sa.Column("name", sa.VARCHAR(length=255), nullable=False),
sa.Column("avatar", sa.VARCHAR(length=255), nullable=True),
sa.PrimaryKeyConstraint("id", name="pk_user"),
)
op.add_column("task", sa.Column("user_id", sa.INTEGER(), nullable=True))
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.create_foreign_key("fk_task_user_id_user", "user", ["user_id"], ["id"])
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""create user table with uuid as pk
Revision ID: 69243fd24222
Revises: 22298f3aac77
Create Date: 2024-11-12 09:49:53.558089
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "69243fd24222"
down_revision: str | None = "22298f3aac77"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
op.create_table(
"user",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_user")),
)
op.add_column("task", sa.Column("user_id", sa.UUID(), nullable=True))
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.create_foreign_key(op.f("fk_task_user_id_user"), "user", ["user_id"], ["id"])


def downgrade() -> None:
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.drop_constraint(op.f("fk_task_user_id_user"), type_="foreignkey")
op.drop_column("task", "user_id")
op.drop_table("user")
10 changes: 6 additions & 4 deletions amt/models/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column

Expand All @@ -12,9 +14,9 @@ class Task(Base):
description: Mapped[str]
sort_order: Mapped[float]
status_id: Mapped[int | None] = mapped_column(default=None)
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
# TODO: (Christopher) SQLModel does not allow to give the below restraint an name
# which is needed for alembic. This results in changing the migration file
# manually to give the restrain a name.
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"))
## TODO: (Christopher) SQLModel does not allow to give the below restraint an name
## which is needed for alembic. This results in changing the migration file
## manually to give the restrain a name.
project_id: Mapped[int | None] = mapped_column(ForeignKey("project.id"))
# todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id
6 changes: 4 additions & 2 deletions amt/models/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from uuid import UUID

from sqlalchemy import UUID as SQLAlchemyUUID
from sqlalchemy.orm import Mapped, mapped_column

from amt.models.base import Base
Expand All @@ -6,6 +9,5 @@
class User(Base):
__tablename__ = "user"

id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[UUID] = mapped_column(SQLAlchemyUUID(as_uuid=True), primary_key=True)
name: Mapped[str]
avatar: Mapped[str | None] = mapped_column(default=None)
55 changes: 55 additions & 0 deletions amt/repositories/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from amt.core.exceptions import AMTRepositoryError
from amt.models.user import User
from amt.repositories.deps import get_session

logger = logging.getLogger(__name__)


class UsersRepository:
"""
The UsersRepository provides access to the repository layer.
"""

def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
self.session = session

async def find_by_id(self, id: UUID) -> User | None:
"""
Returns the user with the given id.
:param id: the id of the user to find
:return: the user with the given id or an exception if no user was found
"""
statement = select(User).where(User.id == id)
try:
return (await self.session.execute(statement)).scalars().one()
except NoResultFound:
return None

async def upsert(self, user: User) -> User:
"""
Upserts (create or update) a user.
:param user: the user to upsert.
:return: the upserted user.
"""
try:
existing_user = await self.find_by_id(user.id)
if existing_user:
existing_user.name = user.name
else:
self.session.add(user)
await self.session.commit()
except SQLAlchemyError as e: # pragma: no cover
logger.exception("Error saving user")
await self.session.rollback()
raise AMTRepositoryError from e

return user
25 changes: 25 additions & 0 deletions amt/services/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
from typing import Annotated
from uuid import UUID

from fastapi import Depends

from amt.models.user import User
from amt.repositories.users import UsersRepository

logger = logging.getLogger(__name__)


class UsersService:
def __init__(
self,
repository: Annotated[UsersRepository, Depends(UsersRepository)],
) -> None:
self.repository = repository

async def get(self, id: str | UUID) -> User | None:
id = UUID(id) if isinstance(id, str) else id
return await self.repository.find_by_id(id)

async def create_or_update(self, user: User) -> User:
return await self.repository.upsert(user)
10 changes: 0 additions & 10 deletions compose.override.yml

This file was deleted.

10 changes: 6 additions & 4 deletions tests/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from uuid import UUID

from amt.api.lifecycles import Lifecycles
from amt.api.navigation import BaseNavigationItem, DisplayText
Expand All @@ -18,14 +19,15 @@ def default_base_navigation_item(
return BaseNavigationItem(display_text=display_text, url=url, custom_display_text=custom_display_text, icon=icon)


def default_user(name: str = "default user", avatar: str | None = None) -> User:
return User(name=name, avatar=avatar)


def default_project(name: str = "default project") -> Project:
return Project(name=name)


def default_user(id: str | UUID = "00494b4d-bcdf-425a-8140-bea0f3cbd3c2", name: str = "John Smith") -> User:
id = UUID(id) if isinstance(id, str) else id
return User(id=id, name=name)


def default_project_with_system_card(name: str = "default project") -> Project:
with open("resources/system_card_templates/AMT_Template_1.json") as f:
system_card_from_template = json.load(f)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_model_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
def test_model_basic_user():
# given

user = User(name="Test user", avatar=None)
user = User(id="70419b8d-29ff-4c51-9822-05a46f6c916e", name="Test user")

# then
assert user.id == "70419b8d-29ff-4c51-9822-05a46f6c916e"
assert user.name == "Test user"
50 changes: 50 additions & 0 deletions tests/repositories/test_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import AsyncMock

import pytest
from amt.core.exceptions import AMTRepositoryError
from amt.repositories.users import UsersRepository
from sqlalchemy.exc import SQLAlchemyError
from tests.constants import default_user
from tests.database_test_utils import DatabaseTestUtils


@pytest.mark.asyncio
async def test_find_by_id(db: DatabaseTestUtils):
await db.given([default_user()])
users_repository = UsersRepository(db.get_session())
result = await users_repository.find_by_id(default_user().id)
assert result is not None
assert result.id == default_user().id
assert result.name == default_user().name


@pytest.mark.asyncio
async def test_upsert_new(db: DatabaseTestUtils):
new_user = default_user()
users_repository = UsersRepository(db.get_session())
await users_repository.upsert(new_user)
result = await users_repository.find_by_id(new_user.id)
assert result is not None
assert result.id == new_user.id
assert result.name == new_user.name


@pytest.mark.asyncio
async def test_upsert_existing(db: DatabaseTestUtils):
await db.given([default_user()])
new_user = default_user(name="John Smith New")
users_repository = UsersRepository(db.get_session())
await users_repository.upsert(new_user)
result = await users_repository.find_by_id(new_user.id)
assert result is not None
assert result.id == new_user.id
assert result.name == new_user.name


@pytest.mark.asyncio
async def test_upsert_error(db: DatabaseTestUtils):
new_user = default_user(name="John Smith New")
users_repository = UsersRepository(db.get_session())
users_repository.find_by_id = AsyncMock(side_effect=SQLAlchemyError("Database error"))
with pytest.raises(AMTRepositoryError):
await users_repository.upsert(new_user)
4 changes: 2 additions & 2 deletions tests/services/test_tasks_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ async def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock
@pytest.mark.asyncio
async def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
task1: Task = await mock_tasks_repository.find_by_id(1)
user1: User = User(id=1, name="User 1", avatar="none.jpg")
user1: User = User(id="1", name="User 1")
await tasks_service_with_mock.assign_task(task1, user1)
assert task1.user_id == 1
assert task1.user_id == "1"


@pytest.mark.asyncio
Expand Down
46 changes: 46 additions & 0 deletions tests/services/test_users_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from uuid import UUID

import pytest
from amt.repositories.users import UsersRepository
from amt.services.users import UsersService
from pytest_mock import MockFixture
from tests.constants import default_user


@pytest.mark.asyncio
async def test_get_user(mocker: MockFixture):
# Given
id = UUID("3d284d80-fc47-41ab-9696-fab562bacbd5")
name = "John Smith"
users_service = UsersService(
repository=mocker.AsyncMock(spec=UsersRepository),
)
users_service.repository.find_by_id.return_value = default_user(id=id, name=name) # type: ignore

# When
user = await users_service.get(id)

# Then
assert user is not None
assert user.id == id
assert user.name == name
users_service.repository.find_by_id.assert_awaited_once_with(id) # type: ignore


@pytest.mark.asyncio
async def test_create_or_update(mocker: MockFixture):
# Given
user = default_user()
users_service = UsersService(
repository=mocker.AsyncMock(spec=UsersRepository),
)
users_service.repository.upsert.return_value = user # type: ignore

# When
retreived_user = await users_service.create_or_update(user)

# Then
assert retreived_user is not None
assert retreived_user.id == user.id
assert retreived_user.name == user.name
users_service.repository.upsert.assert_awaited_once_with(user) # type: ignore

0 comments on commit c247982

Please sign in to comment.