Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/scaling #36

Merged
merged 15 commits into from
Sep 21, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ Keep it human-readable, your future self will thank you!
## [Unreleased]

### Added

- New `rescale` keyword in `open_dataset` to change units of variables #36

### Changed

- Added incremental building of datasets
Expand Down
30 changes: 30 additions & 0 deletions docs/using/code/rescale_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Scale and offset can be passed as a dictionary...

ds = open_dataset(
    dataset,
    rescale={"2t": {"scale": 1.0, "offset": -273.15}},
)

# ... as a tuple of floating points ...

ds = open_dataset(
    dataset,
    rescale={"2t": (1.0, -273.15)},
)

# ... or as a tuple of strings representing units.

ds = open_dataset(
    dataset,
    rescale={"2t": ("K", "degC")},
)

# Several variables can be rescaled at once.

ds = open_dataset(
    dataset,
    rescale={
        "2t": ("K", "degC"),
        "tp": ("m", "mm"),
    },
)
25 changes: 25 additions & 0 deletions docs/using/selecting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,28 @@ You can also rename variables:

This will be useful when you join datasets and do not want variables
from one dataset to override the ones from the other.

*********
rescale
*********

When combining datasets, you may want to rescale the variables so that
they have matching units. This can be done with the `rescale` option:

.. literalinclude:: code/rescale_.py
:language: python

The `rescale` option will also rescale the statistics. The rescaling is
currently limited to simple linear conversions.

When provided with units, the `rescale` option uses the cfunits_ package
to find the `scale` and `offset` attributes of the units and uses these
to rescale the data.

.. warning::

When providing units, the library assumes that the mapping between
them is a linear transformation. No check is done to ensure this is
the case.

.. _cfunits: https://github.com/NCAS-CMS/cfunits
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,8 @@ dynamic = [
"version",
]
dependencies = [
"anemoi-utils[provenance]>=0.3.9",
"anemoi-utils[provenance]>=0.3.15",
"cfunits",
"numpy",
"pyyaml",
"semantic-version",
Expand Down
8 changes: 5 additions & 3 deletions src/anemoi/datasets/create/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -926,7 +926,7 @@ def __init__(self, *args, **kwargs):
self.actors.append(cls(*args, delta=k, **kwargs))

if not self.actors:
LOG.warning("No delta found in kwargs, no addtions will be computed.")
LOG.warning("No delta found in kwargs, no additions will be computed.")

def run(self):
for actor in self.actors:
Expand Down Expand Up @@ -954,7 +954,9 @@ def run(self):
)
start, end = np.datetime64(start), np.datetime64(end)
dates = self.dataset.anemoi_dataset.dates
assert type(dates[0]) == type(start), (type(dates[0]), type(start))

assert type(dates[0]) is type(start), (type(dates[0]), type(start))

dates = [d for d in dates if d >= start and d <= end]
dates = [d for i, d in enumerate(dates) if i not in self.dataset.anemoi_dataset.missing]
variables = self.dataset.anemoi_dataset.variables
Expand All @@ -963,7 +965,7 @@ def run(self):
LOG.info(stats)

if not all(self.registry.get_flags(sync=False)):
raise Exception(f"❗Zarr {self.path} is not fully built, not writting statistics into dataset.")
raise Exception(f"❗Zarr {self.path} is not fully built, not writing statistics into dataset.")

for k in ["mean", "stdev", "minimum", "maximum", "sums", "squares", "count", "has_nans"]:
self.dataset.add_dataset(name=k, array=stats[k], dimensions=("variable",))
Expand Down
1 change: 1 addition & 0 deletions src/anemoi/datasets/create/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def set_to_test_mode(cfg):
dates = cfg["dates"]
LOG.warning(f"Running in test mode. Changing the list of dates to use only {NUMBER_OF_DATES}.")
groups = Groups(**LoadersConfig(cfg).dates)

