From 410654f827c57125bc19648c4f5a7de6d1ef02bf Mon Sep 17 00:00:00 2001
From: anders-albert
Date: Sat, 16 Nov 2024 02:38:51 +0100
Subject: [PATCH] refactor: lookup system containers

Resolve Cognite system containers referenced by a toolkit data model.

* _schema.py: add referenced_container(), which returns the ids of the
  containers referenced by the schema's views together with the ids of
  the schema's own containers.
* _dms2rules.py: add DMSImporter.update_referenced_containers(), which
  registers the given containers unless their id is already known to
  the importer.
* _read.py: for format="toolkit", read a file with
  DMSImporter.from_zip_file (from_directory was previously used for
  both files and directories) and a directory with from_directory.
  Referenced containers whose space is in COGNITE_SPACES are retrieved
  through the client and handed to the importer; a NeatValueError is
  raised when they are referenced but no client is available, or when
  the path is neither a file nor a directory.
---
 cognite/neat/_rules/importers/_dms2rules.py | 10 ++++++++-
 cognite/neat/_rules/models/dms/_schema.py   |  7 ++++++
 cognite/neat/_session/_read.py              | 25 +++++++++++++++++----
 3 files changed, 37 insertions(+), 5 deletions(-)

diff --git a/cognite/neat/_rules/importers/_dms2rules.py b/cognite/neat/_rules/importers/_dms2rules.py
index 122030b9e..79a7f3c05 100644
--- a/cognite/neat/_rules/importers/_dms2rules.py
+++ b/cognite/neat/_rules/importers/_dms2rules.py
@@ -1,5 +1,5 @@
 from collections import Counter
-from collections.abc import Collection, Sequence
+from collections.abc import Collection, Iterable, Sequence
 from datetime import datetime, timezone
 from pathlib import Path
 from typing import Literal, cast
@@ -86,6 +86,14 @@ def __init__(
             self._all_containers_by_id.update(schema.reference.containers.items())
             self._all_views_by_id.update(schema.reference.views.items())
 
+    def update_referenced_containers(self, containers: Iterable[dm.ContainerApply]) -> None:
+        """Update the referenced containers. This is useful to add Cognite containers identified after the root schema
+        is read"""
+        for container in containers:
+            if container.as_id() in self._all_containers_by_id:
+                continue
+            self._all_containers_by_id[container.as_id()] = container
+
     @classmethod
     def from_data_model_id(
         cls,
diff --git a/cognite/neat/_rules/models/dms/_schema.py b/cognite/neat/_rules/models/dms/_schema.py
index 0aff93372..efae07916 100644
--- a/cognite/neat/_rules/models/dms/_schema.py
+++ b/cognite/neat/_rules/models/dms/_schema.py
@@ -708,6 +708,13 @@ def referenced_spaces(self, include_indirect_references: bool = True) -> set[str
         referenced_spaces |= {s.space for s in self.spaces.values()}
         return referenced_spaces
 
+    def referenced_container(self) -> set[dm.ContainerId]:
+        referenced_containers = {
+            container for view in self.views.values() for container in view.referenced_containers()
+        }
+        referenced_containers |= set(self.containers.keys())
+        return referenced_containers
+
     def as_read_model(self) -> dm.DataModel[dm.View]:
         if self.data_model is None:
             raise ValueError("Data model is not defined")
diff --git a/cognite/neat/_session/_read.py b/cognite/neat/_session/_read.py
index d5977de55..4d47b0973 100644
--- a/cognite/neat/_session/_read.py
+++ b/cognite/neat/_session/_read.py
@@ -6,6 +6,7 @@
 from cognite.client import CogniteClient
 from cognite.client.data_classes.data_modeling import DataModelId, DataModelIdentifier
 
+from cognite.neat._constants import COGNITE_SPACES
 from cognite.neat._graph import examples as instances_examples
 from cognite.neat._graph import extractors
 from cognite.neat._issues import IssueList
@@ -158,10 +159,26 @@ def __call__(self, io: Any, format: Literal["neat", "toolkit"] = "neat") -> Issu
         importer: BaseImporter
         if format == "neat":
             importer = importers.YAMLImporter.from_file(reader.path)
-        elif format == "toolkit" and reader.path.is_file():
-            importer = importers.DMSImporter.from_directory(reader.path)
-        elif format == "toolkit" and reader.path.is_dir():
-            importer = importers.DMSImporter.from_directory(reader.path)
+        elif format == "toolkit":
+            if reader.path.is_file():
+                dms_importer = importers.DMSImporter.from_zip_file(reader.path)
+            elif reader.path.is_dir():
+                dms_importer = importers.DMSImporter.from_directory(reader.path)
+            else:
+                raise NeatValueError(f"Toolkit path is neither a file nor a directory: {reader.path}")
+            ref_containers = dms_importer.root_schema.referenced_container()
+            if system_container_ids := [
+                container_id for container_id in ref_containers if container_id.space in COGNITE_SPACES
+            ]:
+                if self._client is None:
+                    raise NeatValueError(
+                        "No client provided. You are referencing Cognite containers in the data model, "
+                        "and a client is required to lookup the container definitions."
+                    )
+                system_containers = self._client.data_modeling.containers.retrieve(system_container_ids)
+                dms_importer.update_referenced_containers(system_containers)
+
+            importer = dms_importer
         else:
             raise NeatValueError(f"Unsupported YAML format: {format}")
         input_rules: ReadRules = importer.to_rules()