Skip to content

Commit

Permalink
Fix all failing tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 14, 2024
1 parent 4d68d9a commit 376c8f7
Show file tree
Hide file tree
Showing 26 changed files with 14,567 additions and 206 deletions.
1 change: 0 additions & 1 deletion amt/api/routes/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def get_instrument_state(system_card: SystemCard) -> dict[str, Any]:

def get_requirements_state(system_card: SystemCard) -> dict[str, Any]:
requirements = fetch_requirements([requirement.urn for requirement in system_card.requirements])
logging.info(f"{requirements=}")
requirements_state_service = RequirementsStateService(system_card)
requirements_state = requirements_state_service.get_requirements_state(requirements)

Expand Down
21 changes: 3 additions & 18 deletions amt/clients/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import httpx
from amt.core.exceptions import AMTInstrumentError, AMTNotFound
from amt.schema.github import RepositoryContent
from amt.schema.instrument import Instrument
from amt.schema.requirement import Requirement

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -48,22 +46,9 @@ def get_list_of_task(self, task: TaskType = TaskType.INSTRUMENTS) -> RepositoryC
response_data = self._make_request(f"{task.value}/")
return RepositoryContent.model_validate(response_data["entries"])

def get_task_by_urn(self, task_type: TaskType, urn: str, version: str = "latest") -> Instrument | Requirement:
def get_task_by_urn(self, task_type: TaskType, urn: str, version: str = "latest") -> dict[str, Any]:
response_data = self._make_request(f"{task_type.value}/urn/{urn}", params={"version": version})
if "urn" not in response_data:
logger.exception(f"Invalid task {task_type.value} fetched: key 'urn' must occur in instrument.")
logger.exception(f"Invalid task {task_type.value} fetched: key 'urn' must occur in task {task_type.value}.")
raise AMTInstrumentError()

match task_type:
case TaskType.INSTRUMENTS:
return Instrument(**response_data)
case TaskType.REQUIREMENTS:
return Requirement(**response_data)
case _:
return NotImplemented

def get_instrument(self, urn: str, version: str = "latest") -> Instrument:
return self.get_task_by_urn(TaskType.INSTRUMENTS, urn, version) # type: ignore

def get_requirement(self, urn: str, version: str = "latest") -> Requirement:
return self.get_task_by_urn(TaskType.REQUIREMENTS, urn, version) # type: ignore
return response_data
16 changes: 12 additions & 4 deletions amt/services/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound
from amt.schema.instrument import Instrument

logger = logging.getLogger(__name__)
Expand All @@ -25,14 +26,21 @@ def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Ins

all_valid_urns = self.fetch_urns()

if urns is not None:
return [self.client.get_instrument(urn) for urn in urns if urn in all_valid_urns]
if urns is None:
return [Instrument(**self.client.get_task_by_urn(TaskType.INSTRUMENTS, urn)) for urn in all_valid_urns]

return [self.client.get_instrument(urn) for urn in all_valid_urns]
instruments: list[Instrument] = []
for urn in urns:
try:
instruments.append(Instrument(**self.client.get_task_by_urn(TaskType.INSTRUMENTS, urn)))
except AMTNotFound:
logger.warning(f"cannot find instrument with URN {urn}")

return instruments

def fetch_urns(self) -> list[str]:
"""
Fetches all valid instrument URN's.
"""
content_list = self.client.get_list_of_task(TaskType.INSTRUMENTS)
content_list = self.client.get_list_of_task(task=TaskType.INSTRUMENTS)
return [content.urn for content in content_list.root]
3 changes: 2 additions & 1 deletion amt/services/instruments_and_requirements_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ def get_state_per_instrument(self) -> list[dict[str, int]]:
# Otherwise the instrument is completed as there are not any tasks left.

urns = [instrument.urn for instrument in self.system_card.instruments]
instruments = InstrumentsService().fetch_instruments(urns)
instruments_service = InstrumentsService()
instruments = instruments_service.fetch_instruments(urns)
# TODO: refactor this data structure in 3 lines below (also change in get_all_next_tasks + check_state.py)
instruments_dict = {}
instrument_states: dict[str, Any] = {}
Expand Down
47 changes: 47 additions & 0 deletions amt/services/measures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound
from amt.schema.measure import Measure

logger = logging.getLogger(__name__)


class MeasuresService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()

def fetch_measures(self, urns: str | Sequence[str] | None = None) -> list[Measure]:
"""
This functions returns measures with given URN's. If urns contains an URN that is not a
valid URN of an measure it is simply ignored.
@param: URN's of measures to fetch. If empty, function returns all measures.
@return: List of measures with given URN's in 'urns'.
"""

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is None:
return [Measure(**self.client.get_task_by_urn(TaskType.MEASURES, urn)) for urn in all_valid_urns]

