Skip to content

Commit

Permalink
feat: wrote loading
Browse files Browse the repository at this point in the history
  • Loading branch information
doctrino committed Nov 9, 2024
1 parent d68b9cc commit afce70b
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cognite/neat/_session/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(
client: CogniteClient | None = None,
storage: Literal["memory", "oxigraph"] = "memory",
verbose: bool = True,
load_engine: bool = True,
load_engine: Literal["newest", "cache", "skip"] = "cache",
) -> None:
self._client = client
self._verbose = verbose
Expand All @@ -44,7 +44,7 @@ def __init__(
self.show = ShowAPI(self._state)
self.set = SetAPI(self._state, verbose)
self.inspect = InspectAPI(self._state)
if load_engine and (engine_version := load_neat_engine(client)):
if load_engine != "skip" and (engine_version := load_neat_engine(client, load_engine)):
print(f"Neat Engine loaded: {engine_version}")

@property
Expand Down
115 changes: 112 additions & 3 deletions cognite/neat/_session/_engine.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,116 @@
import os
import re
import shutil
import sys
import tempfile
import warnings
from collections.abc import Callable
from pathlib import Path
from typing import Literal, cast

from cognite.client import CogniteClient
from packaging.version import Version
from packaging.version import parse as parse_version

from cognite.neat._version import __engine__

# Name of the environment variable that may hold a path (file or directory)
# to look for a Neat Engine archive in; read by load_neat_engine.
ENVIRONMENT_VARIABLE = "NEATENGINE"
# Base name of the engine: used for the cache directory name, as the prefix of
# the archive file names, and as data set external id / external id prefix when
# listing files in CDF.
PACKAGE_NAME = "neatengine"
# Major and minor version of the running interpreter concatenated without a
# separator (e.g. "311" for Python 3.11); part of the expected archive file name
# and used as the "python_version" metadata filter when listing files in CDF.
PYVERSION = f"{sys.version_info.major}{sys.version_info.minor}"


def load_neat_engine(client: CogniteClient | None, location: Literal["newest", "cache"]) -> str | None:
    """Find a compatible Neat Engine archive, add it to ``sys.path`` and import it.

    Engine archives are files named
    ``<PACKAGE_NAME>-<major>.<minor>.<patch>-<PYVERSION>.zip``.

    With ``location="cache"`` the cache directory (``<tempdir>/<PACKAGE_NAME>``) is
    consulted first and the other sources are only searched when the cache holds
    no archive at all. With ``location="newest"`` the cache is not consulted and
    the other sources are always searched. Those sources are, in increasing
    priority: the user's ``Downloads`` folder, CDF (only when ``client`` is given)
    and the path stored in the ``ENVIRONMENT_VARIABLE`` environment variable. When
    the same version is found in several sources, the higher priority one wins.

    Among the candidates, the highest version satisfying the caret requirement in
    ``__engine__`` (at least the stated version, below the next major version) is
    selected, copied into the cache directory if not already there, appended to
    ``sys.path`` and imported.

    Args:
        client: Client used to list and download archives from CDF, or None to
            skip CDF.
        location: "cache" to prefer archives already in the cache directory,
            "newest" to always search the other sources.

    Returns:
        The ``__version__`` of the imported engine, or None when no candidate was
        found, no candidate satisfies the version requirement, or the import fails.

    Raises:
        ValueError: If at least one candidate was found and ``__engine__`` does
            not start with "^".

    Warns:
        UserWarning: If the environment variable points to a non-existing path.
    """
    cache_dir = Path(tempfile.gettempdir()) / PACKAGE_NAME
    cache_dir.mkdir(exist_ok=True)
    # The helpers use pattern.match, which only anchors at the start. The dot is
    # escaped and the end is anchored so that names such as "<archive>.zip.part"
    # (partial downloads) are not mistaken for engine archives.
    pattern = re.compile(rf"{re.escape(PACKAGE_NAME)}-(\d+\.\d+\.\d+)-{PYVERSION}\.zip$")

    candidates: dict[Version, Callable[[], Path]] = {}
    if location == "cache":
        # The directory is guaranteed to exist after the mkdir above.
        candidates = _load_from_path(cache_dir, pattern)

    if location == "newest" or not candidates:
        # Loading in reverse order of priority: later updates override earlier
        # entries for the same version.
        # 3. Downloads folder
        candidates = _load_from_path(Path.home() / "Downloads", pattern)
        # 2. CDF
        if client:
            candidates.update(_load_from_cdf(client, pattern, cache_dir))
        # 1. Environment variable
        if ENVIRONMENT_VARIABLE in os.environ:
            environ_path = Path(os.environ[ENVIRONMENT_VARIABLE])
            if environ_path.exists():
                candidates.update(_load_from_path(environ_path, pattern))
            else:
                warnings.warn(
                    f"Environment variable {ENVIRONMENT_VARIABLE} points to non-existing path: {environ_path}",
                    UserWarning,
                    stacklevel=2,
                )

    if not candidates:
        return None

    if not __engine__.startswith("^"):
        raise ValueError(f"Invalid engine version: {__engine__}")

    # Caret requirement: lower_bound <= version < next major version.
    lower_bound = parse_version(__engine__[1:])
    upper_bound = Version(f"{lower_bound.major + 1}.0.0")
    selected_version = max(
        (version for version in candidates if lower_bound <= version < upper_bound),
        default=None,
    )
    if selected_version is None:
        return None

    # Candidates are lazy: calling the value resolves (and for CDF, downloads) the file.
    source_path = candidates[selected_version]()
    destination_path = cache_dir / source_path.name
    if not destination_path.exists():
        shutil.copy(source_path, destination_path)

    # Avoid piling up duplicate entries when the engine is loaded more than once
    # in the same interpreter.
    engine_location = str(destination_path)
    if engine_location not in sys.path:
        sys.path.append(engine_location)

    try:
        from neatengine._version import __version__ as engine_version  # type: ignore[import-not-found]
    except ImportError:
        return None
    return engine_version


def _load_from_path(path: Path, pattern: re.Pattern[str]) -> dict[Version, Callable[[], Path]]:
    """Collect engine archive candidates from a single file or a directory.

    Args:
        path: Either a file, which is itself checked against ``pattern``, or a
            directory, whose direct children that are files are checked. Any
            other path (for example a non-existing one) yields no candidates.
        pattern: Compiled pattern applied with ``match`` to each file name; its
            first group must hold the version string.

    Returns:
        Mapping from the parsed version to a zero-argument callable returning the
        path of the matching file. If several files in a directory map to the
        same version, the one iterated last is kept.
    """
    if path.is_file():
        files = [path]
    elif path.is_dir():
        files = [entry for entry in path.iterdir() if entry.is_file()]
    else:
        return {}

    output: dict[Version, Callable[[], Path]] = {}
    for file in files:
        match = pattern.match(file.name)
        if match is None:
            continue

        # The default argument binds the file of the current iteration; a plain
        # closure would return the last file of the loop for every entry.
        def return_path(the_path: Path = file) -> Path:
            return the_path

        output[parse_version(match.group(1))] = return_path
    return output


def _load_from_cdf(
    client: CogniteClient, pattern: re.Pattern[str], cache_dir: Path
) -> dict[Version, Callable[[], Path]]:
    """Collect engine archive candidates from files stored in CDF.

    Lists all files whose data set external id and external id prefix equal
    ``PACKAGE_NAME`` and whose ``python_version`` metadata equals ``PYVERSION``.
    Nothing is downloaded here; each candidate is a callable that downloads the
    file on demand.

    Args:
        client: Client used to list and download the files.
        pattern: Compiled pattern applied with ``match`` to each file name; its
            first group must hold the version string.
        cache_dir: Directory the files are downloaded into.

    Returns:
        Mapping from the parsed version to a zero-argument callable that
        downloads the file into ``cache_dir`` and returns ``cache_dir / <file name>``.
    """
    file_metadata = client.files.list(
        limit=-1,
        data_set_external_ids=PACKAGE_NAME,
        external_id_prefix=PACKAGE_NAME,
        metadata={"python_version": PYVERSION},
    )
    output: dict[Version, Callable[[], Path]] = {}
    for file in file_metadata:
        name = cast(str, file.name)
        match = pattern.match(name)
        if match is None:
            continue

        # Use a function to download the file lazily, only if it gets selected.
        # The default arguments bind the id and name of the current iteration; a
        # plain closure would use the last file of the loop for every entry.
        # NOTE(review): assumes client.files.download stores the file under its
        # CDF name inside cache_dir - confirm against the SDK.
        def download_file(file_id: int = file.id, filename: str = name) -> Path:
            client.files.download(cache_dir, file_id)
            return cache_dir / filename

        output[parse_version(match.group(1))] = download_file

    return output

0 comments on commit afce70b

Please sign in to comment.