Commit
recenter fields
b8raoult committed Sep 23, 2024
1 parent e5ee8d3 commit cf4b994
Showing 4 changed files with 233 additions and 100 deletions.
36 changes: 21 additions & 15 deletions src/ai_models/inputs/compute.py
@@ -7,27 +7,33 @@

 import logging

+import earthkit.data as ekd
+import tqdm
+from earthkit.data.core.temporary import temp_file
 from earthkit.data.indexing.fieldlist import FieldArray

-from .transform import NewDataField
-from .transform import NewMetadataField
-
 LOG = logging.getLogger(__name__)

+G = 9.80665  # Same a pgen

-def make_z_from_gh(previous):
-    g = 9.80665  # Same a pgen
-
-    def _proc(ds):
-        ds = previous(ds)
-
-        result = []
-        for f in ds:
-            if f.metadata("param") == "gh":
-                result.append(NewMetadataField(NewDataField(f, f.to_numpy() * g), param="z"))
-            else:
-                result.append(f)
-        return FieldArray(result)
-
-    return _proc
+
+def make_z_from_gh(ds):
+
+    tmp = temp_file()
+
+    out = ekd.new_grib_output(tmp.path)
+    other = []
+
+    for f in tqdm.tqdm(ds, delay=0.5, desc="GH to Z", leave=False):
+        if f.metadata("param") == "gh":
+            out.write(f.to_numpy() * G, template=f, param="z")
+        else:
+            other.append(f)
+
+    out.close()
+
+    result = FieldArray(other) + ekd.from_source("file", tmp.path)
+    result._tmp = tmp
+
+    return result
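
For context, a minimal usage sketch of the rewritten make_z_from_gh. The input file name and the ai_models import path are assumptions, not part of this commit:

import earthkit.data as ekd

from ai_models.inputs.compute import make_z_from_gh  # assumed import path

# "pl.grib" is a hypothetical GRIB file holding pressure-level fields,
# including geopotential height ("gh").
fields = ekd.from_source("file", "pl.grib")

# Rewrites every "gh" field as "z" (gh * 9.80665) via a temporary GRIB file
# and returns all other fields unchanged.
fields = make_z_from_gh(fields)

for f in fields:
    print(f.metadata("param"))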
24 changes: 17 additions & 7 deletions src/ai_models/inputs/interpolate.py
@@ -7,25 +7,35 @@

 import logging

+import earthkit.data as ekd
 import earthkit.regrid as ekr
 import tqdm
-from earthkit.data.indexing.fieldlist import FieldArray
-
-from .transform import NewDataField
+from earthkit.data.core.temporary import temp_file

 LOG = logging.getLogger(__name__)


 class Interpolate:
-    def __init__(self, grid, source):
+    def __init__(self, grid, source, metadata):
         self.grid = list(grid) if isinstance(grid, tuple) else grid
         self.source = list(source) if isinstance(source, tuple) else source
+        self.metadata = metadata

     def __call__(self, ds):
-        result = []
+        tmp = temp_file()
+
+        out = ekd.new_grib_output(tmp.path)
+
         for f in tqdm.tqdm(ds, delay=0.5, desc="Interpolating", leave=False):
             data = ekr.interpolate(f.to_numpy(), dict(grid=self.source), dict(grid=self.grid))
-            result.append(NewDataField(f, data))
+            out.write(data, template=f, **self.metadata)
+
+        out.close()
+
+        result = ekd.from_source("file", tmp.path)
+        result._tmp = tmp
+
+        print("Interpolated data", tmp.path)
+
         LOG.info("Interpolated %d fields. Input shape %s, output shape %s.", len(result), ds[0].shape, result[0].shape)
-        return FieldArray(result)
+        return result
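
A usage sketch of the extended Interpolate. The metadata values mirror what RESOLS now carries for the 0.1 degree target; the file name and import path are assumptions:

import earthkit.data as ekd

from ai_models.inputs.interpolate import Interpolate  # assumed import path

# Regrid 0.25 degree fields to 0.1 degree; the metadata keys are written
# into every output GRIB message (see out.write(..., **self.metadata) above).
interpolate = Interpolate(
    grid=(0.1, 0.1),
    source=(0.25, 0.25),
    metadata=dict(
        longitudeOfLastGridPointInDegrees=359.9,
        iDirectionIncrementInDegrees=0.1,
        jDirectionIncrementInDegrees=0.1,
        Ni=3600,
        Nj=1801,
    ),
)

fields_0p25 = ekd.from_source("file", "input-0p25.grib")  # hypothetical file
fields_0p1 = interpolate(fields_0p25)
print(len(fields_0p1), fields_0p1[0].shape)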
181 changes: 103 additions & 78 deletions src/ai_models/inputs/opendata.py
@@ -5,23 +5,23 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.

-import datetime
 import itertools
 import logging
 import os

 import earthkit.data as ekd
+from earthkit.data.core.temporary import temp_file
 from earthkit.data.indexing.fieldlist import FieldArray
 from multiurl import download

 from .base import RequestBasedInput
 from .compute import make_z_from_gh
 from .interpolate import Interpolate
+from .recenter import recenter
 from .transform import NewMetadataField

 LOG = logging.getLogger(__name__)


 CONSTANTS = (
     "z",
     "sdor",
@@ -30,17 +30,33 @@

 CONSTANTS_URL = "https://get.ecmwf.int/repository/test-data/ai-models/opendata/constants-{resol}.grib2"

+RESOLS = {
+    (0.25, 0.25): ("0p25", (0.25, 0.25), False, False, {}),
+    (0.1, 0.1): (
+        "0p25",
+        (0.25, 0.25),
+        True,
+        True,
+        dict(
+            longitudeOfLastGridPointInDegrees=359.9,
+            iDirectionIncrementInDegrees=0.1,
+            jDirectionIncrementInDegrees=0.1,
+            Ni=3600,
+            Nj=1801,
+        ),
+    ),
+    # "N320": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg')),
+    # "O96": ("0p25", (0.25, 0.25), True, False, dict(gridType='reduced_gg', )),
+}
+
+
+def _identity(x):
+    return x
+
+
 class OpenDataInput(RequestBasedInput):
     WHERE = "OPENDATA"

-    RESOLS = {
-        (0.25, 0.25): ("0p25", (0.25, 0.25), False, False),
-        "N320": ("0p25", (0.25, 0.25), True, False),
-        "O96": ("0p25", (0.25, 0.25), True, False),
-        (0.1, 0.1): ("0p25", (0.25, 0.25), True, True),
-    }
-
     def __init__(self, owner, **kwargs):
         self.owner = owner

@@ -56,7 +72,7 @@ def _adjust(self, kwargs):
         if isinstance(grid, list):
             grid = tuple(grid)

-        kwargs["resol"], source, interp, oversampling = self.RESOLS[grid]
+        kwargs["resol"], source, interp, oversampling, metadata = RESOLS[grid]
         r = dict(**kwargs)
         r.update(self.owner.retrieve)
@@ -65,12 +81,15 @@ def _adjust(self, kwargs):
             logging.info("Interpolating input data from %s to %s.", source, grid)
             if oversampling:
                 logging.warning("This will oversample the input data.")
-            return Interpolate(grid, source)
+            return Interpolate(grid, source, metadata)
         else:
-            return lambda x: x
+            return _identity

     def pl_load_source(self, **kwargs):
-        pproc = self._adjust(kwargs)
+        gh_to_z = _identity
+        interpolate = self._adjust(kwargs)

         kwargs["levtype"] = "pl"
         request = kwargs.copy()
@@ -84,13 +103,68 @@ def pl_load_source(self, **kwargs):
             if "gh" not in param:
                 param.append("gh")
             kwargs["param"] = param
-            pproc = make_z_from_gh(pproc)
+            gh_to_z = make_z_from_gh

         logging.debug("load source ecmwf-open-data %s", kwargs)
-        return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request)
+
+        opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs))
+        opendata = gh_to_z(opendata)
+        opendata = interpolate(opendata)
+
+        return self.check_pl(opendata, request)
+
+    def constants(self, constant_params, request, kwargs):
+        if len(constant_params) == 1:
+            logging.warning(
+                f"Single level parameter '{constant_params[0]}' is"
+                " not available in ECMWF open data, using constants.grib2 instead"
+            )
+        else:
+            logging.warning(
+                f"Single level parameters {constant_params} are"
+                " not available in ECMWF open data, using constants.grib2 instead"
+            )
+
+        cachedir = os.path.expanduser("~/.cache/ai-models")
+        constants_url = CONSTANTS_URL.format(resol=request["resol"])
+        basename = os.path.basename(constants_url)
+
+        if not os.path.exists(cachedir):
+            os.makedirs(cachedir)
+
+        path = os.path.join(cachedir, basename)
+
+        if not os.path.exists(path):
+            logging.info("Downloading %s to %s", constants_url, path)
+            download(constants_url, path + ".tmp")
+            os.rename(path + ".tmp", path)
+
+        ds = ekd.from_source("file", path)
+        ds = ds.sel(param=constant_params)
+
+        tmp = temp_file()
+
+        out = ekd.new_grib_output(tmp.path)
+
+        for f in ds:
+            out.write(
+                f.to_numpy(),
+                template=f,
+                date=kwargs["date"],
+                time=kwargs["time"],
+                step=kwargs.get("step", 0),
+            )
+
+        out.close()
+
+        result = ekd.from_source("file", tmp.path)
+        result._tmp = tmp
+
+        return result

     def sfc_load_source(self, **kwargs):
-        pproc = self._adjust(kwargs)
+        interpolate = self._adjust(kwargs)

         kwargs["levtype"] = "sfc"
         request = kwargs.copy()
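
Putting the pieces together, a sketch of the new pl_load_source ordering: recenter first, then gh-to-z, then optional interpolation. The request values and import paths are assumptions; recenter comes from the fourth changed file, which is not shown here.

import earthkit.data as ekd

from ai_models.inputs.compute import make_z_from_gh
from ai_models.inputs.interpolate import Interpolate
from ai_models.inputs.opendata import RESOLS, _identity
from ai_models.inputs.recenter import recenter

grid = (0.25, 0.25)
resol, source, interp, oversampling, metadata = RESOLS[grid]
interpolate = Interpolate(grid, source, metadata) if interp else _identity

request = dict(levtype="pl", param=["gh", "t"], levelist=[500, 850])  # illustrative
opendata = ekd.from_source("ecmwf-open-data", **request)

opendata = recenter(opendata)        # new in this commit (recenter module, not shown)
opendata = make_z_from_gh(opendata)  # only applied when "z" was requested as "gh"
opendata = interpolate(opendata)     # _identity here, since the grid is native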

@@ -104,81 +178,32 @@ def sfc_load_source(self, **kwargs):
                 param.remove(c)
                 constant_params.append(c)

-        constants = ekd.from_source("empty")
-
         if constant_params:
-            if len(constant_params) == 1:
-                logging.warning(
-                    f"Single level parameter '{constant_params[0]}' is"
-                    " not available in ECMWF open data, using constants.grib2 instead"
-                )
-            else:
-                logging.warning(
-                    f"Single level parameters {constant_params} are"
-                    " not available in ECMWF open data, using constants.grib2 instead"
-                )
-            constants = []
-
-            cachedir = os.path.expanduser("~/.cache/ai-models")
-            constants_url = CONSTANTS_URL.format(resol=request["resol"])
-            basename = os.path.basename(constants_url)
-
-            if not os.path.exists(cachedir):
-                os.makedirs(cachedir)
-
-            path = os.path.join(cachedir, basename)
-
-            if not os.path.exists(path):
-                logging.info("Downloading %s to %s", constants_url, path)
-                download(constants_url, path + ".tmp")
-                os.rename(path + ".tmp", path)
-
-            ds = ekd.from_source("file", path)
-            ds = ds.sel(param=constant_params)
-
-            date = int(kwargs["date"])
-            time = int(kwargs["time"])
-            if time < 100:
-                time *= 100
-            step = int(kwargs.get("step", 0))
-            valid = datetime.datetime(
-                date // 10000, date // 100 % 100, date % 100, time // 100, time % 100
-            ) + datetime.timedelta(hours=step)
-
-            for f in ds:
-
-                # assert False, (date, time, step)
-                constants.append(
-                    NewMetadataField(
-                        f,
-                        valid_datetime=str(valid),
-                        date=date,
-                        time="%4d" % (time,),
-                        step=step,
-                    )
-                )
-
-            constants = FieldArray(constants)
+            constants = self.constants(constant_params, request, kwargs)
+        else:
+            constants = ekd.from_source("empty")

         kwargs["param"] = param

         logging.debug("load source ecmwf-open-data %s", kwargs)

-        fields = pproc(ekd.from_source("ecmwf-open-data", **kwargs) + constants)
+        opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs))
+        opendata = opendata + constants
+        opendata = interpolate(opendata)

         # Fix grib2/eccodes bug

-        fields = FieldArray([NewMetadataField(f, levelist=None) for f in fields])
+        opendata = FieldArray([NewMetadataField(f, levelist=None) for f in opendata])

-        return self.check_sfc(fields, request)
+        return self.check_sfc(opendata, request)

     def ml_load_source(self, **kwargs):
-        pproc = self._adjust(kwargs)
+        interpolate = self._adjust(kwargs)
         kwargs["levtype"] = "ml"
         request = kwargs.copy()

         logging.debug("load source ecmwf-open-data %s", kwargs)
-        return self.check_ml(pproc(ekd.from_source("ecmwf-open-data", kwargs)), request)
+        opendata = recenter(ekd.from_source("ecmwf-open-data", **kwargs))
+        opendata = interpolate(opendata)
+
+        return self.check_ml(opendata, request)

     def check_pl(self, ds, request):
         self._check(ds, "PL", request, "param", "levelist")
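
A short sketch of the "Fix grib2/eccodes bug" step above: every surface field is wrapped so its levelist metadata is cleared before the checks run. The file name and import paths are assumptions:

import earthkit.data as ekd
from earthkit.data.indexing.fieldlist import FieldArray

from ai_models.inputs.transform import NewMetadataField  # assumed import path

sfc = ekd.from_source("file", "sfc.grib")  # hypothetical surface-level GRIB file
sfc = FieldArray([NewMetadataField(f, levelist=None) for f in sfc])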
(Diff for the fourth changed file not shown.)
