
Commit

use config
b8raoult committed Mar 29, 2024
1 parent d5c1a77 commit 2341df0
Showing 15 changed files with 104 additions and 55 deletions.
4 changes: 2 additions & 2 deletions anemoi/datasets/data/__init__.py
@@ -18,7 +18,7 @@ class MissingDateError(Exception):
     pass


-def open_dataset(*args, zarr_root=None, **kwargs):
-    ds = _open_dataset(*args, zarr_root=zarr_root, **kwargs)
+def open_dataset(*args, **kwargs):
+    ds = _open_dataset(*args, **kwargs)
     ds.arguments = {"args": args, "kwargs": kwargs}
     return ds
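After this change the configuration-based lookup is the only way a bare name is resolved. A minimal usage sketch, assuming open_dataset is re-exported at the package top level as in the project docs; the dataset name is a placeholder:

from anemoi.datasets import open_dataset

# A bare name (no .zarr/.zip extension) is resolved through the user config;
# an explicit path or URL is opened directly. "my-dataset" is hypothetical.
ds = open_dataset("my-dataset")
print(ds.arguments)  # {"args": ("my-dataset",), "kwargs": {}}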
4 changes: 2 additions & 2 deletions anemoi/datasets/data/ensemble.py
@@ -20,7 +20,7 @@ def tree(self):
         return Node(self, [d.tree() for d in self.datasets])


-def ensemble_factory(args, kwargs, zarr_root):
+def ensemble_factory(args, kwargs):
     if "grids" in kwargs:
         raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")

@@ -29,7 +29,7 @@ def ensemble_factory(args, kwargs, zarr_root):
     assert len(args) == 0
     assert isinstance(ensemble, (list, tuple))

-    datasets = [_open(e, zarr_root) for e in ensemble]
+    datasets = [_open(e) for e in ensemble]
     datasets, kwargs = _auto_adjust(datasets, kwargs)

     return Ensemble(datasets, axis=axis)._subset(**kwargs)
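For context, a hedged sketch of how this factory is reached through open_dataset; the member names are placeholders, and the axis handling (popped from kwargs before the lines shown) is assumed to have a default:

from anemoi.datasets import open_dataset

# Each entry in the list is opened with _open(e) and the members are
# combined along the ensemble axis by the Ensemble class.
ds = open_dataset(ensemble=["member-0", "member-1"])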
4 changes: 2 additions & 2 deletions anemoi/datasets/data/grids.py
@@ -207,7 +207,7 @@ def tree(self):
         return Node(self, [d.tree() for d in self.datasets], mode="cutout")


-def grids_factory(args, kwargs, zarr_root):
+def grids_factory(args, kwargs):
     if "ensemble" in kwargs:
         raise NotImplementedError("Cannot use both 'ensemble' and 'grids'")

@@ -224,7 +224,7 @@ def grids_factory(args, kwargs, zarr_root):
     if mode not in KLASSES:
         raise ValueError(f"Unknown grids mode: {mode}, values are {list(KLASSES.keys())}")

-    datasets = [_open(e, zarr_root) for e in grids]
+    datasets = [_open(e) for e in grids]
     datasets, kwargs = _auto_adjust(datasets, kwargs)

     return KLASSES[mode](datasets, axis=axis)._subset(**kwargs)
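Similarly for grids, a hedged illustration; the dataset names are placeholders and "cutout" is one mode suggested by the Node(..., mode="cutout") call above:

from anemoi.datasets import open_dataset

# mode picks an entry of KLASSES; the datasets are opened via _open(e)
# and combined over their grids.
ds = open_dataset(grids=["regional-dataset", "global-dataset"], mode="cutout")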
4 changes: 2 additions & 2 deletions anemoi/datasets/data/join.py
@@ -140,15 +140,15 @@ def tree(self):
         return Node(self, [d.tree() for d in self.datasets])


-def join_factory(args, kwargs, zarr_root):
+def join_factory(args, kwargs):

     datasets = kwargs.pop("join")
     assert isinstance(datasets, (list, tuple))
     assert len(args) == 0

     assert isinstance(datasets, (list, tuple))

-    datasets = [_open(e, zarr_root) for e in datasets]
+    datasets = [_open(e) for e in datasets]

     if len(datasets) == 1:
         return datasets[0]._subset(**kwargs)
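A hedged sketch of the corresponding call; the dataset names are placeholders:

from anemoi.datasets import open_dataset

# Joining combines datasets along the variables dimension; with a single
# entry the factory short-circuits to that dataset's _subset().
ds = open_dataset(join=["dataset-a", "dataset-b"])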
63 changes: 40 additions & 23 deletions anemoi/datasets/data/misc.py
@@ -13,25 +13,39 @@
 from pathlib import PurePath

 import numpy as np
-import yaml
 import zarr

 from .dataset import Dataset

 LOG = logging.getLogger(__name__)

+CONFIG = None

-def _name_to_path(name, zarr_root):
-    _, ext = os.path.splitext(name)
+try:
+    import tomllib  # Only available since 3.11
+except ImportError:
+    import tomli as tomllib

-    if ext in (".zarr", ".zip"):
-        return name

-    if zarr_root is None:
-        with open(os.path.expanduser("~/.anemoi.datasets")) as f:
-            zarr_root = yaml.safe_load(f)["zarr_root"]
+def load_config():
+    global CONFIG
+    if CONFIG is not None:
+        return CONFIG

-    return os.path.join(zarr_root, name + ".zarr")
+    conf = os.path.expanduser("~/.anemoi.toml")
+
+    if os.path.exists(conf):
+
+        with open(conf, "rb") as f:
+            CONFIG = tomllib.load(f)
+    else:
+        CONFIG = {}
+
+    CONFIG.setdefault("datasets", {})
+    CONFIG["datasets"].setdefault("lookup", [])
+    CONFIG["datasets"].setdefault("named", {})
+
+    return CONFIG


 def _frequency_to_hours(frequency):
@@ -182,8 +196,9 @@ def _concat_or_join(datasets, kwargs):
     return Concat(datasets), kwargs


-def _open(a, zarr_root):
+def _open(a):
     from .stores import Zarr
+    from .stores import zarr_lookup

     if isinstance(a, Dataset):
         return a
@@ -192,16 +207,16 @@ def _open(a, zarr_root):
         return Zarr(a).mutate()

     if isinstance(a, str):
-        return Zarr(_name_to_path(a, zarr_root)).mutate()
+        return Zarr(zarr_lookup(a)).mutate()

     if isinstance(a, PurePath):
-        return Zarr(str(a)).mutate()
+        return _open(str(a))

     if isinstance(a, dict):
-        return _open_dataset(zarr_root=zarr_root, **a)
+        return _open_dataset(**a)

     if isinstance(a, (list, tuple)):
-        return _open_dataset(*a, zarr_root=zarr_root)
+        return _open_dataset(*a)

     raise NotImplementedError()
@@ -276,52 +291,54 @@ def _auto_adjust(datasets, kwargs):
     return datasets, kwargs


-def _open_dataset(*args, zarr_root, **kwargs):
+def _open_dataset(*args, **kwargs):
     sets = []
     for a in args:
-        sets.append(_open(a, zarr_root))
+        sets.append(_open(a))

     if "zip" in kwargs:
         from .unchecked import zip_factory

         assert not sets, sets
-        return zip_factory(args, kwargs, zarr_root)
+        return zip_factory(args, kwargs)

     if "chain" in kwargs:
         from .unchecked import chain_factory

         assert not sets, sets
-        return chain_factory(args, kwargs, zarr_root)
+        return chain_factory(args, kwargs)

     if "join" in kwargs:
         from .join import join_factory

         assert not sets, sets
-        return join_factory(args, kwargs, zarr_root)
+        return join_factory(args, kwargs)

     if "concat" in kwargs:
         from .concat import concat_factory

         assert not sets, sets
-        return concat_factory(args, kwargs, zarr_root)
+        return concat_factory(args, kwargs)

     if "ensemble" in kwargs:
         from .ensemble import ensemble_factory

         assert not sets, sets
-        return ensemble_factory(args, kwargs, zarr_root)
+        return ensemble_factory(args, kwargs)

     if "grids" in kwargs:
         from .grids import grids_factory

         assert not sets, sets
-        return grids_factory(args, kwargs, zarr_root)
+        return grids_factory(args, kwargs)

     for name in ("datasets", "dataset"):
         if name in kwargs:
             datasets = kwargs.pop(name)
             if not isinstance(datasets, (list, tuple)):
                 datasets = [datasets]
             for a in datasets:
-                sets.append(_open(a, zarr_root))
+                sets.append(_open(a))

     assert len(sets) > 0, (args, kwargs)
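A small sketch of what the new loader returns when ~/.anemoi.toml is missing or only partially filled, following the setdefault calls above:

from anemoi.datasets.data.misc import load_config

config = load_config()
# With no ~/.anemoi.toml present, the defaults apply:
#   {"datasets": {"lookup": [], "named": {}}}
print(config["datasets"]["lookup"], config["datasets"]["named"])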
34 changes: 32 additions & 2 deletions anemoi/datasets/data/stores.py
@@ -19,6 +19,7 @@
 from .debug import Source
 from .debug import debug_indexing
 from .indexing import expand_list_indexing
+from .misc import load_config

 LOG = logging.getLogger(__name__)
@@ -99,7 +100,7 @@ def __contains__(self, key):
         return key in self.store


-def open_zarr(path):
+def open_zarr(path, silent=False):
     try:
         store = path
@@ -116,7 +117,8 @@ def open_zarr(path):

         return zarr.convenience.open(store, "r")
     except Exception:
-        LOG.exception("Failed to open %r", path)
+        if not silent:
+            LOG.exception("Failed to open %r", path)
         raise
@@ -135,6 +137,12 @@ def __init__(self, path):
         self.data = self.z.data
         self.missing = set()

+    @classmethod
+    def from_name(cls, name):
+        if name.endswith(".zip") or name.endswith(".zarr"):
+            return Zarr(name)
+        return Zarr(zarr_lookup(name))
+
     def __len__(self):
         return self.data.shape[0]
@@ -318,3 +326,25 @@ def _report_missing(self, n):

     def tree(self):
         return Node(self, [], path=self.path, missing=sorted(self.missing))
+
+
+def zarr_lookup(name):
+    config = load_config()["datasets"]
+    if name in config["named"]:
+        return zarr_lookup(config["named"][name])
+
+    tried = []
+    for location in config["lookup"]:
+        if not location.endswith("/"):
+            location += "/"
+        full = location + name + ".zarr"
+        tried.append(full)
+        try:
+            open_zarr(full, silent=True)
+            # Cache for next time
+            config["named"][name] = full
+            return full
+        except zarr.errors.PathNotFoundError:
+            pass
+
+    raise ValueError(f"Cannot find a dataset that matched '{name}'. Tried: {tried}")
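Putting load_config() and zarr_lookup() together, a hedged illustration of how a bare name resolves; the configuration values are invented for the example:

# Hypothetical ~/.anemoi.toml
#
#   [datasets]
#   lookup = ["/data/zarr/", "s3://my-bucket/datasets/"]
#
#   [datasets.named]
#   era5 = "era5-o96-1979-2022-6h-v1"

from anemoi.datasets.data.stores import zarr_lookup

# "era5" is first mapped to "era5-o96-1979-2022-6h-v1" via [datasets.named],
# then each lookup location is tried in turn until <location><name>.zarr
# opens; the first hit is cached back into the config and returned.
path = zarr_lookup("era5")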
8 changes: 4 additions & 4 deletions anemoi/datasets/data/unchecked.py
@@ -146,25 +146,25 @@ def dates(self):
         raise NotImplementedError()


-def zip_factory(args, kwargs, zarr_root):
+def zip_factory(args, kwargs):

     zip = kwargs.pop("zip")
     assert len(args) == 0
     assert isinstance(zip, (list, tuple))

-    datasets = [_open(e, zarr_root) for e in zip]
+    datasets = [_open(e) for e in zip]
     datasets, kwargs = _auto_adjust(datasets, kwargs)

     return Zip(datasets)._subset(**kwargs)


-def chain_factory(args, kwargs, zarr_root):
+def chain_factory(args, kwargs):

     chain = kwargs.pop("chain")
     assert len(args) == 0
     assert isinstance(chain, (list, tuple))

-    datasets = [_open(e, zarr_root) for e in chain]
+    datasets = [_open(e) for e in chain]
     datasets, kwargs = _auto_adjust(datasets, kwargs)

     return Chain(datasets)._subset(**kwargs)
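A hedged sketch of the calls that reach these factories; the names are placeholders, and, as the module name unchecked suggests, these combiners skip the compatibility checks performed by join and concat:

from anemoi.datasets import open_dataset

zipped = open_dataset(zip=["dataset-a", "dataset-b"])      # pair datasets item by item
chained = open_dataset(chain=["dataset-a", "dataset-b"])   # concatenate without checks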
2 changes: 2 additions & 0 deletions anemoi/datasets/utils/dates/__init__.py
@@ -9,9 +9,11 @@
 import datetime
 import warnings

+
 def no_time_zone(date):
     return date.replace(tzinfo=None)

+
 def frequency_to_hours(frequency):
     if isinstance(frequency, int):
         return frequency
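A short illustration of the two helpers; the string form for frequency_to_hours is assumed from the project's frequency notation and is not shown in this hunk:

import datetime

from anemoi.datasets.utils.dates import frequency_to_hours, no_time_zone

frequency_to_hours(6)     # integers pass through unchanged -> 6
frequency_to_hours("6h")  # assumed: string frequencies are parsed into hours
no_time_zone(datetime.datetime(2024, 3, 29, tzinfo=datetime.timezone.utc))
# -> datetime.datetime(2024, 3, 29, 0, 0), tzinfo stripped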
12 changes: 2 additions & 10 deletions docs/using/configuration.rst
@@ -25,13 +25,5 @@
 ..
    corresponding path.

-.. code:: toml
-
-   [datasets]
-   dataset1 = "/path/to/dataset1.zarr"
-   dataset2 = "/path/to/dataset2.zarr"
-   "*" = "/path/to/{name}.zarr"
-
-   [dataset.path]
-   name = "dataset1"
-   path = "/path/to/dataset1.zarr"
+.. literalinclude:: configuration.toml
+   :language: toml
5 changes: 5 additions & 0 deletions docs/using/configuration.toml
@@ -0,0 +1,5 @@
+[datasets]
+lookup = ["a", "b", "c"]
+
+[datasets.named]
+name1 = "foo"
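Read through load_config() and zarr_lookup(), this sample says: any bare name is searched under the locations a, b and c (with a trailing slash added and .zarr appended), and name1 is an alias for the name foo. A hedged illustration:

from anemoi.datasets import open_dataset

# With the sample configuration above, this tries a/foo.zarr, b/foo.zarr
# and c/foo.zarr, and opens the first one that exists.
ds = open_dataset("name1")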
3 changes: 3 additions & 0 deletions docs/using/selecting.rst
@@ -4,6 +4,9 @@
 Selecting variables
 #####################

+Selecting is the action of filtering the dataset by its second dimension (variables).
+
+
 .. _select:

 ********
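A hedged example of the option this page documents; the dataset and variable names are placeholders:

from anemoi.datasets import open_dataset

# Keep only the listed variables (the second dimension of the dataset).
ds = open_dataset("my-dataset", select=["2t", "10u", "10v"])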
2 changes: 2 additions & 0 deletions docs/using/subsetting.rst
@@ -4,6 +4,8 @@
 Subsetting datasets
 #####################

+Subsetting is the action of filtering the dataset by its first dimension (dates).
+
 .. _start:

 *******
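A hedged example of the option this page documents; the dataset name and the years are placeholders:

from anemoi.datasets import open_dataset

# Restrict the first dimension (dates) to the requested range.
ds = open_dataset("my-dataset", start=2020, end=2021)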
1 change: 1 addition & 0 deletions setup.py
@@ -30,6 +30,7 @@ def read(fname):


 data_requires = [
+    "tomli",  # Only needed before 3.11
     "zarr",
     "pyyaml",
     "numpy",
2 changes: 1 addition & 1 deletion tests/test_data.py
@@ -34,7 +34,7 @@ def mockup_open_zarr(func):
     @wraps(func)
     def wrapper(*args, **kwargs):
         with patch("zarr.convenience.open", zarr_from_str):
-            with patch("anemoi.datasets.data.misc._name_to_path", lambda name, zarr_root: name):
+            with patch("anemoi.datasets.data.stores.zarr_lookup", lambda name: name):
                 return func(*args, **kwargs)

     return wrapper
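With the patch target moved to anemoi.datasets.data.stores.zarr_lookup, a test written against this decorator would look roughly like the sketch below; the test name and the dataset string are hypothetical and depend on what zarr_from_str in this module accepts:

@mockup_open_zarr
def test_open_by_name():
    # zarr_lookup is patched to the identity and zarr.convenience.open is
    # patched to zarr_from_str, so the name never touches the filesystem.
    ds = open_dataset("some-test-dataset")
    assert len(ds) > 0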