Skip to content

Commit

Permalink
Merge branch 'main' into feature/add_example_usecase_for_demo
Browse files Browse the repository at this point in the history
  • Loading branch information
ravimeijerrig authored Oct 16, 2024
2 parents cf1df3a + d3ec543 commit 6f21bf4
Show file tree
Hide file tree
Showing 127 changed files with 3,881 additions and 569 deletions.
11 changes: 2 additions & 9 deletions amt/api/ai_act_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import Request

from amt.api.publication_category import PublicationCategories
from amt.core.internationalization import get_current_translation


Expand Down Expand Up @@ -72,15 +73,7 @@ def get_ai_act_profile_selector(request: Request) -> AiActProfileSelector:
"AI-model voor algemene doeleinden",
)
role_options = ("aanbieder", "gebruiksverantwoordelijke")
publication_category_options = (
"impactvol algoritme",
"niet-impactvol algoritme",
"hoog-risico AI",
"geen hoog-risico AI",
"verboden AI",
"uitzondering van toepassing",
"niet van toepassing",
)
publication_category_options = (*(p.value for p in PublicationCategories), "niet van toepassing")
systemic_risk_options = ("systeemrisico", "geen systeemrisico", "niet van toepassing")
transparency_obligations_options = (
"transparantieverplichtingen",
Expand Down
3 changes: 3 additions & 0 deletions amt/api/navigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,9 @@ class Navigation:
PROJECT_DETAILS = BaseNavigationItem(
display_text=DisplayText.DETAILS, url="/project/{project_id}/details/system_card"
)
PROJECT_MODEL = BaseNavigationItem(
display_text=DisplayText.MODEL, url="/project/{project_id}/details/model/inference"
)
PROJECT_NEW = BaseNavigationItem(display_text=DisplayText.NEW, url="/projects/new")
PROJECT_SYSTEM_INFO = BaseNavigationItem(display_text=DisplayText.INFO, url="/project/{project_id}/details")
PROJECT_SYSTEM_ALGORITHM_DETAILS = BaseNavigationItem(
Expand Down
47 changes: 47 additions & 0 deletions amt/api/publication_category.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from enum import Enum

from fastapi import Request

from amt.core.internationalization import get_current_translation
from amt.schema.publication_category import PublicationCategory

logger = logging.getLogger(__name__)


class PublicationCategories(Enum):
IMPACTVOL_ALGORITME = "impactvol algoritme"
NIET_IMPACTVOL_ALGORITME = "niet-impactvol algoritme"
HOOG_RISICO_AI = "hoog-risico AI"
GEEN_HOOG_RISICO_AI = "geen hoog-risico AI"
VERBODEN_AI = "verboden AI"
UITZONDERING_VAN_TOEPASSING = "uitzondering van toepassing"


def get_publication_category(key: PublicationCategories | None, request: Request) -> PublicationCategory | None:
"""
Given the key and translation, returns the translated text.
:param key: the key
:param request: request to get the current language
:return: a Publication Category model with the correct translation
"""

if key is None:
return None

translations = get_current_translation(request)
_ = translations.gettext
# translations are determined at runtime, which is why we use the dictionary below
keys = {
PublicationCategories.IMPACTVOL_ALGORITME: _("Impactful algorithm"),
PublicationCategories.NIET_IMPACTVOL_ALGORITME: _("Non-impactful algorithm"),
PublicationCategories.HOOG_RISICO_AI: _("High-risk AI"),
PublicationCategories.GEEN_HOOG_RISICO_AI: _("No high-risk AI"),
PublicationCategories.VERBODEN_AI: _("Forbidden AI"),
PublicationCategories.UITZONDERING_VAN_TOEPASSING: _("Exception of application"),
}
return PublicationCategory(id=key.name, name=keys[key])


def get_publication_categories(request: Request) -> list[PublicationCategory | None]:
return [get_publication_category(p, request) for p in PublicationCategories]
158 changes: 156 additions & 2 deletions amt/api/routes/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from fastapi import APIRouter, Depends, Request
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field

from amt.api.deps import templates
from amt.api.lifecycles import get_lifecycle
Expand All @@ -18,8 +19,11 @@
from amt.core.exceptions import AMTNotFound, AMTRepositoryError
from amt.enums.status import Status
from amt.models import Project
from amt.schema.measure import ExtendedMeasureTask, MeasureTask
from amt.schema.requirement import RequirementTask
from amt.schema.system_card import SystemCard
from amt.schema.task import MovedTask
from amt.services import task_registry
from amt.services.instruments_and_requirements_state import InstrumentStateService, RequirementsStateService
from amt.services.projects import ProjectsService
from amt.services.storage import StorageFactory
Expand Down Expand Up @@ -77,6 +81,7 @@ def get_project_details_tabs(request: Request) -> list[NavigationItem]:
[
Navigation.PROJECT_SYSTEM_INFO,
Navigation.PROJECT_SYSTEM_ALGORITHM_DETAILS,
Navigation.PROJECT_MODEL,
Navigation.PROJECT_REQUIREMENTS,
Navigation.PROJECT_DATA_CARD,
Navigation.PROJECT_TASKS,
Expand Down Expand Up @@ -242,6 +247,42 @@ async def get_system_card(
return templates.TemplateResponse(request, "pages/system_card.html.j2", context)


@router.get("/{project_id}/details/model/inference")
async def get_project_inference(
request: Request, project_id: int, projects_service: Annotated[ProjectsService, Depends(ProjectsService)]
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)

breadcrumbs = resolve_base_navigation_items(
[
Navigation.PROJECTS_ROOT,
BaseNavigationItem(custom_display_text=project.name, url="/project/{project_id}/details/model/inference"),
Navigation.PROJECT_MODEL,
],
request,
)

system_card_data = get_system_card_data()
instrument_state = get_instrument_state()
requirements_state = get_requirements_state(project.system_card)

tab_items = get_project_details_tabs(request)

context = {
"lifecycle": get_lifecycle(project.lifecycle, request),
"last_edited": project.last_edited,
"system_card": system_card_data,
"instrument_state": instrument_state,
"requirements_state": requirements_state,
"project": project,
"project_id": project.id,
"breadcrumbs": breadcrumbs,
"tab_items": tab_items,
}

return templates.TemplateResponse(request, "projects/details_inference.html.j2", context)


# !!!
# Implementation of this endpoint is for now independent of the project ID, meaning
# that the same system card is rendered for all project ID's. This is due to the fact
Expand Down Expand Up @@ -274,8 +315,27 @@ async def get_system_card_requirements(
system_card = project.system_card
requirements = fetch_requirements([requirement.urn for requirement in system_card.requirements])

# Get measures that correspond to the requirements.
requirements_and_measures = [(requirement, fetch_measures(requirement.links)) for requirement in requirements]
# Get measures that correspond to the requirements and merge them with the measuretasks
requirements_and_measures = []
for requirement in requirements:
completed_measures_count = 0
linked_measures = fetch_measures(requirement.links)
extended_linked_measures: list[ExtendedMeasureTask] = []
for measure in linked_measures:
measure_task = find_measure_task(system_card, measure.urn)
if measure_task:
ext_measure_task = ExtendedMeasureTask(
name=measure.name,
description=measure.description,
urn=measure.urn,
state=measure_task.state,
value=measure_task.value,
version=measure_task.version,
)
if ext_measure_task.state == "done":
completed_measures_count += 1
extended_linked_measures.append(ext_measure_task)
requirements_and_measures.append((requirement, completed_measures_count, extended_linked_measures)) # pyright: ignore [reportUnknownMemberType]

context = {
"instrument_state": instrument_state,
Expand All @@ -290,6 +350,100 @@ async def get_system_card_requirements(
return templates.TemplateResponse(request, "projects/details_requirements.html.j2", context)


def find_measure_task(system_card: SystemCard, urn: str) -> MeasureTask | None:
for measure in system_card.measures:
if measure.urn == urn:
return measure
return None


def find_requirement_task(system_card: SystemCard, requirement_urn: str) -> RequirementTask | None:
for requirement in system_card.requirements:
if requirement.urn == requirement_urn:
return requirement
return None


def find_requirement_tasks_by_measure_urn(system_card: SystemCard, measure_urn: str) -> list[RequirementTask]:
requirement_mapper: dict[str, RequirementTask] = {}
for requirement_task in system_card.requirements:
requirement_mapper[requirement_task.urn] = requirement_task

requirement_tasks: list[RequirementTask] = []
measure = fetch_measures([measure_urn])
for requirement_urn in measure[0].links:
# TODO: This is because measure are linked to too many requirement not applicable in our use case
if len(fetch_requirements([requirement_urn])) > 0:
requirement_tasks.append(requirement_mapper[requirement_urn])

return requirement_tasks


@router.get("/{project_id}/measure/{measure_urn}")
async def get_measure(
request: Request,
project_id: int,
measure_urn: str,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)
measure = task_registry.fetch_measures([measure_urn])
measure_task = find_measure_task(project.system_card, measure_urn)

context = {
"measure": measure[0],
"measure_state": measure_task.state, # pyright: ignore [reportOptionalMemberAccess]
"measure_value": measure_task.value, # pyright: ignore [reportOptionalMemberAccess]
"project_id": project_id,
}

return templates.TemplateResponse(request, "projects/details_measure_modal.html.j2", context)


class MeasureUpdate(BaseModel):
measure_state: str = Field(default=None)
measure_value: str = Field(default=None)


@router.post("/{project_id}/measure/{measure_urn}")
async def update_measure_value(
request: Request,
project_id: int,
measure_urn: str,
measure_update: MeasureUpdate,
projects_service: Annotated[ProjectsService, Depends(ProjectsService)],
) -> HTMLResponse:
project = get_project_or_error(project_id, projects_service, request)

measure_task = find_measure_task(project.system_card, measure_urn)
measure_task.state = measure_update.measure_state # pyright: ignore [reportOptionalMemberAccess]
measure_task.value = measure_update.measure_value # pyright: ignore [reportOptionalMemberAccess]

# update for the linked requirements the state based on all it's measures
requirement_tasks = find_requirement_tasks_by_measure_urn(project.system_card, measure_urn)
requirement_urns = [requirement_task.urn for requirement_task in requirement_tasks]
requirements = fetch_requirements(requirement_urns)

for requirement in requirements:
count_completed = 0
for link_measure_urn in requirement.links:
link_measure_task = find_measure_task(project.system_card, link_measure_urn)
if link_measure_task: # noqa: SIM102
if link_measure_task.state == "done":
count_completed += 1
requirement_task = find_requirement_task(project.system_card, requirement.urn)
if count_completed == len(requirement.links):
requirement_task.state = "done" # pyright: ignore [reportOptionalMemberAccess]
elif count_completed == 0 and len(requirement.links) > 0:
requirement_task.state = "to do" # pyright: ignore [reportOptionalMemberAccess]
else:
requirement_task.state = "in progress" # pyright: ignore [reportOptionalMemberAccess]

projects_service.update(project)
# TODO: FIX THIS!! The page now reloads at the top, which is annoying
return templates.Redirect(request, f"/project/{project_id}/details/system_card/requirements")


# !!!
# Implementation of this endpoint is for now independent of the project ID, meaning
# that the same system card is rendered for all project ID's. This is due to the fact
Expand Down
58 changes: 50 additions & 8 deletions amt/api/routes/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@

from amt.api.ai_act_profile import get_ai_act_profile_selector
from amt.api.deps import templates
from amt.api.lifecycles import get_lifecycles
from amt.api.lifecycles import Lifecycles, get_lifecycle, get_lifecycles
from amt.api.navigation import Navigation, resolve_base_navigation_items, resolve_navigation_items
from amt.api.publication_category import PublicationCategories, get_publication_categories, get_publication_category
from amt.schema.localized_value_item import LocalizedValueItem
from amt.schema.project import ProjectNew
from amt.services.instruments import InstrumentsService
from amt.services.projects import ProjectsService
Expand All @@ -16,6 +18,24 @@
logger = logging.getLogger(__name__)


def get_localized_value(key: str, value: str, request: Request) -> LocalizedValueItem:
match key:
case "lifecycle":
lifecycle = get_lifecycle(Lifecycles[value], request)
if lifecycle:
return LocalizedValueItem(value=value, display_value=lifecycle.name)
else:
return LocalizedValueItem(value=value, display_value="Unknown")
case "publication-category":
publication_category = get_publication_category(PublicationCategories[value], request)
if publication_category:
return LocalizedValueItem(value=value, display_value=publication_category.name)
else:
return LocalizedValueItem(value=value, display_value="Unknown")
case _:
return LocalizedValueItem(value=value, display_value="Unknown filter option")


@router.get("/")
async def get_root(
request: Request,
Expand All @@ -24,7 +44,25 @@ async def get_root(
limit: int = Query(100, ge=1),
search: str = Query(""),
) -> HTMLResponse:
projects = projects_service.paginate(skip=skip, limit=limit, search=search)
active_filters = {
k.removeprefix("active-filter-"): v
for k, v in request.query_params.items()
if k.startswith("active-filter") and v != ""
}
add_filters = {
k.removeprefix("add-filter-"): v
for k, v in request.query_params.items()
if k.startswith("add-filter") and v != ""
}
drop_filters = [v for k, v in request.query_params.items() if k.startswith("drop-filter") and v != ""]
filters = {k: v for k, v in (active_filters | add_filters).items() if k not in drop_filters}
localized_filters = {k: get_localized_value(k, v, request) for k, v in filters.items()}

projects = projects_service.paginate(skip=skip, limit=limit, search=search, filters=filters)
# todo: the lifecycle has to be 'localized', maybe for display 'Project' should become a different object
for project in projects:
project.lifecycle = get_lifecycle(project.lifecycle, request) # pyright: ignore [reportAttributeAccessIssue]

next = skip + limit

sub_menu_items = resolve_navigation_items([Navigation.PROJECTS_OVERVIEW], request) # pyright: ignore [reportUnusedVariable] # noqa
Expand All @@ -36,15 +74,19 @@ async def get_root(
"projects": projects,
"next": next,
"limit": limit,
"start": skip,
"search": search,
"lifecycles": get_lifecycles(request),
"publication_categories": get_publication_categories(request),
"filters": localized_filters,
}

if request.state.htmx:
return templates.TemplateResponse(
request, "projects/_list.html.j2", {"projects": projects, "next": next, "search": search, "limit": limit}
)

return templates.TemplateResponse(request, "projects/index.html.j2", context)
if request.state.htmx and drop_filters:
return templates.TemplateResponse(request, "parts/project_search.html.j2", context)
elif request.state.htmx:
return templates.TemplateResponse(request, "parts/filter_list.html.j2", context)
else:
return templates.TemplateResponse(request, "projects/index.html.j2", context)


@router.get("/new")
Expand Down
Loading

0 comments on commit 6f21bf4

Please sign in to comment.