Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ChristopherSpelt committed Nov 14, 2024
1 parent 376c8f7 commit fd5d501
Show file tree
Hide file tree
Showing 11 changed files with 180 additions and 120 deletions.
63 changes: 63 additions & 0 deletions :w
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import logging
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.schema.requirement import Requirement

logger = logging.getLogger(__name__)


class RequirementsService:
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_measures(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
"""
Fetches measures with the given URNs.
If urns contains an URN that is not a valid URN of an measure, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all measures.
@return: List of measures with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.REQUIREMENTS, urns)
return [Requirement(**data) for data in task_data]


def create_requirements_service() -> RequirementsService:
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
return RequirementsService(repository)


class RequirementsService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()

def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
"""
This functions returns requirement with given URN's. If urns contains an URN that is not a
valid URN of an requirement it is simply ignored.

@param: URN's of requirements to fetch. If empty, function returns all requirements.
@return: List of requirements with given URN's in 'urns'.
"""

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is not None:
return [
Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn))
for urn in urns
if urn in all_valid_urns
]

return [Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)) for urn in all_valid_urns]

def fetch_urns(self) -> list[str]:
"""
Fetches all valid requirement URN's.
"""
content_list = self.client.get_list_of_task(TaskType.REQUIREMENTS)
return [content.urn for content in content_list.root]
4 changes: 2 additions & 2 deletions amt/api/routes/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from amt.schema.algorithm import AlgorithmNew
from amt.schema.localized_value_item import LocalizedValueItem
from amt.services.algorithms import AlgorithmsService, get_template_files
from amt.services.instruments import InstrumentsService
from amt.services.instruments import InstrumentsService, create_instrument_service

router = APIRouter()
logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -125,7 +125,7 @@ async def get_root(
@router.get("/new")
async def get_new(
request: Request,
instrument_service: Annotated[InstrumentsService, Depends(InstrumentsService)],
instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)],
) -> HTMLResponse:
sub_menu_items = resolve_navigation_items([Navigation.ALGORITHMS_OVERVIEW], request) # pyright: ignore [reportUnusedVariable] # noqa
breadcrumbs = resolve_base_navigation_items([Navigation.ALGORITHMS_ROOT, Navigation.ALGORITHM_NEW], request)
Expand Down
4 changes: 2 additions & 2 deletions amt/cli/check_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from amt.schema.instrument import Instrument
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.instruments import create_instrument_service
from amt.services.instruments_and_requirements_state import all_lifecycles, get_all_next_tasks
from amt.services.storage import StorageFactory

Expand All @@ -29,7 +29,7 @@ def get_requested_instruments(all_instruments: list[Instrument], urns: list[str]
def get_tasks_by_priority(urns: list[str], system_card_path: Path) -> None:
try:
system_card = get_system_card(system_card_path)
instruments_service = InstrumentsService()
instruments_service = create_instrument_service()
all_instruments = instruments_service.fetch_instruments()
instruments = get_requested_instruments(all_instruments, urns)
next_tasks = get_all_next_tasks(instruments, system_card)
Expand Down
54 changes: 54 additions & 0 deletions amt/repositories/task_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import logging
from collections.abc import Sequence
from typing import Any

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound

logger = logging.getLogger(__name__)


class TaskRegistryRepository:
"""
Responsible for fetching tasks (instruments, measures, etc.) from the Task Registry API.
"""

def __init__(self, client: TaskRegistryAPIClient) -> None:
self.client = client

def fetch_tasks(self, task_type: TaskType, urns: str | Sequence[str] | None = None) -> list[dict[str, Any]]:
"""
Fetches tasks (instruments, measures, etc.) with the given URNs.
If urns contains an URN that is not a valid URN of a task, it is simply ignored.
@param task_type: The type of task to fetch (e.g. TaskType.INSTRUMENTS, TaskType.MEASURES).
@param urns: URNs of tasks to fetch. If None, function returns all tasks of the given type.
@return: List of task data dictionaries with the given URNs in 'urns'.
"""
if isinstance(urns, str):
urns = [urns]

all_valid_urns: list[str] = self.fetch_urns(task_type)

if urns is None:
return [self.client.get_task_by_urn(task_type, urn) for urn in all_valid_urns]

tasks: list[dict[str, Any]] = []
for urn in urns:
# For backward compatibilty of this method we now simply ignore invalid URN's.
# We might want to refactor this later to throw exceptions when task with URN is not
# found.
try:
tasks.append(self.client.get_task_by_urn(task_type, urn))
except AMTNotFound:
logger.warning(f"Cannot find {task_type.value} with URN {urn}")

return tasks

def fetch_urns(self, task_type: TaskType) -> list[str]:
"""
Fetches all valid URNs for the given task type.
"""
content_list = self.client.get_list_of_task(task_type)
return [content.urn for content in content_list.root]
4 changes: 2 additions & 2 deletions amt/services/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from amt.schema.algorithm import AlgorithmNew
from amt.schema.instrument import InstrumentBase
from amt.schema.system_card import AiActProfile, SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.instruments import InstrumentsService, create_instrument_service
from amt.services.task_registry import get_requirements_and_measures
from amt.services.tasks import TasksService

Expand All @@ -29,7 +29,7 @@ def __init__(
self,
repository: Annotated[AlgorithmsRepository, Depends(AlgorithmsRepository)],
task_service: Annotated[TasksService, Depends(TasksService)],
instrument_service: Annotated[InstrumentsService, Depends(InstrumentsService)],
instrument_service: Annotated[InstrumentsService, Depends(create_instrument_service)],
) -> None:
self.repository = repository
self.instrument_service = instrument_service
Expand Down
43 changes: 13 additions & 30 deletions amt/services/instruments.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,45 +2,28 @@
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound
from amt.repositories.task_registry import TaskRegistryRepository
from amt.schema.instrument import Instrument

logger = logging.getLogger(__name__)


class InstrumentsService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_instruments(self, urns: str | Sequence[str] | None = None) -> list[Instrument]:
"""
This functions returns instruments with given URN's. If urns contains an URN that is not a
valid URN of an instrument it is simply ignored.
@param: URN's of instruments to fetch. If empty, function returns all instruments.
@return: List of instruments with given URN's in 'urns'.
Fetches instruments with the given URNs.
If urns contains an URN that is not a valid URN of an instrument, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all instruments.
@return: List of instruments with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.INSTRUMENTS, urns)
return [Instrument(**data) for data in task_data]

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is None:
return [Instrument(**self.client.get_task_by_urn(TaskType.INSTRUMENTS, urn)) for urn in all_valid_urns]

