Skip to content

Commit

Permalink
Refactor registry (#289)
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt authored Oct 21, 2024
2 parents 48f1ffb + f5b4861 commit 2e14ca2
Show file tree
Hide file tree
Showing 8 changed files with 186 additions and 239 deletions.
91 changes: 21 additions & 70 deletions amt/clients/clients.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,43 @@
import logging
from abc import ABC, abstractmethod
from datetime import datetime, timezone

import httpx
from amt.core.exceptions import AMTNotFound
from amt.core.exceptions import AMTInstrumentError, AMTNotFound
from amt.schema.github import RepositoryContent
from amt.schema.instrument import Instrument

logger = logging.getLogger(__name__)


class Client(ABC):
class TaskRegistryAPIClient:
"""
Abstract class which is used to set up HTTP clients that retrieve instruments from the
task registry.
This class interacts with the Task Registry API.
Currently it supports:
- Retrieving the list of instruments.
- Getting an instrument by URN.
"""

@abstractmethod
base_url = "https://task-registry.apps.digilab.network"

def __init__(self, max_retries: int = 3, timeout: int = 5) -> None:
transport = httpx.HTTPTransport(retries=max_retries)
self.client = httpx.Client(timeout=timeout, transport=transport)

@abstractmethod
def get_content(self, url: str) -> bytes:
"""
This method should implement getting the content of an instrument from given URL.
"""

@abstractmethod
def list_content(self, url: str = "") -> RepositoryContent:
"""
This method should implement getting list of instruments from given URL.
"""

def _get(self, url: str) -> httpx.Response:
"""
Private function that performs a GET request to given URL.
"""
response = self.client.get(url)
def get_instrument_list(self) -> RepositoryContent:
response = self.client.get(f"{TaskRegistryAPIClient.base_url}/instruments/")
if response.status_code != 200:
raise AMTNotFound()
return response


def get_client(repo_type: str) -> Client:
match repo_type:
case "github_pages":
return GitHubPagesClient()
case "github":
return GitHubClient()
case _:
raise AMTNotFound()


class GitHubPagesClient(Client):
def __init__(self) -> None:
super().__init__()

def get_content(self, url: str) -> bytes:
return super()._get(url).content

def list_content(self, url: str = "https://minbzk.github.io/task-registry/index.json") -> RepositoryContent:
response = super()._get(url)
return RepositoryContent.model_validate(response.json()["entries"])

def get_instrument(self, urn: str, version: str = "latest") -> Instrument:
response = self.client.get(f"{TaskRegistryAPIClient.base_url}/urns/", params={"version": version, "urn": urn})

class GitHubClient(Client):
def __init__(self) -> None:
super().__init__()
self.client.event_hooks["response"] = [self._check_rate_limit]
# TODO(Berry): add authentication headers with event_hooks

def get_content(self, url: str) -> bytes:
return super()._get(url).content
if response.status_code != 200:
raise AMTNotFound()

def list_content(
self,
url: str = "https://api.github.com/repos/MinBZK/task-registry/contents/instruments?ref=main",
) -> RepositoryContent:
response = super()._get(url)
return RepositoryContent.model_validate(response.json())
data = response.json()
if "urn" not in data:
logger.exception("Invalid instrument fetched: key 'urn' must occur in instrument.")
raise AMTInstrumentError()

def _check_rate_limit(self, response: httpx.Response) -> None:
if "x-ratelimit-remaining" in response.headers:
remaining = int(response.headers["X-RateLimit-Remaining"])
if remaining == 0:
reset_timestamp = int(response.headers["X-RateLimit-Reset"])
reset_time = datetime.fromtimestamp(reset_timestamp, timezone.utc) # noqa: UP017
wait_seconds = (reset_time - datetime.now(timezone.utc)).total_seconds() # noqa: UP017
logger.warning(
f"Rate limit exceeded. We need to wait for {wait_seconds} seconds. (not implemented yet)"
)
return Instrument(**data)
1 change: 1 addition & 0 deletions amt/schema/github.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Links(BaseModel):

class ContentItem(BaseModel):
name: str
urn: str
path: str
sha: str | None = Field(default=None)
size: int
Expand Down
1 change: 1 addition & 0 deletions amt/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info(f"Starting {PROJECT_NAME} version {VERSION}")
logger.info(f"Settings: {mask.secrets(get_settings().model_dump())}")
yield

logger.info(f"Stopping application {PROJECT_NAME} version {VERSION}")
logging.shutdown()

Expand Down
53 changes: 22 additions & 31 deletions amt/services/instruments.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,38 @@
import logging
from collections.abc import Sequence

import yaml

from amt.clients.clients import get_client
from amt.core.exceptions import AMTInstrumentError
from amt.schema.github import RepositoryContent
from amt.clients.clients import TaskRegistryAPIClient
from amt.schema.instrument import Instrument

logger = logging.getLogger(__name__)


class InstrumentsService:
def __init__(self, repo_type: str = "github_pages") -> None:
self.client = get_client(repo_type)

def fetch_github_content_list(self) -> RepositoryContent:
response = self.client.list_content()
return RepositoryContent.model_validate(response)
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()

def fetch_github_content(self, url: str) -> Instrument:
bytes_data = self.client.get_content(url)
def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Instrument]:
"""
This function returns the instruments with the given URNs. If urns contains a URN that is not a
valid URN of an instrument, it is simply ignored.
# assume yaml
data = yaml.safe_load(bytes_data)
@param urns: URN or URNs of the instruments to fetch. If None, the function returns all instruments.
@return: List of the instruments whose URNs are given in 'urns'.
"""