dates = groups.dates
cfg["dates"] = dict(
start=dates[0],
Expand Down
12 changes: 11 additions & 1 deletion src/anemoi/datasets/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@
class Dataset:
arguments = {}

def mutate(self):
def mutate(self) -> "Dataset":
"""
Give an opportunity to a subclass to return a new Dataset
object of a different class, if needed.
"""
return self

def swap_with_parent(self, parent):
Expand Down Expand Up @@ -90,6 +94,12 @@ def _subset(self, **kwargs):
rename = kwargs.pop("rename")
return Rename(self, rename)._subset(**kwargs).mutate()

if "rescale" in kwargs:
from .rescale import Rescale

rescale = kwargs.pop("rescale")
return Rescale(self, rescale)._subset(**kwargs).mutate()

if "statistics" in kwargs:
from ..data import open_dataset
from .statistics import Statistics
Expand Down
147 changes: 147 additions & 0 deletions src/anemoi/datasets/data/rescale.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

import logging
from functools import cached_property

import numpy as np

from .debug import Node
from .debug import debug_indexing
from .forwards import Forwards
from .indexing import apply_index_to_slices_changes
from .indexing import expand_list_indexing
from .indexing import index_to_slices
from .indexing import update_tuple

LOG = logging.getLogger(__name__)


def make_rescale(variable, rescale):
    """Normalise a user-provided rescaling spec into a ``(scale, offset)`` pair.

    Parameters
    ----------
    variable : str
        Name of the variable being rescaled (used in error messages).
    rescale : tuple or list or dict
        One of:

        * ``(scale, offset)`` as numbers, returned as-is;
        * ``(from_unit, to_unit)`` as strings, converted with cfunits
          (the mapping between units is assumed to be linear);
        * a mapping with ``"scale"`` and ``"offset"`` keys.

    Returns
    -------
    tuple
        ``(scale, offset)`` such that ``new = old * scale + offset``.

    Raises
    ------
    ValueError
        If ``rescale`` is not one of the supported forms.
    """

    if isinstance(rescale, (tuple, list)):

        if len(rescale) != 2:
            raise ValueError(f"Invalid rescale for {variable!r}: expected two elements, got {rescale!r}")

        if isinstance(rescale[0], (int, float)):
            return rescale

        # Units given as strings: derive the linear map by converting two
        # sample points (0 and 1) from the source to the target unit.
        # NOTE: this assumes the unit conversion is linear; no check is made.
        from cfunits import Units

        u0 = Units(rescale[0])
        u1 = Units(rescale[1])

        x1, x2 = 0.0, 1.0
        y1, y2 = Units.conform([x1, x2], u0, u1)

        a = (y2 - y1) / (x2 - x1)
        b = y1 - a * x1

        return a, b

    if isinstance(rescale, dict):
        if "scale" not in rescale or "offset" not in rescale:
            raise ValueError(f"Invalid rescale for {variable!r}: need 'scale' and 'offset' keys, got {rescale!r}")
        return rescale["scale"], rescale["offset"]

    raise ValueError(f"Invalid rescale for {variable!r}: unsupported type {type(rescale).__name__}")


class Rescale(Forwards):
    """A dataset wrapper that applies a per-variable linear rescaling.

    Each selected variable ``v`` is transformed as ``v * scale + offset``
    when data is read; the wrapped dataset is not modified. Statistics are
    rescaled accordingly.
    """

    def __init__(self, dataset, rescale):
        """
        Parameters
        ----------
        dataset : Dataset
            The dataset to wrap.
        rescale : dict
            Maps variable names to a rescaling spec accepted by
            ``make_rescale`` (a ``(scale, offset)`` pair, a pair of unit
            strings, or a dict with ``scale`` and ``offset`` keys).
        """
        super().__init__(dataset)
        # Every requested variable must exist in the wrapped dataset.
        for n in rescale:
            assert n in dataset.variables, n

        variables = dataset.variables

        # Identity transform (scale 1, offset 0) for variables not rescaled.
        self._a = np.ones(len(variables))
        self._b = np.zeros(len(variables))

        # Keep the resolved (scale, offset) pairs for metadata/debugging.
        self.rescale = {}
        for i, v in enumerate(variables):
            if v in rescale:
                a, b = make_rescale(v, rescale[v])
                self.rescale[v] = a, b
                self._a[i], self._b[i] = a, b

        # Reshape to (1, variables, 1, 1) so the coefficients broadcast
        # against data indexed as (dates, variables, ensemble, grid).
        self._a = self._a[np.newaxis, :, np.newaxis, np.newaxis]
        self._b = self._b[np.newaxis, :, np.newaxis, np.newaxis]

        # Match the wrapped dataset's dtype to avoid promoting results.
        self._a = self._a.astype(self.forward.dtype)
        self._b = self._b.astype(self.forward.dtype)

    def tree(self):
        """Return the debug tree node for this wrapper."""
        return Node(self, [self.forward.tree()], rescale=self.rescale)

    def subclass_metadata_specific(self):
        """Return the metadata specific to this wrapper (the rescale map)."""
        return dict(rescale=self.rescale)

    @debug_indexing
    @expand_list_indexing
    def _get_tuple(self, index):
        # Widen the variable axis (axis 1) to a full slice so the
        # broadcast-shaped coefficients line up, rescale, then re-apply the
        # caller's original variable selection and slice changes.
        index, changes = index_to_slices(index, self.shape)
        index, previous = update_tuple(index, 1, slice(None))
        result = self.forward[index]
        result = result * self._a + self._b
        result = result[:, previous]
        result = apply_index_to_slices_changes(result, changes)
        return result

    @debug_indexing
    def __get_slice_(self, n):
        # Slicing along the first (dates) axis keeps all four axes, so the
        # (1, variables, 1, 1) coefficients broadcast directly.
        data = self.forward[n]
        return data * self._a + self._b

    @debug_indexing
    def __getitem__(self, n):
        """Return the rescaled data for index ``n`` (int, slice or tuple)."""

        if isinstance(n, tuple):
            return self._get_tuple(n)

        if isinstance(n, slice):
            return self.__get_slice_(n)

        # Integer index drops the dates axis, so drop the leading axis of
        # the coefficients too before broadcasting.
        data = self.forward[n]

        return data * self._a[0] + self._b[0]

    @cached_property
    def statistics(self):
        """Return the wrapped dataset's statistics, rescaled per variable."""
        result = {}
        a = self._a.squeeze()
        # A negative scale would swap minimum/maximum; not supported.
        assert np.all(a >= 0)

        b = self._b.squeeze()
        for k, v in self.forward.statistics.items():
            # Location statistics get the full affine transform.
            if k in ("maximum", "minimum", "mean"):
                result[k] = v * a + b
                continue

            # Spread statistics are offset-invariant: only scale applies.
            if k in ("stdev",):
                result[k] = v * a
                continue

            raise NotImplementedError("rescale statistics", k)

        return result

    def statistics_tendencies(self, delta=None):
        """Return the wrapped dataset's tendency statistics, rescaled.

        Tendencies are differences, so the offset cancels and only the
        scale is applied — to all statistics, including mean/min/max.
        """
        result = {}
        a = self._a.squeeze()
        # A negative scale would swap minimum/maximum; not supported.
        assert np.all(a >= 0)

        for k, v in self.forward.statistics_tendencies(delta).items():
            if k in ("maximum", "minimum", "mean", "stdev"):
                result[k] = v * a
                continue

            raise NotImplementedError("rescale tendencies statistics", k)

        return result
Loading