diff --git a/src/ai_models/inputs/compute.py b/src/ai_models/inputs/compute.py index 41ece36..bd656b9 100644 --- a/src/ai_models/inputs/compute.py +++ b/src/ai_models/inputs/compute.py @@ -7,27 +7,33 @@ import logging +import earthkit.data as ekd +import tqdm +from earthkit.data.core.temporary import temp_file from earthkit.data.indexing.fieldlist import FieldArray -from .transform import NewDataField -from .transform import NewMetadataField - LOG = logging.getLogger(__name__) +G = 9.80665 # Same a pgen + + +def make_z_from_gh(ds): + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + other = [] -def make_z_from_gh(previous): - g = 9.80665 # Same a pgen + for f in tqdm.tqdm(ds, delay=0.5, desc="GH to Z", leave=False): - def _proc(ds): + if f.metadata("param") == "gh": + out.write(f.to_numpy() * G, template=f, param="z") + else: + other.append(f) - ds = previous(ds) + out.close() - result = [] - for f in ds: - if f.metadata("param") == "gh": - result.append(NewMetadataField(NewDataField(f, f.to_numpy() * g), param="z")) - else: - result.append(f) - return FieldArray(result) + result = FieldArray(other) + ekd.from_source("file", tmp.path) + result._tmp = tmp - return _proc + return result diff --git a/src/ai_models/inputs/interpolate.py b/src/ai_models/inputs/interpolate.py index 3fec238..c5e3145 100644 --- a/src/ai_models/inputs/interpolate.py +++ b/src/ai_models/inputs/interpolate.py @@ -7,25 +7,35 @@ import logging +import earthkit.data as ekd import earthkit.regrid as ekr import tqdm -from earthkit.data.indexing.fieldlist import FieldArray - -from .transform import NewDataField +from earthkit.data.core.temporary import temp_file LOG = logging.getLogger(__name__) class Interpolate: - def __init__(self, grid, source): + def __init__(self, grid, source, metadata): self.grid = list(grid) if isinstance(grid, tuple) else grid self.source = list(source) if isinstance(source, tuple) else source + self.metadata = metadata def __call__(self, ds): + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + result = [] for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False): data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid)) - result.append(NewDataField(f, data)) + out.write(data, template=f, **self.metadata) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + print("Interpolated data", tmp.path) - LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape) - return FieldArray(result) + return result diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 61d4e7b..5ea9dbb 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -5,23 +5,23 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -import datetime import itertools import logging import os import earthkit.data as ekd +from earthkit.data.core.temporary import temp_file from earthkit.data.indexing.fieldlist import FieldArray from multiurl import download from .base import RequestBasedInput from .compute import make_z_from_gh from .interpolate import Interpolate +from .recenter import recenter from .transform import NewMetadataField LOG = logging.getLogger(__name__) - CONSTANTS = ( "z", "sdor", @@ -30,17 +30,33 @@ CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2" +RESOLS = { + (0.25, 0.25): ("0p25", (0.25, 0.25), False, False, {}), + (0.1, 0.1): ( + "0p25", + (0.25, 0.25), + True, + True, + dict( + longitudeOfLastGridPointInDegrees=359.9, + iDirectionIncrementInDegrees=0.1, + jDirectionIncrementInDegrees=0.1, + Ni=3600, + Nj=1801, + ), + ), + # "N320": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg')), + # "O96": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg', )), +} + + +def _identity(x): + return x + class OpenDataInput(RequestBasedInput): WHERE = "OPENDATA" - RESOLS = { - (0.25, 0.25): ("0p25", (0.25, 0.25), False, False), - "N320": ("0p25", (0.25, 0.25), True, False), - "O96": ("0p25", (0.25, 0.25), True, False), - (0.1, 0.1): ("0p25", (0.25, 0.25), True, True), - } - def __init__(self, owner, **kwargs): self.owner = owner @@ -56,7 +72,7 @@ def _adjust(self, kwargs): if isinstance(grid, list): grid = tuple(grid) - kwargs["resol"], source, interp, oversampling = self.RESOLS[grid] + kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid] r = dict(**kwargs) r.update(self.owner.retrieve) @@ -65,12 +81,15 @@ def _adjust(self, kwargs): logging.info("Interpolating input data from %s to %s.", source, grid) if oversampling: logging.warning("This will oversample the input data.") - return Interpolate(grid, source) + return Interpolate(grid, source, metadata) else: - return lambda x: x + return _identity def pl_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + + gh_to_z = _identity + interpolate = self._adjust(kwargs) + kwargs["levtype"] = "pl" request = kwargs.copy() @@ -84,13 +103,68 @@ def pl_load_source(self, **kwargs): if "gh" not in param: param.append("gh") kwargs["param"] = param - pproc = make_z_from_gh(pproc) + gh_to_z = make_z_from_gh logging.debug("load source ecmwf-open-data %s", kwargs) - return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request) + + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = gh_to_z(opendata) + opendata = interpolate(opendata) + + return self.check_pl(opendata, request) + + def constants(self, constant_params, request, kwargs): + if len(constant_params) == 1: + logging.warning( + f"Single level parameter '{constant_params[0]}' is" + " not available in ECMWF open data, using constants.grib2 instead" + ) + else: + logging.warning( + f"Single level parameters {constant_params} are" + " not available in ECMWF open data, using constants.grib2 instead" + ) + + cachedir = os.path.expanduser("~/.cache/ai-models") + constants_url = CONSTANTS_URL.format(resol=request["resol"]) + basename = os.path.basename(constants_url) + + if not os.path.exists(cachedir): + os.makedirs(cachedir) + + path = os.path.join(cachedir, basename) + + if not os.path.exists(path): + logging.info("Downloading %s to %s", constants_url, path) + download(constants_url, path + ".tmp") + os.rename(path + ".tmp", path) + + ds = ekd.from_source("file", path) + ds = ds.sel(param=constant_params) + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + for f in ds: + out.write( + f.to_numpy(), + template=f, + date=kwargs["date"], + time=kwargs["time"], + step=kwargs.get("step", 0), + ) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result def sfc_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + interpolate = self._adjust(kwargs) + kwargs["levtype"] = "sfc" request = kwargs.copy() @@ -104,81 +178,32 @@ def sfc_load_source(self, **kwargs): param.remove(c) constant_params.append(c) - constants = ekd.from_source("empty") - if constant_params: - if len(constant_params) == 1: - logging.warning( - f"Single level parameter '{constant_params[0]}' is" - " not available in ECMWF open data, using constants.grib2 instead" - ) - else: - logging.warning( - f"Single level parameters {constant_params} are" - " not available in ECMWF open data, using constants.grib2 instead" - ) - constants = [] - - cachedir = os.path.expanduser("~/.cache/ai-models") - constants_url = CONSTANTS_URL.format(resol=request["resol"]) - basename = os.path.basename(constants_url) - - if not os.path.exists(cachedir): - os.makedirs(cachedir) - - path = os.path.join(cachedir, basename) - - if not os.path.exists(path): - logging.info("Downloading %s to %s", constants_url, path) - download(constants_url, path + ".tmp") - os.rename(path + ".tmp", path) - - ds = ekd.from_source("file", path) - ds = ds.sel(param=constant_params) - - date = int(kwargs["date"]) - time = int(kwargs["time"]) - if time < 100: - time *= 100 - step = int(kwargs.get("step", 0)) - valid = datetime.datetime( - date // 10000, date // 100 % 100, date % 100, time // 100, time % 100 - ) + datetime.timedelta(hours=step) - - for f in ds: - - # assert False, (date, time, step) - constants.append( - NewMetadataField( - f, - valid_datetime=str(valid), - date=date, - time="%4d" % (time,), - step=step, - ) - ) - - constants = FieldArray(constants) + constants = self.constants(constant_params, request, kwargs) + else: + constants = ekd.from_source("empty") kwargs["param"] = param - logging.debug("load source ecmwf-open-data %s", kwargs) - - fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants) + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = opendata + constants + opendata = interpolate(opendata) # Fix grib2/eccodes bug - fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields]) + opendata = FieldArray([NewMetadataField(f, levelist=None) for f in opendata]) - return self.check_sfc(fields, request) + return self.check_sfc(opendata, request) def ml_load_source(self, **kwargs): - pproc = self._adjust(kwargs) + interpolate = self._adjust(kwargs) kwargs["levtype"] = "ml" request = kwargs.copy() - logging.debug("load source ecmwf-open-data %s", kwargs) - return self.check_ml(pproc(ekd.from_source("ecmwf-open-data", kwargs)), request) + opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs)) + opendata = interpolate(opendata) + + return self.check_ml(opendata, request) def check_pl(self, ds, request): self._check(ds, "PL", request, "param", "levelist") diff --git a/src/ai_models/inputs/recenter.py b/src/ai_models/inputs/recenter.py new file mode 100644 index 0000000..33bba3e --- /dev/null +++ b/src/ai_models/inputs/recenter.py @@ -0,0 +1,92 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging + +import earthkit.data as ekd +import numpy as np +import tqdm +from earthkit.data.core.temporary import temp_file + +LOG = logging.getLogger(__name__) + +CHECKED = set() + + +def _init_recenter(ds, f): + + # For now, we only support the 0.25x0.25 grid from OPENDATA (centered on the greenwich meridian) + + latitudeOfFirstGridPointInDegrees = f.metadata("latitudeOfFirstGridPointInDegrees") + longitudeOfFirstGridPointInDegrees = f.metadata("longitudeOfFirstGridPointInDegrees") + latitudeOfLastGridPointInDegrees = f.metadata("latitudeOfLastGridPointInDegrees") + longitudeOfLastGridPointInDegrees = f.metadata("longitudeOfLastGridPointInDegrees") + iDirectionIncrementInDegrees = f.metadata("iDirectionIncrementInDegrees") + jDirectionIncrementInDegrees = f.metadata("jDirectionIncrementInDegrees") + scanningMode = f.metadata("scanningMode") + Ni = f.metadata("Ni") + Nj = f.metadata("Nj") + + assert scanningMode == 0 + assert latitudeOfFirstGridPointInDegrees == 90 + assert longitudeOfFirstGridPointInDegrees == 180 + assert latitudeOfLastGridPointInDegrees == -90 + assert longitudeOfLastGridPointInDegrees == 179.75 + assert iDirectionIncrementInDegrees == 0.25 + assert jDirectionIncrementInDegrees == 0.25 + + assert Ni == 1440 + assert Nj == 721 + + shape = (Nj, Ni) + roll = -Ni // 2 + axis = 1 + + key = ( + latitudeOfFirstGridPointInDegrees, + longitudeOfFirstGridPointInDegrees, + latitudeOfLastGridPointInDegrees, + longitudeOfLastGridPointInDegrees, + iDirectionIncrementInDegrees, + jDirectionIncrementInDegrees, + Ni, + Nj, + ) + + ############################ + + if key not in CHECKED: + lon = ekd.from_source("forcings", ds, param=["longitude"], date=f.metadata("date"))[0] + assert np.all(np.roll(lon.to_numpy(), roll, axis=axis)[:, 0] == 0) + CHECKED.add(key) + + return (shape, roll, axis, dict(longitudeOfFirstGridPointInDegrees=0, longitudeOfLastGridPointInDegrees=359.75)) + + +def recenter(ds): + + tmp = temp_file() + + out = ekd.new_grib_output(tmp.path) + + for f in tqdm.tqdm(ds, delay=0.5, desc="Recentering", leave=False): + + shape, roll, axis, metadata = _init_recenter(ds, f) + + data = f.to_numpy() + assert data.shape == shape, (data.shape, shape) + + data = np.roll(data, roll, axis=axis) + + out.write(data, template=f, **metadata) + + out.close() + + result = ekd.from_source("file", tmp.path) + result._tmp = tmp + + return result