Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add user id to user table in db #355

Merged
merged 1 commit into from
Nov 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions amt/api/routes/auth.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import hashlib
import logging
from typing import Annotated
from urllib.parse import quote_plus
from uuid import UUID

from authlib.integrations.starlette_client import OAuth, OAuthError # type: ignore
from fastapi import APIRouter, Request
from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse, RedirectResponse, Response

from amt.api.deps import templates
from amt.core.authorization import get_user
from amt.core.exceptions import AMTAuthorizationFlowError
from amt.models.user import User
from amt.services.users import UsersService

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -42,7 +46,10 @@ async def logout(request: Request) -> RedirectResponse: # pragma: no cover


@router.get("/callback", response_class=Response)
async def auth_callback(request: Request) -> Response: # pragma: no cover
async def auth_callback(
request: Request,
users_service: Annotated[UsersService, Depends(UsersService)],
) -> Response: # pragma: no cover
oauth: OAuth = request.app.state.oauth
try:
token = await oauth.keycloak.authorize_access_token(request) # type: ignore
Expand All @@ -58,6 +65,10 @@ async def auth_callback(request: Request) -> Response: # pragma: no cover
name: str = str(user["name"]).strip().lower() # type: ignore
user["name_encoded"] = quote_plus(name)

if "sub" in user and isinstance(user["sub"], str):
new_user = User(id=UUID(user["sub"]), name=user["name"]) # type: ignore
new_user = await users_service.create_or_update(new_user)