instruments: list[Instrument] = []
for urn in urns:
try:
instruments.append(Instrument(**self.client.get_task_by_urn(TaskType.INSTRUMENTS, urn)))
except AMTNotFound:
logger.warning(f"cannot find instrument with URN {urn}")

return instruments

def fetch_urns(self) -> list[str]:
"""
Fetches all valid instrument URN's.
"""
content_list = self.client.get_list_of_task(task=TaskType.INSTRUMENTS)
return [content.urn for content in content_list.root]
def create_instrument_service() -> InstrumentsService:
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
return InstrumentsService(repository)
4 changes: 2 additions & 2 deletions amt/services/instruments_and_requirements_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from amt.schema.instrument import Instrument, InstrumentTask
from amt.schema.requirement import Requirement
from amt.schema.system_card import SystemCard
from amt.services.instruments import InstrumentsService
from amt.services.instruments import create_instrument_service

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -132,7 +132,7 @@ def get_state_per_instrument(self) -> list[dict[str, int]]:
# Otherwise the instrument is completed as there are not any tasks left.

urns = [instrument.urn for instrument in self.system_card.instruments]
instruments_service = InstrumentsService()
instruments_service = create_instrument_service()
instruments = instruments_service.fetch_instruments(urns)
# TODO: refactor this data structure in 3 lines below (also change in get_all_next_tasks + check_state.py)
instruments_dict = {}
Expand Down
44 changes: 13 additions & 31 deletions amt/services/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,28 @@
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound
from amt.repositories.task_registry import TaskRegistryRepository
from amt.schema.measure import Measure

