diff --git a/ai_models/__init__.py b/ai_models/__init__.py index a9f10ab..c51499b 100644 --- a/ai_models/__init__.py +++ b/ai_models/__init__.py @@ -5,4 +5,4 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -__version__ = "0.3.5" +__version__ = "0.3.6" diff --git a/ai_models/model.py b/ai_models/model.py index 1ec5e1e..25ce083 100644 --- a/ai_models/model.py +++ b/ai_models/model.py @@ -319,10 +319,10 @@ def print_assets_list(self): def _print_request(self, verb, request, file=sys.stdout): r = [verb] - for k, v in request.items(): + for k, v in sorted(request.items()): if not isinstance(v, (list, tuple, set)): v = [v] - v = [str(_) for _ in v] + v = [str(_) for _ in sorted(v)] v = "/".join(v) r.append(f"{k}={v}") @@ -428,6 +428,49 @@ def gridpoints(self): def start_datetime(self): return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime() + def write_input_fields( + self, + fields, + accumulations=None, + accumulations_template=None, + accumulations_shape=None, + ignore=None, + ): + if ignore is None: + ignore = [] + + with self.timer("Writing step 0"): + for field in fields: + if field.metadata("shortName") in ignore: + continue + + if field.valid_datetime() == self.start_datetime: + self.write( + None, + template=field, + step=0, + ) + + if accumulations is not None: + if accumulations_template is None: + accumulations_template = fields.sel(param="2t")[0] + + if accumulations_shape is None: + accumulations_shape = accumulations_template.shape + + for param in accumulations: + self.write( + np.zeros(accumulations_shape, dtype=np.float32), + stepType="accum", + template=accumulations_template, + param=param, + startStep=0, + endStep=0, + date=int(self.start_datetime.strftime("%Y%m%d")), + time=int(self.start_datetime.strftime("%H%M")), + check=True, + ) + def load_model(name, **kwargs): return available_models()[name].load()(**kwargs) diff --git a/ai_models/outputs/__init__.py b/ai_models/outputs/__init__.py index 377a54a..bc9dcc9 100644 --- a/ai_models/outputs/__init__.py +++ b/ai_models/outputs/__init__.py @@ -17,6 +17,7 @@ class FileOutput: def __init__(self, owner, path, metadata, **kwargs): self._first = True + metadata.setdefault("stream", "oper") metadata.setdefault("expver", owner.expver) metadata.setdefault("class", "ml") @@ -27,7 +28,8 @@ def __init__(self, owner, path, metadata, **kwargs): edition = metadata.pop("edition", 2) self.grib_keys = dict( - edition=edition, generatingProcessIdentifier=self.owner.version + edition=edition, + generatingProcessIdentifier=self.owner.version, ) self.grib_keys.update(metadata) @@ -37,9 +39,10 @@ def __init__(self, owner, path, metadata, **kwargs): **self.grib_keys, ) - def write(self, data, *args, **kwargs): + def write(self, data, *args, check=False, **kwargs): try: - return self.output.write(data, *args, **kwargs) + handle, path = self.output.write(data, *args, **kwargs) + except Exception: if np.isnan(data).any(): raise ValueError( @@ -51,6 +54,23 @@ def write(self, data, *args, **kwargs): ) raise + if check: + # Check that the GRIB keys are as expected + for key, value in itertools.chain(self.grib_keys.items(), kwargs.items()): + if key in ("template",): + continue + + # If "param" is a string, we what to compare it to the shortName + if key == "param": + try: + float(value) + except ValueError: + key = "shortName" + + assert str(handle.get(key)) == str(value), (key, handle.get(key), value) + + return handle, path + class HindcastReLabel: def __init__(self, owner, output, hindcast_reference_year, **kwargs): @@ -81,25 +101,7 @@ def write(self, *args, **kwargs): kwargs["referenceDate"] = referenceDate kwargs["hdate"] = date - handle, path = self.output.write(*args, **kwargs) - - # Check that the GRIB keys are as expected - for key, value in itertools.chain( - self.output.grib_keys.items(), kwargs.items() - ): - if key in ("template",): - continue - - # If "param" is a string, we what to compare it to the shortName - if key == "param": - try: - float(value) - except ValueError: - key = "shortName" - - assert str(handle.get(key)) == str(value), (key, handle.get(key), value) - - return handle, path + return self.output.write(*args, check=True, **kwargs) class NoneOutput: diff --git a/setup.py b/setup.py index c5329f9..2cc71d9 100644 --- a/setup.py +++ b/setup.py @@ -47,6 +47,7 @@ def read(fname): "multiurl", "ecmwflibs>=0.6.1", "gputil", + "earthkit-meteo", ], extras_require={ "provenance": [