From 3c8a1dea4f7033cd6c3fafd60ffe036b6e561966 Mon Sep 17 00:00:00 2001 From: b8raoult <53792887+b8raoult@users.noreply.github.com> Date: Thu, 3 Oct 2024 20:33:59 +0100 Subject: [PATCH] Feature/multi dates match (#64) * add support for constants * massive refactoring * update documentation --- .pre-commit-config.yaml | 2 +- CHANGELOG.md | 1 + docs/building/sources.rst | 1 + docs/building/sources/repeated_dates.rst | 25 + .../sources/yaml/repeated_dates1.yaml | 6 + .../sources/yaml/repeated_dates2.yaml | 6 + .../sources/yaml/repeated_dates3.yaml | 8 + .../sources/yaml/repeated_dates4.yaml | 9 + docs/index.rst | 6 +- src/anemoi/datasets/create/__init__.py | 6 +- src/anemoi/datasets/create/check.py | 6 + .../create/functions/sources/__init__.py | 8 +- .../create/functions/sources/accumulations.py | 1 + .../datasets/create/functions/sources/grib.py | 2 +- .../functions/sources/xarray/__init__.py | 13 +- src/anemoi/datasets/create/input.py | 1087 ----------------- src/anemoi/datasets/create/input/__init__.py | 69 ++ src/anemoi/datasets/create/input/action.py | 123 ++ src/anemoi/datasets/create/input/concat.py | 92 ++ src/anemoi/datasets/create/input/context.py | 59 + .../datasets/create/input/data_sources.py | 71 ++ src/anemoi/datasets/create/input/empty.py | 42 + src/anemoi/datasets/create/input/filter.py | 76 ++ src/anemoi/datasets/create/input/function.py | 122 ++ src/anemoi/datasets/create/input/join.py | 57 + src/anemoi/datasets/create/input/misc.py | 85 ++ src/anemoi/datasets/create/input/pipe.py | 33 + .../datasets/create/input/repeated_dates.py | 217 ++++ src/anemoi/datasets/create/input/result.py | 413 +++++++ src/anemoi/datasets/create/input/step.py | 99 ++ .../datasets/create/{ => input}/template.py | 42 - .../datasets/create/{ => input}/trace.py | 0 .../datasets/create/statistics/__init__.py | 2 +- src/anemoi/datasets/dates/__init__.py | 1 + src/anemoi/datasets/dates/groups.py | 16 +- src/anemoi/datasets/fields.py | 66 + 36 files changed, 1720 insertions(+), 
1152 deletions(-) create mode 100644 docs/building/sources/repeated_dates.rst create mode 100644 docs/building/sources/yaml/repeated_dates1.yaml create mode 100644 docs/building/sources/yaml/repeated_dates2.yaml create mode 100644 docs/building/sources/yaml/repeated_dates3.yaml create mode 100644 docs/building/sources/yaml/repeated_dates4.yaml delete mode 100644 src/anemoi/datasets/create/input.py create mode 100644 src/anemoi/datasets/create/input/__init__.py create mode 100644 src/anemoi/datasets/create/input/action.py create mode 100644 src/anemoi/datasets/create/input/concat.py create mode 100644 src/anemoi/datasets/create/input/context.py create mode 100644 src/anemoi/datasets/create/input/data_sources.py create mode 100644 src/anemoi/datasets/create/input/empty.py create mode 100644 src/anemoi/datasets/create/input/filter.py create mode 100644 src/anemoi/datasets/create/input/function.py create mode 100644 src/anemoi/datasets/create/input/join.py create mode 100644 src/anemoi/datasets/create/input/misc.py create mode 100644 src/anemoi/datasets/create/input/pipe.py create mode 100644 src/anemoi/datasets/create/input/repeated_dates.py create mode 100644 src/anemoi/datasets/create/input/result.py create mode 100644 src/anemoi/datasets/create/input/step.py rename src/anemoi/datasets/create/{ => input}/template.py (58%) rename src/anemoi/datasets/create/{ => input}/trace.py (100%) create mode 100644 src/anemoi/datasets/fields.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e4341a3..4b9f40cd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,7 +44,7 @@ repos: hooks: - id: ruff # Next line if for documenation cod snippets - exclude: '^[^_].*_\.py$' + exclude: '.*/[^_].*_\.py$' args: - --line-length=120 - --fix diff --git a/CHANGELOG.md b/CHANGELOG.md index 8bfd4354..ea5ee595 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ Keep it human-readable, your future self will thank you! 
### Added - Adding the user recipe in the dataset PR #59. +- Add `multi_dates_match` action in create. ### Changed diff --git a/docs/building/sources.rst b/docs/building/sources.rst index c67e6118..e327f949 100644 --- a/docs/building/sources.rst +++ b/docs/building/sources.rst @@ -28,6 +28,7 @@ The following `sources` are currently available: sources/netcdf sources/opendap sources/recentre + sources/repeated_dates sources/xarray-kerchunk sources/xarray-zarr sources/zenodo diff --git a/docs/building/sources/repeated_dates.rst b/docs/building/sources/repeated_dates.rst new file mode 100644 index 00000000..4fefb078 --- /dev/null +++ b/docs/building/sources/repeated_dates.rst @@ -0,0 +1,25 @@ +################ + repeated_dates +################ + +The generale format of the `repeated_dates` source is: + +.. literalinclude:: yaml/repeated_dates1.yaml + +********** + constant +********** + +.. literalinclude:: yaml/repeated_dates2.yaml + +************* + climatology +************* + +.. literalinclude:: yaml/repeated_dates3.yaml + +********* + closest +********* + +.. literalinclude:: yaml/repeated_dates4.yaml diff --git a/docs/building/sources/yaml/repeated_dates1.yaml b/docs/building/sources/yaml/repeated_dates1.yaml new file mode 100644 index 00000000..4d13e9bb --- /dev/null +++ b/docs/building/sources/yaml/repeated_dates1.yaml @@ -0,0 +1,6 @@ + +repeated_dates: + mode: mode + # ... parameters related to the mode ... + source: + # ... a source definition ... 
diff --git a/docs/building/sources/yaml/repeated_dates2.yaml b/docs/building/sources/yaml/repeated_dates2.yaml new file mode 100644 index 00000000..498312be --- /dev/null +++ b/docs/building/sources/yaml/repeated_dates2.yaml @@ -0,0 +1,6 @@ +repeated_dates: + mode: constant + source: + xarray-zarr: + url: dem.zarr + variable: dem diff --git a/docs/building/sources/yaml/repeated_dates3.yaml b/docs/building/sources/yaml/repeated_dates3.yaml new file mode 100644 index 00000000..8a1dbed3 --- /dev/null +++ b/docs/building/sources/yaml/repeated_dates3.yaml @@ -0,0 +1,8 @@ +repeated_dates: + mode: climatology + year: 2019 + day: 15 + source: + grib: + path: some/path/to/data.grib + param: [some_param] diff --git a/docs/building/sources/yaml/repeated_dates4.yaml b/docs/building/sources/yaml/repeated_dates4.yaml new file mode 100644 index 00000000..c9fb1e59 --- /dev/null +++ b/docs/building/sources/yaml/repeated_dates4.yaml @@ -0,0 +1,9 @@ +repeated_dates: + mode: closest + frequency: 24h + maximum: 30d + skip_all_nans: true + source: + grib: + path: path/to/data.grib + param: [some_param] diff --git a/docs/index.rst b/docs/index.rst index 1a5ff9cd..fbddc874 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -1,8 +1,8 @@ .. _index-page: -#################################### - Welcome to Anemoi's documentation! -#################################### +############################################# + Welcome to `anemoi-datasets` documentation! +############################################# .. 
warning:: diff --git a/src/anemoi/datasets/create/__init__.py b/src/anemoi/datasets/create/__init__.py index d1026de3..68f78a9e 100644 --- a/src/anemoi/datasets/create/__init__.py +++ b/src/anemoi/datasets/create/__init__.py @@ -25,9 +25,11 @@ from anemoi.utils.dates import frequency_to_timedelta from anemoi.utils.humanize import compress_dates from anemoi.utils.humanize import seconds_to_human +from earthkit.data.core.order import build_remapping from anemoi.datasets import MissingDateError from anemoi.datasets import open_dataset +from anemoi.datasets.create.input.trace import enable_trace from anemoi.datasets.create.persistent import build_storage from anemoi.datasets.data.misc import as_first_date from anemoi.datasets.data.misc import as_last_date @@ -309,7 +311,6 @@ def create_elements(self, config): def build_input_(main_config, output_config): - from earthkit.data.core.order import build_remapping builder = build_input( main_config.input, @@ -563,7 +564,7 @@ def _run(self): # assert isinstance(group[0], datetime.datetime), type(group[0]) LOG.debug(f"Building data for group {igroup}/{self.n_groups}") - result = self.input.select(dates=group) + result = self.input.select(group_of_dates=group) assert result.group_of_dates == group, (len(result.group_of_dates), len(group), group) # There are several groups. 
@@ -1031,7 +1032,6 @@ def run(self): def creator_factory(name, trace=None, **kwargs): if trace: - from anemoi.datasets.create.trace import enable_trace enable_trace(trace) diff --git a/src/anemoi/datasets/create/check.py b/src/anemoi/datasets/create/check.py index 58b63ae9..902f66e9 100644 --- a/src/anemoi/datasets/create/check.py +++ b/src/anemoi/datasets/create/check.py @@ -140,9 +140,15 @@ class StatisticsValueError(ValueError): def check_data_values(arr, *, name: str, log=[], allow_nans=False): + shape = arr.shape + if (isinstance(allow_nans, (set, list, tuple, dict)) and name in allow_nans) or allow_nans: arr = arr[~np.isnan(arr)] + if arr.size == 0: + warnings.warn(f"Empty array for {name} ({shape})") + return + assert arr.size > 0, (name, *log) min, max = arr.min(), arr.max() diff --git a/src/anemoi/datasets/create/functions/sources/__init__.py b/src/anemoi/datasets/create/functions/sources/__init__.py index 6192702a..d2c4c570 100644 --- a/src/anemoi/datasets/create/functions/sources/__init__.py +++ b/src/anemoi/datasets/create/functions/sources/__init__.py @@ -16,6 +16,10 @@ def _expand(paths): + + if not isinstance(paths, list): + paths = [paths] + for path in paths: if path.startswith("file://"): path = path[7:] @@ -40,8 +44,10 @@ def iterate_patterns(path, dates, **kwargs): given_paths = path if isinstance(path, list) else [path] dates = [d.isoformat() for d in dates] + if len(dates) > 0: + kwargs["date"] = dates for path in given_paths: - paths = Pattern(path, ignore_missing_keys=True).substitute(date=dates, **kwargs) + paths = Pattern(path, ignore_missing_keys=True).substitute(**kwargs) for path in _expand(paths): yield path, dates diff --git a/src/anemoi/datasets/create/functions/sources/accumulations.py b/src/anemoi/datasets/create/functions/sources/accumulations.py index 4f9b605e..b74eb33f 100644 --- a/src/anemoi/datasets/create/functions/sources/accumulations.py +++ b/src/anemoi/datasets/create/functions/sources/accumulations.py @@ -376,6 +376,7 @@ 
def accumulations(context, dates, **request): ("ea", "oper"): dict(data_accumulation_period=1, base_times=(6, 18)), ("ea", "enda"): dict(data_accumulation_period=3, base_times=(6, 18)), ("rr", "oper"): dict(data_accumulation_period=3, base_times=(0, 3, 6, 9, 12, 15, 18, 21)), + ("l5", "oper"): dict(data_accumulation_period=1, base_times=(0,)), } kwargs = KWARGS.get((class_, stream), {}) diff --git a/src/anemoi/datasets/create/functions/sources/grib.py b/src/anemoi/datasets/create/functions/sources/grib.py index 8b2b7e35..88a7c1e9 100644 --- a/src/anemoi/datasets/create/functions/sources/grib.py +++ b/src/anemoi/datasets/create/functions/sources/grib.py @@ -135,7 +135,7 @@ def execute(context, dates, path, latitudes=None, longitudes=None, *args, **kwar s = s.sel(valid_datetime=dates, **kwargs) ds = ds + s - if kwargs: + if kwargs and not context.partial_ok: check(ds, given_paths, valid_datetime=dates, **kwargs) if geography is not None: diff --git a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py index 9e7cd77d..3afc899b 100644 --- a/src/anemoi/datasets/create/functions/sources/xarray/__init__.py +++ b/src/anemoi/datasets/create/functions/sources/xarray/__init__.py @@ -10,10 +10,9 @@ import logging from earthkit.data.core.fieldlist import MultiFieldList -from earthkit.data.indexing.fieldlist import FieldArray from anemoi.datasets.data.stores import name_to_zarr_store -from anemoi.datasets.utils.fields import NewMetadataField +from anemoi.datasets.utils.fields import NewMetadataField as NewMetadataField from .. 
import iterate_patterns from .fieldlist import XarrayFieldList @@ -31,7 +30,7 @@ def check(what, ds, paths, **kwargs): raise ValueError(f"Expected {count} fields, got {len(ds)} (kwargs={kwargs}, {what}s={paths})") -def load_one(emoji, context, dates, dataset, options={}, match_all_dates=False, flavour=None, **kwargs): +def load_one(emoji, context, dates, dataset, options={}, flavour=None, **kwargs): import xarray as xr """ @@ -52,12 +51,8 @@ def load_one(emoji, context, dates, dataset, options={}, match_all_dates=False, fs = XarrayFieldList.from_xarray(data, flavour) - if match_all_dates: - match = fs.sel(**kwargs) - result = [] - for date in dates: - result.append(FieldArray([NewMetadataField(f, valid_datetime=date) for f in match])) - result = MultiFieldList(result) + if len(dates) == 0: + return fs.sel(**kwargs) else: result = MultiFieldList([fs.sel(valid_datetime=date, **kwargs) for date in dates]) diff --git a/src/anemoi/datasets/create/input.py b/src/anemoi/datasets/create/input.py deleted file mode 100644 index 162da21c..00000000 --- a/src/anemoi/datasets/create/input.py +++ /dev/null @@ -1,1087 +0,0 @@ -# (C) Copyright 2023 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. 
-# -import datetime -import itertools -import logging -import math -import time -from collections import defaultdict -from copy import deepcopy -from functools import cached_property -from functools import wraps - -import numpy as np -from anemoi.utils.humanize import seconds_to_human -from anemoi.utils.humanize import shorten_list -from earthkit.data.core.fieldlist import FieldList -from earthkit.data.core.fieldlist import MultiFieldList -from earthkit.data.core.order import build_remapping - -from anemoi.datasets.dates import DatesProvider - -from .functions import import_function -from .template import Context -from .template import notify_result -from .template import resolve -from .template import substitute -from .trace import trace -from .trace import trace_datasource -from .trace import trace_select - -LOG = logging.getLogger(__name__) - - -def parse_function_name(name): - - if name.endswith("h") and name[:-1].isdigit(): - - if "-" in name: - name, delta = name.split("-") - sign = -1 - - elif "+" in name: - name, delta = name.split("+") - sign = 1 - - else: - return name, None - - assert delta[-1] == "h", (name, delta) - delta = sign * int(delta[:-1]) - return name, delta - - return name, None - - -def time_delta_to_string(delta): - assert isinstance(delta, datetime.timedelta), delta - seconds = delta.total_seconds() - hours = int(seconds // 3600) - assert hours * 3600 == seconds, delta - hours = abs(hours) - - if seconds > 0: - return f"plus_{hours}h" - if seconds == 0: - return "" - if seconds < 0: - return f"minus_{hours}h" - - -def is_function(name, kind): - name, _ = parse_function_name(name) - try: - import_function(name, kind) - return True - except ImportError as e: - print(e) - return False - - -def assert_fieldlist(method): - @wraps(method) - def wrapper(self, *args, **kwargs): - result = method(self, *args, **kwargs) - assert isinstance(result, FieldList), type(result) - return result - - return wrapper - - -def assert_is_fieldlist(obj): - assert 
isinstance(obj, FieldList), type(obj) - - -def _data_request(data): - date = None - params_levels = defaultdict(set) - params_steps = defaultdict(set) - - area = grid = None - - for field in data: - try: - if date is None: - date = field.metadata("valid_datetime") - - if field.metadata("valid_datetime") != date: - continue - - as_mars = field.metadata(namespace="mars") - if not as_mars: - continue - step = as_mars.get("step") - levtype = as_mars.get("levtype", "sfc") - param = as_mars["param"] - levelist = as_mars.get("levelist", None) - area = field.mars_area - grid = field.mars_grid - - if levelist is None: - params_levels[levtype].add(param) - else: - params_levels[levtype].add((param, levelist)) - - if step: - params_steps[levtype].add((param, step)) - except Exception: - LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True) - - def sort(old_dic): - new_dic = {} - for k, v in old_dic.items(): - new_dic[k] = sorted(list(v)) - return new_dic - - params_steps = sort(params_steps) - params_levels = sort(params_levels) - - return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) - - -class Action: - def __init__(self, context, action_path, /, *args, **kwargs): - if "args" in kwargs and "kwargs" in kwargs: - """We have: - args = [] - kwargs = {args: [...], kwargs: {...}} - move the content of kwargs to args and kwargs. - """ - assert len(kwargs) == 2, (args, kwargs) - assert not args, (args, kwargs) - args = kwargs.pop("args") - kwargs = kwargs.pop("kwargs") - - assert isinstance(context, ActionContext), type(context) - self.context = context - self.kwargs = kwargs - self.args = args - self.action_path = action_path - - @classmethod - def _short_str(cls, x): - x = str(x) - if len(x) < 1000: - return x - return x[:1000] + "..." 
- - def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs): - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - more = more[:5000] - txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def select(self, dates, **kwargs): - self._raise_not_implemented() - - def _raise_not_implemented(self): - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_select(self, dates): - return f"{self.__class__.__name__}({shorten(dates)})" - - -def shorten(dates): - if isinstance(dates, (list, tuple)): - dates = [d.isoformat() for d in dates] - if len(dates) > 5: - return f"{dates[0]}...{dates[-1]}" - return dates - - -class Result: - empty = False - _coords_already_built = False - - def __init__(self, context, action_path, dates): - from anemoi.datasets.dates.groups import GroupOfDates - - assert isinstance(dates, GroupOfDates), dates - - assert isinstance(context, ActionContext), type(context) - assert isinstance(action_path, list), action_path - - self.context = context - self.group_of_dates = dates - self.action_path = action_path - - @property - @trace_datasource - def datasource(self): - self._raise_not_implemented() - - @property - def data_request(self): - """Returns a dictionary with the parameters needed to retrieve the data.""" - return _data_request(self.datasource) - - def get_cube(self): - trace("🧊", f"getting cube from {self.__class__.__name__}") - ds = self.datasource - - remapping = self.context.remapping - order_by = self.context.order_by - flatten_grid = self.context.flatten_grid - start = time.time() - LOG.debug("Sorting dataset %s %s", dict(order_by), remapping) - assert order_by, order_by - - patches = {"number": {None: 0}} - - try: - cube = ds.cube( - order_by, - remapping=remapping, - flatten_values=flatten_grid, - patches=patches, - ) - cube = cube.squeeze() - 
LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.") - except ValueError: - self.explain(ds, order_by, remapping=remapping, patches=patches) - # raise ValueError(f"Error in {self}") - exit(1) - - if LOG.isEnabledFor(logging.DEBUG): - LOG.debug("Cube shape: %s", cube) - for k, v in cube.user_coords.items(): - LOG.debug(" %s %s", k, shorten_list(v, max_length=10)) - - return cube - - def explain(self, ds, *args, remapping, patches): - - METADATA = ( - "date", - "time", - "step", - "hdate", - "valid_datetime", - "levtype", - "levelist", - "number", - "level", - "shortName", - "paramId", - "variable", - ) - - # We redo the logic here - print() - print("❌" * 40) - print() - if len(args) == 1 and isinstance(args[0], (list, tuple)): - args = args[0] - - # print("Executing", self.action_path) - # print("Dates:", compress_dates(self.dates)) - - names = [] - for a in args: - if isinstance(a, str): - names.append(a) - elif isinstance(a, dict): - names += list(a.keys()) - - print(f"Building a {len(names)}D hypercube using", names) - ds = ds.order_by(*args, remapping=remapping, patches=patches) - user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False) - - print() - print("Number of unique values found for each coordinate:") - for k, v in user_coords.items(): - print(f" {k:20}:", len(v), shorten_list(v, max_length=10)) - print() - user_shape = tuple(len(v) for k, v in user_coords.items()) - print("Shape of the hypercube :", user_shape) - print( - "Number of expected fields :", math.prod(user_shape), "=", " x ".join([str(i) for i in user_shape]) - ) - print("Number of fields in the dataset :", len(ds)) - print("Difference :", abs(len(ds) - math.prod(user_shape))) - print() - - remapping = build_remapping(remapping, patches) - expected = set(itertools.product(*user_coords.values())) - extra = set() - - if math.prod(user_shape) > len(ds): - print(f"This means that all the fields in the datasets do not exists for all 
combinations of {names}.") - - for f in ds: - metadata = remapping(f.metadata) - key = tuple(metadata(n, default=None) for n in names) - if key in expected: - expected.remove(key) - else: - extra.add(key) - - print("Missing fields:") - print() - for i, f in enumerate(sorted(expected)): - print(" ", f) - if i >= 9 and len(expected) > 10: - print("...", len(expected) - i - 1, "more") - break - - print("Extra fields:") - print() - for i, f in enumerate(sorted(extra)): - print(" ", f) - if i >= 9 and len(extra) > 10: - print("...", len(extra) - i - 1, "more") - break - - print() - print("Missing values:") - per_name = defaultdict(set) - for e in expected: - for n, v in zip(names, e): - per_name[n].add(v) - - for n, v in per_name.items(): - print(" ", n, len(v), shorten_list(sorted(v), max_length=10)) - print() - - print("Extra values:") - per_name = defaultdict(set) - for e in extra: - for n, v in zip(names, e): - per_name[n].add(v) - - for n, v in per_name.items(): - print(" ", n, len(v), shorten_list(sorted(v), max_length=10)) - print() - - print("To solve this issue, you can:") - print( - " - Provide a better selection, like 'step: 0' or 'level: 1000' to " - "reduce the number of selected fields." - ) - print( - " - Split the 'input' part in smaller sections using 'join', " - "making sure that each section represent a full hypercube." - ) - - else: - print(f"More fields in dataset that expected for {names}. 
" "This means that some fields are duplicated.") - duplicated = defaultdict(list) - for f in ds: - # print(f.metadata(namespace="default")) - metadata = remapping(f.metadata) - key = tuple(metadata(n, default=None) for n in names) - duplicated[key].append(f) - - print("Duplicated fields:") - print() - duplicated = {k: v for k, v in duplicated.items() if len(v) > 1} - for i, (k, v) in enumerate(sorted(duplicated.items())): - print(" ", k) - for f in v: - x = {k: f.metadata(k, default=None) for k in METADATA if f.metadata(k, default=None) is not None} - print(" ", f, x) - if i >= 9 and len(duplicated) > 10: - print("...", len(duplicated) - i - 1, "more") - break - - print() - print("To solve this issue, you can:") - print(" - Provide a better selection, like 'step: 0' or 'level: 1000'") - print(" - Change the way 'param' is computed using 'variable_naming' " "in the 'build' section.") - - print() - print("❌" * 40) - print() - exit(1) - - def __repr__(self, *args, _indent_="\n", **kwargs): - more = ",".join([str(a)[:5000] for a in args]) - more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) - - dates = " no-dates" - if self.group_of_dates is not None: - dates = f" {len(self.group_of_dates)} dates" - dates += " (" - dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates) - if len(dates) > 100: - dates = dates[:100] + "..." 
- dates += ")" - - more = more[:5000] - txt = f"{self.__class__.__name__}:{dates}{_indent_}{more}" - if _indent_: - txt = txt.replace("\n", "\n ") - return txt - - def _raise_not_implemented(self): - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - def _trace_datasource(self, *args, **kwargs): - return f"{self.__class__.__name__}({self.group_of_dates})" - - def build_coords(self): - if self._coords_already_built: - return - from_data = self.get_cube().user_coords - from_config = self.context.order_by - - keys_from_config = list(from_config.keys()) - keys_from_data = list(from_data.keys()) - assert keys_from_data == keys_from_config, f"Critical error: {keys_from_data=} != {keys_from_config=}. {self=}" - - variables_key = list(from_config.keys())[1] - ensembles_key = list(from_config.keys())[2] - - if isinstance(from_config[variables_key], (list, tuple)): - assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), ( - from_data[variables_key], - from_config[variables_key], - ) - - self._variables = from_data[variables_key] # "param_level" - self._ensembles = from_data[ensembles_key] # "number" - - first_field = self.datasource[0] - grid_points = first_field.grid_points() - - lats, lons = grid_points - - assert len(lats) == len(lons), (len(lats), len(lons), first_field) - assert len(lats) == math.prod(first_field.shape), (len(lats), first_field.shape, first_field) - - north = np.amax(lats) - south = np.amin(lats) - east = np.amax(lons) - west = np.amin(lons) - - assert -90 <= south <= north <= 90, (south, north, first_field) - assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), ( - west, - east, - first_field, - ) - - grid_values = list(range(len(grid_points[0]))) - - self._grid_points = grid_points - self._resolution = first_field.resolution - self._grid_values = grid_values - self._field_shape = first_field.shape - self._proj_string = first_field.proj_string if hasattr(first_field, 
"proj_string") else None - - @property - def variables(self): - self.build_coords() - return self._variables - - @property - def ensembles(self): - self.build_coords() - return self._ensembles - - @property - def resolution(self): - self.build_coords() - return self._resolution - - @property - def grid_values(self): - self.build_coords() - return self._grid_values - - @property - def grid_points(self): - self.build_coords() - return self._grid_points - - @property - def field_shape(self): - self.build_coords() - return self._field_shape - - @property - def proj_string(self): - self.build_coords() - return self._proj_string - - @cached_property - def shape(self): - return [ - len(self.group_of_dates), - len(self.variables), - len(self.ensembles), - len(self.grid_values), - ] - - @cached_property - def coords(self): - return { - "dates": list(self.group_of_dates), - "variables": self.variables, - "ensembles": self.ensembles, - "values": self.grid_values, - } - - -class EmptyResult(Result): - empty = True - - def __init__(self, context, action_path, dates): - super().__init__(context, action_path + ["empty"], dates) - - @cached_property - @assert_fieldlist - @trace_datasource - def datasource(self): - from earthkit.data import from_source - - return from_source("empty") - - @property - def variables(self): - return [] - - -def _flatten(ds): - if isinstance(ds, MultiFieldList): - return [_tidy(f) for s in ds._indexes for f in _flatten(s)] - return [ds] - - -def _tidy(ds, indent=0): - if isinstance(ds, MultiFieldList): - - sources = [s for s in _flatten(ds) if len(s) > 0] - if len(sources) == 1: - return sources[0] - return MultiFieldList(sources) - return ds - - -class FunctionResult(Result): - def __init__(self, context, action_path, dates, action): - super().__init__(context, action_path, dates) - assert isinstance(action, Action), type(action) - self.action = action - - self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) - - def 
_trace_datasource(self, *args, **kwargs): - return f"{self.action.name}({self.group_of_dates})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self): - args, kwargs = resolve(self.context, (self.args, self.kwargs)) - - try: - return _tidy( - self.action.function( - FunctionContext(self), - list(self.group_of_dates), # Will provide a list of datetime objects - *args, - **kwargs, - ) - ) - except Exception: - LOG.error(f"Error in {self.action.function.__name__}", exc_info=True) - raise - - def __repr__(self): - try: - return f"{self.action.name}({self.group_of_dates})" - except Exception: - return f"{self.__class__.__name__}(unitialised)" - - @property - def function(self): - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class JoinResult(Result): - def __init__(self, context, action_path, dates, results, **kwargs): - super().__init__(context, action_path, dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self): - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - def __repr__(self): - content = "\n".join([str(i) for i in self.results]) - return super().__repr__(content) - - -class DateShiftAction(Action): - def __init__(self, context, action_path, delta, **kwargs): - super().__init__(context, action_path, **kwargs) - - if isinstance(delta, str): - if delta[0] == "-": - delta, sign = int(delta[1:]), -1 - else: - delta, sign = int(delta), 1 - delta = datetime.timedelta(hours=sign * delta) - assert isinstance(delta, int), delta - delta = datetime.timedelta(hours=delta) - self.delta = delta - - self.content = action_factory(kwargs, context, self.action_path + ["shift"]) - - @trace_select - def select(self, dates): - shifted_dates = [d + self.delta for d in dates] - result = 
self.content.select(shifted_dates) - return UnShiftResult(self.context, self.action_path, dates, result, action=self) - - def __repr__(self): - return super().__repr__(f"{self.delta}\n{self.content}") - - -class UnShiftResult(Result): - def __init__(self, context, action_path, dates, result, action): - super().__init__(context, action_path, dates) - # dates are the actual requested dates - # result does not have the same dates - self.action = action - self.result = result - - def _trace_datasource(self, *args, **kwargs): - return f"{self.action.delta}({shorten(self.dates)})" - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self): - from earthkit.data.indexing.fieldlist import FieldArray - - class DateShiftedField: - def __init__(self, field, delta): - self.field = field - self.delta = delta - - def metadata(self, key): - value = self.field.metadata(key) - if key == "param": - return value + "_" + time_delta_to_string(self.delta) - if key == "valid_datetime": - dt = datetime.datetime.fromisoformat(value) - new_dt = dt - self.delta - new_value = new_dt.isoformat() - return new_value - if key in ["date", "time", "step", "hdate"]: - raise NotImplementedError(f"metadata {key} not implemented when shifting dates") - return value - - def __getattr__(self, name): - return getattr(self.field, name) - - ds = self.result.datasource - ds = FieldArray([DateShiftedField(fs, self.action.delta) for fs in ds]) - return _tidy(ds) - - -class FunctionAction(Action): - def __init__(self, context, action_path, _name, **kwargs): - super().__init__(context, action_path, **kwargs) - self.name = _name - - @trace_select - def select(self, dates): - return FunctionResult(self.context, self.action_path, dates, action=self) - - @property - def function(self): - # name, delta = parse_function_name(self.name) - return import_function(self.name, "sources") - - def __repr__(self): - content = "" - content += ",".join([self._short_str(a) for a in 
self.args]) - content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) - content = self._short_str(content) - return super().__repr__(_inline_=content, _indent_=" ") - - def _trace_select(self, dates): - return f"{self.name}({shorten(dates)})" - - -class PipeAction(Action): - def __init__(self, context, action_path, *configs): - super().__init__(context, action_path, *configs) - assert len(configs) > 1, configs - current = action_factory(configs[0], context, action_path + ["0"]) - for i, c in enumerate(configs[1:]): - current = step_factory(c, context, action_path + [str(i + 1)], previous_step=current) - self.last_step = current - - @trace_select - def select(self, dates): - return self.last_step.select(dates) - - def __repr__(self): - return super().__repr__(self.last_step) - - -class StepResult(Result): - def __init__(self, context, action_path, dates, action, upstream_result): - super().__init__(context, action_path, dates) - assert isinstance(upstream_result, Result), type(upstream_result) - self.upstream_result = upstream_result - self.action = action - - @property - @notify_result - @trace_datasource - def datasource(self): - raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") - - -class StepAction(Action): - result_class = None - - def __init__(self, context, action_path, previous_step, *args, **kwargs): - super().__init__(context, action_path, *args, **kwargs) - self.previous_step = previous_step - - @trace_select - def select(self, dates): - return self.result_class( - self.context, - self.action_path, - dates, - self, - self.previous_step.select(dates), - ) - - def __repr__(self): - return super().__repr__(self.previous_step, _inline_=str(self.kwargs)) - - -class StepFunctionResult(StepResult): - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self): - try: - return _tidy( - self.action.function( - FunctionContext(self), - self.upstream_result.datasource, - 
*self.action.args[1:], - **self.action.kwargs, - ) - ) - - except Exception: - LOG.error(f"Error in {self.action.name}", exc_info=True) - raise - - def _trace_datasource(self, *args, **kwargs): - return f"{self.action.name}({shorten(self.group_of_dates)})" - - -class FilterStepResult(StepResult): - @property - @notify_result - @assert_fieldlist - @trace_datasource - def datasource(self): - ds = self.upstream_result.datasource - ds = ds.sel(**self.action.kwargs) - return _tidy(ds) - - -class FilterStepAction(StepAction): - result_class = FilterStepResult - - -class FunctionStepAction(StepAction): - result_class = StepFunctionResult - - def __init__(self, context, action_path, previous_step, *args, **kwargs): - super().__init__(context, action_path, previous_step, *args, **kwargs) - self.name = args[0] - self.function = import_function(self.name, "filters") - - -class ConcatResult(Result): - def __init__(self, context, action_path, dates, results, **kwargs): - super().__init__(context, action_path, dates) - self.results = [r for r in results if not r.empty] - - @cached_property - @assert_fieldlist - @notify_result - @trace_datasource - def datasource(self): - ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource - for i in self.results: - ds += i.datasource - return _tidy(ds) - - @property - def variables(self): - """Check that all the results objects have the same variables.""" - variables = None - for f in self.results: - if f.empty: - continue - if variables is None: - variables = f.variables - assert variables == f.variables, (variables, f.variables) - assert variables is not None, self.results - return variables - - def __repr__(self): - content = "\n".join([str(i) for i in self.results]) - return super().__repr__(content) - - -class DataSourcesResult(Result): - def __init__(self, context, action_path, dates, input_result, sources_results): - super().__init__(context, action_path, dates) - # result is the main input result - 
self.input_result = input_result - # sources_results is the list of the sources_results - self.sources_results = sources_results - - @cached_property - def datasource(self): - for i in self.sources_results: - # for each result trigger the datasource to be computed - # and saved in context - self.context.notify_result(i.action_path[:-1], i.datasource) - # then return the input result - # which can use the datasources of the included results - return _tidy(self.input_result.datasource) - - -class DataSourcesAction(Action): - def __init__(self, context, action_path, sources, input): - super().__init__(context, ["data_sources"], *sources) - if isinstance(sources, dict): - configs = [(str(k), c) for k, c in sources.items()] - elif isinstance(sources, list): - configs = [(str(i), c) for i, c in enumerate(sources)] - else: - raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}") - - self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs] - self.input = action_factory(input, context, ["input"]) - - def select(self, dates): - sources_results = [a.select(dates) for a in self.sources] - return DataSourcesResult( - self.context, - self.action_path, - dates, - self.input.select(dates), - sources_results, - ) - - def __repr__(self): - content = "\n".join([str(i) for i in self.sources]) - return super().__repr__(content) - - -class ConcatAction(Action): - def __init__(self, context, action_path, *configs): - super().__init__(context, action_path, *configs) - parts = [] - for i, cfg in enumerate(configs): - if "dates" not in cfg: - raise ValueError(f"Missing 'dates' in {cfg}") - cfg = deepcopy(cfg) - dates_cfg = cfg.pop("dates") - assert isinstance(dates_cfg, dict), dates_cfg - filtering_dates = DatesProvider.from_config(**dates_cfg) - action = action_factory(cfg, context, action_path + [str(i)]) - parts.append((filtering_dates, action)) - self.parts = parts - - def __repr__(self): - 
content = "\n".join([str(i) for i in self.parts]) - return super().__repr__(content) - - @trace_select - def select(self, dates): - from anemoi.datasets.dates.groups import GroupOfDates - - results = [] - for filtering_dates, action in self.parts: - newdates = GroupOfDates(sorted(set(dates) & set(filtering_dates)), dates.provider) - if newdates: - results.append(action.select(newdates)) - if not results: - return EmptyResult(self.context, self.action_path, dates) - - return ConcatResult(self.context, self.action_path, dates, results) - - -class JoinAction(Action): - def __init__(self, context, action_path, *configs): - super().__init__(context, action_path, *configs) - self.actions = [action_factory(c, context, action_path + [str(i)]) for i, c in enumerate(configs)] - - def __repr__(self): - content = "\n".join([str(i) for i in self.actions]) - return super().__repr__(content) - - @trace_select - def select(self, dates): - results = [a.select(dates) for a in self.actions] - return JoinResult(self.context, self.action_path, dates, results) - - -def action_factory(config, context, action_path): - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - if len(config) != 1: - raise ValueError(f"Invalid input config. 
Expecting dict with only one key, got {list(config.keys())}") - - config = deepcopy(config) - key = list(config.keys())[0] - - if isinstance(config[key], list): - args, kwargs = config[key], {} - elif isinstance(config[key], dict): - args, kwargs = [], config[key] - else: - raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") - - cls = { - # "date_shift": DateShiftAction, - # "date_filter": DateFilterAction, - "data_sources": DataSourcesAction, - "concat": ConcatAction, - "join": JoinAction, - "pipe": PipeAction, - "function": FunctionAction, - }.get(key) - - if cls is None: - if not is_function(key, "sources"): - raise ValueError(f"Unknown action '{key}' in {config}") - cls = FunctionAction - args = [key] + args - - return cls(context, action_path + [key], *args, **kwargs) - - -def step_factory(config, context, action_path, previous_step): - assert isinstance(context, Context), (type, context) - if not isinstance(config, dict): - raise ValueError(f"Invalid input config {config}") - - config = deepcopy(config) - assert len(config) == 1, config - - key = list(config.keys())[0] - cls = dict( - filter=FilterStepAction, - # rename=RenameAction, - # remapping=RemappingAction, - ).get(key) - - if isinstance(config[key], list): - args, kwargs = config[key], {} - - if isinstance(config[key], dict): - args, kwargs = [], config[key] - - if isinstance(config[key], str): - args, kwargs = [config[key]], {} - - if cls is None: - if not is_function(key, "filters"): - raise ValueError(f"Unknown step {key}") - cls = FunctionStepAction - args = [key] + args - # print("========", args) - - return cls(context, action_path, previous_step, *args, **kwargs) - - -class FunctionContext: - """A FunctionContext is passed to all functions, it will be used to pass information - to the functions from the other actions and filters and results. 
- """ - - def __init__(self, owner): - self.owner = owner - self.use_grib_paramid = owner.context.use_grib_paramid - - def trace(self, emoji, *args): - trace(emoji, *args) - - def info(self, *args, **kwargs): - LOG.info(*args, **kwargs) - - @property - def dates_provider(self): - return self.owner.group_of_dates.provider - - -class ActionContext(Context): - def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid): - super().__init__() - self.order_by = order_by - self.flatten_grid = flatten_grid - self.remapping = build_remapping(remapping) - self.use_grib_paramid = use_grib_paramid - - -class InputBuilder: - def __init__(self, config, data_sources, **kwargs): - self.kwargs = kwargs - - config = deepcopy(config) - if data_sources: - config = dict( - data_sources=dict( - sources=data_sources, - input=config, - ) - ) - self.config = config - self.action_path = ["input"] - - @trace_select - def select(self, dates): - """This changes the context.""" - context = ActionContext(**self.kwargs) - action = action_factory(self.config, context, self.action_path) - return action.select(dates) - - def __repr__(self): - context = ActionContext(**self.kwargs) - a = action_factory(self.config, context, self.action_path) - return repr(a) - - def _trace_select(self, dates): - return f"InputBuilder({shorten(dates)})" - - -build_input = InputBuilder diff --git a/src/anemoi/datasets/create/input/__init__.py b/src/anemoi/datasets/create/input/__init__.py new file mode 100644 index 00000000..d23f038d --- /dev/null +++ b/src/anemoi/datasets/create/input/__init__.py @@ -0,0 +1,69 @@ +# (C) Copyright 2023 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +import datetime +import itertools +import logging +import math +import time +from collections import defaultdict +from copy import deepcopy +from functools import cached_property +from functools import wraps + +import numpy as np +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from .trace import trace_select + +LOG = logging.getLogger(__name__) + + +class InputBuilder: + def __init__(self, config, data_sources, **kwargs): + self.kwargs = kwargs + + config = deepcopy(config) + if data_sources: + config = dict( + data_sources=dict( + sources=data_sources, + input=config, + ) + ) + self.config = config + self.action_path = ["input"] + + @trace_select + def select(self, group_of_dates): + from .action import ActionContext + from .action import action_factory + + """This changes the context.""" + context = ActionContext(**self.kwargs) + action = action_factory(self.config, context, self.action_path) + return action.select(group_of_dates) + + def __repr__(self): + from .action import ActionContext + from .action import action_factory + + context = ActionContext(**self.kwargs) + a = action_factory(self.config, context, self.action_path) + return repr(a) + + def _trace_select(self, group_of_dates): + return f"InputBuilder({group_of_dates})" + + +build_input = InputBuilder diff --git a/src/anemoi/datasets/create/input/action.py b/src/anemoi/datasets/create/input/action.py new file mode 100644 index 00000000..e486d561 --- /dev/null +++ b/src/anemoi/datasets/create/input/action.py @@ -0,0 +1,123 @@ +# (C) Copyright 2024 ECMWF. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import logging +from copy import deepcopy + +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta +from earthkit.data.core.order import build_remapping + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from .context import Context +from .misc import is_function + +LOG = logging.getLogger(__name__) + + +class Action: + def __init__(self, context, action_path, /, *args, **kwargs): + if "args" in kwargs and "kwargs" in kwargs: + """We have: + args = [] + kwargs = {args: [...], kwargs: {...}} + move the content of kwargs to args and kwargs. + """ + assert len(kwargs) == 2, (args, kwargs) + assert not args, (args, kwargs) + args = kwargs.pop("args") + kwargs = kwargs.pop("kwargs") + + assert isinstance(context, ActionContext), type(context) + self.context = context + self.kwargs = kwargs + self.args = args + self.action_path = action_path + + @classmethod + def _short_str(cls, x): + x = str(x) + if len(x) < 1000: + return x + return x[:1000] + "..." 
+ + def __repr__(self, *args, _indent_="\n", _inline_="", **kwargs): + more = ",".join([str(a)[:5000] for a in args]) + more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()]) + + more = more[:5000] + txt = f"{self.__class__.__name__}: {_inline_}{_indent_}{more}" + if _indent_: + txt = txt.replace("\n", "\n ") + return txt + + def select(self, dates, **kwargs): + self._raise_not_implemented() + + def _raise_not_implemented(self): + raise NotImplementedError(f"Not implemented in {self.__class__.__name__}") + + def _trace_select(self, group_of_dates): + return f"{self.__class__.__name__}({group_of_dates})" + + +class ActionContext(Context): + def __init__(self, /, order_by, flatten_grid, remapping, use_grib_paramid): + super().__init__() + self.order_by = order_by + self.flatten_grid = flatten_grid + self.remapping = build_remapping(remapping) + self.use_grib_paramid = use_grib_paramid + + +def action_factory(config, context, action_path): + + from .concat import ConcatAction + from .data_sources import DataSourcesAction + from .function import FunctionAction + from .join import JoinAction + from .pipe import PipeAction + from .repeated_dates import RepeatedDatesAction + + # from .data_sources import DataSourcesAction + + assert isinstance(context, Context), (type, context) + if not isinstance(config, dict): + raise ValueError(f"Invalid input config {config}") + if len(config) != 1: + raise ValueError(f"Invalid input config. 
Expecting dict with only one key, got {list(config.keys())}") + + config = deepcopy(config) + key = list(config.keys())[0] + + if isinstance(config[key], list): + args, kwargs = config[key], {} + elif isinstance(config[key], dict): + args, kwargs = [], config[key] + else: + raise ValueError(f"Invalid input config {config[key]} ({type(config[key])}") + + cls = { + "data_sources": DataSourcesAction, + "concat": ConcatAction, + "join": JoinAction, + "pipe": PipeAction, + "function": FunctionAction, + "repeated_dates": RepeatedDatesAction, + }.get(key) + + if cls is None: + if not is_function(key, "sources"): + raise ValueError(f"Unknown action '{key}' in {config}") + cls = FunctionAction + args = [key] + args + + return cls(context, action_path + [key], *args, **kwargs) diff --git a/src/anemoi/datasets/create/input/concat.py b/src/anemoi/datasets/create/input/concat.py new file mode 100644 index 00000000..600dcb5b --- /dev/null +++ b/src/anemoi/datasets/create/input/concat.py @@ -0,0 +1,92 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +import logging +from copy import deepcopy +from functools import cached_property + +from anemoi.datasets.dates import DatesProvider + +from .action import Action +from .action import action_factory +from .empty import EmptyResult +from .misc import _tidy +from .misc import assert_fieldlist +from .result import Result +from .template import notify_result +from .trace import trace_datasource +from .trace import trace_select + +LOG = logging.getLogger(__name__) + + +class ConcatResult(Result): + def __init__(self, context, action_path, group_of_dates, results, **kwargs): + super().__init__(context, action_path, group_of_dates) + self.results = [r for r in results if not r.empty] + + @cached_property + @assert_fieldlist + @notify_result + @trace_datasource + def datasource(self): + ds = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource + for i in self.results: + ds += i.datasource + return _tidy(ds) + + @property + def variables(self): + """Check that all the results objects have the same variables.""" + variables = None + for f in self.results: + if f.empty: + continue + if variables is None: + variables = f.variables + assert variables == f.variables, (variables, f.variables) + assert variables is not None, self.results + return variables + + def __repr__(self): + content = "\n".join([str(i) for i in self.results]) + return super().__repr__(content) + + +class ConcatAction(Action): + def __init__(self, context, action_path, *configs): + super().__init__(context, action_path, *configs) + parts = [] + for i, cfg in enumerate(configs): + if "dates" not in cfg: + raise ValueError(f"Missing 'dates' in {cfg}") + cfg = deepcopy(cfg) + dates_cfg = cfg.pop("dates") + assert isinstance(dates_cfg, dict), dates_cfg + filtering_dates = DatesProvider.from_config(**dates_cfg) + action = action_factory(cfg, context, action_path + [str(i)]) + parts.append((filtering_dates, action)) + self.parts = parts + + def __repr__(self): + content = 
"\n".join([str(i) for i in self.parts]) + return super().__repr__(content) + + @trace_select + def select(self, group_of_dates): + from anemoi.datasets.dates.groups import GroupOfDates + + results = [] + for filtering_dates, action in self.parts: + newdates = GroupOfDates(sorted(set(group_of_dates) & set(filtering_dates)), group_of_dates.provider) + if newdates: + results.append(action.select(newdates)) + if not results: + return EmptyResult(self.context, self.action_path, group_of_dates) + + return ConcatResult(self.context, self.action_path, group_of_dates, results) diff --git a/src/anemoi/datasets/create/input/context.py b/src/anemoi/datasets/create/input/context.py new file mode 100644 index 00000000..81217006 --- /dev/null +++ b/src/anemoi/datasets/create/input/context.py @@ -0,0 +1,59 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +import logging +import textwrap + +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta +from anemoi.utils.humanize import plural + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from .trace import step +from .trace import trace + +LOG = logging.getLogger(__name__) + + +class Context: + def __init__(self): + # used_references is a set of reference paths that will be needed + self.used_references = set() + # results is a dictionary of reference path -> obj + self.results = {} + + def will_need_reference(self, key): + assert isinstance(key, (list, tuple)), key + key = tuple(key) + self.used_references.add(key) + + def notify_result(self, key, result): + trace( + "🎯", + step(key), + "notify result", + textwrap.shorten(repr(result).replace(",", ", "), width=40), + plural(len(result), "field"), + ) + assert isinstance(key, (list, tuple)), key + key = tuple(key) + if key in self.used_references: + if key in self.results: + raise ValueError(f"Duplicate result {key}") + self.results[key] = result + + def get_result(self, key): + assert isinstance(key, (list, tuple)), key + key = tuple(key) + if key in self.results: + return self.results[key] + all_keys = sorted(list(self.results.keys())) + raise ValueError(f"Cannot find result {key} in {all_keys}") diff --git a/src/anemoi/datasets/create/input/data_sources.py b/src/anemoi/datasets/create/input/data_sources.py new file mode 100644 index 00000000..e977eea2 --- /dev/null +++ b/src/anemoi/datasets/create/input/data_sources.py @@ -0,0 +1,71 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. 
+# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import logging +from functools import cached_property + +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from .action import Action +from .action import action_factory +from .misc import _tidy +from .result import Result + +LOG = logging.getLogger(__name__) + + +class DataSourcesAction(Action): + def __init__(self, context, action_path, sources, input): + super().__init__(context, ["data_sources"], *sources) + if isinstance(sources, dict): + configs = [(str(k), c) for k, c in sources.items()] + elif isinstance(sources, list): + configs = [(str(i), c) for i, c in enumerate(sources)] + else: + raise ValueError(f"Invalid data_sources, expecting list or dict, got {type(sources)}: {sources}") + + self.sources = [action_factory(config, context, ["data_sources"] + [a_path]) for a_path, config in configs] + self.input = action_factory(input, context, ["input"]) + + def select(self, group_of_dates): + sources_results = [a.select(group_of_dates) for a in self.sources] + return DataSourcesResult( + self.context, + self.action_path, + group_of_dates, + self.input.select(group_of_dates), + sources_results, + ) + + def __repr__(self): + content = "\n".join([str(i) for i in self.sources]) + return super().__repr__(content) + + +class DataSourcesResult(Result): + def __init__(self, context, action_path, dates, input_result, sources_results): + super().__init__(context, action_path, dates) + # result is the main input result + self.input_result = input_result + # sources_results 
is the list of the sources_results + self.sources_results = sources_results + + @cached_property + def datasource(self): + for i in self.sources_results: + # for each result trigger the datasource to be computed + # and saved in context + self.context.notify_result(i.action_path[:-1], i.datasource) + # then return the input result + # which can use the datasources of the included results + return _tidy(self.input_result.datasource) diff --git a/src/anemoi/datasets/create/input/empty.py b/src/anemoi/datasets/create/input/empty.py new file mode 100644 index 00000000..4d3e4058 --- /dev/null +++ b/src/anemoi/datasets/create/input/empty.py @@ -0,0 +1,42 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +import logging +from functools import cached_property + +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from .misc import assert_fieldlist +from .result import Result +from .trace import trace_datasource + +LOG = logging.getLogger(__name__) + + +class EmptyResult(Result): + empty = True + + def __init__(self, context, action_path, dates): + super().__init__(context, action_path + ["empty"], dates) + + @cached_property + @assert_fieldlist + @trace_datasource + def datasource(self): + from earthkit.data import from_source + + return from_source("empty") + + @property + def variables(self): + return [] diff --git a/src/anemoi/datasets/create/input/filter.py b/src/anemoi/datasets/create/input/filter.py new file mode 100644 index 00000000..ec971737 --- /dev/null +++ b/src/anemoi/datasets/create/input/filter.py @@ -0,0 +1,76 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# +import logging +from functools import cached_property + +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from ..functions import import_function +from .function import FunctionContext +from .misc import _tidy +from .misc import assert_fieldlist +from .step import StepAction +from .step import StepResult +from .template import notify_result +from .trace import trace_datasource + +LOG = logging.getLogger(__name__) + + +class FilterStepResult(StepResult): + @property + @notify_result + @assert_fieldlist + @trace_datasource + def datasource(self): + ds = self.upstream_result.datasource + ds = ds.sel(**self.action.kwargs) + return _tidy(ds) + + +class FilterStepAction(StepAction): + result_class = FilterStepResult + + +class StepFunctionResult(StepResult): + @cached_property + @assert_fieldlist + @notify_result + @trace_datasource + def datasource(self): + try: + return _tidy( + self.action.function( + FunctionContext(self), + self.upstream_result.datasource, + *self.action.args[1:], + **self.action.kwargs, + ) + ) + + except Exception: + LOG.error(f"Error in {self.action.name}", exc_info=True) + raise + + def _trace_datasource(self, *args, **kwargs): + return f"{self.action.name}({self.group_of_dates})" + + +class FunctionStepAction(StepAction): + result_class = StepFunctionResult + + def __init__(self, context, action_path, previous_step, *args, **kwargs): + super().__init__(context, action_path, previous_step, *args, **kwargs) + self.name = args[0] + self.function = import_function(self.name, "filters") diff --git a/src/anemoi/datasets/create/input/function.py b/src/anemoi/datasets/create/input/function.py new file mode 100644 index 00000000..bc7ec28f --- 
/dev/null +++ b/src/anemoi/datasets/create/input/function.py @@ -0,0 +1,122 @@ +# (C) Copyright 2024 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# +import logging +from functools import cached_property + +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from ..functions import import_function +from .action import Action +from .misc import _tidy +from .misc import assert_fieldlist +from .result import Result +from .template import notify_result +from .template import resolve +from .template import substitute +from .trace import trace +from .trace import trace_datasource +from .trace import trace_select + +LOG = logging.getLogger(__name__) + + +class FunctionContext: + """A FunctionContext is passed to all functions, it will be used to pass information + to the functions from the other actions and filters and results. 
+ """ + + def __init__(self, owner): + self.owner = owner + self.use_grib_paramid = owner.context.use_grib_paramid + + def trace(self, emoji, *args): + trace(emoji, *args) + + def info(self, *args, **kwargs): + LOG.info(*args, **kwargs) + + @property + def dates_provider(self): + return self.owner.group_of_dates.provider + + @property + def partial_ok(self): + return self.owner.group_of_dates.partial_ok + + +class FunctionAction(Action): + def __init__(self, context, action_path, _name, **kwargs): + super().__init__(context, action_path, **kwargs) + self.name = _name + + @trace_select + def select(self, group_of_dates): + return FunctionResult(self.context, self.action_path, group_of_dates, action=self) + + @property + def function(self): + # name, delta = parse_function_name(self.name) + return import_function(self.name, "sources") + + def __repr__(self): + content = "" + content += ",".join([self._short_str(a) for a in self.args]) + content += " ".join([self._short_str(f"{k}={v}") for k, v in self.kwargs.items()]) + content = self._short_str(content) + return super().__repr__(_inline_=content, _indent_=" ") + + def _trace_select(self, group_of_dates): + return f"{self.name}({group_of_dates})" + + +class FunctionResult(Result): + def __init__(self, context, action_path, group_of_dates, action): + super().__init__(context, action_path, group_of_dates) + assert isinstance(action, Action), type(action) + self.action = action + + self.args, self.kwargs = substitute(context, (self.action.args, self.action.kwargs)) + + def _trace_datasource(self, *args, **kwargs): + return f"{self.action.name}({self.group_of_dates})" + + @cached_property + @assert_fieldlist + @notify_result + @trace_datasource + def datasource(self): + args, kwargs = resolve(self.context, (self.args, self.kwargs)) + + try: + return _tidy( + self.action.function( + FunctionContext(self), + list(self.group_of_dates), # Will provide a list of datetime objects + *args, + **kwargs, + ) + ) + except 
class JoinResult(Result):
    """Result that concatenates the datasources of several child results."""

    def __init__(self, context, action_path, group_of_dates, results, **kwargs):
        super().__init__(context, action_path, group_of_dates)
        # Keep only children that can contribute fields.
        kept = []
        for candidate in results:
            if not candidate.empty:
                kept.append(candidate)
        self.results = kept

    @cached_property
    @assert_fieldlist
    @notify_result
    @trace_datasource
    def datasource(self):
        # Seed with an empty datasource, then fold in every child's fields.
        combined = EmptyResult(self.context, self.action_path, self.group_of_dates).datasource
        for child in self.results:
            combined = combined + child.datasource
        return _tidy(combined)

    def __repr__(self):
        return super().__repr__("\n".join(str(child) for child in self.results))


class JoinAction(Action):
    """Action combining several sub-actions; selecting joins their results."""

    def __init__(self, context, action_path, *configs):
        super().__init__(context, action_path, *configs)
        children = []
        for position, child_config in enumerate(configs):
            children.append(action_factory(child_config, context, action_path + [str(position)]))
        self.actions = children

    def __repr__(self):
        return super().__repr__("\n".join(str(child) for child in self.actions))

    @trace_select
    def select(self, group_of_dates):
        selected = [child.select(group_of_dates) for child in self.actions]
        return JoinResult(self.context, self.action_path, group_of_dates, selected)
def parse_function_name(name):
    """Split a source name with an optional time-offset suffix.

    ``"mars-12h"`` -> ``("mars", -12)``; ``"mars+6h"`` -> ``("mars", 6)``;
    a name without a valid ``<digits>h`` suffix is returned unchanged with a
    ``None`` delta, e.g. ``"grib"`` -> ``("grib", None)``.

    Bug fix: the previous guard ``name.endswith("h") and name[:-1].isdigit()``
    could never be true when the name contained "-" or "+" (the digit test saw
    the separator), so the split/sign branches were unreachable and every
    input returned ``(name, None)``. The suffix is now checked after the split.
    """
    for separator, sign in (("-", -1), ("+", 1)):
        if separator in name:
            base, _, delta = name.rpartition(separator)
            # Only treat the suffix as an offset when it is "<digits>h";
            # otherwise the separator is part of the name itself.
            if delta.endswith("h") and delta[:-1].isdigit():
                return base, sign * int(delta[:-1])
    return name, None


def is_function(name, kind):
    """Return True if *name* (stripped of any time offset) imports as a *kind* function."""
    name, _ = parse_function_name(name)
    try:
        import_function(name, kind)
        return True
    except ImportError as e:
        # Best-effort diagnostic; an import failure simply means "not a function".
        print(e)
        return False


def assert_fieldlist(method):
    """Decorator asserting that *method* returns an earthkit FieldList."""

    @wraps(method)
    def wrapper(self, *args, **kwargs):
        result = method(self, *args, **kwargs)
        assert isinstance(result, FieldList), type(result)
        return result

    return wrapper


def assert_is_fieldlist(obj):
    """Assert that *obj* is an earthkit FieldList."""
    assert isinstance(obj, FieldList), type(obj)


def _flatten(ds):
    # Recursively expand a nested MultiFieldList into a flat list of
    # simple (tidied) fieldlists.
    if isinstance(ds, MultiFieldList):
        return [_tidy(f) for s in ds._indexes for f in _flatten(s)]
    return [ds]


def _tidy(ds, indent=0):
    """Simplify a MultiFieldList: drop empty sources, unwrap a single source.

    *indent* is unused; kept for backward compatibility of the signature.
    """
    if isinstance(ds, MultiFieldList):
        sources = [s for s in _flatten(ds) if len(s) > 0]
        # Avoid the MultiFieldList wrapper when one non-empty source remains.
        if len(sources) == 1:
            return sources[0]
        return MultiFieldList(sources)
    return ds
class PipeAction(Action):
    """Chain a source action through one or more filter steps.

    The first config is the source; every following config wraps the
    previous step, so selecting from the last step runs the whole pipe.
    """

    def __init__(self, context, action_path, *configs):
        super().__init__(context, action_path, *configs)
        assert len(configs) > 1, configs
        head, *step_configs = configs
        chain = action_factory(head, context, action_path + ["0"])
        for position, step_config in enumerate(step_configs, start=1):
            chain = step_factory(step_config, context, action_path + [str(position)], previous_step=chain)
        self.last_step = chain

    @trace_select
    def select(self, group_of_dates):
        # Delegating to the last step pulls dates through the whole chain.
        return self.last_step.select(group_of_dates)

    def __repr__(self):
        return super().__repr__(self.last_step)
class DateMapper:
    """Maps the dates requested by the user onto the dates at which the
    underlying source is actually queried (closest / climatology / constant).
    """

    @staticmethod
    def from_mode(mode, source, config):
        # Factory: pick the mapper implementation from the recipe's `mode`.

        MODES = dict(
            closest=DateMapperClosest,
            climatology=DateMapperClimatology,
            constant=DateMapperConstant,
        )

        if mode not in MODES:
            raise ValueError(f"Invalid mode for DateMapper: {mode}")

        return MODES[mode](source, **config)


class DateMapperClosest(DateMapper):
    """For each requested date, use the closest date (within *maximum*)
    for which the source actually returns usable fields.
    """

    def __init__(self, source, frequency="1h", maximum="30d", skip_all_nans=False):
        self.source = source
        self.maximum = frequency_to_timedelta(maximum)
        self.frequency = frequency_to_timedelta(frequency)
        self.skip_all_nans = skip_all_nans
        self.tried = set()  # candidate dates already queried against the source
        self.found = set()  # dates for which the source returned usable fields

    def transform(self, group_of_dates):
        # Generator yielding (source_dates, requested_dates) pairs:
        # one GroupOfDates with the single date to fetch, and the group of
        # requested dates it stands in for.
        from anemoi.datasets.dates.groups import GroupOfDates

        asked_dates = list(group_of_dates)
        if not asked_dates:
            # `return []` in a generator just ends the iteration.
            return []

        # Candidate dates: every frequency step within +/- maximum of each
        # requested date.
        to_try = set()
        for date in asked_dates:
            start = date
            while start >= date - self.maximum:
                to_try.add(start)
                start -= self.frequency

            end = date
            while end <= date + self.maximum:
                to_try.add(end)
                end += self.frequency

        # Only probe candidates not already tried on a previous call.
        to_try = sorted(to_try - self.tried)

        if to_try:
            result = self.source.select(
                GroupOfDates(
                    sorted(to_try),
                    group_of_dates.provider,
                    partial_ok=True,  # the source may cover only some candidates
                )
            )

            for f in result.datasource:
                # We could keep the fields in a dictionary, but we don't want to keep the fields in memory
                date = as_datetime(f.metadata("valid_datetime"))

                if self.skip_all_nans:
                    if np.isnan(f.to_numpy()).all():
                        LOG.warning(f"Skipping {date} because all values are NaN")
                        continue

                self.found.add(date)

            self.tried.update(to_try)

        # Map each requested date to its closest found date.
        new_dates = defaultdict(list)

        for date in asked_dates:
            best = None
            for found_date in sorted(self.found):
                delta = abs(date - found_date)
                # With < we prefer the first date
                # With <= we prefer the last date
                if best is None or delta <= best[0]:
                    best = delta, found_date
            # NOTE(review): if self.found is empty (source returned nothing
            # usable within `maximum`), `best` stays None and `best[1]` raises
            # TypeError — confirm whether an explicit error is wanted here.
            new_dates[best[1]].append(date)

        for date, dates in new_dates.items():
            yield (
                GroupOfDates([date], group_of_dates.provider),
                GroupOfDates(dates, group_of_dates.provider),
            )


class DateMapperClimatology(DateMapper):
    """Map every requested date to the same (year, day) climatology date,
    keeping the requested month and time.
    """

    def __init__(self, source, year, day):
        # `source` is accepted to keep the factory signature uniform but is
        # not used by this mapper.
        self.year = year
        self.day = day

    def transform(self, group_of_dates):
        from anemoi.datasets.dates.groups import GroupOfDates

        dates = list(group_of_dates)
        if not dates:
            # `return []` in a generator just ends the iteration.
            return []

        # Group requested dates by the climatology date they map to.
        new_dates = defaultdict(list)
        for date in dates:
            new_date = date.replace(year=self.year, day=self.day)
            new_dates[new_date].append(date)

        for date, dates in new_dates.items():
            yield (
                GroupOfDates([date], group_of_dates.provider),
                GroupOfDates(dates, group_of_dates.provider),
            )


class DateMapperConstant(DateMapper):
    """Map all requested dates to a single constant date (or to no date at
    all for date-less/constant sources).
    """

    def __init__(self, source, date=None):
        self.source = source
        self.date = date

    def transform(self, group_of_dates):
        from anemoi.datasets.dates.groups import GroupOfDates

        if self.date is None:
            # No date: the source is queried with an empty group of dates.
            return [
                (
                    GroupOfDates([], group_of_dates.provider),
                    group_of_dates,
                )
            ]

        return [
            (
                GroupOfDates([self.date], group_of_dates.provider),
                group_of_dates,
            )
        ]
class DateMapperResult(Result):
    """Result that re-stamps the fields of a source result with the dates
    the user originally requested.
    """

    def __init__(
        self,
        context,
        action_path,
        group_of_dates,
        source_result,
        mapper,
        original_group_of_dates,
    ):
        super().__init__(context, action_path, group_of_dates)

        self.source_results = source_result
        self.mapper = mapper
        # The dates the user asked for (the source was queried with
        # `group_of_dates`, which the mapper derived from these).
        self.original_group_of_dates = original_group_of_dates

    @property
    def datasource(self):
        # Duplicate every source field once per requested date, wrapped so
        # its metadata reports the requested date instead of the date it was
        # actually retrieved for.
        result = []

        for field in self.source_results.datasource:
            for date in self.original_group_of_dates:
                result.append(NewValidDateTimeField(field, date))

        return FieldArray(result)


class RepeatedDatesAction(Action):
    """Action reusing one source selection for several requested dates,
    according to a DateMapper mode (closest / climatology / constant).
    """

    def __init__(self, context, action_path, source, mode, **kwargs):
        super().__init__(context, action_path, source, mode, **kwargs)

        self.source = action_factory(source, context, action_path + ["source"])
        self.mapper = DateMapper.from_mode(mode, self.source, kwargs)

    @trace_select
    def select(self, group_of_dates):
        results = []
        for one_date_group, many_dates_group in self.mapper.transform(group_of_dates):
            results.append(
                DateMapperResult(
                    self.context,
                    self.action_path,
                    one_date_group,
                    self.source.select(one_date_group),
                    self.mapper,
                    many_dates_group,
                )
            )

        return JoinResult(self.context, self.action_path, group_of_dates, results)

    def __repr__(self):
        # Fixed: the previous repr referred to "MultiDateMatchAction", a class
        # name that no longer exists, which was misleading in traces.
        return f"{self.__class__.__name__}({self.source}, {self.mapper})"
+# +import itertools +import logging +import math +import time +from collections import defaultdict +from functools import cached_property + +import numpy as np +from anemoi.utils.dates import as_datetime as as_datetime +from anemoi.utils.dates import frequency_to_timedelta as frequency_to_timedelta +from anemoi.utils.humanize import seconds_to_human +from anemoi.utils.humanize import shorten_list +from earthkit.data.core.order import build_remapping + +from anemoi.datasets.dates import DatesProvider as DatesProvider +from anemoi.datasets.fields import FieldArray as FieldArray +from anemoi.datasets.fields import NewValidDateTimeField as NewValidDateTimeField + +from .trace import trace +from .trace import trace_datasource + +LOG = logging.getLogger(__name__) + + +def _data_request(data): + date = None + params_levels = defaultdict(set) + params_steps = defaultdict(set) + + area = grid = None + + for field in data: + try: + if date is None: + date = field.metadata("valid_datetime") + + if field.metadata("valid_datetime") != date: + continue + + as_mars = field.metadata(namespace="mars") + if not as_mars: + continue + step = as_mars.get("step") + levtype = as_mars.get("levtype", "sfc") + param = as_mars["param"] + levelist = as_mars.get("levelist", None) + area = field.mars_area + grid = field.mars_grid + + if levelist is None: + params_levels[levtype].add(param) + else: + params_levels[levtype].add((param, levelist)) + + if step: + params_steps[levtype].add((param, step)) + except Exception: + LOG.error(f"Error in retrieving metadata (cannot build data request info) for {field}", exc_info=True) + + def sort(old_dic): + new_dic = {} + for k, v in old_dic.items(): + new_dic[k] = sorted(list(v)) + return new_dic + + params_steps = sort(params_steps) + params_levels = sort(params_levels) + + return dict(param_level=params_levels, param_step=params_steps, area=area, grid=grid) + + +class Result: + empty = False + _coords_already_built = False + + def __init__(self, 
class Result:
    """Base class for the outcome of selecting an Action for a group of dates.

    A Result lazily provides a `datasource` (an earthkit fieldlist) and
    derives from it the coordinates of the output hypercube (variables,
    ensembles, grid, shape).
    """

    empty = False
    _coords_already_built = False

    def __init__(self, context, action_path, dates):
        from anemoi.datasets.dates.groups import GroupOfDates

        from .action import ActionContext

        assert isinstance(dates, GroupOfDates), dates

        assert isinstance(context, ActionContext), type(context)
        assert isinstance(action_path, list), action_path

        self.context = context
        self.group_of_dates = dates
        self.action_path = action_path

    @property
    @trace_datasource
    def datasource(self):
        # Subclasses must provide the fieldlist.
        self._raise_not_implemented()

    @property
    def data_request(self):
        """Returns a dictionary with the parameters needed to retrieve the data."""
        return _data_request(self.datasource)

    def get_cube(self):
        """Sort the datasource into a hypercube following the configured order."""
        trace("🧊", f"getting cube from {self.__class__.__name__}")
        ds = self.datasource

        remapping = self.context.remapping
        order_by = self.context.order_by
        flatten_grid = self.context.flatten_grid
        start = time.time()
        LOG.debug("Sorting dataset %s %s", dict(order_by), remapping)
        assert order_by, order_by

        # Fields with no `number` are treated as ensemble member 0.
        patches = {"number": {None: 0}}

        try:
            cube = ds.cube(
                order_by,
                remapping=remapping,
                flatten_values=flatten_grid,
                patches=patches,
            )
            cube = cube.squeeze()
            LOG.debug(f"Sorting done in {seconds_to_human(time.time()-start)}.")
        except ValueError:
            # The dataset does not form a full hypercube: print a detailed
            # diagnostic and abort the whole creation.
            self.explain(ds, order_by, remapping=remapping, patches=patches)
            # raise ValueError(f"Error in {self}")
            exit(1)

        if LOG.isEnabledFor(logging.DEBUG):
            LOG.debug("Cube shape: %s", cube)
            for k, v in cube.user_coords.items():
                LOG.debug(" %s %s", k, shorten_list(v, max_length=10))

        return cube

    def explain(self, ds, *args, remapping, patches):
        """Print a human-readable diagnostic of why the cube could not be built
        (missing, extra or duplicated fields), then exit."""

        METADATA = (
            "date",
            "time",
            "step",
            "hdate",
            "valid_datetime",
            "levtype",
            "levelist",
            "number",
            "level",
            "shortName",
            "paramId",
            "variable",
        )

        # We redo the logic here
        print()
        print("❌" * 40)
        print()
        if len(args) == 1 and isinstance(args[0], (list, tuple)):
            args = args[0]

        # print("Executing", self.action_path)
        # print("Dates:", compress_dates(self.dates))

        names = []
        for a in args:
            if isinstance(a, str):
                names.append(a)
            elif isinstance(a, dict):
                names += list(a.keys())

        print(f"Building a {len(names)}D hypercube using", names)
        ds = ds.order_by(*args, remapping=remapping, patches=patches)
        user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)

        print()
        print("Number of unique values found for each coordinate:")
        for k, v in user_coords.items():
            print(f" {k:20}:", len(v), shorten_list(v, max_length=10))
        print()
        user_shape = tuple(len(v) for k, v in user_coords.items())
        print("Shape of the hypercube :", user_shape)
        print(
            "Number of expected fields :", math.prod(user_shape), "=", " x ".join([str(i) for i in user_shape])
        )
        print("Number of fields in the dataset :", len(ds))
        print("Difference :", abs(len(ds) - math.prod(user_shape)))
        print()

        remapping = build_remapping(remapping, patches)
        expected = set(itertools.product(*user_coords.values()))
        extra = set()

        if math.prod(user_shape) > len(ds):
            # Fewer fields than coordinate combinations: some are missing.
            print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.")

            for f in ds:
                metadata = remapping(f.metadata)
                key = tuple(metadata(n, default=None) for n in names)
                if key in expected:
                    expected.remove(key)
                else:
                    extra.add(key)

            print("Missing fields:")
            print()
            for i, f in enumerate(sorted(expected)):
                print(" ", f)
                if i >= 9 and len(expected) > 10:
                    print("...", len(expected) - i - 1, "more")
                    break

            print("Extra fields:")
            print()
            for i, f in enumerate(sorted(extra)):
                print(" ", f)
                if i >= 9 and len(extra) > 10:
                    print("...", len(extra) - i - 1, "more")
                    break

            print()
            print("Missing values:")
            per_name = defaultdict(set)
            for e in expected:
                for n, v in zip(names, e):
                    per_name[n].add(v)

            for n, v in per_name.items():
                print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
            print()

            print("Extra values:")
            per_name = defaultdict(set)
            for e in extra:
                for n, v in zip(names, e):
                    per_name[n].add(v)

            for n, v in per_name.items():
                print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
            print()

            print("To solve this issue, you can:")
            print(
                " - Provide a better selection, like 'step: 0' or 'level: 1000' to "
                "reduce the number of selected fields."
            )
            print(
                " - Split the 'input' part in smaller sections using 'join', "
                "making sure that each section represent a full hypercube."
            )

        else:
            # More fields than coordinate combinations: some are duplicated.
            print(f"More fields in dataset that expected for {names}. " "This means that some fields are duplicated.")
            duplicated = defaultdict(list)
            for f in ds:
                # print(f.metadata(namespace="default"))
                metadata = remapping(f.metadata)
                key = tuple(metadata(n, default=None) for n in names)
                duplicated[key].append(f)

            print("Duplicated fields:")
            print()
            duplicated = {k: v for k, v in duplicated.items() if len(v) > 1}
            for i, (k, v) in enumerate(sorted(duplicated.items())):
                print(" ", k)
                for f in v:
                    x = {k: f.metadata(k, default=None) for k in METADATA if f.metadata(k, default=None) is not None}
                    print(" ", f, x)
                if i >= 9 and len(duplicated) > 10:
                    print("...", len(duplicated) - i - 1, "more")
                    break

            print()
            print("To solve this issue, you can:")
            print(" - Provide a better selection, like 'step: 0' or 'level: 1000'")
            print(" - Change the way 'param' is computed using 'variable_naming' " "in the 'build' section.")

        print()
        print("❌" * 40)
        print()
        exit(1)

    def __repr__(self, *args, _indent_="\n", **kwargs):
        more = ",".join([str(a)[:5000] for a in args])
        more += ",".join([f"{k}={v}"[:5000] for k, v in kwargs.items()])

        dates = " no-dates"
        if self.group_of_dates is not None:
            dates = f" {len(self.group_of_dates)} dates"
            dates += " ("
            dates += "/".join(d.strftime("%Y-%m-%d:%H") for d in self.group_of_dates)
            if len(dates) > 100:
                dates = dates[:100] + "..."
            dates += ")"

        more = more[:5000]
        txt = f"{self.__class__.__name__}:{dates}{_indent_}{more}"
        if _indent_:
            txt = txt.replace("\n", "\n ")
        return txt

    def _raise_not_implemented(self):
        raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")

    def _trace_datasource(self, *args, **kwargs):
        return f"{self.__class__.__name__}({self.group_of_dates})"

    def build_coords(self):
        """Compute and cache the output coordinates from the sorted cube."""
        if self._coords_already_built:
            return
        from_data = self.get_cube().user_coords
        from_config = self.context.order_by

        keys_from_config = list(from_config.keys())
        keys_from_data = list(from_data.keys())
        assert keys_from_data == keys_from_config, f"Critical error: {keys_from_data=} != {keys_from_config=}. {self=}"

        variables_key = list(from_config.keys())[1]
        ensembles_key = list(from_config.keys())[2]

        if isinstance(from_config[variables_key], (list, tuple)):
            assert all([v == w for v, w in zip(from_data[variables_key], from_config[variables_key])]), (
                from_data[variables_key],
                from_config[variables_key],
            )

        self._variables = from_data[variables_key]  # "param_level"
        self._ensembles = from_data[ensembles_key]  # "number"

        first_field = self.datasource[0]
        grid_points = first_field.grid_points()

        lats, lons = grid_points

        assert len(lats) == len(lons), (len(lats), len(lons), first_field)
        assert len(lats) == math.prod(first_field.shape), (len(lats), first_field.shape, first_field)

        north = np.amax(lats)
        south = np.amin(lats)
        east = np.amax(lons)
        west = np.amin(lons)

        assert -90 <= south <= north <= 90, (south, north, first_field)
        assert (-180 <= west <= east <= 180) or (0 <= west <= east <= 360), (
            west,
            east,
            first_field,
        )

        grid_values = list(range(len(grid_points[0])))

        self._grid_points = grid_points
        self._resolution = first_field.resolution
        self._grid_values = grid_values
        self._field_shape = first_field.shape
        self._proj_string = first_field.proj_string if hasattr(first_field, "proj_string") else None

        # Fixed: the guard above checked this flag but nothing ever set it,
        # so every accessor below re-ran the whole get_cube() pipeline.
        self._coords_already_built = True

    @property
    def variables(self):
        self.build_coords()
        return self._variables

    @property
    def ensembles(self):
        self.build_coords()
        return self._ensembles

    @property
    def resolution(self):
        self.build_coords()
        return self._resolution

    @property
    def grid_values(self):
        self.build_coords()
        return self._grid_values

    @property
    def grid_points(self):
        self.build_coords()
        return self._grid_points

    @property
    def field_shape(self):
        self.build_coords()
        return self._field_shape

    @property
    def proj_string(self):
        self.build_coords()
        return self._proj_string

    @cached_property
    def shape(self):
        return [
            len(self.group_of_dates),
            len(self.variables),
            len(self.ensembles),
            len(self.grid_values),
        ]

    @cached_property
    def coords(self):
        return {
            "dates": list(self.group_of_dates),
            "variables": self.variables,
            "ensembles": self.ensembles,
            "values": self.grid_values,
        }
class StepResult(Result):
    """Result produced by a step (filter) applied to an upstream result."""

    def __init__(self, context, action_path, group_of_dates, action, upstream_result):
        super().__init__(context, action_path, group_of_dates)
        assert isinstance(upstream_result, Result), type(upstream_result)
        self.upstream_result = upstream_result
        self.action = action

    @property
    @notify_result
    @trace_datasource
    def datasource(self):
        # Concrete step results (e.g. filter results) provide the fields.
        raise NotImplementedError(f"Not implemented in {self.__class__.__name__}")


class StepAction(Action):
    """Action wrapping a previous pipe step; subclasses set `result_class`."""

    result_class = None

    def __init__(self, context, action_path, previous_step, *args, **kwargs):
        super().__init__(context, action_path, *args, **kwargs)
        self.previous_step = previous_step

    @trace_select
    def select(self, group_of_dates):
        # Select upstream first, then wrap the upstream result in this
        # step's result class.
        return self.result_class(
            self.context,
            self.action_path,
            group_of_dates,
            self,
            self.previous_step.select(group_of_dates),
        )

    def __repr__(self):
        return super().__repr__(self.previous_step, _inline_=str(self.kwargs))


def step_factory(config, context, action_path, previous_step):
    """Build the StepAction for one (single-key) entry of a 'pipe' config.

    Raises ValueError for malformed configs or unknown step names.
    """

    from .filter import FilterStepAction
    from .filter import FunctionStepAction

    # Fixed: the assert message used the builtin `type` instead of
    # `type(context)`, which made failures uninformative.
    assert isinstance(context, Context), (type(context), context)
    if not isinstance(config, dict):
        raise ValueError(f"Invalid input config {config}")

    config = deepcopy(config)
    assert len(config) == 1, config

    key = list(config.keys())[0]
    cls = dict(
        filter=FilterStepAction,
        # rename=RenameAction,
        # remapping=RemappingAction,
    ).get(key)

    # Normalise the step's payload into (args, kwargs).
    # Fixed: the original used independent `if`s with no fallback, so an
    # unsupported payload type left args/kwargs unbound (NameError).
    if isinstance(config[key], list):
        args, kwargs = config[key], {}
    elif isinstance(config[key], dict):
        args, kwargs = [], config[key]
    elif isinstance(config[key], str):
        args, kwargs = [config[key]], {}
    else:
        raise ValueError(f"Invalid input config {config}")

    if cls is None:
        # Not a built-in step: fall back to a named 'filters' function.
        if not is_function(key, "filters"):
            raise ValueError(f"Unknown step {key}")
        cls = FunctionStepAction
        args = [key] + args

    return cls(context, action_path, previous_step, *args, **kwargs)
isinstance(key, (list, tuple)), key - key = tuple(key) - if key in self.results: - return self.results[key] - all_keys = sorted(list(self.results.keys())) - raise ValueError(f"Cannot find result {key} in {all_keys}") - - class Substitution: pass diff --git a/src/anemoi/datasets/create/trace.py b/src/anemoi/datasets/create/input/trace.py similarity index 100% rename from src/anemoi/datasets/create/trace.py rename to src/anemoi/datasets/create/input/trace.py diff --git a/src/anemoi/datasets/create/statistics/__init__.py b/src/anemoi/datasets/create/statistics/__init__.py index d788c203..e5fcf460 100644 --- a/src/anemoi/datasets/create/statistics/__init__.py +++ b/src/anemoi/datasets/create/statistics/__init__.py @@ -155,7 +155,7 @@ def compute_statistics(array, check_variables_names=None, allow_nans=False): check_data_values(values[j, :], name=name, allow_nans=allow_nans) if np.isnan(values[j, :]).all(): # LOG.warning(f"All NaN values for {name} ({j}) for date {i}") - raise ValueError(f"All NaN values for {name} ({j}) for date {i}") + LOG.warning(f"All NaN values for {name} ({j}) for date {i}") # Ignore NaN values minimum[i] = np.nanmin(values, axis=1) diff --git a/src/anemoi/datasets/dates/__init__.py b/src/anemoi/datasets/dates/__init__.py index fe1054ee..84d49636 100644 --- a/src/anemoi/datasets/dates/__init__.py +++ b/src/anemoi/datasets/dates/__init__.py @@ -12,6 +12,7 @@ # from anemoi.utils.dates import as_datetime from anemoi.utils.dates import DateTimes from anemoi.utils.dates import as_datetime +from anemoi.utils.dates import frequency_to_string from anemoi.utils.dates import frequency_to_timedelta from anemoi.utils.hindcasts import HindcastDatesTimes from anemoi.utils.humanize import print_dates diff --git a/src/anemoi/datasets/dates/groups.py b/src/anemoi/datasets/dates/groups.py index 624f308e..934a823f 100644 --- a/src/anemoi/datasets/dates/groups.py +++ b/src/anemoi/datasets/dates/groups.py @@ -9,18 +9,26 @@ import itertools from functools import 
cached_property -from anemoi.datasets.create.input import shorten from anemoi.datasets.dates import DatesProvider from anemoi.datasets.dates import as_datetime +def _shorten(dates): + if isinstance(dates, (list, tuple)): + dates = [d.isoformat() for d in dates] + if len(dates) > 5: + return f"{dates[0]}...{dates[-1]}" + return dates + + class GroupOfDates: - def __init__(self, dates, provider): + def __init__(self, dates, provider, partial_ok=False): assert isinstance(provider, DatesProvider), type(provider) assert isinstance(dates, list) self.dates = dates self.provider = provider + self.partial_ok = partial_ok def __len__(self): return len(self.dates) @@ -29,7 +37,7 @@ def __iter__(self): return iter(self.dates) def __repr__(self) -> str: - return f"GroupOfDates(dates={shorten(self.dates)})" + return f"GroupOfDates(dates={_shorten(self.dates)})" def __eq__(self, other: object) -> bool: return isinstance(other, GroupOfDates) and self.dates == other.dates @@ -93,7 +101,7 @@ def _len(self): return n def __repr__(self): - return f"{self.__class__.__name__}(dates={len(self)},{shorten(self._dates)})" + return f"{self.__class__.__name__}(dates={len(self)},{_shorten(self._dates)})" def describe(self): return self.dates.summary diff --git a/src/anemoi/datasets/fields.py b/src/anemoi/datasets/fields.py new file mode 100644 index 00000000..b341d401 --- /dev/null +++ b/src/anemoi/datasets/fields.py @@ -0,0 +1,66 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
def list_to_fieldlist(fields):
    """Wrap a plain list of fields in a FieldArray."""
    return FieldArray(fields)


def empty_fieldlist():
    """A FieldArray containing no fields."""
    return FieldArray([])


class WrappedField:
    """Delegating wrapper around a single field object."""

    def __init__(self, field):
        self._field = field

    def __getattr__(self, name):
        # Anything not defined on the wrapper is served by the wrapped field.
        return getattr(self._field, name)

    def __repr__(self) -> str:
        return repr(self._field)


class NewDataField(WrappedField):
    """A field whose values are replaced by a new array; metadata unchanged."""

    def __init__(self, field, data):
        super().__init__(field)
        self._data = data
        self.shape = data.shape

    def to_numpy(self, flatten=False, dtype=None, index=None):
        values = self._data
        if dtype is not None:
            values = values.astype(dtype)
        if flatten:
            values = values.flatten()
        if index is None:
            return values
        return values[index]


class NewMetadataField(WrappedField):
    """A field with some metadata keys overridden; values unchanged."""

    def __init__(self, field, **kwargs):
        super().__init__(field)
        self._metadata = kwargs

    def metadata(self, *args, **kwargs):
        # A single-key lookup that we override is answered locally;
        # everything else falls through to the wrapped field.
        if len(args) == 1:
            key = args[0]
            if key in self._metadata:
                return self._metadata[key]
        return self._field.metadata(*args, **kwargs)


class NewValidDateTimeField(NewMetadataField):
    """A field re-dated to *valid_datetime*, with its step forced to 0."""

    def __init__(self, field, valid_datetime):
        self.valid_datetime = valid_datetime
        super().__init__(
            field,
            date=valid_datetime.date().strftime("%Y%m%d"),
            time=valid_datetime.time().strftime("%H%M"),
            step=0,
            valid_datetime=valid_datetime.isoformat(),
        )