logger = logging.getLogger(__name__)


class MeasuresService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_measures(self, urns: str | Sequence[str] | None = None) -> list[Measure]:
"""
This functions returns measures with given URN's. If urns contains an URN that is not a
valid URN of an measure it is simply ignored.
@param: URN's of measures to fetch. If empty, function returns all measures.
@return: List of measures with given URN's in 'urns'.
Fetches measures with the given URNs.
If urns contains an URN that is not a valid URN of an measure, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all measures.
@return: List of measures with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.MEASURES, urns)
return [Measure(**data) for data in task_data]

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is None:
return [Measure(**self.client.get_task_by_urn(TaskType.MEASURES, urn)) for urn in all_valid_urns]

measures: list[Measure] = []
for urn in urns:
try:
logger.info(f"adding measure with URN {urn}")
measures.append(Measure(**self.client.get_task_by_urn(TaskType.MEASURES, urn)))
except AMTNotFound:
logger.warning(f"cannot find measure with URN {urn}")

return measures

def fetch_urns(self) -> list[str]:
"""
Fetches all valid measure URN's.
"""
content_list = self.client.get_list_of_task(TaskType.MEASURES)
return [content.urn for content in content_list.root]
def create_measures_service() -> MeasuresService:
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
return MeasuresService(repository)
44 changes: 13 additions & 31 deletions amt/services/requirements.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,28 @@
from collections.abc import Sequence

from amt.clients.clients import TaskRegistryAPIClient, TaskType
from amt.core.exceptions import AMTNotFound
from amt.repositories.task_registry import TaskRegistryRepository
from amt.schema.requirement import Requirement

logger = logging.getLogger(__name__)


class RequirementsService:
def __init__(self) -> None:
self.client = TaskRegistryAPIClient()
def __init__(self, repository: TaskRegistryRepository) -> None:
self.repository = repository

def fetch_requirements(self, urns: str | Sequence[str] | None = None) -> list[Requirement]:
"""
This functions returns requirement with given URN's. If urns contains an URN that is not a
valid URN of an requirement it is simply ignored.
@param: URN's of requirements to fetch. If empty, function returns all requirements.
@return: List of requirements with given URN's in 'urns'.
Fetches measures with the given URNs.
If urns contains an URN that is not a valid URN of an measure, it is simply ignored.
@param urns: URNs of instruments to fetch. If None, function returns all measures.
@return: List of measures with the given URNs in 'urns'.
"""
task_data = self.repository.fetch_tasks(TaskType.REQUIREMENTS, urns)
return [Requirement(**data) for data in task_data]

if isinstance(urns, str):
urns = [urns]

all_valid_urns = self.fetch_urns()

if urns is None:
return [Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)) for urn in all_valid_urns]

requirements: list[Requirement] = []
for urn in urns:
try:
requirements.append(Requirement(**self.client.get_task_by_urn(TaskType.REQUIREMENTS, urn)))
except AMTNotFound:
logger.warning(f"cannot find instrument with URN {urn}")

return requirements


def fetch_urns(self) -> list[str]:
"""
Fetches all valid requirement URN's.
"""
content_list = self.client.get_list_of_task(TaskType.REQUIREMENTS)
return [content.urn for content in content_list.root]
def create_requirements_service() -> RequirementsService:
client = TaskRegistryAPIClient()
repository = TaskRegistryRepository(client)
return RequirementsService(repository)
Loading

0 comments on commit fd5d501

Please sign in to comment.