Skip to content

Commit

Permalink
better support for hindcasts data
Browse files Browse the repository at this point in the history
  • Loading branch information
b8raoult committed Aug 8, 2024
1 parent cfe6384 commit f2c6739
Show file tree
Hide file tree
Showing 6 changed files with 152 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/anemoi/datasets/compute/recentre.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def recentre(

keys = ["param", "level", "valid_datetime", "date", "time", "step", "number"]

number_list = members.unique_values("number")["number"]
number_list = members.unique_values("number", progress_bar=False)["number"]
n_numbers = len(number_list)

assert None not in number_list
Expand Down
39 changes: 27 additions & 12 deletions src/anemoi/datasets/create/functions/sources/hindcasts.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,16 @@
# nor does it submit to any jurisdiction.
#
import datetime
import logging

from earthkit.data.core.fieldlist import MultiFieldList

from anemoi.datasets.create.functions.sources.mars import mars

LOGGER = logging.getLogger(__name__)
DEBUG = True


def _member(field):
# Bug in eccodes has number=0 randomly
number = field.metadata("number")
if number is None:
number = 0
return number


def _to_list(x):
if isinstance(x, (list, tuple)):
return x
Expand Down Expand Up @@ -63,9 +59,19 @@ def compute_hindcast(self, date):
def use_reference_year(reference_year, request):
    """Re-date a hindcast request onto ``reference_year``.

    The request's ``date`` entry (a datetime-like object) is moved to
    ``hdate`` and replaced by the same month/day in ``reference_year``,
    both formatted as ``YYYY-MM-DD``. The input mapping is not modified;
    a copy is updated and returned.

    Returns a ``(request, ok)`` pair:

    * ``(updated_request, True)`` when the request could be re-dated;
    * ``(None, False)`` when the hindcast year is not strictly before the
      reference year, or when the hindcast date is 29 February and the
      reference year has no such day.

    Raises ``ValueError`` for any other date that cannot be built.
    """
    request = request.copy()
    hdate = request.pop("date")

    # Only dates strictly before the reference year are usable as hindcasts.
    if hdate.year >= reference_year:
        return None, False

    try:
        date = datetime.datetime(reference_year, hdate.month, hdate.day)
    except ValueError:
        # 29 February cannot be mapped onto a non-leap reference year:
        # report the request as unusable instead of failing.
        if hdate.month == 2 and hdate.day == 29:
            return None, False
        raise

    request.update(date=date.strftime("%Y-%m-%d"), hdate=hdate.strftime("%Y-%m-%d"))
    return request, True


def hindcasts(context, dates, **request):
Expand All @@ -89,9 +95,18 @@ def hindcasts(context, dates, **request):
requests = []
for d in dates:
req = c.compute_hindcast(d)
req = use_reference_year(reference_year, req)
req, ok = use_reference_year(reference_year, req)
if ok:
requests.append(req)

# print("HINDCASTS requests", reference_year, base_times, available_steps)
# print("HINDCASTS dates", compress_dates(dates))

if len(requests) == 0:
# print("HINDCASTS no requests")
return MultiFieldList([])

requests.append(req)
# print("HINDCASTS requests", requests)

return mars(
context,
Expand Down
2 changes: 1 addition & 1 deletion src/anemoi/datasets/create/functions/sources/tendencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def tendencies(dates, time_increment, **kwargs):

ds = mars(dates=all_dates, **kwargs)

dates_in_data = ds.unique_values("valid_datetime")["valid_datetime"]
dates_in_data = ds.unique_values("valid_datetime", progress_bar=False)["valid_datetime"]
for d in all_dates:
assert d.isoformat() in dates_in_data, d

Expand Down
42 changes: 39 additions & 3 deletions src/anemoi/datasets/create/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,9 @@ def explain(self, ds, *args, remapping, patches):
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]

# print("Executing", self.action_path)
# print("Dates:", compress_dates(self.dates))

names = []
for a in args:
if isinstance(a, str):
Expand All @@ -287,12 +290,12 @@ def explain(self, ds, *args, remapping, patches):
print(f"Building a {len(names)}D hypercube using", names)

ds = ds.order_by(*args, remapping=remapping, patches=patches)
user_coords = ds.unique_values(*names, remapping=remapping, patches=patches)
user_coords = ds.unique_values(*names, remapping=remapping, patches=patches, progress_bar=False)

print()
print("Number of unique values found for each coordinate:")
for k, v in user_coords.items():
print(f" {k:20}:", len(v))
print(f" {k:20}:", len(v), shorten_list(v, max_length=10))
print()
user_shape = tuple(len(v) for k, v in user_coords.items())
print("Shape of the hypercube :", user_shape)
Expand All @@ -305,13 +308,18 @@ def explain(self, ds, *args, remapping, patches):

remapping = build_remapping(remapping, patches)
expected = set(itertools.product(*user_coords.values()))
extra = set()

if math.prod(user_shape) > len(ds):
print(f"This means that all the fields in the datasets do not exists for all combinations of {names}.")

for f in ds:
metadata = remapping(f.metadata)
expected.remove(tuple(metadata(n) for n in names))
key = tuple(metadata(n, default=None) for n in names)
if key in expected:
expected.remove(key)
else:
extra.add(key)

print("Missing fields:")
print()
Expand All @@ -321,7 +329,35 @@ def explain(self, ds, *args, remapping, patches):
print("...", len(expected) - i - 1, "more")
break

print("Extra fields:")
print()
for i, f in enumerate(sorted(extra)):
print(" ", f)
if i >= 9 and len(extra) > 10:
print("...", len(extra) - i - 1, "more")
break

print()
print("Missing values:")
per_name = defaultdict(set)
for e in expected:
for n, v in zip(names, e):
per_name[n].add(v)

for n, v in per_name.items():
print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
print()

print("Extra values:")
per_name = defaultdict(set)
for e in extra:
for n, v in zip(names, e):
per_name[n].add(v)

for n, v in per_name.items():
print(" ", n, len(v), shorten_list(sorted(v), max_length=10))
print()

print("To solve this issue, you can:")
print(
" - Provide a better selection, like 'step: 0' or 'level: 1000' to "
Expand Down
22 changes: 17 additions & 5 deletions src/anemoi/datasets/create/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
import tqdm
import zarr
from anemoi.utils.config import DotDict
from anemoi.utils.dates import as_datetime
from anemoi.utils.humanize import seconds_to_human

from anemoi.datasets import MissingDateError
from anemoi.datasets import open_dataset
from anemoi.datasets.create.persistent import build_storage
from anemoi.datasets.data.misc import as_first_date
from anemoi.datasets.data.misc import as_last_date
from anemoi.datasets.dates import compress_dates
from anemoi.datasets.dates.groups import Groups

from .check import DatasetName
Expand Down Expand Up @@ -442,14 +444,24 @@ def load_result(self, result):
dates = result.dates

cube = result.get_cube()
assert cube.extended_user_shape[0] == len(dates), (
cube.extended_user_shape[0],
len(dates),
)

shape = cube.extended_user_shape
dates_in_data = cube.user_coords["valid_datetime"]

if cube.extended_user_shape[0] != len(dates):
print(f"Cube shape does not match the number of dates {cube.extended_user_shape[0]}, {len(dates)}")
print("Requested dates", compress_dates(dates))
print("Cube dates", compress_dates(dates_in_data))

a = set(as_datetime(_) for _ in dates)
b = set(as_datetime(_) for _ in dates_in_data)

print("Missing dates", compress_dates(a - b))
print("Extra dates", compress_dates(b - a))

raise ValueError(
f"Cube shape does not match the number of dates {cube.extended_user_shape[0]}, {len(dates)}"
)

LOG.debug(f"Loading {shape=} in {self.data_array.shape=}")

def check_dates_in_data(lst, lst2):
Expand Down
68 changes: 67 additions & 1 deletion src/anemoi/datasets/dates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,45 @@
import datetime
import warnings

from anemoi.utils.dates import as_datetime


def _compress_dates(dates):
dates = sorted(dates)
if len(dates) < 3:
yield dates
return

prev = first = dates.pop(0)
curr = dates.pop(0)
delta = curr - prev
while curr - prev == delta:
prev = curr
if not dates:
break
curr = dates.pop(0)

yield (first, prev, delta)
if dates:
yield from _compress_dates([curr] + dates)


def compress_dates(dates):
    """Return a compact, human-readable list of strings describing ``dates``.

    Each date is first converted with ``as_datetime``. Evenly spaced runs
    found by ``_compress_dates`` are rendered as
    ``"<first> to <last> by <delta>"``; dates that are not part of a run are
    rendered individually with ``str``.
    """
    compressed = []

    for chunk in _compress_dates([as_datetime(d) for d in dates]):
        if isinstance(chunk, list):
            # Isolated dates: one entry each.
            compressed.extend(str(d) for d in chunk)
            continue

        first, last, delta = chunk
        compressed.append(f"{first} to {last} by {delta}")

    return compressed


def print_dates(dates):
    """Print the compressed representation of ``dates`` to standard output."""
    summary = compress_dates(dates)
    print(summary)


def no_time_zone(date):
    """Return a copy of ``date`` with its ``tzinfo`` removed.

    Only the time zone attribute is cleared; the date and time fields are
    left as they are (no conversion is performed).
    """
    naive = date.replace(tzinfo=None)
    return naive
Expand All @@ -30,6 +69,27 @@ def normalize_date(x):
return x


def extend(x):
    """Yield the normalised dates described by ``x``.

    ``x`` may be:

    * a list or tuple, whose items are expanded recursively;
    * a string of the form ``"start/end/step"``, expanded into every date
      from ``start`` to ``end`` (inclusive), ``step`` being converted to a
      number of hours with ``frequency_to_hours``;
    * anything else, which is passed to ``normalize_date`` and yielded.
    """
    if isinstance(x, (list, tuple)):
        for item in x:
            yield from extend(item)

    elif isinstance(x, str) and "/" in x:
        first, last, frequency = x.split("/")
        current = normalize_date(first)
        stop = normalize_date(last)
        hours = frequency_to_hours(frequency)
        # NOTE(review): a non-positive step would never reach `stop`;
        # presumably frequency_to_hours only returns positive values - confirm.
        while current <= stop:
            yield current
            current = current + datetime.timedelta(hours=hours)

    else:
        yield normalize_date(x)


class Dates:
"""Base class for date generation.
Expand Down Expand Up @@ -59,7 +119,7 @@ class Dates:
def __init__(self, missing=None):
if not missing:
missing = []
self.missing = [normalize_date(x) for x in missing]
self.missing = list(extend(missing))
if set(self.missing) - set(self.values):
warnings.warn(f"Missing dates {self.missing} not in list.")

Expand Down Expand Up @@ -145,3 +205,9 @@ def as_dict(self):
"end": self.end.isoformat(),
"frequency": f"{self.frequency}h",
}


if __name__ == "__main__":
    # Manual smoke test: print the compressed form of a single date, then of
    # an hourly range spanning one day.
    single = [datetime.datetime(2023, 1, 1, 0, 0)]
    print_dates(single)

    hourly = StartEndDates(start="2023-01-01 00:00", end="2023-01-02 00:00", frequency=1)
    print_dates(list(hourly))

0 comments on commit f2c6739

Please sign in to comment.