measures: list[Measure] = []
for urn in urns:
try:
logger.info(f"adding measure with URN {urn}")
measures.append(Measure(**self.client.get_task_by_urn(TaskType.MEASURES, urn)))
except AMTNotFound:
logger.warning(f"cannot find measure with URN {urn}")

return measures

def fetch_urns(self) -> list[str]:
"""
Fetches all valid measure URN's.
"""
content_list = self.client.get_list_of_task(TaskType.MEASURES)
return [content.urn for content in content_list.root]
47 changes: 47 additions & 0 deletions amt/services/requirements.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import logging
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound
from amt.schema.requirement import Requirement

logger = logging.getLogger(__name__)


class RequirementsService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()

def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
"""
This functions returns requirement with given URN's. If urns contains an URN that is not a
valid URN of an requirement it is simply ignored.
@param: URN's of requirements to fetch. If empty, function returns all requirements.
@return: List of requirements with given URN's in 'urns'.
"""

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is None:
return [Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)) for urn in all_valid_urns]

requirements: list[Requirement] = []
for urn in urns:
try:
requirements.append(Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)))
except AMTNotFound:
logger.warning(f"cannot find instrument with URN {urn}")

return requirements


def fetch_urns(self) -> list[str]:
"""
Fetches all valid requirement URN's.
"""
content_list = self.client.get_list_of_task(TaskType.REQUIREMENTS)
return [content.urn for content in content_list.root]
64 changes: 11 additions & 53 deletions amt/services/task_registry.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,19 @@
import logging
from collections.abc import Sequence
from functools import lru_cache
from pathlib import Path
from typing import Any

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.schema.measure import Measure, MeasureTask
from amt.schema.requirement import Requirement, RequirementTask
from amt.schema.system_card import AiActProfile
from amt.services.storage import StorageFactory
from amt.services.measures import MeasuresService
from amt.services.requirements import RequirementsService

logger = logging.getLogger(__name__)


class RequirementsService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()

def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
"""
This functions returns requirements with given URN's. If urns contains an URN that is not a
valid URN of an requirement it is simply ignored.
@param: URN's of requirements to fetch. If empty, function returns all requiements.
@return: List of requiements with given URN's in 'urns'.
"""

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is not None:
return [self.client.get_requirement(urn) for urn in urns if urn in all_valid_urns]

return [self.client.get_requirement(urn) for urn in all_valid_urns]

def fetch_urns(self) -> list[str]:
"""
Fetches all valid instrument URN's.
"""
content_list = self.client.get_list_of_task(task=TaskType.REQUIREMENTS)
return [content.urn for content in content_list.root]
def get_requirements(ai_act_profile: AiActProfile) -> list[RequirementTask]:
requirements_card: list[RequirementTask] = []
return requirements_card


def get_requirements_and_measures(
Expand All @@ -62,25 +34,14 @@ def fetch_all_requirements() -> dict[str, Requirement]:
"""
Fetch requirements with URN in urns.
"""
all_requirements = RequirementsService().fetch_requirements()
logging.info(all_requirements)
requirement_service = RequirementsService()
all_requirements = requirement_service.fetch_requirements()
requirements: dict[str, Requirement] = {}

for requirement in all_requirements:
requirements[requirement.urn] = requirement

return requirements
#mock_registry_path = Path("example_registry/requirements")
#requirements: dict[str, Requirement] = {}

#for requirement_path in mock_registry_path.glob("*.yaml"):
# requirement: Any = StorageFactory.init(
# storage_type="file", location=requirement_path.parent, filename=requirement_path.name
# ).read()
# requirements[requirement["urn"]] = Requirement(**requirement)

#return requirements



def fetch_requirements(urns: Sequence[str]) -> list[Requirement]:
Expand All @@ -91,19 +52,16 @@ def fetch_requirements(urns: Sequence[str]) -> list[Requirement]:
return [all_requirements[urn] for urn in urns if urn in all_requirements]


@lru_cache
def fetch_all_measures() -> dict[str, Measure]:
"""
Fetch measures with URN in urns.
"""
mock_registry_path = Path("example_registry/measures")
measure_service = MeasuresService()
all_measures = measure_service.fetch_measures()
measures: dict[str, Measure] = {}

for measure_path in mock_registry_path.glob("*.yaml"):
measure: Any = StorageFactory.init(
storage_type="file", location=measure_path.parent, filename=measure_path.name
).read()
measures[measure["urn"]] = Measure(**measure)
for measure in all_measures:
measures[measure.urn] = measure

return measures

Expand Down
Loading

0 comments on commit 376c8f7

Please sign in to comment.