Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Sep 15, 2024
1 parent e27e6e7 commit 3d63cb0
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions src/ai_models/inputs/opendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,20 +62,20 @@ def _adjust(self, kwargs):

if interp:
logging.debug("Interpolating from %s to %s", source, grid)
return (Interpolate(grid, source), source)
return Interpolate(grid, source)
else:
return (lambda x: x, source)
return lambda x: x

def pl_load_source(self, **kwargs):
pproc, _ = self._adjust(kwargs)
pproc = self._adjust(kwargs)
kwargs["levtype"] = "pl"
request = kwargs.copy()

param = [p.lower() for p in kwargs["param"]]
assert isinstance(param, (list, tuple))

if "z" in param:
logging.warning("Parameter 'z' on pressure levels is not available in open data, using 'gh' instead")
logging.warning("Parameter 'z' on pressure levels is not available in ECMWF open data, using 'gh' instead")
param = list(param)
param.remove("z")
if "gh" not in param:
Expand All @@ -87,7 +87,7 @@ def pl_load_source(self, **kwargs):
return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request)

def sfc_load_source(self, **kwargs):
pproc, resol = self._adjust(kwargs)
pproc = self._adjust(kwargs)
kwargs["levtype"] = "sfc"
request = kwargs.copy()

Expand All @@ -107,27 +107,27 @@ def sfc_load_source(self, **kwargs):
if len(constant_params) == 1:
logging.warning(
f"Single level parameter '{constant_params[0]}' is"
" not available in open data, using constants.grib2 instead"
" not available in ECMWF open data, using constants.grib2 instead"
)
else:
logging.warning(
f"Single level parameters {constant_params} are"
" not available in open data, using constants.grib2 instead"
" not available in ECMWF open data, using constants.grib2 instead"
)
constants = []

cachedir = os.path.expanduser("~/.cache/ai-models")
constant_url = CONSTANTS_URL.format(resol=resol)
basename = os.path.basename(constant_url)
constants_url = CONSTANTS_URL.format(resol=request["resol"])
basename = os.path.basename(constants_url)

if not os.path.exists(cachedir):
os.makedirs(cachedir)

path = os.path.join(cachedir, basename)

if not os.path.exists(path):
logging.info("Downloading %s to %s", constant_url, path)
download(constant_url, path + ".tmp")
logging.info("Downloading %s to %s", constants_url, path)
download(constants_url, path + ".tmp")
os.rename(path + ".tmp", path)

ds = ekd.from_source("file", path)
Expand Down Expand Up @@ -191,6 +191,10 @@ def check_ml(self, ds, request):

def _check(self, ds, what, request, *keys):

def _(p):
if len(p) == 1:
return p[0]

expected = set()
for p in itertools.product(*[request[key] for key in keys]):
expected.add(p)
Expand All @@ -201,8 +205,14 @@ def _check(self, ds, what, request, *keys):

missing = expected - found
if missing:
raise ValueError(f"The following {what} parameters {missing} are not available in open data")
missing = [_(p) for p in missing]
if len(missing) == 1:
raise ValueError(f"The following {what} parameter '{missing[0]}' is not available in ECMWF open data")
raise ValueError(f"The following {what} parameters {missing} are not available in ECMWF open data")

extra = found - expected
if extra:
raise ValueError(f"Unexpected {what} parameters {extra} from open data")
extra = [_(p) for p in extra]
if len(extra) == 1:
raise ValueError(f"Unexpected {what} parameter '{extra[0]}' from ECMWF open data")
raise ValueError(f"Unexpected {what} parameters {extra} from ECMWF open data")

0 comments on commit 3d63cb0

Please sign in to comment.