diff --git a/src/ai_models/inputs/opendata.py b/src/ai_models/inputs/opendata.py index 9c3a1d9..16634c0 100644 --- a/src/ai_models/inputs/opendata.py +++ b/src/ai_models/inputs/opendata.py @@ -62,12 +62,12 @@ def _adjust(self, kwargs): if interp: logging.debug("Interpolating from %s to %s", source, grid) - return (Interpolate(grid, source), source) + return Interpolate(grid, source) else: - return (lambda x: x, source) + return lambda x: x def pl_load_source(self, **kwargs): - pproc, _ = self._adjust(kwargs) + pproc = self._adjust(kwargs) kwargs["levtype"] = "pl" request = kwargs.copy() @@ -75,7 +75,7 @@ def pl_load_source(self, **kwargs): assert isinstance(param, (list, tuple)) if "z" in param: - logging.warning("Parameter 'z' on pressure levels is not available in open data, using 'gh' instead") + logging.warning("Parameter 'z' on pressure levels is not available in ECMWF open data, using 'gh' instead") param = list(param) param.remove("z") if "gh" not in param: @@ -87,7 +87,7 @@ def pl_load_source(self, **kwargs): return self.check_pl(pproc(ekd.from_source("ecmwf-open-data", **kwargs)), request) def sfc_load_source(self, **kwargs): - pproc, resol = self._adjust(kwargs) + pproc = self._adjust(kwargs) kwargs["levtype"] = "sfc" request = kwargs.copy() @@ -107,18 +107,18 @@ def sfc_load_source(self, **kwargs): if len(constant_params) == 1: logging.warning( f"Single level parameter '{constant_params[0]}' is" - " not available in open data, using constants.grib2 instead" + " not available in ECMWF open data, using constants.grib2 instead" ) else: logging.warning( f"Single level parameters {constant_params} are" - " not available in open data, using constants.grib2 instead" + " not available in ECMWF open data, using constants.grib2 instead" ) constants = [] cachedir = os.path.expanduser("~/.cache/ai-models") - constant_url = CONSTANTS_URL.format(resol=resol) - basename = os.path.basename(constant_url) + constants_url = CONSTANTS_URL.format(resol=request["resol"]) + basename = os.path.basename(constants_url) if not os.path.exists(cachedir): os.makedirs(cachedir) @@ -126,8 +126,8 @@ def sfc_load_source(self, **kwargs): path = os.path.join(cachedir, basename) if not os.path.exists(path): - logging.info("Downloading %s to %s", constant_url, path) - download(constant_url, path + ".tmp") + logging.info("Downloading %s to %s", constants_url, path) + download(constants_url, path + ".tmp") os.rename(path + ".tmp", path) ds = ekd.from_source("file", path) @@ -191,6 +191,10 @@ def check_ml(self, ds, request): def _check(self, ds, what, request, *keys): + def _(p): + if len(p) == 1: + return p[0] + expected = set() for p in itertools.product(*[request[key] for key in keys]): expected.add(p) @@ -201,8 +205,14 @@ def _check(self, ds, what, request, *keys): missing = expected - found if missing: - raise ValueError(f"The following {what} parameters {missing} are not available in open data") + missing = [_(p) for p in missing] + if len(missing) == 1: + raise ValueError(f"The following {what} parameter '{missing[0]}' is not available in ECMWF open data") + raise ValueError(f"The following {what} parameters {missing} are not available in ECMWF open data") extra = found - expected if extra: - raise ValueError(f"Unexpected {what} parameters {extra} from open data") + extra = [_(p) for p in extra] + if len(extra) == 1: + raise ValueError(f"Unexpected {what} parameter '{extra[0]}' from ECMWF open data") + raise ValueError(f"Unexpected {what} parameters {extra} from ECMWF open data")