Skip to content

Commit

Permalink
Move systemcard to database
Browse files Browse the repository at this point in the history
  • Loading branch information
anneschuth committed Oct 12, 2024
1 parent ee1719b commit 5b525d0
Show file tree
Hide file tree
Showing 10 changed files with 120 additions and 38 deletions.
30 changes: 30 additions & 0 deletions amt/migrations/versions/7f20f8562007_add_system_card_column.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
"""Add system_card_column
Revision ID: 7f20f8562007
Revises: 1019b72fe63a
Create Date: 2024-10-12 14:28:32.153283
"""

from collections.abc import Sequence

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision: str = "7f20f8562007"
down_revision: str | None = "1019b72fe63a"
branch_labels: str | Sequence[str] | None = None
depends_on: str | Sequence[str] | None = None


def upgrade() -> None:
with op.batch_alter_table("project", schema=None) as batch_op:
batch_op.add_column(sa.Column("system_card_json", sa.JSON(), nullable=True))
batch_op.drop_column("model_card")


def downgrade() -> None:
with op.batch_alter_table("project", schema=None) as batch_op:
batch_op.add_column(sa.Column("model_card", sa.String(), nullable=True))
batch_op.drop_column("system_card_json")
25 changes: 23 additions & 2 deletions amt/models/project.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from datetime import datetime
from typing import TypeVar
from typing import Any, TypeVar

from sqlalchemy import String, func
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.types import JSON

from amt.api.lifecycles import Lifecycles
from amt.models.base import Base
from amt.schema.system_card import SystemCard

T = TypeVar("T", bound="Project")

Expand All @@ -17,5 +19,24 @@ class Project(Base):
id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(String(255))
lifecycle: Mapped[Lifecycles | None] = mapped_column(ENUM(Lifecycles), nullable=True)
model_card: Mapped[str | None] = mapped_column(default=None)
system_card_json: Mapped[dict[str, Any]] = mapped_column(JSON, default=dict)
last_edited: Mapped[datetime] = mapped_column(server_default=func.now(), onupdate=func.now(), nullable=False)

def __init__(self, *args: Any, **kwargs: Any) -> None: # noqa: ANN401
system_card: SystemCard | None = kwargs.pop("system_card", None)
super().__init__(*args, **kwargs)
if system_card is not None:
self.system_card = system_card

@property
def system_card(self) -> SystemCard | None:
if self.system_card_json:
return SystemCard.model_validate(self.system_card_json)
return None

@system_card.setter
def system_card(self, value: SystemCard | None) -> None:
if value is None:
self.system_card_json = {}
else:
self.system_card_json = value.model_dump(exclude_unset=True)
18 changes: 2 additions & 16 deletions amt/services/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@

from fastapi import Depends

from amt.core.config import get_settings
from amt.models import Project
from amt.repositories.projects import ProjectsRepository
from amt.schema.instrument import InstrumentBase
from amt.schema.project import ProjectNew
from amt.schema.system_card import AiActProfile, SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.storage import Storage, StorageFactory
from amt.services.tasks import TasksService

logger = logging.getLogger(__name__)
Expand All @@ -31,16 +29,6 @@ def get(self, project_id: int) -> Project:
return self.repository.find_by_id(project_id)

def create(self, project_new: ProjectNew) -> Project:
project = Project(name=project_new.name, lifecycle=project_new.lifecycle)

self.repository.save(project)

system_card_file = get_settings().CARD_DIR / f"{project.id}_system.yaml"

project.model_card = str(system_card_file)

project = self.update(project)

instruments: list[InstrumentBase] = [
InstrumentBase(urn=instrument_urn) for instrument_urn in project_new.instruments
]
Expand All @@ -56,10 +44,8 @@ def create(self, project_new: ProjectNew) -> Project:

system_card = SystemCard(name=project_new.name, ai_act_profile=ai_act_profile, instruments=instruments)

storage_writer: Storage = StorageFactory.init(
storage_type="file", location=system_card_file.parent, filename=system_card_file.name
)
storage_writer.write(system_card.model_dump())
project = Project(name=project_new.name, lifecycle=project_new.lifecycle, system_card=system_card)
project = self.update(project)

