Skip to content

Commit

Permalink
Support for step zero
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Feb 11, 2024
1 parent a35ea77 commit 5b08fa1
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 25 deletions.
2 changes: 1 addition & 1 deletion ai_models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

__version__ = "0.3.5"
__version__ = "0.3.6"
47 changes: 45 additions & 2 deletions ai_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,10 +319,10 @@ def print_assets_list(self):

def _print_request(self, verb, request, file=sys.stdout):
r = [verb]
for k, v in request.items():
for k, v in sorted(request.items()):
if not isinstance(v, (list, tuple, set)):
v = [v]
v = [str(_) for _ in v]
v = [str(_) for _ in sorted(v)]
v = "/".join(v)
r.append(f"{k}={v}")

Expand Down Expand Up @@ -428,6 +428,49 @@ def gridpoints(self):
def start_datetime(self):
return self.all_fields.order_by(valid_datetime="ascending")[-1].datetime()

def write_input_fields(
self,
fields,
accumulations=None,
accumulations_template=None,
accumulations_shape=None,
ignore=None,
):
if ignore is None:
ignore = []

with self.timer("Writing step 0"):
for field in fields:
if field.metadata("shortName") in ignore:
continue

if field.valid_datetime() == self.start_datetime:
self.write(
None,
template=field,
step=0,
)

if accumulations is not None:
if accumulations_template is None:
accumulations_template = fields.sel(param="2t")[0]

if accumulations_shape is None:
accumulations_shape = accumulations_template.shape

for param in accumulations:
self.write(
np.zeros(accumulations_shape, dtype=np.float32),
stepType="accum",
template=accumulations_template,
param=param,
startStep=0,
endStep=0,
date=int(self.start_datetime.strftime("%Y%m%d")),
time=int(self.start_datetime.strftime("%H%M")),
check=True,
)


def load_model(name, **kwargs):
return available_models()[name].load()(**kwargs)
Expand Down
46 changes: 24 additions & 22 deletions ai_models/outputs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class FileOutput:
def __init__(self, owner, path, metadata, **kwargs):
self._first = True
metadata.setdefault("stream", "oper")
metadata.setdefault("expver", owner.expver)
metadata.setdefault("class", "ml")

Expand All @@ -27,7 +28,8 @@ def __init__(self, owner, path, metadata, **kwargs):
edition = metadata.pop("edition", 2)

self.grib_keys = dict(
edition=edition, generatingProcessIdentifier=self.owner.version
edition=edition,
generatingProcessIdentifier=self.owner.version,
)
self.grib_keys.update(metadata)

Expand All @@ -37,9 +39,10 @@ def __init__(self, owner, path, metadata, **kwargs):
**self.grib_keys,
)

def write(self, data, *args, **kwargs):
def write(self, data, *args, check=False, **kwargs):
try:
return self.output.write(data, *args, **kwargs)
handle, path = self.output.write(data, *args, **kwargs)

except Exception:
if np.isnan(data).any():
raise ValueError(
Expand All @@ -51,6 +54,23 @@ def write(self, data, *args, **kwargs):
)
raise

if check:
# Check that the GRIB keys are as expected
for key, value in itertools.chain(self.grib_keys.items(), kwargs.items()):
if key in ("template",):
continue

# If "param" is a string, we what to compare it to the shortName
if key == "param":
try:
float(value)
except ValueError:
key = "shortName"

assert str(handle.get(key)) == str(value), (key, handle.get(key), value)

return handle, path


class HindcastReLabel:
def __init__(self, owner, output, hindcast_reference_year, **kwargs):
Expand Down Expand Up @@ -81,25 +101,7 @@ def write(self, *args, **kwargs):
kwargs["referenceDate"] = referenceDate
kwargs["hdate"] = date

handle, path = self.output.write(*args, **kwargs)

# Check that the GRIB keys are as expected
for key, value in itertools.chain(
self.output.grib_keys.items(), kwargs.items()
):
if key in ("template",):
continue

# If "param" is a string, we what to compare it to the shortName
if key == "param":
try:
float(value)
except ValueError:
key = "shortName"

assert str(handle.get(key)) == str(value), (key, handle.get(key), value)

return handle, path
return self.output.write(*args, check=True, **kwargs)


class NoneOutput:
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def read(fname):
"multiurl",
"ecmwflibs>=0.6.1",
"gputil",
"earthkit-meteo",
],
extras_require={
"provenance": [
Expand Down

0 comments on commit 5b08fa1

Please sign in to comment.