From d989ee1c35011aec710fa3bf7fb73441cfaccb3b Mon Sep 17 00:00:00 2001 From: Jesse Date: Sun, 24 Mar 2024 06:55:01 -0700 Subject: [PATCH] type hints, docstrings; RSRC class inherits from Node --- xerparser/schemas/rsrc.py | 22 ++++++++++++---------- xerparser/src/xer.py | 10 +++++++++- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/xerparser/schemas/rsrc.py b/xerparser/schemas/rsrc.py index e3dd7ae..a205f9e 100644 --- a/xerparser/schemas/rsrc.py +++ b/xerparser/schemas/rsrc.py @@ -2,30 +2,32 @@ # rsrc.py from typing import Any + +from xerparser.schemas._node import Node from xerparser.schemas.udftype import UDFTYPE -class RSRC: +class RSRC(Node): """ A class to represent a Resource. """ - def __init__(self, **data) -> None: + def __init__(self, **data: str) -> None: + super().__init__() self.uid: str = data["rsrc_id"] self.clndr_id: str = data["clndr_id"] self.name: str = data["rsrc_name"] + self.parent_rsrc_id: str = data["parent_rsrc_id"] self.short_name: str = data["rsrc_short_name"] self.type: str = data["rsrc_type"] self.user_defined_fields: dict[UDFTYPE, Any] = {} def __eq__(self, __o: "RSRC") -> bool: - return all( - ( - self.name == __o.name, - self.short_name == __o.short_name, - self.type == __o.type, - ) - ) + self.full_code == __o.full_code def __hash__(self) -> int: - return hash((self.name, self.short_name, self.type)) + return hash(self.full_code) + + @property + def full_code(self) -> str: + return ".".join(reversed([node.short_name for node in self.lineage])) diff --git a/xerparser/src/xer.py b/xerparser/src/xer.py index b698fc3..19053e5 100644 --- a/xerparser/src/xer.py +++ b/xerparser/src/xer.py @@ -56,7 +56,7 @@ def __init__(self, xer_file_contents: str) -> None: self.notebook_topics: dict[str, MEMOTYPE] = self._get_attr("MEMOTYPE") self.project_code_types: dict[str, PCATTYPE] = self._get_attr("PCATTYPE") self.project_code_values: dict[str, PCATVAL] = self._get_proj_code_values() - self.resources: dict[str, RSRC] = self._get_attr("RSRC") + self.resources: dict[str, RSRC] = self._get_rsrcs() self.sched_options: dict[str, SCHEDOPTIONS] = self._get_attr("SCHEDOPTIONS") self.udf_types: dict[str, UDFTYPE] = self._get_attr("UDFTYPE") self.projects = self._get_projects() @@ -147,6 +147,14 @@ def _get_relationships(self) -> dict[str, TASKPRED]: for rel in self.tables.get("TASKPRED", []) } + def _get_rsrcs(self) -> dict[str, RSRC]: + rsrcs: dict[str, RSRC] = self._get_attr("RSRC") + for rsrc in rsrcs.values(): + if rsrc.parent_rsrc_id: + rsrc.parent = rsrcs[rsrc.parent_rsrc_id] + rsrc.parent.addChild(rsrc) + return rsrcs + def _get_tasks(self) -> dict[str, TASK]: return { task["task_id"]: self._set_task(**task)