selected_instruments = self.instrument_service.fetch_instruments(project_new.instruments) # type: ignore
for instrument in selected_instruments:
Expand Down
1 change: 1 addition & 0 deletions amt/services/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def move_task(
task = self.repository.find_by_id(task_id)

if status_id == Status.DONE:
# TODO: This seems off, tasks should be written to the correct location in the system card.
self.system_card.name = task.title
self.storage_writer.write(self.system_card.model_dump())

Expand Down
23 changes: 15 additions & 8 deletions tests/api/routes/test_projects.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
from collections.abc import Generator
from unittest.mock import MagicMock, Mock
from typing import cast
from unittest.mock import Mock

import pytest
from amt.models import Project
from amt.models.base import Base
from amt.schema.ai_act_profile import AiActProfile
from amt.schema.project import ProjectNew
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.storage import FileSystemStorageService
from fastapi.testclient import TestClient
from fastapi_csrf_protect import CsrfProtect # type: ignore # noqa

from tests.constants import default_instrument
from tests.database_test_utils import DatabaseTestUtils


@pytest.fixture
Expand Down Expand Up @@ -97,14 +100,17 @@ def test_post_new_projects(


def test_post_new_projects_write_system_card(
client: TestClient, mock_csrf: Generator[None, None, None], init_instruments: Generator[None, None, None]
client: TestClient,
mock_csrf: Generator[None, None, None],
init_instruments: Generator[None, None, None],
db: DatabaseTestUtils,
) -> None:
# Given
client.cookies["fastapi-csrf-token"] = "1"
origin = FileSystemStorageService.write
FileSystemStorageService.write = MagicMock()

name = "name1"
project_new = ProjectNew(
name="name1",
name=name,
lifecycle="DESIGN",
type="AI-systeem",
open_source="open-source",
Expand All @@ -129,5 +135,6 @@ def test_post_new_projects_write_system_card(
client.post("/projects/new", json=project_new.model_dump(), headers={"X-CSRF-Token": "1"})

# then
FileSystemStorageService.write.assert_called_with(system_card.model_dump())
FileSystemStorageService.write = origin
base_projects: list[Base] = db.get(Project, "name", name)
projects: list[Project] = cast(list[Project], base_projects)
assert any(project.system_card == system_card for project in projects if project.system_card is not None)
4 changes: 2 additions & 2 deletions tests/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ def default_user(name: str = "default user", avatar: str | None = None) -> User:
return User(name=name, avatar=avatar)


def default_project(name: str = "default project", model_card: str = "/tmp/1.yaml") -> Project: # noqa: S108
return Project(name=name, model_card=model_card)
def default_project(name: str = "default project") -> Project:
return Project(name=name)


def default_fastapi_request(url: str = "/") -> Request:
Expand Down
8 changes: 8 additions & 0 deletions tests/database_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,11 @@ def get_database_file(self) -> Path | None:

def exists(self, model: type[Base], model_field: str, field_value: str | int) -> Base | None:
return self.get_session().execute(select(model).where(model_field == field_value)).one() is not None # type: ignore

def get(self, model: type[Base], model_field: str, field_value: str | int) -> list[Base]:
try:
query = select(model).where(getattr(model, model_field) == field_value)
result = self.get_session().execute(query)
return list(result.scalars().all())
except AttributeError as err:
raise ValueError(f"Field '{model_field}' does not exist in model {model.__name__}") from err
34 changes: 33 additions & 1 deletion tests/models/test_model_project.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,41 @@
from amt.models.project import Project
from amt.schema.system_card import SystemCard


def test_model_basic_project():
# given
project = Project(name="Test Project", model_card="Test Card")
project = Project(name="Test Project")

# then
assert project.name == "Test Project"


def test_model_systemcard():
# given
system_card = SystemCard(name="Test System Card")

project = Project(name="Test Project", system_card=system_card)

# then
assert project.system_card is not None
assert project.system_card.name == "Test System Card"

# when
project.system_card = None

# then
assert project.system_card is None


def test_model_systemcard_full():
# given
system_card = SystemCard(name="Test System Card", description="Test description", status="active")

project = Project(name="Test Project")
project.system_card = system_card

# then
assert project.system_card is not None
assert project.system_card.name == "Test System Card"
assert project.system_card.description == "Test description"
assert project.system_card.status == "active"
11 changes: 5 additions & 6 deletions tests/services/test_projects_service.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from pathlib import Path
from unittest.mock import Mock

from amt.models.project import Project
from amt.repositories.projects import ProjectsRepository
from amt.schema.project import ProjectNew
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.projects import ProjectsService
from amt.services.tasks import TasksService
Expand All @@ -15,14 +15,13 @@ def test_get_project():
project_id = 1
project_name = "Project 1"
project_lifecycle = "development"
project_model_card = "model_card_path"
projects_service = ProjectsService(
repository=Mock(spec=ProjectsRepository),
task_service=Mock(spec=TasksService),
instrument_service=Mock(spec=InstrumentsService),
)
projects_service.repository.find_by_id.return_value = Project( # type: ignore
id=project_id, name=project_name, lifecycle=project_lifecycle, model_card=project_model_card
id=project_id, name=project_name, lifecycle=project_lifecycle
)

# When
Expand All @@ -32,22 +31,22 @@ def test_get_project():
assert project.id == project_id
assert project.name == project_name
assert project.lifecycle == project_lifecycle
assert project.model_card == project_model_card
projects_service.repository.find_by_id.assert_called_once_with(project_id) # type: ignore


def test_create_project():
project_id = 1
project_name = "Project 1"
project_lifecycle = "development"
project_model_card = Path("model_card_path")
system_card = SystemCard(name=project_name)

projects_service = ProjectsService(
repository=Mock(spec=ProjectsRepository),
task_service=Mock(spec=TasksService),
instrument_service=Mock(spec=InstrumentsService),
)
projects_service.repository.save.return_value = Project( # type: ignore
id=project_id, name=project_name, lifecycle=project_lifecycle, model_card=str(project_model_card)
id=project_id, name=project_name, lifecycle=project_lifecycle, system_card=system_card
)
projects_service.instrument_service.fetch_instruments.return_value = [default_instrument()] # type: ignore

Expand Down
4 changes: 1 addition & 3 deletions tests/services/test_tasks_service.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from pathlib import Path
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -92,8 +91,7 @@ def test_create_instrument_tasks(tasks_service_with_mock: TasksService, mock_tas

project_id = 1
project_name = "Project 1"
project_model_card = Path("model_card_path")
project = Project(id=project_id, name=project_name, model_card=str(project_model_card))
project = Project(id=project_id, name=project_name)

# When
mock_tasks_repository.clear()
Expand Down

0 comments on commit 5b525d0

Please sign in to comment.