From f5b48613682afed658057ec2ebab907dc867d957 Mon Sep 17 00:00:00 2001 From: ChristopherSpelt Date: Fri, 18 Oct 2024 10:50:52 +0200 Subject: [PATCH] Refactor InstrumentsService to use the Task Registry API --- amt/clients/clients.py | 91 +++++---------------- amt/schema/github.py | 1 + amt/server.py | 1 + amt/services/instruments.py | 53 +++++------- tests/clients/test_clients.py | 93 +++++++--------------- tests/constants.py | 76 +++++++++--------- tests/services/test_instruments_service.py | 93 +++++++++++++++------- tests/services/test_instruments_state.py | 17 ++-- 8 files changed, 186 insertions(+), 239 deletions(-) diff --git a/amt/clients/clients.py b/amt/clients/clients.py index e60242f4..0d498203 100644 --- a/amt/clients/clients.py +++ b/amt/clients/clients.py @@ -1,92 +1,43 @@ import logging -from abc import ABC, abstractmethod -from datetime import datetime, timezone import httpx -from amt.core.exceptions import AMTNotFound +from amt.core.exceptions import AMTInstrumentError, AMTNotFound from amt.schema.github import RepositoryContent +from amt.schema.instrument import Instrument logger = logging.getLogger(__name__) -class Client(ABC): +class TaskRegistryAPIClient: """ - Abstract class which is used to set up HTTP clients that retrieve instruments from the - task registry. + This class interacts with the Task Registry API. + + Currently it supports: + - Retrieving the list of instruments. + - Getting an instrument by URN. """ - @abstractmethod + base_url = "https://task-registry.apps.digilab.network" + def __init__(self, max_retries: int = 3, timeout: int = 5) -> None: transport = httpx.HTTPTransport(retries=max_retries) self.client = httpx.Client(timeout=timeout, transport=transport) - @abstractmethod - def get_content(self, url: str) -> bytes: - """ - This method should implement getting the content of an instrument from given URL. - """ - - @abstractmethod - def list_content(self, url: str = "") -> RepositoryContent: - """ - This method should implement getting list of instruments from given URL. - """ - - def _get(self, url: str) -> httpx.Response: - """ - Private function that performs a GET request to given URL. - """ - response = self.client.get(url) + def get_instrument_list(self) -> RepositoryContent: + response = self.client.get(f"{TaskRegistryAPIClient.base_url}/instruments/") if response.status_code != 200: raise AMTNotFound() - return response - - -def get_client(repo_type: str) -> Client: - match repo_type: - case "github_pages": - return GitHubPagesClient() - case "github": - return GitHubClient() - case _: - raise AMTNotFound() - - -class GitHubPagesClient(Client): - def __init__(self) -> None: - super().__init__() - - def get_content(self, url: str) -> bytes: - return super()._get(url).content - - def list_content(self, url: str = "https://minbzk.github.io/task-registry/index.json") -> RepositoryContent: - response = super()._get(url) return RepositoryContent.model_validate(response.json()["entries"]) + def get_instrument(self, urn: str, version: str = "latest") -> Instrument: + response = self.client.get(f"{TaskRegistryAPIClient.base_url}/urns/", params={"version": version, "urn": urn}) -class GitHubClient(Client): - def __init__(self) -> None: - super().__init__() - self.client.event_hooks["response"] = [self._check_rate_limit] - # TODO(Berry): add authentication headers with event_hooks - - def get_content(self, url: str) -> bytes: - return super()._get(url).content + if response.status_code != 200: + raise AMTNotFound() - def list_content( - self, - url: str = "https://api.github.com/repos/MinBZK/task-registry/contents/instruments?ref=main", - ) -> RepositoryContent: - response = super()._get(url) - return RepositoryContent.model_validate(response.json()) + data = response.json() + if "urn" not in data: + logger.exception("Invalid instrument fetched: key 'urn' must occur in instrument.") + raise AMTInstrumentError() - def _check_rate_limit(self, response: httpx.Response) -> None: - if "x-ratelimit-remaining" in response.headers: - remaining = int(response.headers["X-RateLimit-Remaining"]) - if remaining == 0: - reset_timestamp = int(response.headers["X-RateLimit-Reset"]) - reset_time = datetime.fromtimestamp(reset_timestamp, timezone.utc) # noqa: UP017 - wait_seconds = (reset_time - datetime.now(timezone.utc)).total_seconds() # noqa: UP017 - logger.warning( - f"Rate limit exceeded. We need to wait for {wait_seconds} seconds. (not implemented yet)" - ) + return Instrument(**data) diff --git a/amt/schema/github.py b/amt/schema/github.py index cb93b0ef..bbca1680 100644 --- a/amt/schema/github.py +++ b/amt/schema/github.py @@ -9,6 +9,7 @@ class Links(BaseModel): class ContentItem(BaseModel): name: str + urn: str path: str sha: str | None = Field(default=None) size: int diff --git a/amt/server.py b/amt/server.py index 8843cd1f..a758dc9e 100644 --- a/amt/server.py +++ b/amt/server.py @@ -37,6 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: logger.info(f"Starting {PROJECT_NAME} version {VERSION}") logger.info(f"Settings: {mask.secrets(get_settings().model_dump())}") yield + logger.info(f"Stopping application {PROJECT_NAME} version {VERSION}") logging.shutdown() diff --git a/amt/services/instruments.py b/amt/services/instruments.py index 0517222e..daa989e8 100644 --- a/amt/services/instruments.py +++ b/amt/services/instruments.py @@ -1,47 +1,38 @@ import logging from collections.abc import Sequence -import yaml - -from amt.clients.clients import get_client -from amt.core.exceptions import AMTInstrumentError -from amt.schema.github import RepositoryContent +from amt.clients.clients import TaskRegistryAPIClient from amt.schema.instrument import Instrument logger = logging.getLogger(__name__) class InstrumentsService: - def __init__(self, repo_type: str = "github_pages") -> None: - self.client = get_client(repo_type) - - def fetch_github_content_list(self) -> RepositoryContent: - response = self.client.list_content() - return RepositoryContent.model_validate(response) + def __init__(self) -> None: + self.client = TaskRegistryAPIClient() - def fetch_github_content(self, url: str) -> Instrument: - bytes_data = self.client.get_content(url) + def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Instrument]: + """ + This functions returns instruments with given URN's. If urns contains an URN that is not a + valid URN of an instrument it is simply ignored. - # assume yaml - data = yaml.safe_load(bytes_data) + @param: URN's of instruments to fetch. If empty, function returns all instruments. + @return: List of instruments with given URN's in 'urns'. + """ - if "urn" not in data: - # todo: this is now an HTTP error, while a service can also be used from another context - logger.exception("Key 'urn' not found in instrument.") - raise AMTInstrumentError() + if isinstance(urns, str): + urns = [urns] - return Instrument(**data) + all_valid_urns = self.fetch_urns() - def fetch_instruments(self, urns: Sequence[str] | None = None) -> list[Instrument]: - content_list = self.fetch_github_content_list() + if urns is not None: + return [self.client.get_instrument(urn) for urn in urns if urn in all_valid_urns] - instruments: list[Instrument] = [] + return [self.client.get_instrument(urn) for urn in all_valid_urns] - for content in content_list.root: # TODO(Berry): fix root field - instrument = self.fetch_github_content(str(content.download_url)) - if urns is None: - instruments.append(instrument) - else: - if instrument.urn in set(urns): - instruments.append(instrument) - return instruments + def fetch_urns(self) -> list[str]: + """ + Fetches all valid instrument URN's. + """ + content_list = self.client.get_instrument_list() + return [content.urn for content in content_list.root] diff --git a/tests/clients/test_clients.py b/tests/clients/test_clients.py index 55e3f23d..6f158698 100644 --- a/tests/clients/test_clients.py +++ b/tests/clients/test_clients.py @@ -1,95 +1,58 @@ +import json + import pytest -from amt.clients.clients import get_client +from amt.clients.clients import TaskRegistryAPIClient from amt.core.exceptions import AMTNotFound from amt.schema.github import RepositoryContent +from amt.schema.instrument import Instrument from pytest_httpx import HTTPXMock +from tests.constants import TASK_REGISTRY_CONTENT_PAYLOAD, TASK_REGISTRY_LIST_PAYLOAD -def test_get_client_unknown_client(): - with pytest.raises(AMTNotFound, match="The requested page or resource could not be found."): - get_client("unknown_client") - - -def test_get_content_github(httpx_mock: HTTPXMock): - # given - httpx_mock.add_response( - url="https://api.github.com/stuff/123", - content=b"somecontent", - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000"}, - ) - github_client = get_client("github") - - # when - result = github_client.get_content("https://api.github.com/stuff/123") - - # then - assert result == b"somecontent" - - -def test_list_content_github(httpx_mock: HTTPXMock): - # given - url = "https://api.github.com/repos/MinBZK/task-registry/contents/?ref=main" - github_client = get_client("github") - repository_content = RepositoryContent(root=[]) - +def test_task_registry_api_client_get_instrument_list(httpx_mock: HTTPXMock): + task_registry_api_client = TaskRegistryAPIClient() httpx_mock.add_response( - url=url, - json=repository_content.model_dump(), + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() ) - # when - result = github_client.list_content(url) - - # then - assert result == repository_content + result = task_registry_api_client.get_instrument_list() + assert result == RepositoryContent.model_validate(json.loads(TASK_REGISTRY_LIST_PAYLOAD)["entries"]) -def test_github_ratelimit_exceeded(httpx_mock: HTTPXMock): - # given - httpx_mock.add_response( - url="https://api.github.com/stuff/123", - status_code=403, - headers={"X-RateLimit-Remaining": "0", "X-RateLimit-Reset": "200000000"}, - ) - github_client = get_client("github") - # when - with pytest.raises(AMTNotFound) as exc_info: - github_client.get_content("https://api.github.com/stuff/123") +def test_task_registry_api_client_get_instrument_list_not_succesfull(httpx_mock: HTTPXMock): + task_registry_api_client = TaskRegistryAPIClient() + httpx_mock.add_response(status_code=408, url="https://task-registry.apps.digilab.network/instruments/") # then - assert "The requested page or resource could not be found" in str(exc_info.value) + pytest.raises(AMTNotFound, task_registry_api_client.get_instrument_list) -def test_get_content_github_pages(httpx_mock: HTTPXMock): +def test_task_registry_api_client_get_instrument(httpx_mock: HTTPXMock): # given + task_registry_api_client = TaskRegistryAPIClient() httpx_mock.add_response( - url="https://minbzk.github.io/stuff/123", - content=b"somecontent", + url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", + content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), ) - github_client = get_client("github_pages") # when - result = github_client.get_content("https://minbzk.github.io/stuff/123") + urn = "urn:nl:aivt:tr:iama:1.0" + result = task_registry_api_client.get_instrument(urn) # then - assert result == b"somecontent" + assert result == Instrument(**json.loads(TASK_REGISTRY_CONTENT_PAYLOAD)) -def test_list_content_github_pages(httpx_mock: HTTPXMock): - # given - url = "https://minbzk.github.io/task-registry/index.json" - github_client = get_client("github_pages") - repository_content = RepositoryContent(root=[]) - input = {"entries": repository_content.model_dump()} - +def test_task_registry_api_client_get_instrument_not_succesfull(httpx_mock: HTTPXMock): + task_registry_api_client = TaskRegistryAPIClient() httpx_mock.add_response( - url=url, - json=input, + status_code=408, + url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", ) - # when - result = github_client.list_content(url) + urn = "urn:nl:aivt:tr:iama:1.0" # then - assert result == repository_content + with pytest.raises(AMTNotFound): + task_registry_api_client.get_instrument(urn) diff --git a/tests/constants.py b/tests/constants.py index e2c02ac8..3c9df309 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -62,41 +62,43 @@ def default_task( ) -GITHUB_LIST_PAYLOAD = """ -[ - { - "name": "iama.yaml", - "path": "instruments/iama.yaml", - "sha": "50cf187eaea995ba848d93f1799fd34d4af8036b", - "size": 30319, - "url": "https://api.github.com/repos/MinBZK/task-registry/contents/instruments/iama.yaml?ref=main", - "html_url": "https://github.com/MinBZK/task-registry/blob/main/instruments/iama.yaml", - "git_url": "https://api.github.com/repos/MinBZK/task-registry/git/blobs/50cf187eaea995ba848d93f1799fd34d4af8036b", - "download_url": "https://raw.githubusercontent.com/MinBZK/task-registry/main/instruments/iama.yaml", - "type": "file", - "_links": { - "self": "https://api.github.com/repos/MinBZK/task-registry/contents/instruments/iama.yaml?ref=main", - "git": "https://api.github.com/repos/MinBZK/task-registry/git/blobs/50cf187eaea995ba848d93f1799fd34d4af8036b", - "html": "https://github.com/MinBZK/task-registry/blob/main/instruments/iama.yaml" +TASK_REGISTRY_LIST_PAYLOAD = """ +{ +"entries": [ + { + "type": "file", + "size": 32897, + "name": "iama.yaml", + "path": "instruments/iama.yaml", + "urn": "urn:nl:aivt:tr:iama:1.0", + "download_url": "https://minbzk.github.io/task-registry/instruments/iama.yaml", + "_links": { + "self": "https://minbzk.github.io/task-registry/instruments/iama.yaml" + } } - } - ] - """ - -GITHUB_CONTENT_PAYLOAD = """ -systemcard_path: .assessments[] -schema_version: 1.1.0 - -name: "Impact Assessment Mensenrechten en Algoritmes (IAMA)" -description: "Het IAMA helpt om de risico's voor mensenrechten bij het gebruik van algoritmen in kaart te brengen en maatregelen te nemen om deze aan te pakken." -urn: "urn:nl:aivt:ir:iama:1.0" -language: "nl" -owners: -- organization: "" - name: "" - email: "" - role: "" -date: "" -url: "https://www.rijksoverheid.nl/documenten/rapporten/2021/02/25/impact-assessment-mensenrechten-en-algoritmes" -tasks: [] -""" # noqa: E501 +] +} +""" + +TASK_REGISTRY_CONTENT_PAYLOAD = """ +{ + "systemcard_path": ".assessments[]", + "schema_version": "1.1.0", + "name": "Impact Assessment Mensenrechten en Algoritmes (IAMA)", + "description": "Het IAMA helpt om de risico's voor mensenrechten bij het gebruik van algoritmen \ + in kaart te brengen en maatregelen te nemen om deze aan te pakken.", + "urn": "urn:nl:aivt:tr:iama:1.0", + "language": "nl", + "owners": [ + { + "organization": "", + "name": "", + "email": "", + "role": "" + } + ], + "date": "", + "url": "https://www.rijksoverheid.nl/documenten/rapporten/2021/02/25/impact-assessment-mensenrechten-en-algoritmes", + "tasks": [] +} +""" diff --git a/tests/services/test_instruments_service.py b/tests/services/test_instruments_service.py index 6dc8373a..5a039991 100644 --- a/tests/services/test_instruments_service.py +++ b/tests/services/test_instruments_service.py @@ -2,24 +2,39 @@ from amt.core.exceptions import AMTInstrumentError from amt.services.instruments import InstrumentsService from pytest_httpx import HTTPXMock -from tests.constants import GITHUB_CONTENT_PAYLOAD, GITHUB_LIST_PAYLOAD +from tests.constants import ( + TASK_REGISTRY_CONTENT_PAYLOAD, + TASK_REGISTRY_LIST_PAYLOAD, +) # TODO(berry): made payloads to a better location +def test_fetch_urns(httpx_mock: HTTPXMock): + # given + instruments_service = InstrumentsService() + httpx_mock.add_response( + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() + ) + + # when + result = instruments_service.fetch_urns() + + # then + assert len(result) == 1 + assert result[0] == "urn:nl:aivt:tr:iama:1.0" + + def test_fetch_instruments(httpx_mock: HTTPXMock): # given - instruments_service = InstrumentsService("github") + instruments_service = InstrumentsService() httpx_mock.add_response( - url="https://api.github.com/repos/MinBZK/task-registry/contents/instruments?ref=main", - content=GITHUB_LIST_PAYLOAD.encode(), - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000", "Content-Type": "application/json"}, + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() ) httpx_mock.add_response( - url="https://raw.githubusercontent.com/MinBZK/task-registry/main/instruments/iama.yaml", - content=GITHUB_CONTENT_PAYLOAD.encode(), - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000", "content-type": "text/plain"}, + url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", + content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), ) # when @@ -29,45 +44,69 @@ def test_fetch_instruments(httpx_mock: HTTPXMock): assert len(result) == 1 -def test_fetch_instruments_with_urns(httpx_mock: HTTPXMock): +def test_fetch_instrument_with_urn(httpx_mock: HTTPXMock): # given - instruments_service = InstrumentsService("github") + instruments_service = InstrumentsService() httpx_mock.add_response( - url="https://api.github.com/repos/MinBZK/task-registry/contents/instruments?ref=main", - content=GITHUB_LIST_PAYLOAD.encode(), - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000", "Content-Type": "application/json"}, + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() ) - httpx_mock.add_response( - url="https://raw.githubusercontent.com/MinBZK/task-registry/main/instruments/iama.yaml", - content=GITHUB_CONTENT_PAYLOAD.encode(), - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000", "content-type": "text/plain"}, + url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", + content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), ) - urn = "urn:nl:aivt:ir:iama:1.0" + # when + urn = "urn:nl:aivt:tr:iama:1.0" + result = instruments_service.fetch_instruments(urn) + + # then + assert len(result) == 1 + + +def test_fetch_instruments_with_urns(httpx_mock: HTTPXMock): + # given + instruments_service = InstrumentsService() + httpx_mock.add_response( + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() + ) + httpx_mock.add_response( + url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", + content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), + ) # when + urn = "urn:nl:aivt:tr:iama:1.0" result = instruments_service.fetch_instruments([urn]) # then assert len(result) == 1 - assert result[0].urn == urn -def test_fetch_instruments_invalid(httpx_mock: HTTPXMock): +def test_fetch_instruments_with_invalid_urn(httpx_mock: HTTPXMock): # given - instruments_service = InstrumentsService("github") + instruments_service = InstrumentsService() httpx_mock.add_response( - url="https://api.github.com/repos/MinBZK/task-registry/contents/instruments?ref=main", - content=GITHUB_LIST_PAYLOAD.encode(), - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000", "Content-Type": "application/json"}, + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() ) # when + urn = "urn:nl:aivt:ir:iama:1.0" + result = instruments_service.fetch_instruments([urn]) + + # then + assert len(result) == 0 + + +def test_fetch_instruments_invalid(httpx_mock: HTTPXMock): + # given + instruments_service = InstrumentsService() + httpx_mock.add_response( + url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() + ) + httpx_mock.add_response( - url="https://raw.githubusercontent.com/MinBZK/task-registry/main/instruments/iama.yaml", - content=b"test: 1", - headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000", "content-type": "text/plain"}, + url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", + content=b'{"test": 1}', ) # then diff --git a/tests/services/test_instruments_state.py b/tests/services/test_instruments_state.py index 9f8dc1ee..f1122a9e 100644 --- a/tests/services/test_instruments_state.py +++ b/tests/services/test_instruments_state.py @@ -174,15 +174,14 @@ def test_find_next_tasks_for_instrument_correct_lifecycle(system_card: SystemCar def test_get_state_per_instrument(system_card: SystemCard): instrument_state_service = InstrumentStateService(system_card) res = instrument_state_service.get_state_per_instrument() - assert res == [ - {"urn": "urn:nl:aivt:tr:aiia:1.0", "in_progress": 1, "name": "AI Impact Assessment (AIIA)"}, - { - "urn": "urn:nl:aivt:tr:iama:1.0", - "in_progress": 1, - "name": "Impact Assessment Mensenrechten en Algoritmes (IAMA)", - }, - {"in_progress": 0, "name": "URN not found in Task Registry.", "urn": "urn:instrument:assessment"}, - ] + assert {"urn": "urn:nl:aivt:tr:aiia:1.0", "in_progress": 1, "name": "AI Impact Assessment (AIIA)"} in res + assert { + "urn": "urn:nl:aivt:tr:iama:1.0", + "in_progress": 1, + "name": "Impact Assessment Mensenrechten en Algoritmes (IAMA)", + } in res + assert {"urn": "urn:nl:aivt:tr:aiia:1.0", "in_progress": 1, "name": "AI Impact Assessment (AIIA)"} in res + assert {"in_progress": 0, "name": "URN not found in Task Registry.", "urn": "urn:instrument:assessment"} in res def test_get_amount_completed_instruments(system_card: SystemCard):