diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e01d6a37..4de59323 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -43,13 +43,12 @@ repos:
     rev: v0.6.9
     hooks:
       - id: ruff
-        # Next line if for documenation cod snippets
-        exclude: '^[^_].*_\.py$'
         args:
           - --line-length=120
           - --fix
           - --exit-non-zero-on-fix
           - --preview
+          - --exclude=docs/**/*_.py
   - repo: https://github.com/sphinx-contrib/sphinx-lint
     rev: v1.0.0
     hooks:
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5e756cd6..b10243c6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -46,6 +46,8 @@ Keep it human-readable, your future self will thank you!
 - Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)
 
 ### Fixed
+
+- Fix pre-commit regex
 - Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83)
 - Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99] (https://github.com/ecmwf/anemoi-training/pull/99)
 - ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100
diff --git a/docs/conf.py b/docs/conf.py
index 4756af86..dc7a24d9 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -1,3 +1,12 @@
+# (C) Copyright 2024 Anemoi contributors.
+#
+# This software is licensed under the terms of the Apache Licence Version 2.0
+# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
+#
+# In applying this licence, ECMWF does not waive the privileges and immunities
+# granted to it by virtue of its status as an intergovernmental organisation
+# nor does it submit to any jurisdiction.
+
 # Configuration file for the Sphinx documentation builder.
 #
 # This file only contains a selection of the most common options. For a full
@@ -13,10 +22,11 @@
 import datetime
 import os
 import sys
+from pathlib import Path
 
 read_the_docs_build = os.environ.get("READTHEDOCS", None) == "True"
 
-sys.path.insert(0, os.path.join(os.path.abspath(".."), "src"))
+sys.path.insert(0, str(Path("..").absolute() / "src"))
 
 source_suffix = ".rst"
@@ -32,11 +42,10 @@
 
 author = "Anemoi contributors"
 
-year = datetime.datetime.now().year
+year = datetime.datetime.now(tz=datetime.timezone.utc).year
 years = "2024" if year == 2024 else f"2024-{year}"
 
-copyright = f"{years}, Anemoi contributors"
-
+copyright = f"{years}, Anemoi contributors"  # noqa: A001
 
 try:
     from anemoi.training._version import __version__
@@ -64,7 +73,7 @@
 ]
 
 # Add any paths that contain templates here, relative to this directory.
-# templates_path = ["_templates"]
+# templates_path = ["_templates"]  # noqa: ERA001
 
 # List of patterns, relative to source directory, that match files and
 # directories to ignore when looking for source files.
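Reviewer note, not part of the patch: the two docs/conf.py changes above are easy to get wrong, so here is a minimal standalone sketch of the intended behaviour. Everything shown is standard library; no anemoi code is assumed.

    import datetime
    import sys
    from pathlib import Path

    # datetime.now() expects a tzinfo instance; passing the string "UTC"
    # raises TypeError, hence datetime.timezone.utc in the patch.
    year = datetime.datetime.now(tz=datetime.timezone.utc).year

    # sys.path is documented to hold plain strings, so the Path object is
    # converted with str() before being inserted.
    sys.path.insert(0, str(Path("..").absolute() / "src"))
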
diff --git a/src/anemoi/training/__init__.py b/src/anemoi/training/__init__.py
index 7b9efcd6..af4a3aea 100644
--- a/src/anemoi/training/__init__.py
+++ b/src/anemoi/training/__init__.py
@@ -11,7 +11,7 @@
 try:
     # NOTE: the `_version.py` file must not be present in the git repository
     #   as it is generated by setuptools at install time
-    from ._version import __version__  # type: ignore
+    from ._version import __version__
 except ImportError:  # pragma: no cover
     # Local copy or not installed with setuptools
     __version__ = "999"
diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py
index 3968da77..f6a7b894 100644
--- a/src/anemoi/training/diagnostics/callbacks/__init__.py
+++ b/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -10,10 +10,11 @@
 from __future__ import annotations
 
 import logging
+from collections.abc import Iterable
 from datetime import timedelta
 from typing import TYPE_CHECKING
+from typing import Any
 from typing import Callable
-from typing import Iterable
 
 from hydra.utils import instantiate
 from omegaconf import DictConfig
@@ -29,8 +30,8 @@
 LOGGER = logging.getLogger(__name__)
 
 
-def nestedget(conf: DictConfig, key, default):
-    """Get a nested key from a DictConfig object
+def nestedget(conf: DictConfig, key: str, default: Any) -> Any:
+    """Get a nested key from a DictConfig object.
 
     E.g.
     >>> nestedget(config, "diagnostics.log.wandb.enabled", False)
@@ -56,7 +57,7 @@ def nestedget(conf: DictConfig, key, default):
 
 
 def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None:
-    """Get checkpointing callback"""
+    """Get checkpointing callback."""
     if not config.diagnostics.get("enable_checkpointing", True):
         return []
 
@@ -114,8 +115,7 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non
                     ),
                 ]
             )
-            else:
-                LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
+            LOGGER.debug("Not setting up a checkpoint callback with %s", save_key)
     else:
         # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints
         LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!")
@@ -123,19 +123,16 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non
 def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]:
-    """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS
-
-    Provides backwards compatibility
-    """
+    """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS."""
     callbacks = []
 
-    def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]):
+    def check_key(config: dict, key: str | Iterable[str] | Callable[[DictConfig], bool]) -> bool:
         """Check key in config."""
         if isinstance(key, Callable):
             return key(config)
-        elif isinstance(key, str):
+        if isinstance(key, str):
             return nestedget(config, key, False)
-        elif isinstance(key, Iterable):
+        if isinstance(key, Iterable):
             return all(nestedget(config, k, False) for k in key)
         return nestedget(config, key, False)
@@ -146,7 +143,7 @@ def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]):
     return callbacks
 
 
-def get_callbacks(config: DictConfig) -> list[Callback]:  # noqa: C901
+def get_callbacks(config: DictConfig) -> list[Callback]:
     """Setup callbacks for PyTorch Lightning trainer.
 
     Set `config.diagnostics.callbacks` to a list of callback configurations
@@ -180,7 +177,6 @@ def get_callbacks(config: DictConfig) -> list[Callback]:  # noqa: C901
         A list of PyTorch Lightning callbacks
 
     """
-
     trainer_callbacks: list[Callback] = []
 
     # Get Checkpoint callback
@@ -189,14 +185,15 @@
     trainer_callbacks.extend(checkpoint_callback)
 
     # Base callbacks
-    for callback in config.diagnostics.get("callbacks", None) or []:
-        # Instantiate new callbacks
-        trainer_callbacks.append(instantiate(callback, config))
+    trainer_callbacks.extend(
+        instantiate(callback, config) for callback in config.diagnostics.get("callbacks", None) or []
+    )
 
     # Plotting callbacks
-    for callback in config.diagnostics.plot.get("callbacks", None) or []:
-        # Instantiate new callbacks
-        trainer_callbacks.append(instantiate(callback, config))
+
+    trainer_callbacks.extend(
+        instantiate(callback, config) for callback in config.diagnostics.plot.get("callbacks", None) or []
+    )
 
     # Extend with config enabled callbacks
     trainer_callbacks.extend(_get_config_enabled_callbacks(config))
diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py
index 705aa2ff..ebba3c6d 100644
--- a/src/anemoi/training/train/forecaster.py
+++ b/src/anemoi/training/train/forecaster.py
@@ -157,8 +157,7 @@ def get_loss_function(
         scalars: Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None] = None,  # noqa: FA100
         **kwargs,
     ) -> Union[torch.nn.Module, torch.nn.ModuleList]:  # noqa: FA100
-        """
-        Get loss functions from config.
+        """Get loss functions from config.
 
         Can be ModuleList if multiple losses are specified.
@@ -329,8 +328,7 @@ def rollout_step(
         training_mode: bool = True,
         validation_mode: bool = False,
     ) -> Generator[tuple[Union[torch.Tensor, None], dict, list], None, None]:  # noqa: FA100
-        """
-        Rollout step for the forecaster.
+        """Rollout step for the forecaster.
 
         Will run pre_processors on batch, but not post_processors on predictions.
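Reviewer note, not part of the patch: a quick usage sketch for the newly annotated `nestedget` helper, extrapolated from the doctest in its docstring; the config contents here are made up for illustration.

    from anemoi.training.diagnostics.callbacks import nestedget
    from omegaconf import OmegaConf

    # A minimal config mirroring the doctest in nestedget's docstring.
    config = OmegaConf.create({"diagnostics": {"log": {"wandb": {"enabled": True}}}})

    # Follows the dotted key through the nested DictConfig, returning the
    # default when any component along the path is missing.
    nestedget(config, "diagnostics.log.wandb.enabled", False)   # -> True
    nestedget(config, "diagnostics.log.mlflow.enabled", False)  # -> False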