diff --git a/ai_models/checkpoint.py b/ai_models/checkpoint.py
index 9f8e7a7..eef0a56 100644
--- a/ai_models/checkpoint.py
+++ b/ai_models/checkpoint.py
@@ -5,11 +5,15 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
+import json
+import logging
 import os
 import pickle
 import zipfile
 from typing import Any
 
+LOG = logging.getLogger(__name__)
+
 
 class FakeStorage:
     def __init__(self):
@@ -46,6 +50,33 @@ def tidy(x):
     return x
 
 
+def checkpoint_metadata(path, name="ai-models.json"):
+    """Return the JSON metadata stored in a checkpoint archive.
+
+    The zip archive at ``path`` is searched for a member whose base name
+    equals ``name``. Returns the decoded JSON content, or ``None`` when the
+    archive holds no such member. Raises ``Exception`` when more than one
+    member matches.
+    """
+    with zipfile.ZipFile(path, "r") as f:
+        metadata = None
+        for b in f.namelist():
+            if os.path.basename(b) == name:
+                if metadata is not None:
+                    raise Exception(
+                        f"Found two {name} files in {path}: {metadata} and {b}"
+                    )
+                metadata = b
+
+        if metadata is None:
+            return None
+
+        # Read from the archive that is already open, and close the member
+        # handle once the JSON has been parsed.
+        with f.open(metadata, "r") as g:
+            return json.load(g)
+
+
 def peek(path):
     with zipfile.ZipFile(path, "r") as f:
         data_pkl = None
@@ -57,6 +88,9 @@ def peek(path):
                     )
                 data_pkl = b
 
+    LOG.info("Found data.pkl at %s", data_pkl)
+
+    with zipfile.ZipFile(path, "r") as f:
         unpickler = UnpicklerWrapper(f.open(data_pkl, "r"))
         x = tidy(unpickler.load())
         return tidy(x)
diff --git a/ai_models/model.py b/ai_models/model.py
index e52e2cc..5f89416 100644
--- a/ai_models/model.py
+++ b/ai_models/model.py
@@ -20,7 +20,7 @@
 from climetlab.utils.humanize import seconds
 from multiurl import download
 
-from .checkpoint import peek
+from .checkpoint import checkpoint_metadata, peek
 from .inputs import get_input
 from .outputs import get_output
 from .stepper import Stepper
@@ -396,6 +396,10 @@ def patch_retrieve_request(self, request):
     def peek_into_checkpoint(self, path):
         return peek(path)
 
+    def checkpoint_metadata(self, path):
+        """Return the JSON metadata stored in the checkpoint at ``path``."""
+        return checkpoint_metadata(path)
+
     def parse_model_args(self, args):
         if args:
             raise NotImplementedError(f"This model does not accept arguments {args}")
diff --git a/ai_models/outputs/__init__.py b/ai_models/outputs/__init__.py
index eeaa79c..d9af007 100644
--- a/ai_models/outputs/__init__.py
+++ b/ai_models/outputs/__init__.py
@@ -38,8 +38,24 @@ def __init__(self, owner, path, metadata, **kwargs):
         )
 
     def write(self, data, *args, **kwargs):
-        assert not np.isnan(data).any(), (args, kwargs)
-        return self.output.write(data, *args, **kwargs)
+        """Write ``data`` through ``self.output``.
+
+        When the underlying write fails and the field contains NaN or
+        infinite values, a ``ValueError`` naming that cause is raised instead,
+        chained to the original error. Any other failure propagates unchanged.
+        """
+        try:
+            return self.output.write(data, *args, **kwargs)
+        except Exception as e:
+            if np.isnan(data).any():
+                raise ValueError(
+                    f"NaN values found in field. args={args} kwargs={kwargs}"
+                ) from e
+            if np.isinf(data).any():
+                raise ValueError(
+                    f"Infinite values found in field. args={args} kwargs={kwargs}"
+                ) from e
+            raise
 
 
 class HindcastReLabel: