Skip to content

Commit

Permalink
abstraction of from_cdf complete, linkning not
Browse files Browse the repository at this point in the history
  • Loading branch information
Katrine Holm committed Aug 2, 2023
1 parent 9b75c6e commit 3951c54
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 173 deletions.
76 changes: 41 additions & 35 deletions cognite/powerops/resync/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import json
from abc import ABC
from pprint import pprint
from deepdiff import DeepDiff

from pathlib import Path
Expand Down Expand Up @@ -114,14 +113,21 @@ def relationships(self) -> list[Relationship]:
if not value:
continue
if isinstance_list(value, AssetType):
for target in value:
relationships.append(self._create_relationship(target.external_id, "ASSET", target.type_))
relationships.extend(
self._create_relationship(target.external_id, "ASSET", target.type_) for target in value
)
elif isinstance(value, AssetType):
target_type = value.type_
if self.type_ == "plant" and value.type_ == "reservoir":
target_type = "inlet_reservoir"
relationships.append(self._create_relationship(value.external_id, "ASSET", target_type))
elif any(cdf_type in str(field.annotation) for cdf_type in [CDFSequence.__name__, TimeSeries.__name__]):
elif any(
cdf_type in str(field.annotation)
for cdf_type in [
CDFSequence.__name__,
TimeSeries.__name__,
]
):
if TimeSeries.__name__ in str(field.annotation):
target_type = "TIMESERIES"
elif CDFSequence.__name__ in str(field.annotation):
Expand All @@ -130,10 +136,22 @@ def relationships(self) -> list[Relationship]:
raise ValueError(f"Unexpected type {field.annotation}")

if isinstance(value, list):
for target in value:
relationships.append(self._create_relationship(target.external_id, target_type, field_name))
relationships.extend(
self._create_relationship(
target.external_id,
target_type,
field_name,
)
for target in value
)
else:
relationships.append(self._create_relationship(value.external_id, target_type, field_name))
relationships.append(
self._create_relationship(
value.external_id,
target_type,
field_name,
)
)
return relationships

def as_asset(self):
Expand Down Expand Up @@ -199,14 +217,10 @@ def _create_relationship(
target_type=target_cdf_type,
labels=[Label(external_id=label.value)],
)

