Skip to content

Commit

Permalink
any order linkning of assets
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrine Holm committed Aug 2, 2023
1 parent 3951c54 commit 7d9e05d
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 22 deletions.
54 changes: 35 additions & 19 deletions cognite/powerops/resync/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class AssetType(ResourceType, ABC):
label: ClassVar[Union[AssetLabel, str]]
model_config: ClassVar[ConfigDict] = ConfigDict(arbitrary_types_allowed=True)
parent_description: ClassVar[Optional[str]] = None
_instantiated_assets: ClassVar[dict[str, AssetType]] = defaultdict(dict)

name: str
description: Optional[str] = None
_external_id: Optional[str] = None
Expand Down Expand Up @@ -231,19 +233,29 @@ def _from_asset(
if not additional_fields:
additional_fields = {}
metadata = cls._parse_asset_metadata(asset.metadata)

return cls(
instance = cls(
_external_id=asset.external_id,
name=asset.name,
description=asset.description,
**metadata,
**additional_fields,
)
AssetType._instantiated_assets[asset.external_id] = instance
return instance

@classmethod
def _handle_asset_relationship(cls, target_external_id: str) -> T_Asset_Type:
print("Need to find out how to handle asset relationships that dont yet exist")
...
def _find_asset_type_class(cls, field: str) -> TypingType[AssetType]:
for field_name in cls.model_fields:
class_ = cls.model_fields[field_name].annotation
if isinstance(class_, GenericAlias):
asset_resource_class = get_args(class_)[0]
if issubclass(asset_resource_class, AssetType) and field_name == field:
return asset_resource_class
elif get_origin(class_) is Union and type(None) in get_args(class_):
# Optional field `AssetType, not a list
asset_resource_class = get_args(class_)[0]
if issubclass(asset_resource_class, AssetType):
return get_args(class_)[0]

@classmethod
def from_cdf(
Expand All @@ -253,7 +265,6 @@ def from_cdf(
asset: Optional[Asset] = None,
fetch_metadata: bool = True,
fetch_content: bool = False,
instantiated_assets: Optional[dict[str, AssetType]] = None,
) -> T_Asset_Type:
"""
Fetch an asset from CDF and convert it to a model instance.
Expand All @@ -264,14 +275,17 @@ def from_cdf(
This can be enabled by setting `fetch_content=True`.
"""

if asset and external_id:
raise ValueError("Only one of asset and external_id can be provided")
if external_id:
asset = client.assets.retrieve(external_id)
# Check if asset has already been instantiated, eg. by a relationship
if external_id in AssetType._instantiated_assets:
return AssetType._instantiated_assets[external_id]
else:
asset = client.assets.retrieve(external_id=external_id)
if not asset:
raise ValueError(f"Could not retrieve asset with {external_id=}")
if not instantiated_assets:
instantiated_assets = {}

# Prepare non-asset metadata fields
additional_fields = {
Expand All @@ -295,10 +309,17 @@ def from_cdf(
relationship_target = None

if target_type == "asset":
if r.target_external_id in instantiated_assets:
relationship_target = instantiated_assets[r.target_external_id]
if r.target_external_id in AssetType._instantiated_assets:
relationship_target = AssetType._instantiated_assets[r.target_external_id]

else:
relationship_target = cls._handle_asset_relationship(target_external_id=r.target_external_id)
target_class = cls._find_asset_type_class(field=field)
relationship_target = target_class.from_cdf(
client=client,
external_id=r.target_external_id,
fetch_metadata=fetch_metadata,
fetch_content=fetch_content,
)

elif target_type == "timeseries":
relationship_target = client.time_series.retrieve(external_id=r.target_external_id)
Expand All @@ -307,7 +328,7 @@ def from_cdf(
elif target_type == "file":
relationship_target = CDFFile.from_cdf(client, r.target_external_id, fetch_content)
else:
raise ValueError(f"Cannot handle target type {r.target_type}")
raise ValueError(f"Cannot handle target type {r.target_type}")

if isinstance(additional_fields[field], list):
additional_fields[field].append(relationship_target)
Expand Down Expand Up @@ -463,11 +484,8 @@ def from_cdf(
if fetch_content and not fetch_metadata:
raise ValueError("Cannot fetch content without also fetching metadata")

# Instance of model as dict
output = defaultdict(list)

# Cache to avoid fetching the same asset multiple times
instantiated_assets: dict[str:AssetType] = {}
for field_name, asset_cls in cls._asset_types_and_field_names():
assets = client.assets.retrieve_subtree(external_id=asset_cls.parent_external_id)
for asset in assets:
Expand All @@ -478,10 +496,8 @@ def from_cdf(
asset=asset,
fetch_metadata=fetch_metadata,
fetch_content=fetch_content,
instantiated_assets=instantiated_assets,
)
output[field_name].append(instance)
instantiated_assets[asset.external_id] = instance

return cls(**output)

Expand All @@ -504,7 +520,7 @@ def difference(self: T_Asset_Model, other: T_Asset_Model) -> dict:
).to_dict():
diff_dict[model_field] = deep_diff
str_builder.extend(self._field_diff_str_builder(model_field, deep_diff, self_dump[model_field]))
# break

print("".join(str_builder))
return diff_dict

Expand Down
4 changes: 2 additions & 2 deletions cognite/powerops/resync/models/production.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,9 @@ def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]

class ProductionModel(AssetModel):
root_asset: ClassVar[Asset] = Asset(external_id="power_ops", name="PowerOps")
reservoirs: list[Reservoir] = Field(default_factory=list)
generators: list[Generator] = Field(default_factory=list)
plants: list[Plant] = Field(default_factory=list)
generators: list[Generator] = Field(default_factory=list)
reservoirs: list[Reservoir] = Field(default_factory=list)
watercourses: list[Watercourse] = Field(default_factory=list)
price_areas: list[PriceArea] = Field(default_factory=list)

Expand Down
1 change: 0 additions & 1 deletion scripts/diff_production_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,3 @@ def main():

if __name__ == "__main__":
main()
# client = get_powerops_client().cdf

0 comments on commit 7d9e05d

Please sign in to comment.