Fix pre-commit regex (#97)
* Fix pre-commit regex

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix new ruff issues

* Fix more ruff issues

* Move to Pathlib

* noqa

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Harrison Cook <[email protected]>
Co-authored-by: Jesper Dramsch <[email protected]>
4 people authored Nov 5, 2024
1 parent 49ce37e commit 5ade906
Showing 6 changed files with 38 additions and 33 deletions.
3 changes: 1 addition & 2 deletions .pre-commit-config.yaml
@@ -43,13 +43,12 @@ repos:
rev: v0.6.9
hooks:
- id: ruff
# Next line is for documentation code snippets
exclude: '^[^_].*_\.py$'
args:
- --line-length=120
- --fix
- --exit-non-zero-on-fix
- --preview
- --exclude=docs/**/*_.py
- repo: https://github.com/sphinx-contrib/sphinx-lint
rev: v1.0.0
hooks:
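As a sanity check, here is a minimal sketch (not part of the commit) of what the new exclude pattern does. pre-commit matches the pattern against each candidate file path with Python's re.search, and the ^/$ anchors make it an effectively full-path match; the sample paths below are illustrative.

import re

# Pattern added above: skip documentation code snippets, i.e. paths ending in
# "_.py" that do not themselves start with an underscore.
pattern = re.compile(r"^[^_].*_\.py$")

samples = [
    "docs/usage/example_.py",                   # snippet -> excluded from ruff
    "src/anemoi/training/train/forecaster.py",  # ordinary module -> linted
    "_snippet_.py",                             # leading underscore -> linted
]

for path in samples:
    print(path, "excluded" if pattern.search(path) else "linted")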
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -46,6 +46,8 @@ Keep it human-readable, your future self will thank you!
- Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)

### Fixed

- Fix pre-commit regex [#97](https://github.com/ecmwf/anemoi-training/pull/97)
- Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83)
- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99)
- ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100)
19 changes: 14 additions & 5 deletions docs/conf.py
@@ -1,3 +1,12 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# Configuration file for the Sphinx documentation builder.
#
# This file only contains a selection of the most common options. For a full
@@ -13,10 +22,11 @@
import datetime
import os
import sys
from pathlib import Path

read_the_docs_build = os.environ.get("READTHEDOCS", None) == "True"

sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
sys.path.insert(0, str(Path("..").absolute() / "src"))  # sys.path entries must be str


source_suffix = ".rst"
@@ -32,11 +42,10 @@

author = "Anemoi contributors"

year = datetime.datetime.now().year
year = datetime.datetime.now(tz=datetime.timezone.utc).year
years = "2024" if year == 2024 else f"2024-{year}"

copyright = f"{years}, Anemoi contributors"

copyright = f"{years}, Anemoi contributors" # noqa: A001

try:
from anemoi.training._version import __version__
@@ -64,7 +73,7 @@
]

# Add any paths that contain templates here, relative to this directory.
# templates_path = ["_templates"]
# templates_path = ["_templates"] # noqa: ERA001

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
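Both behavioural changes above are easy to check in isolation; a minimal sketch using only the standard library:

import datetime
from pathlib import Path

# Pathlib replacement for os.path.join(os.path.abspath(".."), "src");
# wrapped in str() because sys.path entries are expected to be strings.
src_dir = str(Path("..").absolute() / "src")

# Timezone-aware "now", as ruff's DTZ005 rule demands over a naive call.
year = datetime.datetime.now(tz=datetime.timezone.utc).year
years = "2024" if year == 2024 else f"2024-{year}"
print(src_dir, years)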
2 changes: 1 addition & 1 deletion src/anemoi/training/__init__.py
@@ -11,7 +11,7 @@
try:
# NOTE: the `_version.py` file must not be present in the git repository
# as it is generated by setuptools at install time
from ._version import __version__ # type: ignore
from ._version import __version__
except ImportError: # pragma: no cover
# Local copy or not installed with setuptools
__version__ = "999"
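A quick, hedged way to see which branch of the fallback ran in a given environment:

# Prints the setuptools-generated version for an installed package, or the
# "999" placeholder for a local checkout not installed with setuptools.
import anemoi.training

print(anemoi.training.__version__)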
39 changes: 18 additions & 21 deletions src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -10,10 +10,11 @@
from __future__ import annotations

import logging
from collections.abc import Iterable
from datetime import timedelta
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable

from hydra.utils import instantiate
from omegaconf import DictConfig
@@ -29,8 +30,8 @@
LOGGER = logging.getLogger(__name__)


def nestedget(conf: DictConfig, key, default):
"""Get a nested key from a DictConfig object
def nestedget(conf: DictConfig, key: str, default: Any) -> Any:
"""Get a nested key from a DictConfig object.
E.g.
>>> nestedget(config, "diagnostics.log.wandb.enabled", False)
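The body of nestedget is collapsed in this diff; a plausible sketch of the dotted-key lookup the docstring describes (an assumption, not the repository's actual implementation):

from typing import Any

from omegaconf import DictConfig

def nestedget_sketch(conf: DictConfig, key: str, default: Any) -> Any:
    # Walk the dotted key one segment at a time, falling back to the
    # default as soon as a segment is missing.
    node = conf
    for part in key.split("."):
        if not isinstance(node, DictConfig) or part not in node:
            return default
        node = node[part]
    return node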
@@ -56,7 +57,7 @@ def nestedget(conf: DictConfig, key, default):


def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
"""Get checkpointing callback"""
"""Get checkpointing callback."""
if not config.diagnostics.get("enable_checkpointing", True):
return []

@@ -114,28 +115,24 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
),
]
)
else:
LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
else:
# the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints
LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!")
return None


def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]:
"""Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS
Provides backwards compatibility
"""
"""Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS."""
callbacks = []

def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]):
def check_key(config: dict, key: str | Iterable[str] | Callable[[DictConfig], bool]) -> bool:
"""Check key in config."""
if isinstance(key, Callable):
return key(config)
elif isinstance(key, str):
if isinstance(key, str):
return nestedget(config, key, False)
elif isinstance(key, Iterable):
if isinstance(key, Iterable):
return all(nestedget(config, k, False) for k in key)
return nestedget(config, key, False)
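For reference, the three key forms the rewritten check_key accepts, with hypothetical config paths:

# str: a single dotted path, resolved via nestedget.
check_key(config, "diagnostics.log.wandb.enabled")

# Iterable[str]: every path must resolve to a truthy value.
check_key(config, ("diagnostics.plot.enabled", "diagnostics.plot.enabled_on_epoch"))

# Callable: an arbitrary predicate over the config.
check_key(config, lambda cfg: nestedget(cfg, "diagnostics.profiler", False))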

@@ -146,7 +143,7 @@ def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]):
return callbacks


def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901
def get_callbacks(config: DictConfig) -> list[Callback]:
"""Setup callbacks for PyTorch Lightning trainer.
Set `config.diagnostics.callbacks` to a list of callback configurations
@@ -180,7 +177,6 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901
A list of PyTorch Lightning callbacks
"""

trainer_callbacks: list[Callback] = []

# Get Checkpoint callback
Expand All @@ -189,14 +185,15 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901
trainer_callbacks.extend(checkpoint_callback)

# Base callbacks
for callback in config.diagnostics.get("callbacks", None) or []:
# Instantiate new callbacks
trainer_callbacks.append(instantiate(callback, config))
trainer_callbacks.extend(
instantiate(callback, config) for callback in config.diagnostics.get("callbacks", None) or []
)

# Plotting callbacks
for callback in config.diagnostics.plot.get("callbacks", None) or []:
# Instantiate new callbacks
trainer_callbacks.append(instantiate(callback, config))

trainer_callbacks.extend(
instantiate(callback, config) for callback in config.diagnostics.plot.get("callbacks", None) or []
)

# Extend with config enabled callbacks
trainer_callbacks.extend(_get_config_enabled_callbacks(config))
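The two instantiation loops above were folded into generator expressions feeding list.extend; behaviour is unchanged. Each entry is a Hydra config whose _target_ is instantiated with the training config passed positionally. A hedged sketch with an illustrative target:

from hydra.utils import instantiate
from omegaconf import OmegaConf

# Hypothetical entry as it might appear under config.diagnostics.callbacks;
# the _target_ path is illustrative, not taken from this repository.
callback_cfg = OmegaConf.create(
    {"_target_": "anemoi.training.diagnostics.callbacks.SomeCallback"},
)
callback = instantiate(callback_cfg, config)  # config is a positional argument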
6 changes: 2 additions & 4 deletions src/anemoi/training/train/forecaster.py
@@ -157,8 +157,7 @@ def get_loss_function(
scalars: Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None] = None, # noqa: FA100
**kwargs,
) -> Union[torch.nn.Module, torch.nn.ModuleList]: # noqa: FA100
"""
Get loss functions from config.
"""Get loss functions from config.
Can be ModuleList if multiple losses are specified.
@@ -329,8 +328,7 @@ def rollout_step(
training_mode: bool = True,
validation_mode: bool = False,
) -> Generator[tuple[Union[torch.Tensor, None], dict, list], None, None]: # noqa: FA100
"""
Rollout step for the forecaster.
"""Rollout step for the forecaster.
Will run pre_processors on batch, but not post_processors on predictions.
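Because rollout_step is a generator, a training loop consumes it one rollout step at a time. A hedged usage sketch; the batch variable and keyword values are assumptions:

# Each yielded tuple is (loss or None, metrics dict, predictions list),
# matching the return annotation above.
for loss, metrics, predictions in self.rollout_step(
    batch,
    training_mode=True,
    validation_mode=False,
):
    ...  # accumulate loss / log metrics per rollout step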