if "urn" not in data:
# todo: this is now an HTTP error, while a service can also be used from another context
logger.exception("Key 'urn' not found in instrument.")
raise AMTInstrumentError()
if isinstance(urns, str):
urns = [urns]

return Instrument(**data)
all_valid_urns = self.fetch_urns()

def fetch_instruments(self, urns: Sequence[str] | None = None) -> list[Instrument]:
content_list = self.fetch_github_content_list()
if urns is not None:
return [self.client.get_instrument(urn) for urn in urns if urn in all_valid_urns]

instruments: list[Instrument] = []
return [self.client.get_instrument(urn) for urn in all_valid_urns]

for content in content_list.root: # TODO(Berry): fix root field
instrument = self.fetch_github_content(str(content.download_url))
if urns is None:
instruments.append(instrument)
else:
if instrument.urn in set(urns):
instruments.append(instrument)
return instruments
def fetch_urns(self) -> list[str]:
"""
Fetches all valid instrument URNs.
"""
content_list = self.client.get_instrument_list()
return [content.urn for content in content_list.root]
93 changes: 28 additions & 65 deletions tests/clients/test_clients.py
Original file line number Diff line number Diff line change
@@ -1,95 +1,58 @@
import json

import pytest
from amt.clients.clients import get_client
from amt.clients.clients import TaskRegistryAPIClient
from amt.core.exceptions import AMTNotFound
from amt.schema.github import RepositoryContent
from amt.schema.instrument import Instrument
from pytest_httpx import HTTPXMock
from tests.constants import TASK_REGISTRY_CONTENT_PAYLOAD, TASK_REGISTRY_LIST_PAYLOAD


def test_get_client_unknown_client():
with pytest.raises(AMTNotFound, match="The requested page or resource could not be found."):
get_client("unknown_client")


def test_get_content_github(httpx_mock: HTTPXMock):
# given
httpx_mock.add_response(
url="https://api.github.com/stuff/123",
content=b"somecontent",
headers={"X-RateLimit-Remaining": "7", "X-RateLimit-Reset": "200000000"},
)
github_client = get_client("github")

# when
result = github_client.get_content("https://api.github.com/stuff/123")

# then
assert result == b"somecontent"


def test_list_content_github(httpx_mock: HTTPXMock):
# given
url = "https://api.github.com/repos/MinBZK/task-registry/contents/?ref=main"
github_client = get_client("github")
repository_content = RepositoryContent(root=[])

def test_task_registry_api_client_get_instrument_list(httpx_mock: HTTPXMock):
task_registry_api_client = TaskRegistryAPIClient()
httpx_mock.add_response(
url=url,
json=repository_content.model_dump(),
url="https://task-registry.apps.digilab.network/instruments/", content=TASK_REGISTRY_LIST_PAYLOAD.encode()
)

# when
result = github_client.list_content(url)

# then
assert result == repository_content
result = task_registry_api_client.get_instrument_list()

assert result == RepositoryContent.model_validate(json.loads(TASK_REGISTRY_LIST_PAYLOAD)["entries"])

def test_github_ratelimit_exceeded(httpx_mock: HTTPXMock):
# given
httpx_mock.add_response(
url="https://api.github.com/stuff/123",
status_code=403,
headers={"X-RateLimit-Remaining": "0", "X-RateLimit-Reset": "200000000"},
)
github_client = get_client("github")

# when
with pytest.raises(AMTNotFound) as exc_info:
github_client.get_content("https://api.github.com/stuff/123")
def test_task_registry_api_client_get_instrument_list_not_succesfull(httpx_mock: HTTPXMock):
task_registry_api_client = TaskRegistryAPIClient()
httpx_mock.add_response(status_code=408, url="https://task-registry.apps.digilab.network/instruments/")

# then
assert "The requested page or resource could not be found" in str(exc_info.value)
pytest.raises(AMTNotFound, task_registry_api_client.get_instrument_list)


def test_get_content_github_pages(httpx_mock: HTTPXMock):
def test_task_registry_api_client_get_instrument(httpx_mock: HTTPXMock):
# given
task_registry_api_client = TaskRegistryAPIClient()
httpx_mock.add_response(
url="https://minbzk.github.io/stuff/123",
content=b"somecontent",
url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0",
content=TASK_REGISTRY_CONTENT_PAYLOAD.encode(),
)
github_client = get_client("github_pages")

# when
result = github_client.get_content("https://minbzk.github.io/stuff/123")
urn = "urn:nl:aivt:tr:iama:1.0"
result = task_registry_api_client.get_instrument(urn)

# then
assert result == b"somecontent"
assert result == Instrument(**json.loads(TASK_REGISTRY_CONTENT_PAYLOAD))


def test_list_content_github_pages(httpx_mock: HTTPXMock):
# given
url = "https://minbzk.github.io/task-registry/index.json"
github_client = get_client("github_pages")
repository_content = RepositoryContent(root=[])
input = {"entries": repository_content.model_dump()}

def test_task_registry_api_client_get_instrument_not_succesfull(httpx_mock: HTTPXMock):
task_registry_api_client = TaskRegistryAPIClient()
httpx_mock.add_response(
url=url,
json=input,
status_code=408,
url="https://task-registry.apps.digilab.network/urns/?version=latest&urn=urn%3Anl%3Aaivt%3Atr%3Aiama%3A1.0",
)

# when
result = github_client.list_content(url)
urn = "urn:nl:aivt:tr:iama:1.0"

# then
assert result == repository_content
with pytest.raises(AMTNotFound):
task_registry_api_client.get_instrument(urn)
Loading

0 comments on commit 2e14ca2

Please sign in to comment.