diff --git a/CHANGELOG.md b/CHANGELOG.md index f8d107fe..2fda8c64 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,10 @@ Keep it human-readable, your future self will thank you! ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28 +### Fixed +- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) +- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) + - Enable longer validation rollout than training ### Added - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70) @@ -29,7 +33,9 @@ Keep it human-readable, your future self will thank you! - Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79) ### Fixed - +- Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83) +- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99] (https://github.com/ecmwf/anemoi-training/pull/99) +- ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100 - Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83) - Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99) - ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100) diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst index c46f201c..28eac7c7 100644 --- a/docs/modules/diagnostics.rst +++ b/docs/modules/diagnostics.rst @@ -21,23 +21,94 @@ functionality to use both Weights & Biases and Tensorboard. The callbacks can also be used to evaluate forecasts over longer rollouts beyond the forecast time that the model is trained on. The -number of rollout steps (or forecast iteration steps) is set using -``config.eval.rollout = *num_of_rollout_steps*``. - -Note the user has the option to evaluate the callbacks asynchronously -(using the following config option -``config.diagnostics.plot.asynchronous``, which means that the model -training doesn't stop whilst the callbacks are being evaluated). -However, note that callbacks can still be slow, and therefore the -plotting callbacks can be switched off by setting -``config.diagnostics.plot.enabled`` to ``False`` or all the callbacks -can be completely switched off by setting -``config.diagnostics.eval.enabled`` to ``False``. +number of rollout steps for verification (or forecast iteration steps) +is set using ``config.dataloader.validation_rollout = +*num_of_rollout_steps*``. + +Callbacks are configured in the config file under the +``config.diagnostics`` key. + +For regular callbacks, they can be provided as a list of dictionaries +underneath the ``config.diagnostics.callbacks`` key. Each dictionary +must have a ``_target`` key which is used by hydra to instantiate the +callback, any other kwarg is passed to the callback's constructor. + +.. code:: yaml + + callbacks: + - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval + rollout: ${dataloader.validation_rollout} + frequency: 20 + +Plotting callbacks are configured in a similar way, but they are +specified underneath the ``config.diagnostics.plot.callbacks`` key. 
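+
+Under the hood, each of these entries is resolved with Hydra's
+``instantiate``, and the full job configuration is passed as the first
+positional argument. The following minimal sketch (not the exact call
+site in anemoi-training) shows what that amounts to for the
+``RolloutEval`` entry above, assuming the callback accepts the config
+followed by the keyword arguments from its YAML entry:
+
+.. code:: python
+
+   from hydra.utils import instantiate
+   from omegaconf import OmegaConf
+
+   # Full job configuration, assumed to be already composed by Hydra
+   job_config = OmegaConf.load("config.yaml")
+
+   # One entry taken from config.diagnostics.callbacks
+   callback_cfg = OmegaConf.create(
+       {
+           "_target_": "anemoi.training.diagnostics.callbacks.evaluation.RolloutEval",
+           "rollout": 1,
+           "frequency": 20,
+       }
+   )
+
+   # Equivalent to RolloutEval(job_config, rollout=1, frequency=20)
+   callback = instantiate(callback_cfg, job_config)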
+ +Plot callbacks are kept under a separate key to ensure separation and ease of configuration between experiments. + +``config.diagnostics.plot`` is a broader config file specifying the +parameters to plot, as well as the plotting frequency, and +asynchronicity. + +Setting ``config.diagnostics.plot.asynchronous`` means that the model +training doesn't stop whilst the callbacks are being evaluated. + +.. code:: yaml + + plot: + asynchronous: True # Whether to plot asynchronously + frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + + # Parameters to plot + parameters: + - z_500 + - t_850 + - u_850 + + # Sample index + sample_idx: 0 + + # Precipitation and related fields + precip_and_related_fields: [tp, cp] + + callbacks: + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} Below is the documentation for the default callbacks provided, but it is also possible for users to add callbacks using the same structure: -.. automodule:: anemoi.training.diagnostics.callbacks +.. automodule:: anemoi.training.diagnostics.callbacks.checkpoint + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.evaluation + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.optimiser + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.plot + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.provenance :members: :no-undoc-members: :show-inheritance: diff --git a/docs/user-guide/configuring.rst b/docs/user-guide/configuring.rst index 35efec3f..307cebf0 100644 --- a/docs/user-guide/configuring.rst +++ b/docs/user-guide/configuring.rst @@ -21,7 +21,7 @@ settings at the top as follows: defaults: - data: zarr - dataloader: native_grid - - diagnostics: eval_rollout + - diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn @@ -100,7 +100,7 @@ match the dataset you provide. defaults: - data: zarr - dataloader: native_grid - - diagnostics: eval_rollout + - diagnostics: evaluation - hardware: example - graph: multi_scale - model: transformer # Change from default group diff --git a/docs/user-guide/tracking.rst b/docs/user-guide/tracking.rst index cab5e851..f97182d7 100644 --- a/docs/user-guide/tracking.rst +++ b/docs/user-guide/tracking.rst @@ -33,7 +33,7 @@ the same experiment. Within the MLflow experiments tab, it is possible to define different namespaces. To create a new namespace, the user just needs to pass an 'experiment_name' -(``config.diagnostics.eval_rollout.log.mlflow.experiment_name``) to the +(``config.diagnostics.evaluation.log.mlflow.experiment_name``) to the mlflow logger.
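+
+For orientation, the sketch below shows roughly what this corresponds to
+underneath, using the plain PyTorch Lightning ``MLFlowLogger``;
+anemoi-training wraps this in its own ``AnemoiMLflowLogger``, whose exact
+signature may differ, and the experiment name and tracking server shown
+here are placeholders:
+
+.. code:: python
+
+   from pytorch_lightning.loggers import MLFlowLogger
+
+   # The experiment name selects (or creates) the namespace on the
+   # tracking server under which all runs of this experiment are grouped.
+   logger = MLFlowLogger(
+       experiment_name="my_new_experiment",            # placeholder name
+       tracking_uri="https://mlflow.example.invalid",  # placeholder server
+   )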
**Parent-Child Runs** diff --git a/src/anemoi/training/config/config.yaml b/src/anemoi/training/config/config.yaml index 2045da93..a379acfd 100644 --- a/src/anemoi/training/config/config.yaml +++ b/src/anemoi/training/config/config.yaml @@ -1,7 +1,7 @@ defaults: - data: zarr - dataloader: native_grid -- diagnostics: eval_rollout +- diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index e6d50801..d7aa4f6d 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -45,6 +45,8 @@ training: frequency: ${data.frequency} drop: [] +validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks + validation: dataset: ${dataloader.dataset} start: 2021 diff --git a/src/anemoi/training/config/debug.yaml b/src/anemoi/training/config/debug.yaml index 5be3e9f4..a6143bb6 100644 --- a/src/anemoi/training/config/debug.yaml +++ b/src/anemoi/training/config/debug.yaml @@ -1,7 +1,7 @@ defaults: - data: zarr - dataloader: native_grid -- diagnostics: eval_rollout +- diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn @@ -18,7 +18,7 @@ defaults: diagnostics: plot: - enabled: False + callbacks: [] hardware: files: graph: ??? diff --git a/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml new file mode 100644 index 00000000..1eb35f69 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml @@ -0,0 +1 @@ +# Add callbacks here diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml new file mode 100644 index 00000000..d7daf8d0 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml @@ -0,0 +1,4 @@ +# Add callbacks here +- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval + rollout: ${dataloader.validation_rollout} + frequency: 20 diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/evaluation.yaml similarity index 53% rename from src/anemoi/training/config/diagnostics/eval_rollout.yaml rename to src/anemoi/training/config/diagnostics/evaluation.yaml index 50e9a647..9647fe1c 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/evaluation.yaml @@ -1,53 +1,8 @@ --- -eval: - enabled: False - # use this to evaluate the model over longer rollouts, every so many validation batches - rollout: 12 - frequency: 20 -plot: - enabled: True - asynchronous: True - frequency: 750 - sample_idx: 0 - per_sample: 6 - parameters: - - z_500 - - t_850 - - u_850 - - v_850 - - 2t - - 10u - - 10v - - sp - - tp - - cp - #Defining the accumulation levels for precipitation related fields and the colormap - accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm - cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] - precip_and_related_fields: [tp, cp] - # Histogram and Spectrum plots - parameters_histogram: - - z_500 - - tp - - 2t - - 10u - - 10v - parameters_spectrum: - - z_500 - - tp - - 2t - - 10u - - 10v - # 
group parameters by categories when visualizing contributions to the loss - # one-parameter groups are possible to highlight individual parameters - parameter_groups: - moisture: [tp, cp, tcw] - sfc_wind: [10u, 10v] - learned_features: False - longrollout: - enabled: False - rollout: [60] - frequency: 20 # every X epochs +defaults: + - plot: detailed + - callbacks: pretraining + debug: # this will detect and trace back NaNs / Infs etc. but will slow down training @@ -57,6 +12,7 @@ debug: # remember to also activate the tensorboard logger (below) profiler: False +enable_checkpointing: True checkpoint: every_n_minutes: save_frequency: 30 # Approximate, as this is checked at the end of training steps diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml new file mode 100644 index 00000000..6ff7875e --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -0,0 +1,62 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot + epoch_frequency: 5 + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum + # batch_frequency: 100 # Override for batch frequency + sample_idx: ${diagnostics.plot.sample_idx} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram + sample_idx: ${diagnostics.plot.sample_idx} + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v diff --git a/src/anemoi/training/config/diagnostics/plot/none.yaml b/src/anemoi/training/config/diagnostics/plot/none.yaml new file mode 100644 index 00000000..3101f292 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/none.yaml @@ -0,0 +1 @@ +callbacks: [] diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml new file mode 100644 index 00000000..7c01b575 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml @@ -0,0 +1,68 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of 
the plotting + batch: 750 + epoch: 5 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot + epoch_frequency: 5 + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum + # batch_frequency: 100 # Override for batch frequency + sample_idx: ${diagnostics.plot.sample_idx} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram + sample_idx: ${diagnostics.plot.sample_idx} + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + rollout: + - ${dataloader.validation_rollout} + epoch_frequency: 20 + sample_idx: ${diagnostics.plot.sample_idx} + parameters: ${diagnostics.plot.parameters} diff --git a/src/anemoi/training/config/diagnostics/plot/simple.yaml b/src/anemoi/training/config/diagnostics/plot/simple.yaml new file mode 100644 index 00000000..2a987ccb --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/simple.yaml @@ -0,0 +1,40 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 10 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + 
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index f64a3091..788e75c2 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -129,10 +129,8 @@ def ds_train(self) -> NativeGridDataset: @cached_property def ds_valid(self) -> NativeGridDataset: r = self.rollout - if self.config.diagnostics.eval.enabled: - r = max(r, self.config.diagnostics.eval.rollout) - if self.config.diagnostics.plot.get("longrollout") and self.config.diagnostics.plot.longrollout.enabled: - r = max(r, max(self.config.diagnostics.plot.longrollout.rollout)) + r = max(r, self.config.dataloader.get("validation_rollout", 1)) + assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( f"Training end date {self.config.dataloader.training.end} is not before" f"validation start date {self.config.dataloader.validation.start}" diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 72d4d242..4b0921f1 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -1,1063 +1,66 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# * [WHY ARE CALLBACKS UNDER __init__.py?] 
-# * This functionality will be restructured in the near future -# * so for now callbacks are under __init__.py - from __future__ import annotations -import copy import logging -import sys -import time -import traceback -import uuid -from abc import ABC -from abc import abstractmethod -from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext from datetime import timedelta -from functools import cached_property -from pathlib import Path from typing import TYPE_CHECKING -from typing import Any from typing import Callable +from typing import Iterable -import matplotlib.patches as mpatches -import matplotlib.pyplot as plt -import numpy as np -import torch -import torchinfo -from anemoi.utils.checkpoints import save_metadata -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.utilities import rank_zero_only +from hydra.utils import instantiate +from omegaconf import DictConfig -from anemoi.training.diagnostics.plots import init_plot_settings -from anemoi.training.diagnostics.plots import plot_graph_features -from anemoi.training.diagnostics.plots import plot_histogram -from anemoi.training.diagnostics.plots import plot_loss -from anemoi.training.diagnostics.plots import plot_power_spectrum -from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample -from anemoi.training.losses.weightedloss import BaseWeightedLoss +from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint +from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor +from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging +from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback if TYPE_CHECKING: - import pytorch_lightning as pl - from omegaconf import DictConfig - from omegaconf import OmegaConf + from pytorch_lightning.callbacks import Callback LOGGER = logging.getLogger(__name__) -class ParallelExecutor(ThreadPoolExecutor): - """Wraps parallel execution and provides accurate information about errors. - - Extends ThreadPoolExecutor to preserve the original traceback and line number. - - Reference: https://stackoverflow.com/questions/19309514/getting-original-line- - number-for-exception-in-concurrent-futures/24457608#24457608 +def nestedget(conf: DictConfig, key, default): """ + Get a nested key from a DictConfig object - def submit(self, fn: Any, *args, **kwargs) -> Callable: - """Submits the wrapped function instead of `fn`.""" - return super().submit(self._function_wrapper, fn, *args, **kwargs) - - def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable: - """Wraps `fn` in order to preserve the traceback of any kind of.""" - try: - return fn(*args, **kwargs) - except Exception as exc: - raise sys.exc_info()[0](traceback.format_exc()) from exc - - -class BasePlotCallback(Callback, ABC): - """Factory for creating a callback that plots data to Experiment Logging.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the BasePlotCallback abstract base class. 
- - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__() - self.config = config - self.save_basedir = config.hardware.paths.plots - self.plot_frequency = config.diagnostics.plot.frequency - self.post_processors = None - self.pre_processors = None - self.latlons = None - init_plot_settings() - - self.plot = self._plot - self._executor = None - - if self.config.diagnostics.plot.asynchronous: - self._executor = ParallelExecutor(max_workers=1) - self._error: BaseException | None = None - self.plot = self._async_plot - - @rank_zero_only - def _output_figure( - self, - logger: pl.loggers.base.LightningLoggerBase, - fig: plt.Figure, - epoch: int, - tag: str = "gnn", - exp_log_tag: str = "val_pred_sample", - ) -> None: - """Figure output: save to file and/or display in notebook.""" - if self.save_basedir is not None: - save_path = Path( - self.save_basedir, - "plots", - f"{tag}_epoch{epoch:03d}.png", - ) - - save_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(save_path, dpi=100, bbox_inches="tight") - if self.config.diagnostics.log.wandb.enabled: - import wandb - - logger.experiment.log({exp_log_tag: wandb.Image(fig)}) - - if self.config.diagnostics.log.mlflow.enabled: - run_id = logger.run_id - logger.experiment.log_artifact(run_id, str(save_path)) - - plt.close(fig) # cleanup - - def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: - """Method is called to close the threads.""" - del trainer, pl_module, stage # unused - if self._executor is not None: - self._executor.shutdown(wait=True) - - def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: - if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: - # Fill with NaNs values where the mask is False - data[:, :, ~pl_module.output_mask, :] = np.nan - - return data - - @abstractmethod - @rank_zero_only - def _plot( - *args: list, - **kwargs: dict, - ) -> None: ... - - @rank_zero_only - def _async_plot( - self, - trainer: pl.Trainer, - *args: list, - **kwargs: dict, - ) -> None: - """To execute the plot function but ensuring we catch any errors.""" - future = self._executor.submit( - self._plot, - trainer, - *args, - **kwargs, - ) - # otherwise the error won't be thrown till the validation epoch is finished - try: - future.result() - except Exception: - LOGGER.exception("Critical error occurred in asynchronous plots.") - sys.exit(1) - - -class RolloutEval(Callback): - """Evaluates the model performance over a (longer) rollout window.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialize RolloutEval callback. 
- - Parameters - ---------- - config : dict - Dictionary with configuration settings - - """ - super().__init__() - - LOGGER.debug( - "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", - config.diagnostics.eval.rollout, - config.diagnostics.eval.frequency, - ) - self.rollout = config.diagnostics.eval.rollout - self.frequency = config.diagnostics.eval.frequency - - def _eval( - self, - pl_module: pl.LightningModule, - batch: torch.Tensor, - ) -> None: - loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) - metrics = {} - - # start rollout - batch = pl_module.model.pre_processors(batch, in_place=False) - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= self.rollout + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" - - with torch.no_grad(): - for rollout_step in range(self.rollout): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - y = batch[ - :, - pl_module.multi_step + rollout_step, - ..., - pl_module.data_indices.internal_data.output.full, - ] # target, shape = (bs, latlon, nvar) - # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += pl_module.loss(y_pred, y) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step) - metrics.update(metrics_next) - - # scale loss - loss *= 1.0 / self.rollout - self._log(pl_module, loss, metrics, batch.shape[0]) - - def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: - pl_module.log( - f"val_r{self.rollout}_wmse", - loss, - on_epoch=True, - on_step=True, - prog_bar=False, - logger=pl_module.logger_enabled, - batch_size=bs, - sync_dist=False, - rank_zero_only=True, - ) - for mname, mvalue in metrics.items(): - pl_module.log( - f"val_r{self.rollout}_" + mname, - mvalue, - on_epoch=True, - on_step=False, - prog_bar=False, - logger=pl_module.logger_enabled, - batch_size=bs, - sync_dist=False, - rank_zero_only=True, - ) - - @rank_zero_only - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list, - batch: torch.Tensor, - batch_idx: int, - ) -> None: - del outputs # outputs are not used - if batch_idx % self.frequency == 0: - precision_mapping = { - "16-mixed": torch.float16, - "bf16-mixed": torch.bfloat16, - } - prec = trainer.precision - dtype = precision_mapping.get(prec) - context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - - with context: - self._eval(pl_module, batch) - - -class LongRolloutPlots(BasePlotCallback): - """Evaluates the model performance over a (longer) rollout window.""" - - def __init__(self, config) -> None: - """Initialize RolloutEval callback. 
- - Parameters - ---------- - config : dict - Dictionary with configuration settings - """ - super().__init__(config) - - LOGGER.debug( - "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", - config.diagnostics.plot.longrollout.rollout, - config.diagnostics.plot.longrollout.frequency, - ) - self.rollout = config.diagnostics.plot.longrollout.rollout - self.eval_frequency = config.diagnostics.plot.longrollout.frequency - self.sample_idx = self.config.diagnostics.plot.sample_idx - - @rank_zero_only - def _plot( - self, - trainer, - pl_module: pl.LightningModule, - batch: torch.Tensor, - batch_idx, - epoch, - ) -> None: - - start_time = time.time() - - logger = trainer.logger - - # Build dictionary of inidicies and parameters to be plotted - plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: ( - name, - name not in self.config.data.get("diagnostic", []), - ) - for name in self.config.diagnostics.plot.parameters - } - - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - - batch = pl_module.model.pre_processors(batch, in_place=False) - # prepare input tensor for rollout from preprocessed batch - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= max(self.rollout) + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" - - # prepare input tensor for plotting - input_tensor_0 = batch[ - self.sample_idx, - pl_module.multi_step - 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_0 = self.post_processors(input_tensor_0).numpy() - - # start rollout - with torch.no_grad(): - for rollout_step in range(max(self.rollout)): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - if (rollout_step + 1) in self.rollout: - # prepare true output tensor for plotting - input_tensor_rollout_step = batch[ - self.sample_idx, - pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() - - # prepare predicted output tensor for plotting - output_tensor = self.post_processors( - y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu() - ).numpy() - - fig = plot_predicted_multilevel_flat_sample( - plot_parameters_dict, - self.config.diagnostics.plot.per_sample, - self.latlons, - self.config.diagnostics.plot.get("accumulation_levels_plot", None), - self.config.diagnostics.plot.get("cmap_accumulation", None), - data_0.squeeze(), - data_rollout_step.squeeze(), - output_tensor[0, 0, :, :], # rolloutstep, first member - # force_global_view=self.show_entire_globe, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}", - ) - LOGGER.info(f"Time taken to plot samples after longer rollout: {int(time.time() - start_time)} seconds") - - @rank_zero_only - def on_validation_batch_end(self, trainer, 
pl_module, output, batch, batch_idx) -> None: - if (batch_idx) % self.plot_frequency == 0 and (trainer.current_epoch + 1) % self.eval_frequency == 0: - precision_mapping = { - "16-mixed": torch.float16, - "bf16-mixed": torch.bfloat16, - } - prec = trainer.precision - dtype = precision_mapping.get(prec) - context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - - with context: - self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch) - - -class GraphTrainableFeaturesPlot(BasePlotCallback): - """Visualize the trainable features defined at the data and hidden graph nodes. - - TODO: How best to visualize the learned edge embeddings? Offline, perhaps - using code from @Simon's notebook? + E.g. + >>> nestedget(config, "diagnostics.log.wandb.enabled", False) """ + keys = key.split(".") + for k in keys: + conf = conf.get(k, default) + if not isinstance(conf, (dict, DictConfig)): + break + return conf + + +# Callbacks to add according to flags in the config +# Can be function to check status from config +CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str | Callable[[DictConfig], bool], type[Callback]]] = [ + ("training.swa.enabled", StochasticWeightAveraging), + ( + lambda config: nestedget(config, "diagnostics.log.wandb.enabled", False) + or nestedget(config, "diagnostics.log.mflow.enabled", False), + LearningRateMonitor, + ), +] + + +def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None: + """Get checkpointing callback""" + if not config.diagnostics.get("enable_checkpointing", True): + return [] - def __init__(self, config: OmegaConf) -> None: - """Initialise the GraphTrainableFeaturesPlot callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self._graph_name_data = config.graph.data - self._graph_name_hidden = config.graph.hidden - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - latlons: np.ndarray, - features: np.ndarray, - epoch: int, - tag: str, - exp_log_tag: str, - ) -> None: - fig = plot_graph_features(latlons, features) - self._output_figure(trainer.logger, fig, epoch=epoch, tag=tag, exp_log_tag=exp_log_tag) - - @rank_zero_only - def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - graph = pl_module.graph_data.cpu().detach() - epoch = trainer.current_epoch - - if model.trainable_data is not None: - data_coords = np.rad2deg(graph[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.numpy()) - - self.plot( - trainer, - data_coords, - model.trainable_data.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_data", - exp_log_tag="trainable_data", - ) - - if model.trainable_hidden is not None: - hidden_coords = np.rad2deg( - graph[(self._graph_name_hidden, "to", self._graph_name_hidden)].hcoords_rad.numpy(), - ) - - self.plot( - trainer, - hidden_coords, - model.trainable_hidden.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_hidden", - exp_log_tag="trainable_hidden", - ) - - -class PlotLoss(BasePlotCallback): - """Plots the unsqueezed loss over rollouts.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotLoss callback. 
- - Parameters - ---------- - config : OmegaConf - Object with configuration settings - - """ - super().__init__(config) - self.parameter_names = None - self.parameter_groups = self.config.diagnostics.plot.parameter_groups - if self.parameter_groups is None: - self.parameter_groups = {} - - @cached_property - def sort_and_color_by_parameter_group(self) -> tuple[np.ndarray, np.ndarray, dict, list]: - """Sort parameters by group and prepare colors.""" - - def automatically_determine_group(name: str) -> str: - # first prefix of parameter name is group name - parts = name.split("_") - return parts[0] - - # group parameters by their determined group name for > 15 parameters - if len(self.parameter_names) <= 15: - # for <= 15 parameters, keep the full name of parameters - parameters_to_groups = np.array(self.parameter_names) - sort_by_parameter_group = np.arange(len(self.parameter_names), dtype=int) - else: - parameters_to_groups = np.array( - [ - next( - ( - group_name - for group_name, group_parameters in self.parameter_groups.items() - if name in group_parameters - ), - automatically_determine_group(name), - ) - for name in self.parameter_names - ], - ) - - unique_group_list, group_inverse, group_counts = np.unique( - parameters_to_groups, - return_inverse=True, - return_counts=True, - ) - - # join parameter groups that appear only once and are not given in config-file - unique_group_list = np.array( - [ - unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other" - for tn, count in enumerate(group_counts) - ], - ) - parameters_to_groups = unique_group_list[group_inverse] - unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) - - # sort parameters by groups - sort_by_parameter_group = np.argsort(group_inverse, kind="stable") - - # apply new order to parameters - sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] - parameters_to_groups = parameters_to_groups[sort_by_parameter_group] - unique_group_list, group_inverse, group_counts = np.unique( - parameters_to_groups, - return_inverse=True, - return_counts=True, - ) - - # get a color per group and project to parameter list - cmap = "tab10" if len(unique_group_list) <= 10 else "tab20" - if len(unique_group_list) > 20: - LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.") - # if all groups have count 1 use black color - bar_color_per_group = ( - np.tile("k", len(group_counts)) - if not np.any(group_counts - 1) - else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list))) - ) - - # set x-ticks - x_tick_positions = np.cumsum(group_counts) - group_counts / 2 - 0.5 - xticks = dict(zip(unique_group_list, x_tick_positions)) - - legend_patches = [] - for group_idx, group in enumerate(unique_group_list): - text_label = f"{group}: " - string_length = len(text_label) - for ii in np.where(group_inverse == group_idx)[0]: - text_label += sorted_parameter_names[ii] + ", " - string_length += len(sorted_parameter_names[ii]) + 2 - if string_length > 50: - # linebreak after 50 characters - text_label += "\n" - string_length = 0 - legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2])) - - return sort_by_parameter_group, bar_color_per_group[group_inverse], xticks, legend_patches - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = 
trainer.logger - - parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys()) - parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) - # reorder parameter_names by position - self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] - - if not isinstance(pl_module.loss, BaseWeightedLoss): - logging.warning( - "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.", RuntimeWarning - ) - - batch = pl_module.model.pre_processors(batch, in_place=False) - for rollout_step in range(pl_module.rollout): - y_hat = outputs[1][rollout_step] - y_true = batch[ - :, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full - ] - loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() - - sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group - fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", - exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class PlotSample(BasePlotCallback): - """Plots a post-processed sample: input, target and prediction.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotSample callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info(f"Using defined accumulation colormap for fields: {self.precip_and_related_fields}") - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - # Build dictionary of indices and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters - } - - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor) - - 
output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ) - - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() - data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) - data = data.numpy() - - for rollout_step in range(pl_module.rollout): - fig = plot_predicted_multilevel_flat_sample( - plot_parameters_dict, - self.config.diagnostics.plot.per_sample, - self.latlons, - self.config.diagnostics.plot.accumulation_levels_plot, - self.config.diagnostics.plot.cmap_accumulation, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class PlotAdditionalMetrics(BasePlotCallback): - """Plots TP related metric comparing target and prediction. - - The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. - - - Power Spectrum - - Histograms - """ - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotAdditionalMetrics callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info(f"Using precip histogram plotting method for fields: {self.precip_and_related_fields}") - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list, - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.pre_processors is None: - # Copy to be used across all the training cycle - self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor) - output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ) - - output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, 
fill_value=np.nan).numpy() - data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) - data = data.numpy() - - for rollout_step in range(pl_module.rollout): - if self.config.diagnostics.plot.parameters_histogram is not None: - # Build dictionary of inidicies and parameters to be plotted - - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict_histogram = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_histogram - } - - fig = plot_histogram( - plot_parameters_dict_histogram, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) - - if self.config.diagnostics.plot.parameters_spectrum is not None: - # Build dictionary of inidicies and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - - plot_parameters_dict_spectrum = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_spectrum - } - - fig = plot_power_spectrum( - plot_parameters_dict_spectrum, - self.latlons, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class ParentUUIDCallback(Callback): - """A callback that retrieves the parent UUID for a model, if it is a child model.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the ParentUUIDCallback callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__() - self.config = config - - def on_load_checkpoint( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - checkpoint: torch.nn.Module, - ) -> None: - del trainer # unused - pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] - - -class AnemoiCheckpoint(ModelCheckpoint): - """A checkpoint callback that saves the model after every validation epoch.""" - - def __init__(self, config: OmegaConf, **kwargs: dict) -> None: - """Initialise the AnemoiCheckpoint callback. 
- - Parameters - ---------- - config : OmegaConf - Config object - kwargs : dict - Additional keyword arguments for Pytorch ModelCheckpoint - - """ - super().__init__(**kwargs) - self.config = config - self.start = time.time() - self._model_metadata = None - self._tracker_metadata = None - self._tracker_name = None - - @staticmethod - def _torch_drop_down(trainer: pl.Trainer) -> torch.nn.Module: - # Get the model from the DataParallel wrapper, for single and multi-gpu cases - assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! Is the Pytorch Lightning version correct?" - return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model - - @rank_zero_only - def model_metadata(self, model: torch.nn.Module) -> dict: - if self._model_metadata is not None: - return self._model_metadata - - self._model_metadata = { - "model": model.__class__.__name__, - "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), - "total_parameters": sum(p.numel() for p in model.parameters()), - "summary": repr( - torchinfo.summary( - model, - depth=50, - verbose=0, - row_settings=["var_names"], - ), - ), - } - - return self._model_metadata - - def tracker_metadata(self, trainer: pl.Trainer) -> dict: - if self._tracker_metadata is not None: - return {self._tracker_name: self._tracker_metadata} - - if self.config.diagnostics.log.wandb.enabled: - self._tracker_name = "wand" - import wandb - - run = wandb.run - if run is not None: - self._tracker_metadata = { - "id": run.id, - "name": run.name, - "url": run.url, - "project": run.project, - } - return {self._tracker_name: self._tracker_metadata} - - if self.config.diagnostics.log.mlflow.enabled: - self._tracker_name = "mlflow" - - from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger - - mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger)) - run_id = mlflow_logger.run_id - run = mlflow_logger._mlflow_client.get_run(run_id) - - if run is not None: - self._tracker_metadata = { - "id": run.info.run_id, - "name": run.info.run_name, - "url": run.info.artifact_uri, - "project": run.info.experiment_id, - } - return {self._tracker_name: self._tracker_metadata} - - return {} - - def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: - """Calls the strategy to remove the checkpoint file.""" - super()._remove_checkpoint(trainer, filepath) - trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath)) - - def _get_inference_checkpoint_filepath(self, filepath: str) -> str: - """Defines the filepath for the inference checkpoint.""" - return Path(filepath).parent / Path("inference-" + str(Path(filepath).name)) - - def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: - if trainer.is_global_zero: - model = self._torch_drop_down(trainer) - - # We want a different uuid each time we save the model - # so we can tell them apart in the catalogue (i.e. 
different epochs) - checkpoint_uuid = str(uuid.uuid4()) - trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid - - trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) - trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) - - trainer.lightning_module._hparams["metadata"]["training"] = { - "current_epoch": trainer.current_epoch, - "global_step": trainer.global_step, - "elapsed_time": time.time() - self.start, - } - - Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) - - save_config = model.config - model.config = None - - tmp_metadata = model.metadata - model.metadata = None - - metadata = dict(**tmp_metadata) - - inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) - - torch.save(model, inference_checkpoint_filepath) - - save_metadata(inference_checkpoint_filepath, metadata) - - model.config = save_config - model.metadata = tmp_metadata - - self._last_global_step_saved = trainer.global_step - - trainer.strategy.barrier() - - # saving checkpoint used for pytorch-lightning based training - trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) - - self._last_global_step_saved = trainer.global_step - self._last_checkpoint_saved = lightning_checkpoint_filepath - - if trainer.is_global_zero: - from weakref import proxy - - # save metadata for the training checkpoint in the same format as inference - save_metadata(lightning_checkpoint_filepath, metadata) - - # notify loggers - for logger in trainer.loggers: - logger.after_save_checkpoint(proxy(self)) - - -def get_callbacks(config: DictConfig) -> list: # noqa: C901 - """Setup callbacks for PyTorch Lightning trainer. - - Parameters - ---------- - config : DictConfig - Job configuration - - Returns - ------- - List - A list of PyTorch Lightning callbacks - - """ checkpoint_settings = { "dirpath": config.hardware.paths.checkpoints, "verbose": False, @@ -1071,6 +74,7 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 } ckpt_frequency_save_dict = {} + for key, frequency_dict in config.diagnostics.checkpoint.items(): frequency = frequency_dict["save_frequency"] n_saved = frequency_dict["num_models_saved"] @@ -1079,14 +83,21 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 frequency = timedelta(minutes=frequency_dict["save_frequency"]) else: target = key - ckpt_frequency_save_dict[target] = (config.hardware.files.checkpoint[key], frequency, n_saved) + ckpt_frequency_save_dict[target] = ( + config.hardware.files.checkpoint[key], + frequency, + n_saved, + ) - trainer_callbacks = [] if not config.diagnostics.profiler: - for save_key, (name, save_frequency, save_n_models) in ckpt_frequency_save_dict.items(): + for save_key, ( + name, + save_frequency, + save_n_models, + ) in ckpt_frequency_save_dict.items(): if save_frequency is not None: LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency) - trainer_callbacks.extend( + return ( # save_top_k: the save_top_k flag can either save the best or the last k checkpoints # depending on the monitor flag on ModelCheckpoint. 
# See https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html for reference @@ -1102,58 +113,99 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 mode="max", **checkpoint_settings, ), - ], + ] ) else: LOGGER.debug("Not setting up a checkpoint callback with %s", save_key) else: # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!") + return None - if any([config.diagnostics.log.wandb.enabled, config.diagnostics.log.mlflow.enabled]): - from pytorch_lightning.callbacks import LearningRateMonitor - trainer_callbacks.append( - LearningRateMonitor( - logging_interval="step", - log_momentum=False, - ), - ) +def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: + """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS - if config.diagnostics.eval.enabled: - trainer_callbacks.append(RolloutEval(config)) + Provides backwards compatibility + """ + callbacks = [] - if config.diagnostics.plot.enabled: - trainer_callbacks.extend( - [ - PlotLoss(config), - PlotSample(config), - ], - ) - if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: - trainer_callbacks.extend([PlotAdditionalMetrics(config)]) - if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: - trainer_callbacks.extend([LongRolloutPlots(config)]) + def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]): + """Check key in config.""" + if isinstance(key, Callable): + return key(config) + elif isinstance(key, str): + return nestedget(config, key, False) + elif isinstance(key, Iterable): + return all(nestedget(config, k, False) for k in key) + return nestedget(config, key, False) - if config.training.swa.enabled: - from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging + for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS: + if check_key(config, enable_key): + callbacks.append(callback_list(config)) - trainer_callbacks.append( - StochasticWeightAveraging( - swa_lrs=config.training.swa.lr, - swa_epoch_start=min( - int(0.75 * config.training.max_epochs), - config.training.max_epochs - 1, - ), - annealing_epochs=max(int(0.25 * config.training.max_epochs), 1), - annealing_strategy="cos", - device=None, - ), - ) + return callbacks + + +def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 + """Setup callbacks for PyTorch Lightning trainer. + + Set `config.diagnostics.callbacks` to a list of callback configurations + in hydra form. + + E.g.: + ``` + callbacks: + - _target_: anemoi.training.diagnostics.callbacks.RolloutEval + rollout: 1 + frequency: 12 + ``` + + Set `config.diagnostics.plot.callbacks` to a list of plot callback configurations + will only be added if `config.diagnostics.plot.enabled` is set to True. + + A callback must take a `DictConfig` in its `__init__` method as the first argument, + which will be the complete configuration object. + + Some callbacks are added by default, depending on the configuration. + See CONFIG_ENABLED_CALLBACKS for more information. 
+ + Parameters + ---------- + config : DictConfig + Job configuration + + Returns + ------- + List[Callback] + A list of PyTorch Lightning callbacks + + """ + trainer_callbacks: list[Callback] = [] + + # Get Checkpoint callback + checkpoint_callback = _get_checkpoint_callback(config) + if checkpoint_callback is not None: + trainer_callbacks.extend(checkpoint_callback) + + # Base callbacks + for callback in config.diagnostics.get("callbacks", None) or []: + # Instantiate new callbacks + trainer_callbacks.append(instantiate(callback, config)) + + # Plotting callbacks + for callback in config.diagnostics.plot.get("callbacks", None) or []: + # Instantiate new callbacks + trainer_callbacks.append(instantiate(callback, config)) + + # Extend with config enabled callbacks + trainer_callbacks.extend(_get_config_enabled_callbacks(config)) + + # Parent UUID callback trainer_callbacks.append(ParentUUIDCallback(config)) - if config.diagnostics.plot.learned_features: - LOGGER.debug("Setting up a callback to plot the trainable graph node features ...") - trainer_callbacks.append(GraphTrainableFeaturesPlot(config)) return trainer_callbacks + + +__all__ = ["get_callbacks"] diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py new file mode 100644 index 00000000..cb95f5a4 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py @@ -0,0 +1,181 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +import time +import uuid +from pathlib import Path +from typing import TYPE_CHECKING + +import torch +import torchinfo +from anemoi.utils.checkpoints import save_metadata +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class AnemoiCheckpoint(ModelCheckpoint): + """A checkpoint callback that saves the model after every validation epoch.""" + + def __init__(self, config: OmegaConf, **kwargs: dict) -> None: + """Initialise the AnemoiCheckpoint callback. + + Parameters + ---------- + config : OmegaConf + Config object + kwargs : dict + Additional keyword arguments for Pytorch ModelCheckpoint + + """ + super().__init__(**kwargs) + self.config = config + self.start = time.time() + self._model_metadata = None + self._tracker_metadata = None + self._tracker_name = None + + @staticmethod + def _torch_drop_down(trainer: pl.Trainer) -> torch.nn.Module: + # Get the model from the DataParallel wrapper, for single and multi-gpu cases + assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! Is the Pytorch Lightning version correct?" 
+ return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model + + @rank_zero_only + def model_metadata(self, model: torch.nn.Module) -> dict: + if self._model_metadata is not None: + return self._model_metadata + + self._model_metadata = { + "model": model.__class__.__name__, + "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), + "total_parameters": sum(p.numel() for p in model.parameters()), + "summary": repr( + torchinfo.summary( + model, + depth=50, + verbose=0, + row_settings=["var_names"], + ), + ), + } + + return self._model_metadata + + def tracker_metadata(self, trainer: pl.Trainer) -> dict: + if self._tracker_metadata is not None: + return {self._tracker_name: self._tracker_metadata} + + if self.config.diagnostics.log.wandb.enabled: + self._tracker_name = "wand" + import wandb + + run = wandb.run + if run is not None: + self._tracker_metadata = { + "id": run.id, + "name": run.name, + "url": run.url, + "project": run.project, + } + return {self._tracker_name: self._tracker_metadata} + + if self.config.diagnostics.log.mlflow.enabled: + self._tracker_name = "mlflow" + + from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger + + mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger)) + run_id = mlflow_logger.run_id + run = mlflow_logger._mlflow_client.get_run(run_id) + + if run is not None: + self._tracker_metadata = { + "id": run.info.run_id, + "name": run.info.run_name, + "url": run.info.artifact_uri, + "project": run.info.experiment_id, + } + return {self._tracker_name: self._tracker_metadata} + + return {} + + def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + """Calls the strategy to remove the checkpoint file.""" + super()._remove_checkpoint(trainer, filepath) + trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath)) + + def _get_inference_checkpoint_filepath(self, filepath: str) -> str: + """Defines the filepath for the inference checkpoint.""" + return Path(filepath).parent / Path("inference-" + str(Path(filepath).name)) + + def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: + if trainer.is_global_zero: + model = self._torch_drop_down(trainer) + + # We want a different uuid each time we save the model + # so we can tell them apart in the catalogue (i.e. 
different epochs) + checkpoint_uuid = str(uuid.uuid4()) + trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid + + trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) + trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) + + trainer.lightning_module._hparams["metadata"]["training"] = { + "current_epoch": trainer.current_epoch, + "global_step": trainer.global_step, + "elapsed_time": time.time() - self.start, + } + + Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) + + save_config = model.config + model.config = None + + tmp_metadata = model.metadata + model.metadata = None + + metadata = dict(**tmp_metadata) + + inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) + + torch.save(model, inference_checkpoint_filepath) + + save_metadata(inference_checkpoint_filepath, metadata) + + model.config = save_config + model.metadata = tmp_metadata + + self._last_global_step_saved = trainer.global_step + + trainer.strategy.barrier() + + # saving checkpoint used for pytorch-lightning based training + trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) + + self._last_global_step_saved = trainer.global_step + self._last_checkpoint_saved = lightning_checkpoint_filepath + + if trainer.is_global_zero: + from weakref import proxy + + # save metadata for the training checkpoint in the same format as inference + save_metadata(lightning_checkpoint_filepath, metadata) + + # notify loggers + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py new file mode 100644 index 00000000..6873918a --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -0,0 +1,126 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class RolloutEval(Callback): + """Evaluates the model performance over a (longer) rollout window.""" + + def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None: + """Initialize RolloutEval callback. 
+ + Parameters + ---------- + config : dict + Dictionary with configuration settings + rollout : int + Rollout length for evaluation + frequency : int + Frequency of evaluation, per batch + + """ + super().__init__() + self.config = config + + LOGGER.debug( + "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", + rollout, + frequency, + ) + self.rollout = rollout + self.frequency = frequency + + def _eval( + self, + pl_module: pl.LightningModule, + batch: torch.Tensor, + ) -> None: + loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) + metrics = {} + + assert batch.shape[1] >= self.rollout + pl_module.multi_step, ( + "Batch length not sufficient for requested validation rollout length! " + f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" + ) + + with torch.no_grad(): + for loss_next, metrics_next, _ in pl_module.rollout_step( + batch, + rollout=self.rollout, + validation_mode=True, + training_mode=True, + ): + loss += loss_next + metrics.update(metrics_next) + + # scale loss + loss *= 1.0 / self.rollout + self._log(pl_module, loss, metrics, batch.shape[0]) + + def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: + pl_module.log( + f"val_r{self.rollout}_{getattr(pl_module.loss, 'name', pl_module.loss.__class__.__name__.lower())}", + loss, + on_epoch=True, + on_step=True, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=bs, + sync_dist=False, + rank_zero_only=True, + ) + for mname, mvalue in metrics.items(): + pl_module.log( + f"val_r{self.rollout}_" + mname, + mvalue, + on_epoch=True, + on_step=False, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=bs, + sync_dist=False, + rank_zero_only=True, + ) + + @rank_zero_only + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list, + batch: torch.Tensor, + batch_idx: int, + ) -> None: + del outputs # outputs are not used + if batch_idx % self.frequency == 0: + precision_mapping = { + "16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + } + prec = trainer.precision + dtype = precision_mapping.get(prec) + context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + + with context: + self._eval(pl_module, batch) diff --git a/src/anemoi/training/diagnostics/callbacks/optimiser.py b/src/anemoi/training/diagnostics/callbacks/optimiser.py new file mode 100644 index 00000000..bff82fcb --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/optimiser.py @@ -0,0 +1,77 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
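A quick aside on wiring: the `RolloutEval` callback above is not constructed directly by user code. `get_callbacks()` builds it from an entry under `config.diagnostics.callbacks` via hydra, passing the full job config as the first positional argument. A minimal sketch of that wiring follows; the `build_callback` helper and the rollout/frequency values are illustrative, not project defaults.

```python
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf

# Illustrative entry as it would appear under config.diagnostics.callbacks.
rollout_eval_cfg = OmegaConf.create(
    {
        "_target_": "anemoi.training.diagnostics.callbacks.evaluation.RolloutEval",
        "rollout": 4,
        "frequency": 20,
    },
)


def build_callback(callback_cfg: DictConfig, full_config: DictConfig):
    # get_callbacks() calls instantiate(callback, config), so the target ends up
    # being invoked as RolloutEval(full_config, rollout=4, frequency=20).
    return instantiate(callback_cfg, full_config)
```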
+ +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging + +LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: + from omegaconf import DictConfig + + +class LearningRateMonitor(pl_LearningRateMonitor): + """Provide LearningRateMonitor from pytorch_lightning as a callback.""" + + def __init__( + self, + config: DictConfig, + logging_interval: str = "step", + log_momentum: bool = False, + ) -> None: + super().__init__(logging_interval=logging_interval, log_momentum=log_momentum) + self.config = config + + +class StochasticWeightAveraging(pl_StochasticWeightAveraging): + """Provide StochasticWeightAveraging from pytorch_lightning as a callback.""" + + def __init__( + self, + config: DictConfig, + swa_lrs: int | None = None, + swa_epoch_start: int | None = None, + annealing_epoch: int | None = None, + annealing_strategy: str | None = None, + device: str | None = None, + **kwargs, + ) -> None: + """Stochastic Weight Averaging Callback. + + Parameters + ---------- + config : OmegaConf + Full configuration object + swa_lrs : int, optional + Stochastic Weight Averaging Learning Rate, by default None + swa_epoch_start : int, optional + Epoch start, by default 0.75 * config.training.max_epochs + annealing_epoch : int, optional + Annealing Epoch, by default 0.25 * config.training.max_epochs + annealing_strategy : str, optional + Annealing Strategy, by default 'cos' + device : str, optional + Device to use, by default None + """ + kwargs["swa_lrs"] = swa_lrs or config.training.swa.lr + kwargs["swa_epoch_start"] = swa_epoch_start or min( + int(0.75 * config.training.max_epochs), + config.training.max_epochs - 1, + ) + kwargs["annealing_epoch"] = annealing_epoch or max(int(0.25 * config.training.max_epochs), 1) + kwargs["annealing_strategy"] = annealing_strategy or "cos" + kwargs["device"] = device + + super().__init__(**kwargs) + self.config = config diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py new file mode 100644 index 00000000..98d16dc3 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -0,0 +1,996 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
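To make the derived SWA defaults above concrete, here is a short sketch of what the wrapper computes when no overrides are given; the `max_epochs` and learning-rate values are illustrative, not project defaults.

```python
from omegaconf import OmegaConf

# Illustrative training config; only max_epochs and the swa block matter here.
config = OmegaConf.create({"training": {"max_epochs": 100, "swa": {"enabled": True, "lr": 1.0e-4}}})

# Defaults derived exactly as in the StochasticWeightAveraging wrapper above.
swa_lrs = config.training.swa.lr
swa_epoch_start = min(int(0.75 * config.training.max_epochs), config.training.max_epochs - 1)
annealing_epochs = max(int(0.25 * config.training.max_epochs), 1)

print(swa_lrs, swa_epoch_start, annealing_epochs)  # 0.0001 75 25
```

In other words, with these defaults SWA starts three quarters of the way through training and anneals over the remaining quarter.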
+ +# ruff: noqa: ANN001 + +from __future__ import annotations + +import copy +import logging +import sys +import time +import traceback +from abc import ABC +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only + +from anemoi.training.diagnostics.plots import init_plot_settings +from anemoi.training.diagnostics.plots import plot_graph_edge_features +from anemoi.training.diagnostics.plots import plot_graph_node_features +from anemoi.training.diagnostics.plots import plot_histogram +from anemoi.training.diagnostics.plots import plot_loss +from anemoi.training.diagnostics.plots import plot_power_spectrum +from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample +from anemoi.training.losses.weightedloss import BaseWeightedLoss + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class ParallelExecutor(ThreadPoolExecutor): + """Wraps parallel execution and provides accurate information about errors. + + Extends ThreadPoolExecutor to preserve the original traceback and line number. + + Reference: https://stackoverflow.com/questions/19309514/getting-original-line- + number-for-exception-in-concurrent-futures/24457608#24457608 + """ + + def submit(self, fn: Any, *args, **kwargs) -> Callable: + """Submits the wrapped function instead of `fn`.""" + return super().submit(self._function_wrapper, fn, *args, **kwargs) + + def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable: + """Wraps `fn` in order to preserve the traceback of any kind of.""" + try: + return fn(*args, **kwargs) + except Exception as exc: + raise sys.exc_info()[0](traceback.format_exc()) from exc + + +class BasePlotCallback(Callback, ABC): + """Factory for creating a callback that plots data to Experiment Logging.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialise the BasePlotCallback abstract base class. 
+ + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__() + self.config = config + self.save_basedir = config.hardware.paths.plots + + self.post_processors = None + self.pre_processors = None + self.latlons = None + init_plot_settings() + + self.plot = self._plot + self._executor = None + + if self.config.diagnostics.plot.asynchronous: + self._executor = ParallelExecutor(max_workers=1) + self._error: BaseException | None = None + self.plot = self._async_plot + + @rank_zero_only + def _output_figure( + self, + logger: pl.loggers.base.LightningLoggerBase, + fig: plt.Figure, + epoch: int, + tag: str = "gnn", + exp_log_tag: str = "val_pred_sample", + ) -> None: + """Figure output: save to file and/or display in notebook.""" + if self.save_basedir is not None: + save_path = Path( + self.save_basedir, + "plots", + f"{tag}_epoch{epoch:03d}.png", + ) + + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=100, bbox_inches="tight") + if self.config.diagnostics.log.wandb.enabled: + import wandb + + logger.experiment.log({exp_log_tag: wandb.Image(fig)}) + + if self.config.diagnostics.log.mlflow.enabled: + run_id = logger.run_id + logger.experiment.log_artifact(run_id, str(save_path)) + + plt.close(fig) # cleanup + + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + """Method is called to close the threads.""" + del trainer, pl_module, stage # unused + if self._executor is not None: + self._executor.shutdown(wait=True) + + def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor: + if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None: + # Fill with NaNs values where the mask is False + data[:, :, ~pl_module.output_mask, :] = np.nan + return data + + @abstractmethod + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + *args, + **kwargs, + ) -> None: + """Plotting function to be implemented by subclasses.""" + + @rank_zero_only + def _async_plot( + self, + trainer: pl.Trainer, + *args: list, + **kwargs: dict, + ) -> None: + """To execute the plot function but ensuring we catch any errors.""" + future = self._executor.submit( + self._plot, + trainer, + *args, + **kwargs, + ) + # otherwise the error won't be thrown till the validation epoch is finished + try: + future.result() + except Exception: + LOGGER.exception("Critical error occurred in asynchronous plots.") + sys.exit(1) + + +class BasePerBatchPlotCallback(BasePlotCallback): + """Base Callback for plotting at the end of each batch.""" + + def __init__(self, config: OmegaConf, batch_frequency: int | None = None): + """Initialise the BasePerBatchPlotCallback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + batch_frequency : int, optional + Batch Frequency to plot at, by default None + If not given, uses default from config at `diagnostics.plot.frequency.batch` + + """ + super().__init__(config) + self.batch_frequency = batch_frequency or self.config.diagnostics.plot.frequency.batch + + @abstractmethod + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + epoch: int, + **kwargs, + ) -> None: + """Plotting function to be implemented by subclasses.""" + + @rank_zero_only + def on_validation_batch_end( + self, + trainer, + pl_module, + output, + batch: torch.Tensor, + batch_idx: int, + **kwargs, + ) -> None: + if batch_idx % self.batch_frequency == 0: + self.plot( + trainer, + pl_module, + output, + batch, + batch_idx, + epoch=trainer.current_epoch, + **kwargs, + ) + + +class BasePerEpochPlotCallback(BasePlotCallback): + """Base Callback for plotting at the end of each epoch.""" + + def __init__(self, config: OmegaConf, epoch_frequency: int | None = None): + """Initialise the BasePerEpochPlotCallback. + + Parameters + ---------- + config : OmegaConf + Config object + epoch_frequency : int, optional + Epoch frequency to plot at, by default None + If not given, uses default from config at `diagnostics.plot.frequency.epoch` + """ + super().__init__(config) + self.epoch_frequency = epoch_frequency or self.config.diagnostics.plot.frequency.epoch + + @rank_zero_only + def on_validation_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + **kwargs, + ) -> None: + if trainer.current_epoch % self.epoch_frequency == 0: + self.plot(trainer, pl_module, epoch=trainer.current_epoch, **kwargs) + + +class LongRolloutPlots(BasePlotCallback): + """Evaluates the model performance over a (longer) rollout window.""" + + def __init__( + self, + config: OmegaConf, + rollout: list[int], + sample_idx: int, + parameters: list[str], + accumulation_levels_plot: list[float] | None = None, + cmap_accumulation: list[str] | None = None, + per_sample: int = 6, + epoch_frequency: int = 1, + ) -> None: + """Initialise LongRolloutPlots callback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + rollout : list[int] + Rollout steps to plot at + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + accumulation_levels_plot : list[float] | None + Accumulation levels to plot, by default None + cmap_accumulation : list[str] | None + Colors of the accumulation levels, by default None + per_sample : int, optional + Number of plots per sample, by default 6 + epoch_frequency : int, optional + Epoch frequency to plot at, by default 1 + """ + super().__init__(config) + + self.epoch_frequency = epoch_frequency + + LOGGER.debug( + "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", + rollout, + epoch_frequency, + ) + self.rollout = rollout + self.sample_idx = sample_idx + self.accumulation_levels_plot = accumulation_levels_plot + self.cmap_accumulation = cmap_accumulation + self.per_sample = per_sample + self.parameters = parameters + + @rank_zero_only + def _plot( + self, + trainer, + pl_module: pl.LightningModule, + output: list[torch.Tensor], + batch: torch.Tensor, + batch_idx, + epoch, + ) -> None: + _ = output + + start_time = time.time() + + logger = trainer.logger + + # Build dictionary of inidicies and parameters to be plotted + plot_parameters_dict = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in self.config.data.get("diagnostic", []), + ) + for name in self.parameters + } + + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + local_rank = pl_module.local_rank + + assert batch.shape[1] >= max(self.rollout) + pl_module.multi_step, ( + "Batch length not sufficient for requested validation rollout length! 
" + f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" + ) + + # prepare input tensor for plotting + input_batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor_0 = input_batch[ + self.sample_idx, + pl_module.multi_step - 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_0 = self.post_processors(input_tensor_0).numpy() + + # start rollout + with torch.no_grad(): + for rollout_step, (_, _, y_pred) in enumerate( + pl_module.rollout_step( + batch, + rollout=max(self.rollout), + validation_mode=False, + training_mode=False, + ), + ): + + if (rollout_step + 1) in self.rollout: + # prepare true output tensor for plotting + input_tensor_rollout_step = input_batch[ + self.sample_idx, + pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() + + # prepare predicted output tensor for plotting + output_tensor = self.post_processors( + y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(), + ).numpy() + + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.per_sample, + self.latlons, + self.accumulation_levels_plot, + self.cmap_accumulation, + data_0.squeeze(), + data_rollout_step.squeeze(), + output_tensor[0, 0, :, :], # rolloutstep, first member + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{local_rank:01d}", + ) + LOGGER.info( + "Time taken to plot samples after longer rollout: %s seconds", + int(time.time() - start_time), + ) + + @rank_zero_only + def on_validation_batch_end( + self, + trainer, + pl_module, + output, + batch: torch.Tensor, + batch_idx: int, + ) -> None: + if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.epoch_frequency == 0: + precision_mapping = { + "16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + } + prec = trainer.precision + dtype = precision_mapping.get(prec) + context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + + if self.config.diagnostics.plot.asynchronous: + LOGGER.warning("Asynchronous plotting not supported for long rollout plots.") + + with context: + # Issue with running asyncronously, so call the plot function directly + self._plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) + + +class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback): + """Visualize the node trainable features defined.""" + + def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None: + """Initialise the GraphTrainableFeaturesPlot callback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + epoch_frequency: int | None, optional + Override for frequency to plot at, by default None + """ + super().__init__(config, epoch_frequency=epoch_frequency) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + epoch: int, + ) -> None: + _ = epoch + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + + fig = plot_graph_node_features(model) + + tag = "node_trainable_params" + exp_log_tag = "node_trainable_params" + + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag=tag, + exp_log_tag=exp_log_tag, + ) + + +class GraphEdgeTrainableFeaturesPlot(BasePerEpochPlotCallback): + """Trainable edge features plot. + + Visualize the trainable features defined at the edges between meshes. + """ + + def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None: + """Plot trainable edge features. + + Parameters + ---------- + config : OmegaConf + Config object + epoch_frequency : int | None, optional + Override for frequency to plot at, by default None + """ + super().__init__(config, epoch_frequency=epoch_frequency) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + epoch: int, + ) -> None: + _ = epoch + + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + fig = plot_graph_edge_features(model) + + tag = "edge_trainable_params" + exp_log_tag = "edge_trainable_params" + + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag=tag, + exp_log_tag=exp_log_tag, + ) + + +class PlotLoss(BasePerBatchPlotCallback): + """Plots the unsqueezed loss over rollouts.""" + + def __init__( + self, + config: OmegaConf, + parameter_groups: dict[dict[str, list[str]]], + batch_frequency: int | None = None, + ) -> None: + """Initialise the PlotLoss callback. 
+ + Parameters + ---------- + config : OmegaConf + Object with configuration settings + parameter_groups : dict + Dictionary with parameter groups with parameter names as keys + batch_frequency : int, optional + Override for batch frequency, by default None + + """ + super().__init__(config, batch_frequency=batch_frequency) + self.parameter_names = None + self.parameter_groups = parameter_groups + if self.parameter_groups is None: + self.parameter_groups = {} + + @cached_property + def sort_and_color_by_parameter_group( + self, + ) -> tuple[np.ndarray, np.ndarray, dict, list]: + """Sort parameters by group and prepare colors.""" + + def automatically_determine_group(name: str) -> str: + # first prefix of parameter name is group name + parts = name.split("_") + return parts[0] + + # group parameters by their determined group name for > 15 parameters + if len(self.parameter_names) <= 15: + # for <= 15 parameters, keep the full name of parameters + parameters_to_groups = np.array(self.parameter_names) + sort_by_parameter_group = np.arange(len(self.parameter_names), dtype=int) + else: + parameters_to_groups = np.array( + [ + next( + ( + group_name + for group_name, group_parameters in self.parameter_groups.items() + if name in group_parameters + ), + automatically_determine_group(name), + ) + for name in self.parameter_names + ], + ) + + unique_group_list, group_inverse, group_counts = np.unique( + parameters_to_groups, + return_inverse=True, + return_counts=True, + ) + + # join parameter groups that appear only once and are not given in config-file + unique_group_list = np.array( + [ + (unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other") + for tn, count in enumerate(group_counts) + ], + ) + parameters_to_groups = unique_group_list[group_inverse] + unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) + + # sort parameters by groups + sort_by_parameter_group = np.argsort(group_inverse, kind="stable") + + # apply new order to parameters + sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] + parameters_to_groups = parameters_to_groups[sort_by_parameter_group] + unique_group_list, group_inverse, group_counts = np.unique( + parameters_to_groups, + return_inverse=True, + return_counts=True, + ) + + # get a color per group and project to parameter list + cmap = "tab10" if len(unique_group_list) <= 10 else "tab20" + if len(unique_group_list) > 20: + LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.") + # if all groups have count 1 use black color + bar_color_per_group = ( + np.tile("k", len(group_counts)) + if not np.any(group_counts - 1) + else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list))) + ) + + # set x-ticks + x_tick_positions = np.cumsum(group_counts) - group_counts / 2 - 0.5 + xticks = dict(zip(unique_group_list, x_tick_positions)) + + legend_patches = [] + for group_idx, group in enumerate(unique_group_list): + text_label = f"{group}: " + string_length = len(text_label) + for ii in np.where(group_inverse == group_idx)[0]: + text_label += sorted_parameter_names[ii] + ", " + string_length += len(sorted_parameter_names[ii]) + 2 + if string_length > 50: + # linebreak after 50 characters + text_label += "\n" + string_length = 0 + legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2])) + + return ( + sort_by_parameter_group, + bar_color_per_group[group_inverse], + xticks, + legend_patches, + ) + + 
@rank_zero_only
+    def _plot(
+        self,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        outputs: list[torch.Tensor],
+        batch: torch.Tensor,
+        batch_idx: int,
+        epoch: int,
+    ) -> None:
+        logger = trainer.logger
+        _ = batch_idx
+
+        parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys())
+        parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
+        # reorder parameter_names by position
+        self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]
+        if not isinstance(pl_module.loss, BaseWeightedLoss):
+            LOGGER.warning(
+                "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.",
+            )
+
+        batch = pl_module.model.pre_processors(batch, in_place=False)
+        for rollout_step in range(pl_module.rollout):
+            y_hat = outputs[1][rollout_step]
+            y_true = batch[
+                :,
+                pl_module.multi_step + rollout_step,
+                ...,
+                pl_module.data_indices.internal_data.output.full,
+            ]
+            loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy()
+
+            sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group
+            fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches)
+
+            self._output_figure(
+                logger,
+                fig,
+                epoch=epoch,
+                tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}",
+                exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}",
+            )
+
+
+class PlotSample(BasePerBatchPlotCallback):
+    """Plots a post-processed sample: input, target and prediction."""
+
+    def __init__(
+        self,
+        config: OmegaConf,
+        sample_idx: int,
+        parameters: list[str],
+        accumulation_levels_plot: list[float],
+        cmap_accumulation: list[str],
+        precip_and_related_fields: list[str] | None = None,
+        per_sample: int = 6,
+        batch_frequency: int | None = None,
+    ) -> None:
+        """Initialise the PlotSample callback.
+ + Parameters + ---------- + config : OmegaConf + Config object + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + accumulation_levels_plot : list[float] + Accumulation levels to plot + cmap_accumulation : list[str] + Colors of the accumulation levels + precip_and_related_fields : list[str] | None, optional + Precip variable names, by default None + per_sample : int, optional + Number of plots per sample, by default 6 + batch_frequency : int, optional + Batch frequency to plot at, by default None + """ + super().__init__(config, batch_frequency=batch_frequency) + self.sample_idx = sample_idx + self.parameters = parameters + + self.precip_and_related_fields = precip_and_related_fields + self.accumulation_levels_plot = accumulation_levels_plot + self.cmap_accumulation = cmap_accumulation + self.per_sample = per_sample + + LOGGER.info( + "Using defined accumulation colormap for fields: %s", + self.precip_and_related_fields, + ) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + epoch: int, + ) -> None: + logger = trainer.logger + + # Build dictionary of indices and parameters to be plotted + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + plot_parameters_dict = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in diagnostics, + ) + for name in self.parameters + } + + # When running in Async mode, it might happen that in the last epoch these tensors + # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA + # but internal ones would be on the cpu), The lines below allow to address this problem + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + local_rank = pl_module.local_rank + + batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor = batch[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data = self.post_processors(input_tensor) + + output_tensor = self.post_processors( + torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), + in_place=False, + ) + output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + data[1:, ...] 
= pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) + data = data.numpy() + + for rollout_step in range(pl_module.rollout): + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.per_sample, + self.latlons, + self.accumulation_levels_plot, + self.cmap_accumulation, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + precip_and_related_fields=self.precip_and_related_fields, + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", + ) + + +class BasePlotAdditionalMetrics(BasePerBatchPlotCallback): + """Base processing class for additional metrics.""" + + def process( + self, + pl_module: pl.LightningModule, + outputs: list, + batch: torch.Tensor, + ) -> tuple[np.ndarray, np.ndarray]: + # When running in Async mode, it might happen that in the last epoch these tensors + # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA + # but internal ones would be on the cpu), The lines below allow to address this problem + if self.pre_processors is None: + # Copy to be used across all the training cycle + self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + + batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor = batch[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + + data = self.post_processors(input_tensor) + output_tensor = self.post_processors( + torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), + in_place=False, + ) + output_tensor = pl_module.output_mask.apply(output_tensor, dim=2, fill_value=np.nan).numpy() + data[1:, ...] = pl_module.output_mask.apply(data[1:, ...], dim=2, fill_value=np.nan) + data = data.numpy() + return data, output_tensor + + +class PlotSpectrum(BasePlotAdditionalMetrics): + """Plots TP related metric comparing target and prediction. + + The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. + + - Power Spectrum + """ + + def __init__( + self, + config: OmegaConf, + sample_idx: int, + parameters: list[str], + batch_frequency: int | None = None, + ) -> None: + """Initialise the PlotSpectrum callback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + batch_frequency : int | None, optional + Override for batch frequency, by default None + """ + super().__init__(config, batch_frequency=batch_frequency) + self.sample_idx = sample_idx + self.parameters = parameters + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list, + batch: torch.Tensor, + batch_idx: int, + epoch: int, + ) -> None: + logger = trainer.logger + + local_rank = pl_module.local_rank + data, output_tensor = self.process(pl_module, outputs, batch) + + for rollout_step in range(pl_module.rollout): + # Build dictionary of inidicies and parameters to be plotted + + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + plot_parameters_dict_spectrum = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in diagnostics, + ) + for name in self.parameters + } + + fig = plot_power_spectrum( + plot_parameters_dict_spectrum, + self.latlons, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", + ) + + +class PlotHistogram(BasePlotAdditionalMetrics): + """Plots histograms comparing target and prediction. + + The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. + """ + + def __init__( + self, + config: OmegaConf, + sample_idx: int, + parameters: list[str], + precip_and_related_fields: list[str] | None = None, + batch_frequency: int | None = None, + ) -> None: + """Initialise the PlotHistogram callback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + precip_and_related_fields : list[str] | None, optional + Precip variable names, by default None + batch_frequency : int | None, optional + Override for batch frequency, by default None + """ + super().__init__(config, batch_frequency=batch_frequency) + self.sample_idx = sample_idx + self.parameters = parameters + self.precip_and_related_fields = precip_and_related_fields + LOGGER.info( + "Using precip histogram plotting method for fields: %s.", + self.precip_and_related_fields, + ) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list, + batch: torch.Tensor, + batch_idx: int, + epoch: int, + ) -> None: + logger = trainer.logger + + local_rank = pl_module.local_rank + data, output_tensor = self.process(pl_module, outputs, batch) + + for rollout_step in range(pl_module.rollout): + + # Build dictionary of inidicies and parameters to be plotted + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + + plot_parameters_dict_histogram = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in diagnostics, + ) + for name in self.parameters + } + + fig = plot_histogram( + plot_parameters_dict_histogram, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + self.precip_and_related_fields, + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", + ) diff --git a/src/anemoi/training/diagnostics/callbacks/provenance.py b/src/anemoi/training/diagnostics/callbacks/provenance.py new file mode 100644 index 00000000..414f0311 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/provenance.py @@ -0,0 +1,47 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pytorch_lightning.callbacks import Callback + +if TYPE_CHECKING: + import pytorch_lightning as pl + import torch + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class ParentUUIDCallback(Callback): + """A callback that retrieves the parent UUID for a model, if it is a child model.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialise the ParentUUIDCallback callback. 
+ + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__() + self.config = config + + def on_load_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: torch.nn.Module, + ) -> None: + del trainer # unused + pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 7b4ba711..6bf72637 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -16,6 +16,9 @@ import matplotlib.pyplot as plt import matplotlib.style as mplstyle import numpy as np +import torch +from anemoi.models.layers.mapper import GraphEdgeMixin +from matplotlib.collections import LineCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap from matplotlib.colors import TwoSlopeNorm @@ -587,7 +590,7 @@ def scatter_plot( Parameters ---------- - fig : _type_ + fig : Figure Figure object handle ax : matplotlib.axes Axis object handle @@ -628,36 +631,163 @@ def scatter_plot( fig.colorbar(psc, ax=ax) -def plot_graph_features( - latlons: np.ndarray, - features: np.ndarray, -) -> Figure: - """Plot trainable graph features. +def edge_plot( + fig: Figure, + ax: plt.Axes, + src_coords: np.ndarray, + dst_coords: np.ndarray, + data: np.ndarray, + cmap: str = "coolwarm", + title: str | None = None, +) -> None: + """Lat-lon line plot. Parameters ---------- - latlons : np.ndarray - Latitudes and longitudes - features : np.ndarray - Trainable Features + fig : _type_ + Figure object handle + ax : _type_ + Axis object handle + src_coords : np.ndarray of shape (num_edges, 2) + Source latitudes and longitudes. + dst_coords : np.ndarray of shape (num_edges, 2) + Destination latitudes and longitudes. + data : np.ndarray of shape (num_edges, 1) + Data to plot + cmap : str, optional + Colormap string from matplotlib, by default "viridis". + title : str, optional + Title for plot, by default None + """ + edge_lines = np.stack([src_coords, dst_coords], axis=1) + lc = LineCollection(edge_lines, cmap=cmap, linewidths=1) + lc.set_array(data) + + psc = ax.add_collection(lc) + + xmin, xmax = edge_lines[:, 0, 0].min(), edge_lines[:, 0, 0].max() + ymin, ymax = edge_lines[:, 1, 1].min(), edge_lines[:, 1, 1].max() + ax.set_xlim((xmin - 0.1, xmax + 0.1)) + ax.set_ylim((ymin - 0.1, ymax + 0.1)) + + continents.plot_continents(ax) + + if title is not None: + ax.set_title(title) + + ax.set_aspect("auto", adjustable=None) + _hide_axes_ticks(ax) + fig.colorbar(psc, ax=ax) + + +def sincos_to_latlon(sincos_coords: torch.Tensor) -> torch.Tensor: + """Get the lat/lon coordinates from the model. + + Parameters + ---------- + sincos_coords: torch.Tensor of shape (N, 4) + Sine and cosine of latitude and longitude coordinates. + + Returns + ------- + torch.Tensor of shape (N, 2) + Lat/lon coordinates. + """ + ndim = sincos_coords.shape[1] // 2 + sin_y, cos_y = sincos_coords[:, :ndim], sincos_coords[:, ndim:] + return torch.atan2(sin_y, cos_y) + + +def plot_graph_node_features(model: torch.nn.Module) -> Figure: + """Plot trainable graph node features. 
+ + Parameters + ---------- + model: AneomiModelEncProcDec + Model object Returns ------- Figure Figure object handle + """ + nrows = len(nodes_name := model._graph_data.node_types) + ncols = min(getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name) + figsize = (ncols * 4, nrows * 3) + fig, ax = plt.subplots(nrows, ncols, figsize=figsize) + + for row, mesh in enumerate(nodes_name): + sincos_coords = getattr(model, f"latlons_{mesh}") + latlons = sincos_to_latlon(sincos_coords).cpu().numpy() + features = getattr(model, f"trainable_{mesh}").trainable.cpu().detach().numpy() + + lat, lon = latlons[:, 0], latlons[:, 1] + + for i in range(ncols): + ax_ = ax[row, i] if ncols > 1 else ax[row] + scatter_plot( + fig, + ax_, + lon=lon, + lat=lat, + data=features[..., i], + title=f"{mesh} trainable feature #{i + 1}", + ) + + return fig + +def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0.05) -> Figure: + """Plot trainable graph edge features. + + Parameters + ---------- + model: AneomiModelEncProcDec + Model object + q_extreme_limit : float, optional + Plot top & bottom quantile of edges trainable values, by default 0.05 (5%). + + Returns + ------- + Figure + Figure object handle """ - nplots = features.shape[-1] - figsize = (nplots * 4, 3) - fig, ax = plt.subplots(1, nplots, figsize=figsize) + trainable_modules = { + (model._graph_name_data, model._graph_name_hidden): model.encoder, + (model._graph_name_hidden, model._graph_name_data): model.decoder, + } - lat, lon = latlons[:, 0], latlons[:, 1] + if isinstance(model.processor, GraphEdgeMixin): + trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor - pc = EquirectangularProjection() - pc_lon, pc_lat = pc(lon, lat) + ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values()) + nrows = len(trainable_modules) + figsize = (ncols * 4, nrows * 3) + fig, ax = plt.subplots(nrows, ncols, figsize=figsize) + + for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()): + src_coords = sincos_to_latlon(getattr(model, f"latlons_{src}")).cpu().numpy() + dst_coords = sincos_to_latlon(getattr(model, f"latlons_{dst}")).cpu().numpy() + edge_index = graph_mapper.edge_index_base.cpu().numpy() + edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy() + + for i in range(ncols): + ax_ = ax[row, i] if ncols > 1 else ax[row] + feature = edge_features[..., i] - for i in range(nplots): - ax_ = ax[i] if nplots > 1 else ax - scatter_plot(fig, ax_, lon=pc_lon, lat=pc_lat, data=features[..., i]) + # Get mask of feature values over top and bottom percentiles + top_perc = np.quantile(feature, 1 - q_extreme_limit) + bottom_perc = np.quantile(feature, q_extreme_limit) + + mask = (feature >= top_perc) | (feature <= bottom_perc) + + edge_plot( + fig, + ax_, + src_coords[edge_index[0, mask]][:, ::-1], + dst_coords[edge_index[1, mask]][:, ::-1], + feature[mask], + title=f"{src} -> {dst} trainable feature #{i + 1}", + ) return fig diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 8fae16b3..6a084432 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -12,7 +12,9 @@ import math import os from collections import defaultdict +from collections.abc import Generator from collections.abc import Mapping +from typing import Optional from typing import Union import numpy as np @@ -139,8 +141,6 @@ def __init__( LOGGER.debug("Rollout max : %d", 
self.rollout_max) LOGGER.debug("Multistep: %d", self.multi_step) - self.enable_plot = config.diagnostics.plot.enabled - self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model self.model_comm_num_groups = math.ceil( @@ -322,17 +322,44 @@ def advance_input( ] return x - def _step( + def rollout_step( self, batch: torch.Tensor, - batch_idx: int, + rollout: Optional[int] = None, # noqa: FA100 + training_mode: bool = True, validation_mode: bool = False, - ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: - del batch_idx - loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) + ) -> Generator[tuple[Union[torch.Tensor, None], dict, list], None, None]: # noqa: FA100 + """ + Rollout step for the forecaster. + + Will run pre_processors on batch, but not post_processors on predictions. + + Parameters + ---------- + batch : torch.Tensor + Batch to use for rollout + rollout : Optional[int], optional + Number of times to rollout for, by default None + If None, will use self.rollout + training_mode : bool, optional + Whether in training mode and to calculate the loss, by default True + If False, loss will be None + validation_mode : bool, optional + Whether in validation mode, and to calculate validation metrics, by default False + If False, metrics will be empty + + Yields + ------ + Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] + Loss value, metrics, and predictions (per step) + + Returns + ------- + None + None + """ # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) - metrics = {} # start rollout of preprocessed batch x = batch[ @@ -341,29 +368,52 @@ def _step( ..., self.data_indices.internal_data.input.full, ] # (bs, multi_step, latlon, nvar) + msg = ( + "Batch length not sufficient for requested multi_step length!" 
+ f", {batch.shape[1]} !>= {rollout + self.multi_step}" + ) + assert batch.shape[1] >= rollout + self.multi_step, msg - y_preds = [] - for rollout_step in range(self.rollout): + for rollout_step in range(rollout or self.rollout): # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) y_pred = self(x) y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += checkpoint(self.loss, y_pred, y, use_reentrant=False) + loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) if training_mode else None x = self.advance_input(x, y_pred, batch, rollout_step) + metrics_next = {} if validation_mode: - metrics_next, y_preds_next = self.calculate_val_metrics( + metrics_next = self.calculate_val_metrics( y_pred, y, rollout_step, - enable_plot=self.enable_plot, ) - metrics.update(metrics_next) - y_preds.extend(y_preds_next) + yield loss, metrics_next, y_pred + + def _step( + self, + batch: torch.Tensor, + batch_idx: int, + validation_mode: bool = False, + ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: + del batch_idx + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) + metrics = {} + y_preds = [] + + for loss_next, metrics_next, y_preds_next in self.rollout_step( + batch, + rollout=self.rollout, + training_mode=True, + validation_mode=validation_mode, + ): + loss += loss_next + metrics.update(metrics_next) + y_preds.extend(y_preds_next) - # scale loss loss *= 1.0 / self.rollout return loss, metrics, y_preds @@ -372,7 +422,6 @@ def calculate_val_metrics( y_pred: torch.Tensor, y: torch.Tensor, rollout_step: int, - enable_plot: bool = False, ) -> tuple[dict, list[torch.Tensor]]: """Calculate metrics on the validation output. @@ -384,8 +433,6 @@ def calculate_val_metrics( Ground truth (target). rollout_step: int Rollout step - enable_plot: bool, defaults to False - Generate plots Returns ------- @@ -393,7 +440,6 @@ def calculate_val_metrics( validation metrics and predictions """ metrics = {} - y_preds = [] y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) @@ -416,9 +462,7 @@ def calculate_val_metrics( feature_scale=mkey == "all", ) - if enable_plot: - y_preds.append(y_pred) - return metrics, y_preds + return metrics def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: train_loss, _, _ = self._step(batch, batch_idx) diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py new file mode 100644 index 00000000..dd1d63b4 --- /dev/null +++ b/tests/diagnostics/test_callbacks.py @@ -0,0 +1,73 @@ +# ruff: noqa: ANN001, ANN201 +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import omegaconf +import yaml + +from anemoi.training.diagnostics.callbacks import get_callbacks + +default_config = """ +diagnostics: + callbacks: [] + + plot: + enabled: False + callbacks: [] + + debug: + # this will detect and trace back NaNs / Infs etc. 
but will slow down training + anomaly_detection: False + + profiler: False + + enable_checkpointing: False + checkpoint: + + log: {} +""" + + +def test_no_extra_callbacks_set(): + # No extra callbacks set + config = omegaconf.OmegaConf.create(yaml.safe_load(default_config)) + callbacks = get_callbacks(config) + assert len(callbacks) == 1 # ParentUUIDCallback + + +def test_add_config_enabled_callback(): + # Add logging callback + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.callbacks.append({"log": {"mlflow": {"enabled": True}}}) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_callback(): + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.callbacks.append( + {"_target_": "anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback"}, + ) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_plotting_callback(monkeypatch): + # Add plotting callback + import anemoi.training.diagnostics.callbacks.plot as plot + + class PlotLoss: + def __init__(self, config: omegaconf.DictConfig): + pass + + monkeypatch.setattr(plot, "PlotLoss", PlotLoss) + + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.plot.enabled = True + config.diagnostics.plot.callbacks = [{"_target_": "anemoi.training.diagnostics.callbacks.plot.PlotLoss"}] + callbacks = get_callbacks(config) + assert len(callbacks) == 2
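
Beyond the bundled callbacks exercised in these tests, the contract documented in `get_callbacks()` means a user callback only needs to accept the job config as its first constructor argument. A hypothetical example follows; the class name, module path and `frequency` kwarg are made up for illustration and are not part of anemoi-training.

```python
from omegaconf import DictConfig
from pytorch_lightning.callbacks import Callback


class MyMetricLogger(Callback):
    """Hypothetical user callback; not part of anemoi-training."""

    def __init__(self, config: DictConfig, frequency: int = 10) -> None:
        # get_callbacks() passes the full job config as the first positional
        # argument; any further kwargs come from the config entry itself.
        super().__init__()
        self.config = config
        self.frequency = frequency


# Corresponding (illustrative) entry in the job config:
# diagnostics:
#   callbacks:
#     - _target_: my_package.callbacks.MyMetricLogger
#       frequency: 10
```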
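Finally, stepping back to the forecaster changes above: `rollout_step` is now a generator yielding `(loss, metrics, y_pred)` once per rollout step, with `loss` being `None` when `training_mode=False`. The sketch below mirrors how `_step` and `RolloutEval._eval` consume it; `forecaster` stands in for a trained forecaster object exposing the `rollout_step()` shown above and is not constructed here.

```python
import torch


def accumulate_rollout_loss(forecaster, batch: torch.Tensor, rollout: int) -> tuple[torch.Tensor, dict]:
    """Consume the rollout_step generator and average the loss over the rollout."""
    loss = torch.zeros(1, dtype=batch.dtype, device=batch.device)
    metrics: dict = {}
    for loss_step, metrics_step, _y_pred in forecaster.rollout_step(
        batch,
        rollout=rollout,
        training_mode=True,    # compute the loss at every step
        validation_mode=True,  # also collect validation metrics
    ):
        loss += loss_step
        metrics.update(metrics_step)
    return loss / rollout, metrics  # scale by 1/rollout, as in _step
```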