@classmethod
def _parse_asset_metadata(
cls,
metadata: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
def _parse_asset_metadata(cls, metadata: dict[str, Any] = None) -> dict[str, Any]:
raise NotImplementedError()


@classmethod
def _from_asset(
Expand All @@ -225,13 +239,12 @@ def _from_asset(
**metadata,
**additional_fields,
)

@classmethod
def _handle_asset_relationship(cls, target_external_id:str) -> T_Asset_Type:
def _handle_asset_relationship(cls, target_external_id: str) -> T_Asset_Type:
print("Need to find out how to handle asset relationships that dont yet exist")
...


@classmethod
def from_cdf(
cls,
Expand All @@ -240,7 +253,7 @@ def from_cdf(
asset: Optional[Asset] = None,
fetch_metadata: bool = True,
fetch_content: bool = False,
instantiated_assets: Optional[dict[str, AssetType]]=None,
instantiated_assets: Optional[dict[str, AssetType]] = None,
) -> T_Asset_Type:
"""
Fetch an asset from CDF and convert it to a model instance.
Expand All @@ -259,12 +272,13 @@ def from_cdf(
raise ValueError(f"Could not retrieve asset with {external_id=}")
if not instantiated_assets:
instantiated_assets = {}



# Prepare non-asset metadata fields
additional_fields = {
field: [] if 'list' in str(field_info.annotation) else None
field: [] if "list" in str(field_info.annotation) else None
for field, field_info in cls.model_fields.items()
if field in cls._asset_type_fields()
or any(cdf_type in str(field_info.annotation) for cdf_type in [CDFSequence.__name__, TimeSeries.__name__])
}

# Populate non-asset metadata fields according to relationships/flags
Expand Down Expand Up @@ -294,19 +308,14 @@ def from_cdf(
relationship_target = CDFFile.from_cdf(client, r.target_external_id, fetch_content)
else:
raise ValueError(f"Cannot handle target type {r.target_type}")



if isinstance(additional_fields[field], list):
additional_fields[field].append(relationship_target)
additional_fields[field].append(relationship_target)
else:
additional_fields[field] = relationship_target




return cls._from_asset(asset, additional_fields)


@classmethod
def _asset_type_fields(cls) -> Iterable[str]:
# Exclude fom model_dump in diff (ext_id only)
Expand Down Expand Up @@ -451,15 +460,15 @@ def from_cdf(
fetch_metadata: bool = True,
fetch_content: bool = False,
) -> T_Asset_Model:

if fetch_content and not fetch_metadata:
raise ValueError("Cannot fetch content without also fetching metadata")

# Instance of model as dict
output = defaultdict(list)
instantiated_assets: dict[str: AssetType ]= {}

# Cache to avoid fetching the same asset multiple times
instantiated_assets: dict[str:AssetType] = {}
for field_name, asset_cls in cls._asset_types_and_field_names():
# if asset_cls.parent_external_id not in ("plants"):
# continue
assets = client.assets.retrieve_subtree(external_id=asset_cls.parent_external_id)
for asset in assets:
if asset.external_id == asset_cls.parent_external_id:
Expand All @@ -479,13 +488,10 @@ def from_cdf(
def _prepare_for_diff(self: T_Asset_Model) -> dict[str:dict]:
raise NotImplementedError()

def difference(self: T_Asset_Model, other: T_Asset_Model, debug: bool = False) -> dict:
def difference(self: T_Asset_Model, other: T_Asset_Model) -> dict:
if type(self) != type(other):
raise ValueError("Cannot compare these models of different types.")

if debug:
return

self_dump = self._prepare_for_diff()
other_dump = other._prepare_for_diff()
str_builder = []
Expand Down
138 changes: 7 additions & 131 deletions cognite/powerops/resync/models/production.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,13 @@
from pathlib import Path
from typing import Any, ClassVar, Optional, Union

from cognite.client import CogniteClient
from cognite.client.data_classes import Asset, TimeSeries
from pydantic import ConfigDict, Field

from cognite.powerops.cdf_labels import AssetLabel
from cognite.powerops.resync.models.base import AssetModel, AssetType, NonAssetType
from cognite.powerops.resync.models.cdf_resources import CDFSequence
from cognite.powerops.resync.models.helpers import isinstance_list, match_field_from_relationship
from cognite.powerops.resync.models.helpers import isinstance_list


class Generator(AssetType):
Expand All @@ -31,22 +30,21 @@ def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]
"penstock": asset_metadata.get("penstock", ""),
"startcost": float(asset_metadata.get("startcost", 0.0)),
}


class Reservoir(AssetType):
parent_external_id: ClassVar[str] = "reservoirs"
label: ClassVar[Union[AssetLabel, str]] = AssetLabel.RESERVOIR
display_name: str
ordering: str


@classmethod
def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]:
return {
"display_name": asset_metadata.get("display_name", ""),
"ordering": asset_metadata.get("ordering", ""),
}


class Plant(AssetType):
parent_external_id: ClassVar[str] = "plants"
Expand All @@ -68,8 +66,6 @@ class Plant(AssetType):
inlet_level_time_series: Optional[TimeSeries] = None
head_direct_time_series: Optional[TimeSeries] = None



@classmethod
def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]:
penstock_head_loss_factors_raw: str = asset_metadata.get("penstock_head_loss_factors", "")
Expand All @@ -87,54 +83,9 @@ def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]
"outlet_level": float(asset_metadata.get("outlet_level", 0.0)),
"p_min": float(asset_metadata.get("p_min", 0.0)),
"p_max": float(asset_metadata.get("p_max", 0.0)),
"penstock_head_loss_factors": penstock_head_loss_factors
"penstock_head_loss_factors": penstock_head_loss_factors,
}

@classmethod
def _from_cdf(
cls,
client: CogniteClient,
external_id: Optional[str] = "",
asset: Optional[Asset] = None,
fetch_metadata: bool = True,
fetch_content: bool = False,
) -> Plant:
if asset and external_id:
raise ValueError("Only one of asset and external_id can be provided")
if external_id:
asset = client.assets.retrieve(external_id)
if not asset:
raise ValueError(f"Could not retrieve asset with {external_id=}")
cdf_fields = {
"generators": [],
"inlet_reservoir": None,
"p_min_time_series": None,
"p_max_time_series": None,
"water_value_time_series": None,
"feeding_fee_time_series": None,
"outlet_level_time_series": None,
"inlet_level_time_series": None,
"head_direct_time_series": None,
}
if fetch_metadata:
relationships = client.relationships.list(
source_external_ids=[asset.external_id],
source_types=["asset"],
target_types=["timeseries", "asset"],
limit=-1,
)
for r in relationships:
field = match_field_from_relationship(cls.model_fields.keys(), r)
if r.target_type.lower() == "asset":
# todo: handle later -- we only want to instantiate a class
# one per ext id. Probably a dict when re-written to high-level.
# finding the field is still its own challenge
cdf_fields[field] = None if field == "inlet_reservoir" else []

if r.target_type.lower() == "timeseries":
cdf_fields[field] = client.time_series.retrieve(external_id=r.target_external_id)
return cls._from_asset(asset, cdf_fields)


class WaterCourseShop(NonAssetType):
penalty_limit: str
Expand All @@ -154,96 +105,24 @@ class Watercourse(AssetType):
@classmethod
def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]:
return {
"config_version": "",
"shop": WaterCourseShop(penalty_limit=asset_metadata.get("shop:penalty_limit", "")),
"model_file": None,
"processed_model_file": None,
}

@classmethod
def _from_cdf(
cls,
client: CogniteClient,
external_id: Optional[str] = "",
asset: Optional[Asset] = None,
fetch_metadata: bool = True,
fetch_content: bool = False,
) -> Watercourse:
if asset and external_id:
raise ValueError("Only one of asset and external_id can be provided")
if external_id:
asset = client.assets.retrieve(external_id)
if not asset:
raise ValueError(f"Could not retrieve asset with {external_id=}")
cdf_fields = {
"config_version": None,
"plants": [],
"production_obligation_time_series": [],
}
if fetch_metadata:
relationships = client.relationships.list(
source_external_ids=[asset.external_id],
source_types=["asset"],
target_types=["timeseries", "asset"],
limit=-1,
)
for r in relationships:
field = match_field_from_relationship(cls.model_fields.keys(), r)
if r.target_type.lower() == "asset":
cdf_fields[field] = []
if r.target_type.lower() == "timeseries":
cdf_fields[field] = client.time_series.retrieve(external_id=r.target_external_id)

return cls._from_asset(asset, cdf_fields)


class PriceArea(AssetType):
parent_external_id: ClassVar[str] = "price_areas"
label: ClassVar[Union[AssetLabel, str]] = AssetLabel.PRICE_AREA
dayahead_price_time_series: Optional[TimeSeries] = None
plants: list[Plant] = Field(default_factory=list)
watercourses: list[Watercourse] = Field(default_factory=list)

@classmethod
def _parse_asset_metadata(cls, asset_metadata: dict[str, str]) -> dict[str, Any]:
# In order to maintain the AssetType structure
# Maintain the AssetType structure
return {}


@classmethod
def _from_cdf(
cls,
client: CogniteClient,
external_id: Optional[str] = "",
asset: Optional[Asset] = None,
fetch_metadata: bool = True,
fetch_content: bool = False,
) -> PriceArea:
if asset and external_id:
raise ValueError("Only one of asset and external_id can be provided")
if external_id:
asset = client.assets.retrieve(external_id)
if not asset:
raise ValueError(f"Could not retrieve asset with {external_id=}")
cdf_fields = {
"dayahead_price_time_series": None,
"plants": [],
"watercourses": [],
}
if fetch_metadata:
relationships = client.relationships.list(
source_external_ids=[asset.external_id],
source_types=["asset"],
target_types=["timeseries", "asset"],
limit=-1,
)
for r in relationships:
field = match_field_from_relationship(cls.model_fields.keys(), r)
if r.target_type.lower() == "asset":
cdf_fields[field] = []
if r.target_type.lower() == "timeseries":
cdf_fields[field] = client.time_series.retrieve(external_id=r.target_external_id)

return cls._from_asset(asset, cdf_fields)


class ProductionModel(AssetModel):
Expand All @@ -260,14 +139,11 @@ def _prepare_for_diff(self: ProductionModel) -> dict:
for model_field in clone.model_fields:
field_value = getattr(clone, model_field)
if isinstance_list(field_value, AssetType):
# if isinstance(field_value, list) and field_value and isinstance(field_value[0], AssetType):
# Sort the asset types to have comparable order for diff
_sorted = sorted(field_value, key=lambda x: x.external_id)
# Prepare each asset type for diff
_prepared = map(lambda x: x._asset_type_prepare_for_diff(), _sorted)
setattr(clone, model_field, list(_prepared))
elif isinstance(field_value, AssetType):
# does not apply to this model, but
# might be used in a higher level of abstraction
field_value._asset_type_prepare_for_diff()
return clone.model_dump()
Loading

0 comments on commit 3951c54

Please sign in to comment.