Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Dec 5, 2023
1 parent 44505db commit 3ce57c0
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 3 deletions.
25 changes: 25 additions & 0 deletions ai_models/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import json
import logging
import os
import pickle
import zipfile
from typing import Any

LOG = logging.getLogger(__name__)


class FakeStorage:
def __init__(self):
Expand Down Expand Up @@ -46,6 +50,24 @@ def tidy(x):
return x


def checkpoint_metadata(path, name="ai-models.json"):
with zipfile.ZipFile(path, "r") as f:
metadata = None
for b in f.namelist():
if os.path.basename(b) == name:
if metadata is not None:
raise Exception(
f"Found two metadata.json files in {path}: {metadata} and {b}"
)
metadata = b

if metadata is not None:
with zipfile.ZipFile(path, "r") as f:
return json.load(f.open(metadata, "r"))

return None


def peek(path):
with zipfile.ZipFile(path, "r") as f:
data_pkl = None
Expand All @@ -57,6 +79,9 @@ def peek(path):
)
data_pkl = b

LOG.info(f"Found data.pkl at {data_pkl}")

with zipfile.ZipFile(path, "r") as f:
unpickler = UnpicklerWrapper(f.open(data_pkl, "r"))
x = tidy(unpickler.load())
return tidy(x)
5 changes: 4 additions & 1 deletion ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from climetlab.utils.humanize import seconds
from multiurl import download

from .checkpoint import peek
from .checkpoint import peek, checkpoint_metadata
from .inputs import get_input
from .outputs import get_output
from .stepper import Stepper
Expand Down Expand Up @@ -396,6 +396,9 @@ def patch_retrieve_request(self, request):
def peek_into_checkpoint(self, path):
return peek(path)

def checkpoint_metadata(self, path):
return checkpoint_metadata(path)

def parse_model_args(self, args):
if args:
raise NotImplementedError(f"This model does not accept arguments {args}")
Expand Down
14 changes: 12 additions & 2 deletions ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,18 @@ def __init__(self, owner, path, metadata, **kwargs):
)

def write(self, data, *args, **kwargs):
assert not np.isnan(data).any(), (args, kwargs)
return self.output.write(data, *args, **kwargs)
try:
return self.output.write(data, *args, **kwargs)
except Exception:
if np.isnan(data).any():
raise ValueError(
f"NaN values found in field. args={args} kwargs={kwargs}"
)
if np.isinf(data).any():
raise ValueError(
f"Infinite values found in field. args={args} kwargs={kwargs}"
)
raise


class HindcastReLabel:
Expand Down

0 comments on commit 3ce57c0

Please sign in to comment.