generated from MinBZK/python-project-template
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
186 additions
and
239 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,92 +1,43 @@ | ||
import logging | ||
from abc import ABC, abstractmethod | ||
from datetime import datetime, timezone | ||
|
||
import httpx | ||
from amt.core.exceptions import AMTNotFound | ||
from amt.core.exceptions import AMTInstrumentError, AMTNotFound | ||
from amt.schema.github import RepositoryContent | ||
from amt.schema.instrument import Instrument | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Client(ABC): | ||
class TaskRegistryAPIClient: | ||
""" | ||
Abstract class which is used to set up HTTP clients that retrieve instruments from the | ||
task registry. | ||
This class interacts with the Task Registry API. | ||
Currently it supports: | ||
- Retrieving the list of instruments. | ||
- Getting an instrument by URN. | ||
""" | ||
|
||
@abstractmethod | ||
base_url = "https://task-registry.apps.digilab.network" | ||
|
||
def __init__(self, max_retries: int = 3, timeout: int = 5) -> None: | ||
transport = httpx.HTTPTransport(retries=max_retries) | ||
self.client = httpx.Client(timeout=timeout, transport=transport) | ||
|
||
@abstractmethod | ||
def get_content(self, url: str) -> bytes: | ||
""" | ||
This method should implement getting the content of an instrument from given URL. | ||
""" | ||
|
||
@abstractmethod | ||
def list_content(self, url: str = "") -> RepositoryContent: | ||
""" | ||
This method should implement getting list of instruments from given URL. | ||
""" | ||
|
||
def _get(self, url: str) -> httpx.Response: | ||
""" | ||
Private function that performs a GET request to given URL. | ||
""" | ||
response = self.client.get(url) | ||
def get_instrument_list(self) -> RepositoryContent: | ||
response = self.client.get(f"{TaskRegistryAPIClient.base_url}/instruments/") | ||
if response.status_code != 200: | ||
raise AMTNotFound() | ||
return response | ||
|
||
|
||
def get_client(repo_type: str) -> Client: | ||
match repo_type: | ||
case "github_pages": | ||
return GitHubPagesClient() | ||
case "github": | ||
return GitHubClient() | ||
case _: | ||
raise AMTNotFound() | ||
|
||
|
||
class GitHubPagesClient(Client): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
|
||
def get_content(self, url: str) -> bytes: | ||
return super()._get(url).content | ||
|
||
def list_content(self, url: str = "https://minbzk.github.io/task-registry/index.json") -> RepositoryContent: | ||
response = super()._get(url) | ||
return RepositoryContent.model_validate(response.json()["entries"]) | ||
|
||
def get_instrument(self, urn: str, version: str = "latest") -> Instrument: | ||
response = self.client.get(f"{TaskRegistryAPIClient.base_url}/urns/", params={"version": version, "urn": urn}) | ||
|
||
class GitHubClient(Client): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.client.event_hooks["response"] = [self._check_rate_limit] | ||
# TODO(Berry): add authentication headers with event_hooks | ||
|
||
def get_content(self, url: str) -> bytes: | ||
return super()._get(url).content | ||
if response.status_code != 200: | ||
raise AMTNotFound() | ||
|
||
def list_content( | ||
self, | ||
url: str = "https://api.github.com/repos/MinBZK/task-registry/contents/instruments?ref=main", | ||
) -> RepositoryContent: | ||
response = super()._get(url) | ||
return RepositoryContent.model_validate(response.json()) | ||
data = response.json() | ||
if "urn" not in data: | ||
logger.exception("Invalid instrument fetched: key 'urn' must occur in instrument.") | ||
raise AMTInstrumentError() | ||
|
||
def _check_rate_limit(self, response: httpx.Response) -> None: | ||
if "x-ratelimit-remaining" in response.headers: | ||
remaining = int(response.headers["X-RateLimit-Remaining"]) | ||
if remaining == 0: | ||
reset_timestamp = int(response.headers["X-RateLimit-Reset"]) | ||
reset_time = datetime.fromtimestamp(reset_timestamp, timezone.utc) # noqa: UP017 | ||
wait_seconds = (reset_time - datetime.now(timezone.utc)).total_seconds() # noqa: UP017 | ||
logger.warning( | ||
f"Rate limit exceeded. We need to wait for {wait_seconds} seconds. (not implemented yet)" | ||
) | ||
return Instrument(**data) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,47 +1,38 @@ | ||
import logging | ||
from collections.abc import Sequence | ||
|
||
import yaml | ||
|
||
from amt.clients.clients import get_client | ||
from amt.core.exceptions import AMTInstrumentError | ||
from amt.schema.github import RepositoryContent | ||
from amt.clients.clients import TaskRegistryAPIClient | ||
from amt.schema.instrument import Instrument | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class InstrumentsService: | ||
def __init__(self, repo_type: str = "github_pages") -> None: | ||
self.client = get_client(repo_type) | ||
|
||
def fetch_github_content_list(self) -> RepositoryContent: | ||
response = self.client.list_content() | ||
return RepositoryContent.model_validate(response) | ||
def __init__(self) -> None: | ||
self.client = TaskRegistryAPIClient() | ||
|
||
def fetch_github_content(self, url: str) -> Instrument: | ||
bytes_data = self.client.get_content(url) | ||
def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Instrument]: | ||
""" | ||
This functions returns instruments with given URN's. If urns contains an URN that is not a | ||
valid URN of an instrument it is simply ignored. | ||
# assume yaml | ||
data = yaml.safe_load(bytes_data) | ||
@param: URN's of instruments to fetch. If empty, function returns all instruments. | ||
@return: List of instruments with given URN's in 'urns'. | ||
""" | ||
|
||
if "urn" not in data: | ||
# todo: this is now an HTTP error, while a service can also be used from another context | ||
logger.exception("Key 'urn' not found in instrument.") | ||
raise AMTInstrumentError() | ||
if isinstance(urns, str): | ||
urns = [urns] | ||
|
||
return Instrument(**data) | ||
all_valid_urns = self.fetch_urns() | ||
|
||
def fetch_instruments(self, urns: Sequence[str] | None = None) -> list[Instrument]: | ||
content_list = self.fetch_github_content_list() | ||
if urns is not None: | ||
return [self.client.get_instrument(urn) for urn in urns if urn in all_valid_urns] | ||
|
||
instruments: list[Instrument] = [] | ||
return [self.client.get_instrument(urn) for urn in all_valid_urns] | ||
|
||
for content in content_list.root: # TODO(Berry): fix root field | ||
instrument = self.fetch_github_content(str(content.download_url)) | ||
if urns is None: | ||
instruments.append(instrument) | ||
else: | ||
if instrument.urn in set(urns): | ||
instruments.append(instrument) | ||
return instruments | ||
def fetch_urns(self) -> list[str]: | ||
""" | ||
Fetches all valid instrument URN's. | ||
""" | ||
content_list = self.client.get_instrument_list() | ||
return [content.urn for content in content_list.root] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,95 +1,58 @@ | ||
import json | ||
|
||
import pytest | ||
from amt.clients.clients import get_client | ||
from amt.clients.clients import TaskRegistryAPIClient | ||
from amt.core.exceptions import AMTNotFound | ||
from amt.schema.github import RepositoryContent | ||
from amt.schema.instrument import Instrument | ||
from pytest_httpx import HTTPXMock | ||
from tests.constants import TASK_REGISTRY_CONTENT_PAYLOAD, TASK_REGISTRY_LIST_PAYLOAD | ||
|
||
|
||
def test_get_client_unknown_client(): | ||
with pytest.raises(AMTNotFound, match="The requested page or resource could not be found."): | ||
get_client("unknown_client") | ||
|
||
|
||
def test_get_content_github(httpx_mock: HTTPXMock): | ||
# given | ||
httpx_mock.add_response( | ||
url="https://api.github.com/stuff/123", | ||
content=b"somecontent", | ||
headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000"}, | ||
) | ||
github_client = get_client("github") | ||
|
||
# when | ||
result = github_client.get_content("https://api.github.com/stuff/123") | ||
|
||
# then | ||
assert result == b"somecontent" | ||
|
||
|
||
def test_list_content_github(httpx_mock: HTTPXMock): | ||
# given | ||
url = "https://api.github.com/repos/MinBZK/task-registry/contents/?ref=main" | ||
github_client = get_client("github") | ||
repository_content = RepositoryContent(root=[]) | ||
|
||
def test_task_registry_api_client_get_instrument_list(httpx_mock: HTTPXMock): | ||
task_registry_api_client = TaskRegistryAPIClient() | ||
httpx_mock.add_response( | ||
url=url, | ||
json=repository_content.model_dump(), | ||
url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode() | ||
) | ||
|
||
# when | ||
result = github_client.list_content(url) | ||
|
||
# then | ||
assert result == repository_content | ||
result = task_registry_api_client.get_instrument_list() | ||
|
||
assert result == RepositoryContent.model_validate(json.loads(TASK_REGISTRY_LIST_PAYLOAD)["entries"]) | ||
|
||
def test_github_ratelimit_exceeded(httpx_mock: HTTPXMock): | ||
# given | ||
httpx_mock.add_response( | ||
url="https://api.github.com/stuff/123", | ||
status_code=403, | ||
headers={"X-RateLimit-Remaining": "0", "X-RateLimit-Reset": "200000000"}, | ||
) | ||
github_client = get_client("github") | ||
|
||
# when | ||
with pytest.raises(AMTNotFound) as exc_info: | ||
github_client.get_content("https://api.github.com/stuff/123") | ||
def test_task_registry_api_client_get_instrument_list_not_succesfull(httpx_mock: HTTPXMock): | ||
task_registry_api_client = TaskRegistryAPIClient() | ||
httpx_mock.add_response(status_code=408, url="https://task-registry.apps.digilab.network/instruments/") | ||
|
||
# then | ||
assert "The requested page or resource could not be found" in str(exc_info.value) | ||
pytest.raises(AMTNotFound, task_registry_api_client.get_instrument_list) | ||
|
||
|
||
def test_get_content_github_pages(httpx_mock: HTTPXMock): | ||
def test_task_registry_api_client_get_instrument(httpx_mock: HTTPXMock): | ||
# given | ||
task_registry_api_client = TaskRegistryAPIClient() | ||
httpx_mock.add_response( | ||
url="https://minbzk.github.io/stuff/123", | ||
content=b"somecontent", | ||
url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", | ||
content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(), | ||
) | ||
github_client = get_client("github_pages") | ||
|
||
# when | ||
result = github_client.get_content("https://minbzk.github.io/stuff/123") | ||
urn = "urn:nl:aivt:tr:iama:1.0" | ||
result = task_registry_api_client.get_instrument(urn) | ||
|
||
# then | ||
assert result == b"somecontent" | ||
assert result == Instrument(**json.loads(TASK_REGISTRY_CONTENT_PAYLOAD)) | ||
|
||
|
||
def test_list_content_github_pages(httpx_mock: HTTPXMock): | ||
# given | ||
url = "https://minbzk.github.io/task-registry/index.json" | ||
github_client = get_client("github_pages") | ||
repository_content = RepositoryContent(root=[]) | ||
input = {"entries": repository_content.model_dump()} | ||
|
||
def test_task_registry_api_client_get_instrument_not_succesfull(httpx_mock: HTTPXMock): | ||
task_registry_api_client = TaskRegistryAPIClient() | ||
httpx_mock.add_response( | ||
url=url, | ||
json=input, | ||
status_code=408, | ||
url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0", | ||
) | ||
|
||
# when | ||
result = github_client.list_content(url) | ||
urn = "urn:nl:aivt:tr:iama:1.0" | ||
|
||
# then | ||
assert result == repository_content | ||
with pytest.raises(AMTNotFound): | ||
task_registry_api_client.get_instrument(urn) |
Oops, something went wrong.