if user:
request.session["user"] = dict(user) # type: ignore
request.session["id_token"] = token["id_token"] # type: ignore
Expand Down
38 changes: 38 additions & 0 deletions amt/migrations/versions/22298f3aac77_drop_users_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
"""drop users table

Revision ID: 22298f3aac77
Revises: 7f20f8562007
Create Date: 2024-11-12 09:33:50.853310

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "22298f3aac77"
down_revision: str | None = "6581a03aabec"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.drop_constraint("fk_task_user_id_user", type_="foreignkey")
op.drop_column("task", "user_id")
op.drop_table("user")


def downgrade() -> None:
op.create_table(
"user",
sa.Column("id", sa.INTEGER(), nullable=False),
sa.Column("name", sa.VARCHAR(length=255), nullable=False),
sa.Column("avatar", sa.VARCHAR(length=255), nullable=True),
sa.PrimaryKeyConstraint("id", name="pk_user"),
)
op.add_column("task", sa.Column("user_id", sa.INTEGER(), nullable=True))
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.create_foreign_key("fk_task_user_id_user", "user", ["user_id"], ["id"])
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""create user table with uuid as pk

Revision ID: 69243fd24222
Revises: 22298f3aac77
Create Date: 2024-11-12 09:49:53.558089

"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "69243fd24222"
down_revision: str | None = "22298f3aac77"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
op.create_table(
"user",
sa.Column("id", sa.UUID(), nullable=False),
sa.Column("name", sa.String(), nullable=False),
sa.PrimaryKeyConstraint("id", name=op.f("pk_user")),
)
op.add_column("task", sa.Column("user_id", sa.UUID(), nullable=True))
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.create_foreign_key(op.f("fk_task_user_id_user"), "user", ["user_id"], ["id"])


def downgrade() -> None:
with op.batch_alter_table("task", schema=None) as batch_op:
batch_op.drop_constraint(op.f("fk_task_user_id_user"), type_="foreignkey")
op.drop_column("task", "user_id")
op.drop_table("user")
10 changes: 6 additions & 4 deletions amt/models/task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from uuid import UUID

from sqlalchemy import ForeignKey, String
from sqlalchemy.orm import Mapped, mapped_column

Expand All @@ -12,9 +14,9 @@ class Task(Base):
description: Mapped[str]
sort_order: Mapped[float]
status_id: Mapped[int | None] = mapped_column(default=None)
user_id: Mapped[int | None] = mapped_column(ForeignKey("user.id"))
# TODO: (Christopher) SQLModel does not allow to give the below restraint an name
# which is needed for alembic. This results in changing the migration file
# manually to give the restrain a name.
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"))
## TODO: (Christopher) SQLModel does not allow to give the below restraint an name
## which is needed for alembic. This results in changing the migration file
## manually to give the restrain a name.
project_id: Mapped[int | None] = mapped_column(ForeignKey("project.id"))
# todo(robbert) Tasks probably are grouped (and sub-grouped), so we probably need a reference to a group_id
6 changes: 4 additions & 2 deletions amt/models/user.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from uuid import UUID

from sqlalchemy import UUID as SQLAlchemyUUID
from sqlalchemy.orm import Mapped, mapped_column

from amt.models.base import Base
Expand All @@ -6,6 +9,5 @@
class User(Base):
__tablename__ = "user"

id: Mapped[int] = mapped_column(primary_key=True)
id: Mapped[UUID] = mapped_column(SQLAlchemyUUID(as_uuid=True), primary_key=True)
name: Mapped[str]
avatar: Mapped[str | None] = mapped_column(default=None)
55 changes: 55 additions & 0 deletions amt/repositories/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import logging
from typing import Annotated
from uuid import UUID

from fastapi import Depends
from sqlalchemy import select
from sqlalchemy.exc import NoResultFound, SQLAlchemyError
from sqlalchemy.ext.asyncio import AsyncSession

from amt.core.exceptions import AMTRepositoryError
from amt.models.user import User
from amt.repositories.deps import get_session

logger = logging.getLogger(__name__)


class UsersRepository:
"""
The UsersRepository provides access to the repository layer.
"""

def __init__(self, session: Annotated[AsyncSession, Depends(get_session)]) -> None:
self.session = session

async def find_by_id(self, id: UUID) -> User | None:
"""
Returns the user with the given id.
:param id: the id of the user to find
:return: the user with the given id or an exception if no user was found
"""
statement = select(User).where(User.id == id)
try:
return (await self.session.execute(statement)).scalars().one()
except NoResultFound:
return None

async def upsert(self, user: User) -> User:
"""
Upserts (create or update) a user.
:param user: the user to upsert.
:return: the upserted user.
"""
try:
existing_user = await self.find_by_id(user.id)
if existing_user:
existing_user.name = user.name
else:
self.session.add(user)
await self.session.commit()
except SQLAlchemyError as e: # pragma: no cover
logger.exception("Error saving user")
await self.session.rollback()
raise AMTRepositoryError from e

return user
25 changes: 25 additions & 0 deletions amt/services/users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
from typing import Annotated
from uuid import UUID

from fastapi import Depends

from amt.models.user import User
from amt.repositories.users import UsersRepository

logger = logging.getLogger(__name__)


class UsersService:
def __init__(
self,
repository: Annotated[UsersRepository, Depends(UsersRepository)],
) -> None:
self.repository = repository

async def get(self, id: str | UUID) -> User | None:
id = UUID(id) if isinstance(id, str) else id
return await self.repository.find_by_id(id)

async def create_or_update(self, user: User) -> User:
return await self.repository.upsert(user)
10 changes: 0 additions & 10 deletions compose.override.yml

This file was deleted.

10 changes: 6 additions & 4 deletions tests/constants.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from uuid import UUID

from amt.api.lifecycles import Lifecycles
from amt.api.navigation import BaseNavigationItem, DisplayText
Expand All @@ -18,14 +19,15 @@ def default_base_navigation_item(
return BaseNavigationItem(display_text=display_text, url=url, custom_display_text=custom_display_text, icon=icon)


def default_user(name: str = "default user", avatar: str | None = None) -> User:
return User(name=name, avatar=avatar)


def default_project(name: str = "default project") -> Project:
return Project(name=name)


def default_user(id: str | UUID = "00494b4d-bcdf-425a-8140-bea0f3cbd3c2", name: str = "John Smith") -> User:
id = UUID(id) if isinstance(id, str) else id
return User(id=id, name=name)


def default_project_with_system_card(name: str = "default project") -> Project:
with open("resources/system_card_templates/AMT_Template_1.json") as f:
system_card_from_template = json.load(f)
Expand Down
3 changes: 2 additions & 1 deletion tests/models/test_model_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
def test_model_basic_user():
# given

user = User(name="Test user", avatar=None)
user = User(id="70419b8d-29ff-4c51-9822-05a46f6c916e", name="Test user")

# then
assert user.id == "70419b8d-29ff-4c51-9822-05a46f6c916e"
assert user.name == "Test user"
50 changes: 50 additions & 0 deletions tests/repositories/test_users.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from unittest.mock import AsyncMock

import pytest
from amt.core.exceptions import AMTRepositoryError
from amt.repositories.users import UsersRepository
from sqlalchemy.exc import SQLAlchemyError
from tests.constants import default_user
from tests.database_test_utils import DatabaseTestUtils


@pytest.mark.asyncio
async def test_find_by_id(db: DatabaseTestUtils):
await db.given([default_user()])
users_repository = UsersRepository(db.get_session())
result = await users_repository.find_by_id(default_user().id)
assert result is not None
assert result.id == default_user().id
assert result.name == default_user().name


@pytest.mark.asyncio
async def test_upsert_new(db: DatabaseTestUtils):
new_user = default_user()
users_repository = UsersRepository(db.get_session())
await users_repository.upsert(new_user)
result = await users_repository.find_by_id(new_user.id)
assert result is not None
assert result.id == new_user.id
assert result.name == new_user.name


@pytest.mark.asyncio
async def test_upsert_existing(db: DatabaseTestUtils):
await db.given([default_user()])
new_user = default_user(name="John Smith New")
users_repository = UsersRepository(db.get_session())
await users_repository.upsert(new_user)
result = await users_repository.find_by_id(new_user.id)
assert result is not None
assert result.id == new_user.id
assert result.name == new_user.name


@pytest.mark.asyncio
async def test_upsert_error(db: DatabaseTestUtils):
new_user = default_user(name="John Smith New")
users_repository = UsersRepository(db.get_session())
users_repository.find_by_id = AsyncMock(side_effect=SQLAlchemyError("Database error"))
with pytest.raises(AMTRepositoryError):
await users_repository.upsert(new_user)
4 changes: 2 additions & 2 deletions tests/services/test_tasks_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,9 @@ async def test_get_tasks_for_project(tasks_service_with_mock: TasksService, mock
@pytest.mark.asyncio
async def test_assign_task(tasks_service_with_mock: TasksService, mock_tasks_repository: TasksRepository):
task1: Task = await mock_tasks_repository.find_by_id(1)
user1: User = User(id=1, name="User 1", avatar="none.jpg")
user1: User = User(id="1", name="User 1")
await tasks_service_with_mock.assign_task(task1, user1)
assert task1.user_id == 1
assert task1.user_id == "1"


@pytest.mark.asyncio
Expand Down
46 changes: 46 additions & 0 deletions tests/services/test_users_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
from uuid import UUID

import pytest
from amt.repositories.users import UsersRepository
from amt.services.users import UsersService
from pytest_mock import MockFixture
from tests.constants import default_user


@pytest.mark.asyncio
async def test_get_user(mocker: MockFixture):
# Given
id = UUID("3d284d80-fc47-41ab-9696-fab562bacbd5")
name = "John Smith"
users_service = UsersService(
repository=mocker.AsyncMock(spec=UsersRepository),
)
users_service.repository.find_by_id.return_value = default_user(id=id, name=name) # type: ignore

# When
user = await users_service.get(id)

# Then
assert user is not None
assert user.id == id
assert user.name == name
users_service.repository.find_by_id.assert_awaited_once_with(id) # type: ignore


@pytest.mark.asyncio
async def test_create_or_update(mocker: MockFixture):
# Given
user = default_user()
users_service = UsersService(
repository=mocker.AsyncMock(spec=UsersRepository),
)
users_service.repository.upsert.return_value = user # type: ignore

# When
retreived_user = await users_service.create_or_update(user)

# Then
assert retreived_user is not None
assert retreived_user.id == user.id
assert retreived_user.name == user.name
users_service.repository.upsert.assert_awaited_once_with(user) # type: ignore
Loading