From b12fac854761c4ff852da25feb27c9b14f5c21b5 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 24 Sep 2024 09:44:21 +0000 Subject: [PATCH 01/40] Refactor Callbacks - Split into seperate files - Use list in config to add callbacks - Provide legacy config enabled approach - Fix ruff issues --- .../diagnostics/callbacks/__init__.py | 1145 ++--------------- .../diagnostics/callbacks/checkpointing.py | 179 +++ .../diagnostics/callbacks/evaluation.py | 133 ++ .../training/diagnostics/callbacks/id.py | 45 + .../diagnostics/callbacks/learning_rate.py | 26 + .../diagnostics/callbacks/plotting.py | 737 +++++++++++ .../training/diagnostics/callbacks/weights.py | 35 + 7 files changed, 1243 insertions(+), 1057 deletions(-) create mode 100644 src/anemoi/training/diagnostics/callbacks/checkpointing.py create mode 100644 src/anemoi/training/diagnostics/callbacks/evaluation.py create mode 100644 src/anemoi/training/diagnostics/callbacks/id.py create mode 100644 src/anemoi/training/diagnostics/callbacks/learning_rate.py create mode 100644 src/anemoi/training/diagnostics/callbacks/plotting.py create mode 100644 src/anemoi/training/diagnostics/callbacks/weights.py diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f2195b5f..b4c73ad4 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -5,1038 +5,55 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# * [WHY ARE CALLBACKS UNDER __init__.py?] -# * This functionality will be restructured in the near future -# * so for now callbacks are under __init__.py - from __future__ import annotations -import copy import logging -import sys -import time -import traceback -import uuid -from abc import ABC -from abc import abstractmethod -from concurrent.futures import ThreadPoolExecutor -from contextlib import nullcontext from datetime import timedelta -from functools import cached_property -from pathlib import Path from typing import TYPE_CHECKING -from typing import Any -from typing import Callable - -import matplotlib.patches as mpatches -import matplotlib.pyplot as plt -import numpy as np -import torch -import torchinfo -from anemoi.utils.checkpoints import save_metadata -from pytorch_lightning.callbacks import Callback -from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint -from pytorch_lightning.utilities import rank_zero_only -from anemoi.training.diagnostics.plots import init_plot_settings -from anemoi.training.diagnostics.plots import plot_graph_features -from anemoi.training.diagnostics.plots import plot_histogram -from anemoi.training.diagnostics.plots import plot_loss -from anemoi.training.diagnostics.plots import plot_power_spectrum -from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample +from anemoi.training.diagnostics.callbacks.checkpointing import AnemoiCheckpoint +from anemoi.training.diagnostics.callbacks.evaluation import RolloutEval +from anemoi.training.diagnostics.callbacks.id import ParentUUIDCallback +from anemoi.training.diagnostics.callbacks.learning_rate import LearningRateMonitor +from anemoi.training.diagnostics.callbacks.plotting import GraphTrainableFeaturesPlot +from anemoi.training.diagnostics.callbacks.plotting import LongRolloutPlots +from anemoi.training.diagnostics.callbacks.plotting import PlotAdditionalMetrics +from anemoi.training.diagnostics.callbacks.plotting import PlotLoss +from anemoi.training.diagnostics.callbacks.plotting import PlotSample +from anemoi.training.diagnostics.callbacks.weights import StochasticWeightAveraging if TYPE_CHECKING: - import pytorch_lightning as pl from omegaconf import DictConfig - from omegaconf import OmegaConf + from pytorch_lightning.callbacks import Callback LOGGER = logging.getLogger(__name__) - -class ParallelExecutor(ThreadPoolExecutor): - """Wraps parallel execution and provides accurate information about errors. - - Extends ThreadPoolExecutor to preserve the original traceback and line number. - - Reference: https://stackoverflow.com/questions/19309514/getting-original-line- - number-for-exception-in-concurrent-futures/24457608#24457608 - """ - - def submit(self, fn: Any, *args, **kwargs) -> Callable: - """Submits the wrapped function instead of `fn`.""" - return super().submit(self._function_wrapper, fn, *args, **kwargs) - - def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable: - """Wraps `fn` in order to preserve the traceback of any kind of.""" - try: - return fn(*args, **kwargs) - except Exception as exc: - raise sys.exc_info()[0](traceback.format_exc()) from exc - - -class BasePlotCallback(Callback, ABC): - """Factory for creating a callback that plots data to Experiment Logging.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the BasePlotCallback abstract base class. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__() - self.config = config - self.save_basedir = config.hardware.paths.plots - self.plot_frequency = config.diagnostics.plot.frequency - self.post_processors = None - self.pre_processors = None - self.latlons = None - init_plot_settings() - - self.plot = self._plot - self._executor = None - - if self.config.diagnostics.plot.asynchronous: - self._executor = ParallelExecutor(max_workers=1) - self._error: BaseException | None = None - self.plot = self._async_plot - - @rank_zero_only - def _output_figure( - self, - logger: pl.loggers.base.LightningLoggerBase, - fig: plt.Figure, - epoch: int, - tag: str = "gnn", - exp_log_tag: str = "val_pred_sample", - ) -> None: - """Figure output: save to file and/or display in notebook.""" - if self.save_basedir is not None: - save_path = Path( - self.save_basedir, - "plots", - f"{tag}_epoch{epoch:03d}.png", - ) - - save_path.parent.mkdir(parents=True, exist_ok=True) - fig.savefig(save_path, dpi=100, bbox_inches="tight") - if self.config.diagnostics.log.wandb.enabled: - import wandb - - logger.experiment.log({exp_log_tag: wandb.Image(fig)}) - - if self.config.diagnostics.log.mlflow.enabled: - run_id = logger.run_id - logger.experiment.log_artifact(run_id, str(save_path)) - - plt.close(fig) # cleanup - - def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: - """Method is called to close the threads.""" - del trainer, pl_module, stage # unused - if self._executor is not None: - self._executor.shutdown(wait=True) - - @abstractmethod - @rank_zero_only - def _plot( - *args: list, - **kwargs: dict, - ) -> None: ... - - @rank_zero_only - def _async_plot( - self, - trainer: pl.Trainer, - *args: list, - **kwargs: dict, - ) -> None: - """To execute the plot function but ensuring we catch any errors.""" - future = self._executor.submit( - self._plot, - trainer, - *args, - **kwargs, - ) - # otherwise the error won't be thrown till the validation epoch is finished - try: - future.result() - except Exception: - LOGGER.exception("Critical error occurred in asynchronous plots.") - sys.exit(1) - - -class RolloutEval(Callback): - """Evaluates the model performance over a (longer) rollout window.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialize RolloutEval callback. - - Parameters - ---------- - config : dict - Dictionary with configuration settings - - """ - super().__init__() - - LOGGER.debug( - "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", - config.diagnostics.eval.rollout, - config.diagnostics.eval.frequency, - ) - self.rollout = config.diagnostics.eval.rollout - self.frequency = config.diagnostics.eval.frequency - - def _eval( - self, - pl_module: pl.LightningModule, - batch: torch.Tensor, - ) -> None: - loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) - metrics = {} - - # start rollout - batch = pl_module.model.pre_processors(batch, in_place=False) - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= self.rollout + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" - - with torch.no_grad(): - for rollout_step in range(self.rollout): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - y = batch[ - :, - pl_module.multi_step + rollout_step, - ..., - pl_module.data_indices.internal_data.output.full, - ] # target, shape = (bs, latlon, nvar) - # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += pl_module.loss(y_pred, y) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step) - metrics.update(metrics_next) - - # scale loss - loss *= 1.0 / self.rollout - self._log(pl_module, loss, metrics, batch.shape[0]) - - def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: - pl_module.log( - f"val_r{self.rollout}_wmse", - loss, - on_epoch=True, - on_step=True, - prog_bar=False, - logger=pl_module.logger_enabled, - batch_size=bs, - sync_dist=False, - rank_zero_only=True, - ) - for mname, mvalue in metrics.items(): - pl_module.log( - f"val_r{self.rollout}_" + mname, - mvalue, - on_epoch=True, - on_step=False, - prog_bar=False, - logger=pl_module.logger_enabled, - batch_size=bs, - sync_dist=False, - rank_zero_only=True, - ) - - @rank_zero_only - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list, - batch: torch.Tensor, - batch_idx: int, - ) -> None: - del outputs # outputs are not used - if batch_idx % self.frequency == 0: - precision_mapping = { - "16-mixed": torch.float16, - "bf16-mixed": torch.bfloat16, - } - prec = trainer.precision - dtype = precision_mapping.get(prec) - context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - - with context: - self._eval(pl_module, batch) - - -class LongRolloutPlots(BasePlotCallback): - """Evaluates the model performance over a (longer) rollout window.""" - - def __init__(self, config) -> None: - """Initialize RolloutEval callback. - - Parameters - ---------- - config : dict - Dictionary with configuration settings - """ - super().__init__(config) - - LOGGER.debug( - "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", - config.diagnostics.plot.longrollout.rollout, - config.diagnostics.plot.longrollout.frequency, - ) - self.rollout = config.diagnostics.plot.longrollout.rollout - self.eval_frequency = config.diagnostics.plot.longrollout.frequency - self.sample_idx = self.config.diagnostics.plot.sample_idx - - @rank_zero_only - def _plot( - self, - trainer, - pl_module: pl.LightningModule, - batch: torch.Tensor, - batch_idx, - epoch, - ) -> None: - - start_time = time.time() - - logger = trainer.logger - - # Build dictionary of inidicies and parameters to be plotted - plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: ( - name, - name not in self.config.data.get("diagnostic", []), - ) - for name in self.config.diagnostics.plot.parameters - } - - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - - batch = pl_module.model.pre_processors(batch, in_place=False) - # prepare input tensor for rollout from preprocessed batch - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= max(self.rollout) + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" - - # prepare input tensor for plotting - input_tensor_0 = batch[ - self.sample_idx, - pl_module.multi_step - 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_0 = self.post_processors(input_tensor_0).numpy() - - # start rollout - with torch.no_grad(): - for rollout_step in range(max(self.rollout)): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - if (rollout_step + 1) in self.rollout: - # prepare true output tensor for plotting - input_tensor_rollout_step = batch[ - self.sample_idx, - pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() - - # prepare predicted output tensor for plotting - output_tensor = self.post_processors( - y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu() - ).numpy() - - fig = plot_predicted_multilevel_flat_sample( - plot_parameters_dict, - self.config.diagnostics.plot.per_sample, - self.latlons, - self.config.diagnostics.plot.get("accumulation_levels_plot", None), - self.config.diagnostics.plot.get("cmap_accumulation", None), - data_0.squeeze(), - data_rollout_step.squeeze(), - output_tensor[0, 0, :, :], # rolloutstep, first member - # force_global_view=self.show_entire_globe, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}", - ) - LOGGER.info(f"Time taken to plot samples after longer rollout: {int(time.time() - start_time)} seconds") - - @rank_zero_only - def on_validation_batch_end(self, trainer, pl_module, output, batch, batch_idx) -> None: - if (batch_idx) % self.plot_frequency == 0 and (trainer.current_epoch + 1) % self.eval_frequency == 0: - precision_mapping = { - "16-mixed": torch.float16, - "bf16-mixed": torch.bfloat16, - } - prec = trainer.precision - dtype = precision_mapping.get(prec) - context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() - - with context: - self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch) - - -class GraphTrainableFeaturesPlot(BasePlotCallback): - """Visualize the trainable features defined at the data and hidden graph nodes. - - TODO: How best to visualize the learned edge embeddings? Offline, perhaps - using code from @Simon's notebook? - """ - - def __init__(self, config: OmegaConf) -> None: - """Initialise the GraphTrainableFeaturesPlot callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self._graph_name_data = config.graph.data - self._graph_name_hidden = config.graph.hidden - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - latlons: np.ndarray, - features: np.ndarray, - epoch: int, - tag: str, - exp_log_tag: str, - ) -> None: - fig = plot_graph_features(latlons, features) - self._output_figure(trainer.logger, fig, epoch=epoch, tag=tag, exp_log_tag=exp_log_tag) - - @rank_zero_only - def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - graph = pl_module.graph_data.cpu().detach() - epoch = trainer.current_epoch - - if model.trainable_data is not None: - data_coords = np.rad2deg(graph[(self._graph_name_data, "to", self._graph_name_data)].ecoords_rad.numpy()) - - self.plot( - trainer, - data_coords, - model.trainable_data.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_data", - exp_log_tag="trainable_data", - ) - - if model.trainable_hidden is not None: - hidden_coords = np.rad2deg( - graph[(self._graph_name_hidden, "to", self._graph_name_hidden)].hcoords_rad.numpy(), - ) - - self.plot( - trainer, - hidden_coords, - model.trainable_hidden.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_hidden", - exp_log_tag="trainable_hidden", - ) - - -class PlotLoss(BasePlotCallback): - """Plots the unsqueezed loss over rollouts.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotLoss callback. - - Parameters - ---------- - config : OmegaConf - Object with configuration settings - - """ - super().__init__(config) - self.parameter_names = None - self.parameter_groups = self.config.diagnostics.plot.parameter_groups - if self.parameter_groups is None: - self.parameter_groups = {} - - @cached_property - def sort_and_color_by_parameter_group(self) -> tuple[np.ndarray, np.ndarray, dict, list]: - """Sort parameters by group and prepare colors.""" - - def automatically_determine_group(name: str) -> str: - # first prefix of parameter name is group name - parts = name.split("_") - return parts[0] - - # group parameters by their determined group name for > 15 parameters - if len(self.parameter_names) <= 15: - # for <= 15 parameters, keep the full name of parameters - parameters_to_groups = np.array(self.parameter_names) - sort_by_parameter_group = np.arange(len(self.parameter_names), dtype=int) - else: - parameters_to_groups = np.array( - [ - next( - ( - group_name - for group_name, group_parameters in self.parameter_groups.items() - if name in group_parameters - ), - automatically_determine_group(name), - ) - for name in self.parameter_names - ], - ) - - unique_group_list, group_inverse, group_counts = np.unique( - parameters_to_groups, - return_inverse=True, - return_counts=True, - ) - - # join parameter groups that appear only once and are not given in config-file - unique_group_list = np.array( - [ - unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other" - for tn, count in enumerate(group_counts) - ], - ) - parameters_to_groups = unique_group_list[group_inverse] - unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) - - # sort parameters by groups - sort_by_parameter_group = np.argsort(group_inverse, kind="stable") - - # apply new order to parameters - sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] - parameters_to_groups = parameters_to_groups[sort_by_parameter_group] - unique_group_list, group_inverse, group_counts = np.unique( - parameters_to_groups, - return_inverse=True, - return_counts=True, - ) - - # get a color per group and project to parameter list - cmap = "tab10" if len(unique_group_list) <= 10 else "tab20" - if len(unique_group_list) > 20: - LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.") - # if all groups have count 1 use black color - bar_color_per_group = ( - np.tile("k", len(group_counts)) - if not np.any(group_counts - 1) - else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list))) - ) - - # set x-ticks - x_tick_positions = np.cumsum(group_counts) - group_counts / 2 - 0.5 - xticks = dict(zip(unique_group_list, x_tick_positions)) - - legend_patches = [] - for group_idx, group in enumerate(unique_group_list): - text_label = f"{group}: " - string_length = len(text_label) - for ii in np.where(group_inverse == group_idx)[0]: - text_label += sorted_parameter_names[ii] + ", " - string_length += len(sorted_parameter_names[ii]) + 2 - if string_length > 50: - # linebreak after 50 characters - text_label += "\n" - string_length = 0 - legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2])) - - return sort_by_parameter_group, bar_color_per_group[group_inverse], xticks, legend_patches - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys()) - parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) - # reorder parameter_names by position - self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] - - batch = pl_module.model.pre_processors(batch, in_place=False) - for rollout_step in range(pl_module.rollout): - y_hat = outputs[1][rollout_step] - y_true = batch[ - :, pl_module.multi_step + rollout_step, ..., pl_module.data_indices.internal_data.output.full - ] - loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() - - sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group - fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", - exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class PlotSample(BasePlotCallback): - """Plots a post-processed sample: input, target and prediction.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotSample callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info(f"Using defined accumulation colormap for fields: {self.precip_and_related_fields}") - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - # Build dictionary of indices and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters - } - - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor).numpy() - - output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ).numpy() - - for rollout_step in range(pl_module.rollout): - fig = plot_predicted_multilevel_flat_sample( - plot_parameters_dict, - self.config.diagnostics.plot.per_sample, - self.latlons, - self.config.diagnostics.plot.accumulation_levels_plot, - self.config.diagnostics.plot.cmap_accumulation, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class PlotAdditionalMetrics(BasePlotCallback): - """Plots TP related metric comparing target and prediction. - - The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. - - - Power Spectrum - - Histograms - """ - - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotAdditionalMetrics callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info(f"Using precip histogram plotting method for fields: {self.precip_and_related_fields}") - - @rank_zero_only - def _plot( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list, - batch: torch.Tensor, - batch_idx: int, - epoch: int, - ) -> None: - logger = trainer.logger - - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.pre_processors is None: - # Copy to be used across all the training cycle - self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) - local_rank = pl_module.local_rank - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor).numpy() - output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ).numpy() - - for rollout_step in range(pl_module.rollout): - if self.config.diagnostics.plot.parameters_histogram is not None: - # Build dictionary of inidicies and parameters to be plotted - - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict_histogram = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_histogram - } - - fig = plot_histogram( - plot_parameters_dict_histogram, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) - - if self.config.diagnostics.plot.parameters_spectrum is not None: - # Build dictionary of inidicies and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - - plot_parameters_dict_spectrum = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_spectrum - } - - fig = plot_power_spectrum( - plot_parameters_dict_spectrum, - self.latlons, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - ) - - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) - - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - - -class ParentUUIDCallback(Callback): - """A callback that retrieves the parent UUID for a model, if it is a child model.""" - - def __init__(self, config: OmegaConf) -> None: - """Initialise the ParentUUIDCallback callback. - - Parameters - ---------- - config : OmegaConf - Config object - - """ - super().__init__() - self.config = config - - def on_load_checkpoint( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - checkpoint: torch.nn.Module, - ) -> None: - del trainer # unused - pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] - - -class AnemoiCheckpoint(ModelCheckpoint): - """A checkpoint callback that saves the model after every validation epoch.""" - - def __init__(self, config: OmegaConf, **kwargs: dict) -> None: - """Initialise the AnemoiCheckpoint callback. - - Parameters - ---------- - config : OmegaConf - Config object - kwargs : dict - Additional keyword arguments for Pytorch ModelCheckpoint - - """ - super().__init__(**kwargs) - self.config = config - self.start = time.time() - self._model_metadata = None - self._tracker_metadata = None - self._tracker_name = None - - @staticmethod - def _torch_drop_down(trainer: pl.Trainer) -> torch.nn.Module: - # Get the model from the DataParallel wrapper, for single and multi-gpu cases - assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! Is the Pytorch Lightning version correct?" - return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model - - @rank_zero_only - def model_metadata(self, model: torch.nn.Module) -> dict: - if self._model_metadata is not None: - return self._model_metadata - - self._model_metadata = { - "model": model.__class__.__name__, - "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), - "total_parameters": sum(p.numel() for p in model.parameters()), - "summary": repr( - torchinfo.summary( - model, - depth=50, - verbose=0, - row_settings=["var_names"], - ), - ), - } - - return self._model_metadata - - def tracker_metadata(self, trainer: pl.Trainer) -> dict: - if self._tracker_metadata is not None: - return {self._tracker_name: self._tracker_metadata} - - if self.config.diagnostics.log.wandb.enabled: - self._tracker_name = "wand" - import wandb - - run = wandb.run - if run is not None: - self._tracker_metadata = { - "id": run.id, - "name": run.name, - "url": run.url, - "project": run.project, - } - return {self._tracker_name: self._tracker_metadata} - - if self.config.diagnostics.log.mlflow.enabled: - self._tracker_name = "mlflow" - - from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger - - mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger)) - run_id = mlflow_logger.run_id - run = mlflow_logger._mlflow_client.get_run(run_id) - - if run is not None: - self._tracker_metadata = { - "id": run.info.run_id, - "name": run.info.run_name, - "url": run.info.artifact_uri, - "project": run.info.experiment_id, - } - return {self._tracker_name: self._tracker_metadata} - - return {} - - def _remove_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None: - """Calls the strategy to remove the checkpoint file.""" - super()._remove_checkpoint(trainer, filepath) - trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath)) - - def _get_inference_checkpoint_filepath(self, filepath: str) -> str: - """Defines the filepath for the inference checkpoint.""" - return Path(filepath).parent / Path("inference-" + str(Path(filepath).name)) - - def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: - if trainer.is_global_zero: - model = self._torch_drop_down(trainer) - - # We want a different uuid each time we save the model - # so we can tell them apart in the catalogue (i.e. different epochs) - checkpoint_uuid = str(uuid.uuid4()) - trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid - - trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) - trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) - - trainer.lightning_module._hparams["metadata"]["training"] = { - "current_epoch": trainer.current_epoch, - "global_step": trainer.global_step, - "elapsed_time": time.time() - self.start, - } - - Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) - - save_config = model.config - model.config = None - - tmp_metadata = model.metadata - model.metadata = None - - metadata = dict(**tmp_metadata) - - inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) - - torch.save(model, inference_checkpoint_filepath) - - save_metadata(inference_checkpoint_filepath, metadata) - - model.config = save_config - model.metadata = tmp_metadata - - self._last_global_step_saved = trainer.global_step - - trainer.strategy.barrier() - - # saving checkpoint used for pytorch-lightning based training - trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) - - self._last_global_step_saved = trainer.global_step - self._last_checkpoint_saved = lightning_checkpoint_filepath - - if trainer.is_global_zero: - from weakref import proxy - - # save metadata for the training checkpoint in the same format as inference - save_metadata(lightning_checkpoint_filepath, metadata) - - # notify loggers - for logger in trainer.loggers: - logger.after_save_checkpoint(proxy(self)) - - -def get_callbacks(config: DictConfig) -> list: # noqa: C901 - """Setup callbacks for PyTorch Lightning trainer. - - Parameters - ---------- - config : DictConfig - Job configuration - - Returns - ------- - List - A list of PyTorch Lightning callbacks - - """ +# Dictionary of available callbacks +CALLBACK_DICT: dict[str, type[Callback]] = { + "RolloutEval": RolloutEval, + "LongRolloutPlots": LongRolloutPlots, + "GraphTrainableFeaturesPlot": GraphTrainableFeaturesPlot, + "PlotLoss": PlotLoss, + "PlotSample": PlotSample, + "PlotAdditionalMetrics": PlotAdditionalMetrics, + "ParentUUIDCallback": ParentUUIDCallback, +} + +# Callbacks to add according to flags in the config +CONFIG_ENABLED_CALLBACKS: dict[list[str] | str, list[type[Callback]] | type[Callback]] = { + ["diagnostics.log.wandb.enabled", "diagnostics.log.mlflow.enabled"]: LearningRateMonitor, + "diagnostics.eval.enabled": RolloutEval, + "diagnostics.plot.enabled": [ + PlotLoss, + PlotSample, + ], + "training.swa.enabled": StochasticWeightAveraging, + "diagnostics.plot.learned_features": GraphTrainableFeaturesPlot, +} + + +def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None: + """Get checkpointing callback""" checkpoint_settings = { "dirpath": config.hardware.paths.checkpoints, "verbose": False, @@ -1060,12 +77,11 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 target = key ckpt_frequency_save_dict[target] = (config.hardware.files.checkpoint[key], frequency, n_saved) - trainer_callbacks = [] if not config.diagnostics.profiler: for save_key, (name, save_frequency, save_n_models) in ckpt_frequency_save_dict.items(): if save_frequency is not None: LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency) - trainer_callbacks.extend( + return ( # save_top_k: the save_top_k flag can either save the best or the last k checkpoints # depending on the monitor flag on ModelCheckpoint. # See https://lightning.ai/docs/pytorch/stable/common/checkpointing_intermediate.html for reference @@ -1088,51 +104,66 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 else: # the tensorboard logger + pytorch profiler cause pickling errors when writing checkpoints LOGGER.warning("Profiling is enabled - will not write any training or inference model checkpoints!") + return None - if any([config.diagnostics.log.wandb.enabled, config.diagnostics.log.mlflow.enabled]): - from pytorch_lightning.callbacks import LearningRateMonitor - trainer_callbacks.append( - LearningRateMonitor( - logging_interval="step", - log_momentum=False, - ), - ) +def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: + """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS""" + callbacks = [] - if config.diagnostics.eval.enabled: - trainer_callbacks.append(RolloutEval(config)) + for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS.items(): + if isinstance(enable_key, list): + if not any(config.get(key, False) for key in enable_key): + continue + elif not config.get(enable_key, False): + continue + + if isinstance(callback_list, list): + callbacks.extend(map(lambda x: x(config), callback_list)) + else: + callbacks.append(callback_list(config)) + + return callbacks + + +def get_callbacks(config: DictConfig) -> list: # noqa: C901 + """Setup callbacks for PyTorch Lightning trainer. + + Set config.diagnostics.callbacks to a list of callback names to enable them. + + Parameters + ---------- + config : DictConfig + Job configuration + + Returns + ------- + List + A list of PyTorch Lightning callbacks + + """ + + trainer_callbacks: list[Callback] = [] + checkpoint_callback = _get_checkpoint_callback(config) + if checkpoint_callback is not None: + trainer_callbacks.extend(checkpoint_callback) + + requested_callbacks = config.diagnostics.get("callbacks", []) + + for callback in requested_callbacks: + if callback in CALLBACK_DICT: + trainer_callbacks.append(CALLBACK_DICT[callback](config)) + else: + LOGGER.error(f"Callback {callback} not found in CALLBACK_DICT\n{list(CALLBACK_DICT.keys())}") + + trainer_callbacks.extend(_get_config_enabled_callbacks(config)) if config.diagnostics.plot.enabled: - trainer_callbacks.extend( - [ - PlotLoss(config), - PlotSample(config), - ], - ) if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: trainer_callbacks.extend([PlotAdditionalMetrics(config)]) if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: trainer_callbacks.extend([LongRolloutPlots(config)]) - if config.training.swa.enabled: - from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging - - trainer_callbacks.append( - StochasticWeightAveraging( - swa_lrs=config.training.swa.lr, - swa_epoch_start=min( - int(0.75 * config.training.max_epochs), - config.training.max_epochs - 1, - ), - annealing_epochs=max(int(0.25 * config.training.max_epochs), 1), - annealing_strategy="cos", - device=None, - ), - ) - trainer_callbacks.append(ParentUUIDCallback(config)) - if config.diagnostics.plot.learned_features: - LOGGER.debug("Setting up a callback to plot the trainable graph node features ...") - trainer_callbacks.append(GraphTrainableFeaturesPlot(config)) return trainer_callbacks diff --git a/src/anemoi/training/diagnostics/callbacks/checkpointing.py b/src/anemoi/training/diagnostics/callbacks/checkpointing.py new file mode 100644 index 00000000..6953d8fb --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/checkpointing.py @@ -0,0 +1,179 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +import time +import uuid +from pathlib import Path +from typing import TYPE_CHECKING + +import torch +import torchinfo +from anemoi.utils.checkpoints import save_metadata +from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class AnemoiCheckpoint(ModelCheckpoint): + """A checkpoint callback that saves the model after every validation epoch.""" + + def __init__(self, config: OmegaConf, **kwargs: dict) -> None: + """Initialise the AnemoiCheckpoint callback. + + Parameters + ---------- + config : OmegaConf + Config object + kwargs : dict + Additional keyword arguments for Pytorch ModelCheckpoint + + """ + super().__init__(**kwargs) + self.config = config + self.start = time.time() + self._model_metadata = None + self._tracker_metadata = None + self._tracker_name = None + + @staticmethod + def _torch_drop_down(trainer: pl.Trainer) -> torch.nn.Module: + # Get the model from the DataParallel wrapper, for single and multi-gpu cases + assert hasattr(trainer, "model"), "Trainer has no attribute 'model'! Is the Pytorch Lightning version correct?" + return trainer.model.module.model if hasattr(trainer.model, "module") else trainer.model.model + + @rank_zero_only + def model_metadata(self, model: torch.nn.Module) -> dict: + if self._model_metadata is not None: + return self._model_metadata + + self._model_metadata = { + "model": model.__class__.__name__, + "trainable_parameters": sum(p.numel() for p in model.parameters() if p.requires_grad), + "total_parameters": sum(p.numel() for p in model.parameters()), + "summary": repr( + torchinfo.summary( + model, + depth=50, + verbose=0, + row_settings=["var_names"], + ), + ), + } + + return self._model_metadata + + def tracker_metadata(self, trainer: pl.Trainer) -> dict: + if self._tracker_metadata is not None: + return {self._tracker_name: self._tracker_metadata} + + if self.config.diagnostics.log.wandb.enabled: + self._tracker_name = "wand" + import wandb + + run = wandb.run + if run is not None: + self._tracker_metadata = { + "id": run.id, + "name": run.name, + "url": run.url, + "project": run.project, + } + return {self._tracker_name: self._tracker_metadata} + + if self.config.diagnostics.log.mlflow.enabled: + self._tracker_name = "mlflow" + + from anemoi.training.diagnostics.mlflow.logger import AnemoiMLflowLogger + + mlflow_logger = next(logger for logger in trainer.loggers if isinstance(logger, AnemoiMLflowLogger)) + run_id = mlflow_logger.run_id + run = mlflow_logger._mlflow_client.get_run(run_id) + + if run is not None: + self._tracker_metadata = { + "id": run.info.run_id, + "name": run.info.run_name, + "url": run.info.artifact_uri, + "project": run.info.experiment_id, + } + return {self._tracker_name: self._tracker_metadata} + + return {} + + def _remove_checkpoint(self, trainer: pl.Trainer, filepath: str) -> None: + """Calls the strategy to remove the checkpoint file.""" + super()._remove_checkpoint(trainer, filepath) + trainer.strategy.remove_checkpoint(self._get_inference_checkpoint_filepath(filepath)) + + def _get_inference_checkpoint_filepath(self, filepath: str) -> str: + """Defines the filepath for the inference checkpoint.""" + return Path(filepath).parent / Path("inference-" + str(Path(filepath).name)) + + def _save_checkpoint(self, trainer: pl.Trainer, lightning_checkpoint_filepath: str) -> None: + if trainer.is_global_zero: + model = self._torch_drop_down(trainer) + + # We want a different uuid each time we save the model + # so we can tell them apart in the catalogue (i.e. different epochs) + checkpoint_uuid = str(uuid.uuid4()) + trainer.lightning_module._hparams["metadata"]["uuid"] = checkpoint_uuid + + trainer.lightning_module._hparams["metadata"]["model"] = self.model_metadata(model) + trainer.lightning_module._hparams["metadata"]["tracker"] = self.tracker_metadata(trainer) + + trainer.lightning_module._hparams["metadata"]["training"] = { + "current_epoch": trainer.current_epoch, + "global_step": trainer.global_step, + "elapsed_time": time.time() - self.start, + } + + Path(lightning_checkpoint_filepath).parent.mkdir(parents=True, exist_ok=True) + + save_config = model.config + model.config = None + + tmp_metadata = model.metadata + model.metadata = None + + metadata = dict(**tmp_metadata) + + inference_checkpoint_filepath = self._get_inference_checkpoint_filepath(lightning_checkpoint_filepath) + + torch.save(model, inference_checkpoint_filepath) + + save_metadata(inference_checkpoint_filepath, metadata) + + model.config = save_config + model.metadata = tmp_metadata + + self._last_global_step_saved = trainer.global_step + + trainer.strategy.barrier() + + # saving checkpoint used for pytorch-lightning based training + trainer.save_checkpoint(lightning_checkpoint_filepath, self.save_weights_only) + + self._last_global_step_saved = trainer.global_step + self._last_checkpoint_saved = lightning_checkpoint_filepath + + if trainer.is_global_zero: + from weakref import proxy + + # save metadata for the training checkpoint in the same format as inference + save_metadata(lightning_checkpoint_filepath, metadata) + + # notify loggers + for logger in trainer.loggers: + logger.after_save_checkpoint(proxy(self)) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py new file mode 100644 index 00000000..cbc6b854 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -0,0 +1,133 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class RolloutEval(Callback): + """Evaluates the model performance over a (longer) rollout window.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialize RolloutEval callback. + + Parameters + ---------- + config : dict + Dictionary with configuration settings + + """ + super().__init__() + + LOGGER.debug( + "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", + config.diagnostics.eval.rollout, + config.diagnostics.eval.frequency, + ) + self.rollout = config.diagnostics.eval.rollout + self.frequency = config.diagnostics.eval.frequency + + def _eval( + self, + pl_module: pl.LightningModule, + batch: torch.Tensor, + ) -> None: + loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) + metrics = {} + + # start rollout + batch = pl_module.model.pre_processors(batch, in_place=False) + x = batch[ + :, + 0 : pl_module.multi_step, + ..., + pl_module.data_indices.internal_data.input.full, + ] # (bs, multi_step, latlon, nvar) + assert ( + batch.shape[1] >= self.rollout + pl_module.multi_step + ), "Batch length not sufficient for requested rollout length!" + + with torch.no_grad(): + for rollout_step in range(self.rollout): + y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) + y = batch[ + :, + pl_module.multi_step + rollout_step, + ..., + pl_module.data_indices.internal_data.output.full, + ] # target, shape = (bs, latlon, nvar) + # y includes the auxiliary variables, so we must leave those out when computing the loss + loss += pl_module.loss(y_pred, y) + + x = pl_module.advance_input(x, y_pred, batch, rollout_step) + + metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step) + metrics.update(metrics_next) + + # scale loss + loss *= 1.0 / self.rollout + self._log(pl_module, loss, metrics, batch.shape[0]) + + def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: + pl_module.log( + f"val_r{self.rollout}_wmse", + loss, + on_epoch=True, + on_step=True, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=bs, + sync_dist=False, + rank_zero_only=True, + ) + for mname, mvalue in metrics.items(): + pl_module.log( + f"val_r{self.rollout}_" + mname, + mvalue, + on_epoch=True, + on_step=False, + prog_bar=False, + logger=pl_module.logger_enabled, + batch_size=bs, + sync_dist=False, + rank_zero_only=True, + ) + + @rank_zero_only + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list, + batch: torch.Tensor, + batch_idx: int, + ) -> None: + del outputs # outputs are not used + if batch_idx % self.frequency == 0: + precision_mapping = { + "16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + } + prec = trainer.precision + dtype = precision_mapping.get(prec) + context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + + with context: + self._eval(pl_module, batch) diff --git a/src/anemoi/training/diagnostics/callbacks/id.py b/src/anemoi/training/diagnostics/callbacks/id.py new file mode 100644 index 00000000..3b5da7a6 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/id.py @@ -0,0 +1,45 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pytorch_lightning.callbacks import Callback + +if TYPE_CHECKING: + import pytorch_lightning as pl + import torch + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class ParentUUIDCallback(Callback): + """A callback that retrieves the parent UUID for a model, if it is a child model.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialise the ParentUUIDCallback callback. + + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__() + self.config = config + + def on_load_checkpoint( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + checkpoint: torch.nn.Module, + ) -> None: + del trainer # unused + pl_module.hparams["metadata"]["parent_uuid"] = checkpoint["hyper_parameters"]["metadata"]["uuid"] diff --git a/src/anemoi/training/diagnostics/callbacks/learning_rate.py b/src/anemoi/training/diagnostics/callbacks/learning_rate.py new file mode 100644 index 00000000..5e1b84f9 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/learning_rate.py @@ -0,0 +1,26 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor + +LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: + from omegaconf import DictConfig + + +class LearningRateMonitor(pl_LearningRateMonitor): + """Provide LearningRateMonitor from pytorch_lightning as a callback.""" + + def __init__(self, config: DictConfig): + super().__init__(logging_interval="step", log_momentum=False) + self.config = config diff --git a/src/anemoi/training/diagnostics/callbacks/plotting.py b/src/anemoi/training/diagnostics/callbacks/plotting.py new file mode 100644 index 00000000..67a90016 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/plotting.py @@ -0,0 +1,737 @@ +# ruff: noqa: ANN001 + +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import copy +import logging +import sys +import time +import traceback +from abc import ABC +from abc import abstractmethod +from concurrent.futures import ThreadPoolExecutor +from contextlib import nullcontext +from functools import cached_property +from pathlib import Path +from typing import TYPE_CHECKING +from typing import Any +from typing import Callable + +import matplotlib.patches as mpatches +import matplotlib.pyplot as plt +import numpy as np +import torch +from pytorch_lightning.callbacks import Callback +from pytorch_lightning.utilities import rank_zero_only + +from anemoi.training.diagnostics.plots import init_plot_settings +from anemoi.training.diagnostics.plots import plot_graph_features +from anemoi.training.diagnostics.plots import plot_histogram +from anemoi.training.diagnostics.plots import plot_loss +from anemoi.training.diagnostics.plots import plot_power_spectrum +from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample + +if TYPE_CHECKING: + import pytorch_lightning as pl + from omegaconf import OmegaConf + +LOGGER = logging.getLogger(__name__) + + +class ParallelExecutor(ThreadPoolExecutor): + """Wraps parallel execution and provides accurate information about errors. + + Extends ThreadPoolExecutor to preserve the original traceback and line number. + + Reference: https://stackoverflow.com/questions/19309514/getting-original-line- + number-for-exception-in-concurrent-futures/24457608#24457608 + """ + + def submit(self, fn: Any, *args, **kwargs) -> Callable: + """Submits the wrapped function instead of `fn`.""" + return super().submit(self._function_wrapper, fn, *args, **kwargs) + + def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable: + """Wraps `fn` in order to preserve the traceback of any kind of.""" + try: + return fn(*args, **kwargs) + except Exception as exc: + raise sys.exc_info()[0](traceback.format_exc()) from exc + + +class BasePlotCallback(Callback, ABC): + """Factory for creating a callback that plots data to Experiment Logging.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialise the BasePlotCallback abstract base class. + + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__() + self.config = config + self.save_basedir = config.hardware.paths.plots + self.plot_frequency = config.diagnostics.plot.frequency + self.post_processors = None + self.pre_processors = None + self.latlons = None + init_plot_settings() + + self.plot = self._plot + self._executor = None + + if self.config.diagnostics.plot.asynchronous: + self._executor = ParallelExecutor(max_workers=1) + self._error: BaseException | None = None + self.plot = self._async_plot + + @rank_zero_only + def _output_figure( + self, + logger: pl.loggers.base.LightningLoggerBase, + fig: plt.Figure, + epoch: int, + tag: str = "gnn", + exp_log_tag: str = "val_pred_sample", + ) -> None: + """Figure output: save to file and/or display in notebook.""" + if self.save_basedir is not None: + save_path = Path( + self.save_basedir, + "plots", + f"{tag}_epoch{epoch:03d}.png", + ) + + save_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(save_path, dpi=100, bbox_inches="tight") + if self.config.diagnostics.log.wandb.enabled: + import wandb + + logger.experiment.log({exp_log_tag: wandb.Image(fig)}) + + if self.config.diagnostics.log.mlflow.enabled: + run_id = logger.run_id + logger.experiment.log_artifact(run_id, str(save_path)) + + plt.close(fig) # cleanup + + def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None: + """Method is called to close the threads.""" + del trainer, pl_module, stage # unused + if self._executor is not None: + self._executor.shutdown(wait=True) + + @abstractmethod + @rank_zero_only + def _plot( + *args: list, + **kwargs: dict, + ) -> None: ... + + @rank_zero_only + def _async_plot( + self, + trainer: pl.Trainer, + *args: list, + **kwargs: dict, + ) -> None: + """To execute the plot function but ensuring we catch any errors.""" + future = self._executor.submit( + self._plot, + trainer, + *args, + **kwargs, + ) + # otherwise the error won't be thrown till the validation epoch is finished + try: + future.result() + except Exception: + LOGGER.exception("Critical error occurred in asynchronous plots.") + sys.exit(1) + + +class LongRolloutPlots(BasePlotCallback): + """Evaluates the model performance over a (longer) rollout window.""" + + def __init__(self, config) -> None: + """Initialize RolloutEval callback. + + Parameters + ---------- + config : dict + Dictionary with configuration settings + """ + super().__init__(config) + + LOGGER.debug( + "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", + config.diagnostics.plot.longrollout.rollout, + config.diagnostics.plot.longrollout.frequency, + ) + self.rollout = config.diagnostics.plot.longrollout.rollout + self.eval_frequency = config.diagnostics.plot.longrollout.frequency + self.sample_idx = self.config.diagnostics.plot.sample_idx + + @rank_zero_only + def _plot( + self, + trainer, + pl_module: pl.LightningModule, + batch: torch.Tensor, + batch_idx, + epoch, + ) -> None: + + start_time = time.time() + + logger = trainer.logger + + # Build dictionary of inidicies and parameters to be plotted + plot_parameters_dict = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in self.config.data.get("diagnostic", []), + ) + for name in self.config.diagnostics.plot.parameters + } + + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + local_rank = pl_module.local_rank + + batch = pl_module.model.pre_processors(batch, in_place=False) + # prepare input tensor for rollout from preprocessed batch + x = batch[ + :, + 0 : pl_module.multi_step, + ..., + pl_module.data_indices.internal_data.input.full, + ] # (bs, multi_step, latlon, nvar) + assert ( + batch.shape[1] >= max(self.rollout) + pl_module.multi_step + ), "Batch length not sufficient for requested rollout length!" + + # prepare input tensor for plotting + input_tensor_0 = batch[ + self.sample_idx, + pl_module.multi_step - 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_0 = self.post_processors(input_tensor_0).numpy() + + # start rollout + with torch.no_grad(): + for rollout_step in range(max(self.rollout)): + y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) + + x = pl_module.advance_input(x, y_pred, batch, rollout_step) + + if (rollout_step + 1) in self.rollout: + # prepare true output tensor for plotting + input_tensor_rollout_step = batch[ + self.sample_idx, + pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data_rollout_step = self.post_processors(input_tensor_rollout_step).numpy() + + # prepare predicted output tensor for plotting + output_tensor = self.post_processors( + y_pred[self.sample_idx : self.sample_idx + 1, ...].cpu(), + ).numpy() + + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.config.diagnostics.plot.per_sample, + self.latlons, + self.config.diagnostics.plot.get("accumulation_levels_plot", None), + self.config.diagnostics.plot.get("cmap_accumulation", None), + data_0.squeeze(), + data_rollout_step.squeeze(), + output_tensor[0, 0, :, :], # rolloutstep, first member + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}", + ) + LOGGER.info("Time taken to plot samples after longer rollout: %s seconds", int(time.time() - start_time)) + + @rank_zero_only + def on_validation_batch_end( + self, + trainer, + pl_module, + output, + batch: torch.Tensor, + batch_idx: int, + ) -> None: + _ = output + if (batch_idx) % self.plot_frequency == 0 and (trainer.current_epoch + 1) % self.eval_frequency == 0: + precision_mapping = { + "16-mixed": torch.float16, + "bf16-mixed": torch.bfloat16, + } + prec = trainer.precision + dtype = precision_mapping.get(prec) + context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + + with context: + self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch) + + +class GraphTrainableFeaturesPlot(BasePlotCallback): + """Visualize the trainable features defined at the data and hidden graph nodes. + + TODO: How best to visualize the learned edge embeddings? Offline, perhaps - using code from @Simon's notebook? + """ + + def __init__(self, config: OmegaConf) -> None: + """Initialise the GraphTrainableFeaturesPlot callback. + + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__(config) + self._graph_name_data = config.graph.data + self._graph_name_hidden = config.graph.hidden + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + latlons: np.ndarray, + features: np.ndarray, + epoch: int, + tag: str, + exp_log_tag: str, + ) -> None: + fig = plot_graph_features(latlons, features) + self._output_figure(trainer.logger, fig, epoch=epoch, tag=tag, exp_log_tag=exp_log_tag) + + @rank_zero_only + def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + graph = pl_module.graph_data.cpu().detach() + epoch = trainer.current_epoch + + if model.trainable_data is not None: + data_coords = np.rad2deg(graph[self._graph_name_data, "to", self._graph_name_data].ecoords_rad.numpy()) + + self.plot( + trainer, + data_coords, + model.trainable_data.trainable.cpu().detach().numpy(), + epoch=epoch, + tag="trainable_data", + exp_log_tag="trainable_data", + ) + + if model.trainable_hidden is not None: + hidden_coords = np.rad2deg( + graph[self._graph_name_hidden, "to", self._graph_name_hidden].hcoords_rad.numpy(), + ) + + self.plot( + trainer, + hidden_coords, + model.trainable_hidden.trainable.cpu().detach().numpy(), + epoch=epoch, + tag="trainable_hidden", + exp_log_tag="trainable_hidden", + ) + + +class PlotLoss(BasePlotCallback): + """Plots the unsqueezed loss over rollouts.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialise the PlotLoss callback. + + Parameters + ---------- + config : OmegaConf + Object with configuration settings + + """ + super().__init__(config) + self.parameter_names = None + self.parameter_groups = self.config.diagnostics.plot.parameter_groups + if self.parameter_groups is None: + self.parameter_groups = {} + + @cached_property + def sort_and_color_by_parameter_group(self) -> tuple[np.ndarray, np.ndarray, dict, list]: + """Sort parameters by group and prepare colors.""" + + def automatically_determine_group(name: str) -> str: + # first prefix of parameter name is group name + parts = name.split("_") + return parts[0] + + # group parameters by their determined group name for > 15 parameters + if len(self.parameter_names) <= 15: + # for <= 15 parameters, keep the full name of parameters + parameters_to_groups = np.array(self.parameter_names) + sort_by_parameter_group = np.arange(len(self.parameter_names), dtype=int) + else: + parameters_to_groups = np.array( + [ + next( + ( + group_name + for group_name, group_parameters in self.parameter_groups.items() + if name in group_parameters + ), + automatically_determine_group(name), + ) + for name in self.parameter_names + ], + ) + + unique_group_list, group_inverse, group_counts = np.unique( + parameters_to_groups, + return_inverse=True, + return_counts=True, + ) + + # join parameter groups that appear only once and are not given in config-file + unique_group_list = np.array( + [ + unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other" + for tn, count in enumerate(group_counts) + ], + ) + parameters_to_groups = unique_group_list[group_inverse] + unique_group_list, group_inverse = np.unique(parameters_to_groups, return_inverse=True) + + # sort parameters by groups + sort_by_parameter_group = np.argsort(group_inverse, kind="stable") + + # apply new order to parameters + sorted_parameter_names = np.array(self.parameter_names)[sort_by_parameter_group] + parameters_to_groups = parameters_to_groups[sort_by_parameter_group] + unique_group_list, group_inverse, group_counts = np.unique( + parameters_to_groups, + return_inverse=True, + return_counts=True, + ) + + # get a color per group and project to parameter list + cmap = "tab10" if len(unique_group_list) <= 10 else "tab20" + if len(unique_group_list) > 20: + LOGGER.warning("More than 20 groups detected, but colormap has only 20 colors.") + # if all groups have count 1 use black color + bar_color_per_group = ( + np.tile("k", len(group_counts)) + if not np.any(group_counts - 1) + else plt.get_cmap(cmap)(np.linspace(0, 1, len(unique_group_list))) + ) + + # set x-ticks + x_tick_positions = np.cumsum(group_counts) - group_counts / 2 - 0.5 + xticks = dict(zip(unique_group_list, x_tick_positions)) + + legend_patches = [] + for group_idx, group in enumerate(unique_group_list): + text_label = f"{group}: " + string_length = len(text_label) + for ii in np.where(group_inverse == group_idx)[0]: + text_label += sorted_parameter_names[ii] + ", " + string_length += len(sorted_parameter_names[ii]) + 2 + if string_length > 50: + # linebreak after 50 characters + text_label += "\n" + string_length = 0 + legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2])) + + return sort_by_parameter_group, bar_color_per_group[group_inverse], xticks, legend_patches + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.Lightning_module, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + epoch: int, + ) -> None: + logger = trainer.logger + _ = batch_idx + + parameter_names = list(pl_module.data_indices.internal_model.output.name_to_index.keys()) + parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values()) + # reorder parameter_names by position + self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)] + + batch = pl_module.model.pre_processors(batch, in_place=False) + for rollout_step in range(pl_module.rollout): + y_hat = outputs[1][rollout_step] + y_true = batch[ + :, + pl_module.multi_step + rollout_step, + ..., + pl_module.data_indices.internal_data.output.full, + ] + loss = pl_module.loss(y_hat, y_true, squash=False).cpu().numpy() + + sort_by_parameter_group, colors, xticks, legend_patches = self.sort_and_color_by_parameter_group + fig = plot_loss(loss[sort_by_parameter_group], colors, xticks, legend_patches) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"loss_rstep_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", + exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", + ) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + ) -> None: + if batch_idx % self.plot_frequency == 0: + self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) + + +class PlotSample(BasePlotCallback): + """Plots a post-processed sample: input, target and prediction.""" + + def __init__(self, config: OmegaConf) -> None: + """Initialise the PlotSample callback. + + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__(config) + self.sample_idx = self.config.diagnostics.plot.sample_idx + self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields + LOGGER.info("Using defined accumulation colormap for fields: %s", self.precip_and_related_fields) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.Lightning_module, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + epoch: int, + ) -> None: + logger = trainer.logger + + # Build dictionary of indices and parameters to be plotted + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + plot_parameters_dict = { + pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) + for name in self.config.diagnostics.plot.parameters + } + + # When running in Async mode, it might happen that in the last epoch these tensors + # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA + # but internal ones would be on the cpu), The lines below allow to address this problem + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + local_rank = pl_module.local_rank + + batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor = batch[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data = self.post_processors(input_tensor).numpy() + + output_tensor = self.post_processors( + torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), + in_place=False, + ).numpy() + + for rollout_step in range(pl_module.rollout): + fig = plot_predicted_multilevel_flat_sample( + plot_parameters_dict, + self.config.diagnostics.plot.per_sample, + self.latlons, + self.config.diagnostics.plot.accumulation_levels_plot, + self.config.diagnostics.plot.cmap_accumulation, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + precip_and_related_fields=self.precip_and_related_fields, + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_sample_rstep{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", + ) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.Lightning_module, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + ) -> None: + if batch_idx % self.plot_frequency == 0: + self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) + + +class PlotAdditionalMetrics(BasePlotCallback): + """Plots TP related metric comparing target and prediction. + + The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. + + - Power Spectrum + - Histograms + """ + + def __init__(self, config: OmegaConf) -> None: + """Initialise the PlotAdditionalMetrics callback. + + Parameters + ---------- + config : OmegaConf + Config object + + """ + super().__init__(config) + self.sample_idx = self.config.diagnostics.plot.sample_idx + self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields + LOGGER.info("Using precip histogram plotting method for fields: %s.", self.precip_and_related_fields) + + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list, + batch: torch.Tensor, + batch_idx: int, + epoch: int, + ) -> None: + logger = trainer.logger + + # When running in Async mode, it might happen that in the last epoch these tensors + # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA + # but internal ones would be on the cpu), The lines below allow to address this problem + if self.pre_processors is None: + # Copy to be used across all the training cycle + self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + local_rank = pl_module.local_rank + batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor = batch[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data = self.post_processors(input_tensor).numpy() + output_tensor = self.post_processors( + torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), + in_place=False, + ).numpy() + + for rollout_step in range(pl_module.rollout): + if self.config.diagnostics.plot.parameters_histogram is not None: + # Build dictionary of inidicies and parameters to be plotted + + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + plot_parameters_dict_histogram = { + pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) + for name in self.config.diagnostics.plot.parameters_histogram + } + + fig = plot_histogram( + plot_parameters_dict_histogram, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + precip_and_related_fields=self.precip_and_related_fields, + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", + ) + + if self.config.diagnostics.plot.parameters_spectrum is not None: + # Build dictionary of inidicies and parameters to be plotted + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + + plot_parameters_dict_spectrum = { + pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) + for name in self.config.diagnostics.plot.parameters_spectrum + } + + fig = plot_power_spectrum( + plot_parameters_dict_spectrum, + self.latlons, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", + ) + + def on_validation_batch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + ) -> None: + if batch_idx % self.plot_frequency == 0: + self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) diff --git a/src/anemoi/training/diagnostics/callbacks/weights.py b/src/anemoi/training/diagnostics/callbacks/weights.py new file mode 100644 index 00000000..9e327e38 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/weights.py @@ -0,0 +1,35 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging + +LOGGER = logging.getLogger(__name__) + +if TYPE_CHECKING: + from omegaconf import DictConfig + + +class StochasticWeightAveraging(pl_StochasticWeightAveraging): + """Provide StochasticWeightAveraging from pytorch_lightning as a callback using config.""" + + def __init__(self, config: DictConfig): + super().__init__( + swa_lrs=config.training.swa.lr, + swa_epoch_start=min( + int(0.75 * config.training.max_epochs), + config.training.max_epochs - 1, + ), + annealing_epochs=max(int(0.25 * config.training.max_epochs), 1), + annealing_strategy="cos", + device=None, + ) + self.config = config From 29a84773432ebb688f3f3f23d795ed83384385ae Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 24 Sep 2024 09:49:33 +0000 Subject: [PATCH 02/40] Update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index f3caa1c9..e734e9a1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ Keep it human-readable, your future self will thank you! ### Changed - Updated configuration examples in documentation and corrected links - [#46](https://github.com/ecmwf/anemoi-training/pull/46) +- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) ## [0.1.0 - Anemoi training - First release](https://github.com/ecmwf/anemoi-training/releases/tag/0.1.0) - 2024-08-16 From 15824bea069b5b08ea1eb148f52c8212ab906562 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 24 Sep 2024 09:56:39 +0000 Subject: [PATCH 03/40] Fix TypeError --- .../diagnostics/callbacks/__init__.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index b4c73ad4..60cb03a2 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -40,16 +40,19 @@ } # Callbacks to add according to flags in the config -CONFIG_ENABLED_CALLBACKS: dict[list[str] | str, list[type[Callback]] | type[Callback]] = { - ["diagnostics.log.wandb.enabled", "diagnostics.log.mlflow.enabled"]: LearningRateMonitor, - "diagnostics.eval.enabled": RolloutEval, - "diagnostics.plot.enabled": [ - PlotLoss, - PlotSample, - ], - "training.swa.enabled": StochasticWeightAveraging, - "diagnostics.plot.learned_features": GraphTrainableFeaturesPlot, -} +CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str, list[type[Callback]] | type[Callback]]] = [ + (["diagnostics.log.wandb.enabled", "diagnostics.log.mlflow.enabled"], LearningRateMonitor), + ("diagnostics.eval.enabled", RolloutEval), + ( + "diagnostics.plot.enabled", + [ + PlotLoss, + PlotSample, + ], + ), + ("training.swa.enabled", StochasticWeightAveraging), + ("diagnostics.plot.learned_features", GraphTrainableFeaturesPlot), +] def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None: @@ -111,7 +114,7 @@ def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS""" callbacks = [] - for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS.items(): + for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS: if isinstance(enable_key, list): if not any(config.get(key, False) for key in enable_key): continue From 4077bf430f322df4a84c1b104f042ae47425f25b Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 25 Sep 2024 08:13:31 +0000 Subject: [PATCH 04/40] Move to hydra.instantiate --- .../diagnostics/callbacks/__init__.py | 68 ++++++++++++------- .../training/diagnostics/callbacks/weights.py | 35 ---------- 2 files changed, 42 insertions(+), 61 deletions(-) delete mode 100644 src/anemoi/training/diagnostics/callbacks/weights.py diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 60cb03a2..9624159f 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -11,6 +11,8 @@ from datetime import timedelta from typing import TYPE_CHECKING +from hydra.utils import instantiate + from anemoi.training.diagnostics.callbacks.checkpointing import AnemoiCheckpoint from anemoi.training.diagnostics.callbacks.evaluation import RolloutEval from anemoi.training.diagnostics.callbacks.id import ParentUUIDCallback @@ -20,7 +22,6 @@ from anemoi.training.diagnostics.callbacks.plotting import PlotAdditionalMetrics from anemoi.training.diagnostics.callbacks.plotting import PlotLoss from anemoi.training.diagnostics.callbacks.plotting import PlotSample -from anemoi.training.diagnostics.callbacks.weights import StochasticWeightAveraging if TYPE_CHECKING: from omegaconf import DictConfig @@ -28,16 +29,6 @@ LOGGER = logging.getLogger(__name__) -# Dictionary of available callbacks -CALLBACK_DICT: dict[str, type[Callback]] = { - "RolloutEval": RolloutEval, - "LongRolloutPlots": LongRolloutPlots, - "GraphTrainableFeaturesPlot": GraphTrainableFeaturesPlot, - "PlotLoss": PlotLoss, - "PlotSample": PlotSample, - "PlotAdditionalMetrics": PlotAdditionalMetrics, - "ParentUUIDCallback": ParentUUIDCallback, -} # Callbacks to add according to flags in the config CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str, list[type[Callback]] | type[Callback]]] = [ @@ -50,7 +41,6 @@ PlotSample, ], ), - ("training.swa.enabled", StochasticWeightAveraging), ("diagnostics.plot.learned_features", GraphTrainableFeaturesPlot), ] @@ -111,7 +101,11 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: - """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS""" + """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS + + Provides backwards compatibility + + """ callbacks = [] for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS: @@ -126,13 +120,34 @@ def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: else: callbacks.append(callback_list(config)) + if config.diagnostics.plot.enabled: + if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: + callbacks.extend([PlotAdditionalMetrics(config)]) + if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: + callbacks.extend([LongRolloutPlots(config)]) + return callbacks def get_callbacks(config: DictConfig) -> list: # noqa: C901 """Setup callbacks for PyTorch Lightning trainer. - Set config.diagnostics.callbacks to a list of callback names to enable them. + Set `config.diagnostics.callbacks` to a list of callback configurations + in hydra form. + + E.g.: + ``` + callbacks: + swa: _target_: pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging + swa_lr: 1e-4 + swa_epoch_start: 123 + annealing_epochs: 5 + annealing_strategy: cos + device: null + ``` + + Set `config.diagnostics.plot_callbacks` to a list of plotting callback configurations + will only be added if `config.diagnostics.plot.enabled` is set to True. Parameters ---------- @@ -147,26 +162,27 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 """ trainer_callbacks: list[Callback] = [] + + # Get Checkpoint callback checkpoint_callback = _get_checkpoint_callback(config) if checkpoint_callback is not None: trainer_callbacks.extend(checkpoint_callback) - requested_callbacks = config.diagnostics.get("callbacks", []) + # Base callbacks + for callback in config.diagnostics.get("callbacks", []): + # Instantiate new callbacks + trainer_callbacks.append(instantiate(callback)) - for callback in requested_callbacks: - if callback in CALLBACK_DICT: - trainer_callbacks.append(CALLBACK_DICT[callback](config)) - else: - LOGGER.error(f"Callback {callback} not found in CALLBACK_DICT\n{list(CALLBACK_DICT.keys())}") + # Plotting callbacks + if config.diagnostics.plot.enabled: + for callback in config.diagnostics.get("plot_callbacks", []): + # Instantiate new callbacks + trainer_callbacks.append(instantiate(callback)) + # Extend with backward compatible callbacks trainer_callbacks.extend(_get_config_enabled_callbacks(config)) - if config.diagnostics.plot.enabled: - if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: - trainer_callbacks.extend([PlotAdditionalMetrics(config)]) - if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: - trainer_callbacks.extend([LongRolloutPlots(config)]) - + # Parent UUID callback trainer_callbacks.append(ParentUUIDCallback(config)) return trainer_callbacks diff --git a/src/anemoi/training/diagnostics/callbacks/weights.py b/src/anemoi/training/diagnostics/callbacks/weights.py deleted file mode 100644 index 9e327e38..00000000 --- a/src/anemoi/training/diagnostics/callbacks/weights.py +++ /dev/null @@ -1,35 +0,0 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging - -LOGGER = logging.getLogger(__name__) - -if TYPE_CHECKING: - from omegaconf import DictConfig - - -class StochasticWeightAveraging(pl_StochasticWeightAveraging): - """Provide StochasticWeightAveraging from pytorch_lightning as a callback using config.""" - - def __init__(self, config: DictConfig): - super().__init__( - swa_lrs=config.training.swa.lr, - swa_epoch_start=min( - int(0.75 * config.training.max_epochs), - config.training.max_epochs - 1, - ), - annealing_epochs=max(int(0.25 * config.training.max_epochs), 1), - annealing_strategy="cos", - device=None, - ) - self.config = config From fe37c02385259f0d527ec3c3d5bf09ce95f764f3 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 25 Sep 2024 09:14:23 +0000 Subject: [PATCH 05/40] Add __all__ --- .../diagnostics/callbacks/__init__.py | 23 +++++++++---------- 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 9624159f..8fc53ba7 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -13,15 +13,11 @@ from hydra.utils import instantiate +from anemoi.training.diagnostics.callbacks import plotting from anemoi.training.diagnostics.callbacks.checkpointing import AnemoiCheckpoint from anemoi.training.diagnostics.callbacks.evaluation import RolloutEval from anemoi.training.diagnostics.callbacks.id import ParentUUIDCallback from anemoi.training.diagnostics.callbacks.learning_rate import LearningRateMonitor -from anemoi.training.diagnostics.callbacks.plotting import GraphTrainableFeaturesPlot -from anemoi.training.diagnostics.callbacks.plotting import LongRolloutPlots -from anemoi.training.diagnostics.callbacks.plotting import PlotAdditionalMetrics -from anemoi.training.diagnostics.callbacks.plotting import PlotLoss -from anemoi.training.diagnostics.callbacks.plotting import PlotSample if TYPE_CHECKING: from omegaconf import DictConfig @@ -37,11 +33,11 @@ ( "diagnostics.plot.enabled", [ - PlotLoss, - PlotSample, + plotting.PlotLoss, + plotting.PlotSample, ], ), - ("diagnostics.plot.learned_features", GraphTrainableFeaturesPlot), + ("diagnostics.plot.learned_features", plotting.GraphTrainableFeaturesPlot), ] @@ -122,14 +118,14 @@ def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: if config.diagnostics.plot.enabled: if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: - callbacks.extend([PlotAdditionalMetrics(config)]) + callbacks.extend([plotting.PlotAdditionalMetrics(config)]) if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: - callbacks.extend([LongRolloutPlots(config)]) + callbacks.extend([plotting.LongRolloutPlots(config)]) return callbacks -def get_callbacks(config: DictConfig) -> list: # noqa: C901 +def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 """Setup callbacks for PyTorch Lightning trainer. Set `config.diagnostics.callbacks` to a list of callback configurations @@ -156,7 +152,7 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 Returns ------- - List + List[Callback] A list of PyTorch Lightning callbacks """ @@ -186,3 +182,6 @@ def get_callbacks(config: DictConfig) -> list: # noqa: C901 trainer_callbacks.append(ParentUUIDCallback(config)) return trainer_callbacks + + +__all__ = ["get_callbacks", "RolloutEval", "LearningRateMonitor", "plotting"] From 2d8275ca765a355694a0192918566e7c90b6f176 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 25 Sep 2024 09:40:46 +0000 Subject: [PATCH 06/40] Add to base config --- src/anemoi/training/config/diagnostics/eval_rollout.yaml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index 50e9a647..db7dd7c9 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -70,6 +70,12 @@ checkpoint: save_frequency: null # Does not scale with rollout num_models_saved: 0 +callbacks: + # Add callbacks here + +plot_callbacks: + # Add extra plot callbacks here + log: wandb: enabled: False From 230eb0e3932f5a0e266f97887c705988b6880414 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 25 Sep 2024 09:40:59 +0000 Subject: [PATCH 07/40] Fix nested list --- src/anemoi/training/diagnostics/callbacks/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 8fc53ba7..69fc6ea1 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -86,7 +86,7 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non mode="max", **checkpoint_settings, ), - ], + ] ) else: LOGGER.debug("Not setting up a checkpoint callback with %s", save_key) From 5547b20b42f7af47b0826cc87eb3553a8f754c62 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 26 Sep 2024 14:31:46 +0000 Subject: [PATCH 08/40] Fix nested get issue --- .../training/diagnostics/callbacks/__init__.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 69fc6ea1..648d67d0 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -100,15 +100,22 @@ def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: """Get callbacks that are enabled in the config as according to CONFIG_ENABLED_CALLBACKS Provides backwards compatibility - """ callbacks = [] + def nestedget(conf: DictConfig, key, default): + keys = key.split(".") + for k in keys: + conf = conf.get(k, default) + if not isinstance(conf, dict): + break + return conf + for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS: if isinstance(enable_key, list): - if not any(config.get(key, False) for key in enable_key): + if not any(nestedget(config, key, False) for key in enable_key): continue - elif not config.get(enable_key, False): + elif not nestedget(config, enable_key, False): continue if isinstance(callback_list, list): @@ -118,9 +125,9 @@ def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: if config.diagnostics.plot.enabled: if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: - callbacks.extend([plotting.PlotAdditionalMetrics(config)]) + callbacks.append(plotting.PlotAdditionalMetrics(config)) if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: - callbacks.extend([plotting.LongRolloutPlots(config)]) + callbacks.append(plotting.LongRolloutPlots(config)) return callbacks From 1d80cfb6d1418bc685b0ef60463f6f14e363775e Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 27 Sep 2024 12:10:22 +0000 Subject: [PATCH 09/40] Fix type checking --- src/anemoi/training/diagnostics/callbacks/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 648d67d0..eaba0b75 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING from hydra.utils import instantiate +from omegaconf import DictConfig from anemoi.training.diagnostics.callbacks import plotting from anemoi.training.diagnostics.callbacks.checkpointing import AnemoiCheckpoint @@ -20,7 +21,6 @@ from anemoi.training.diagnostics.callbacks.learning_rate import LearningRateMonitor if TYPE_CHECKING: - from omegaconf import DictConfig from pytorch_lightning.callbacks import Callback LOGGER = logging.getLogger(__name__) @@ -107,7 +107,7 @@ def nestedget(conf: DictConfig, key, default): keys = key.split(".") for k in keys: conf = conf.get(k, default) - if not isinstance(conf, dict): + if not isinstance(conf, DictConfig): break return conf From 96ab74c512278a0f1aa9585ad60fc1ac8c90f594 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 1 Oct 2024 14:37:39 +0000 Subject: [PATCH 10/40] feat: edge plot in callbacks --- .../diagnostics/callbacks/__init__.py | 5 +- .../diagnostics/callbacks/plotting.py | 67 ++++--- src/anemoi/training/diagnostics/plots.py | 170 ++++++++++++++++-- 3 files changed, 186 insertions(+), 56 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index eaba0b75..c634785b 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -37,7 +37,10 @@ plotting.PlotSample, ], ), - ("diagnostics.plot.learned_features", plotting.GraphTrainableFeaturesPlot), + ("diagnostics.plot.learned_features", [ + plotting.GraphNodeTrainableFeaturesPlot, + plotting.GraphEdgeTrainableFeaturesPlot, + ]), ] diff --git a/src/anemoi/training/diagnostics/callbacks/plotting.py b/src/anemoi/training/diagnostics/callbacks/plotting.py index 67a90016..9a562ca2 100644 --- a/src/anemoi/training/diagnostics/callbacks/plotting.py +++ b/src/anemoi/training/diagnostics/callbacks/plotting.py @@ -32,7 +32,7 @@ from pytorch_lightning.utilities import rank_zero_only from anemoi.training.diagnostics.plots import init_plot_settings -from anemoi.training.diagnostics.plots import plot_graph_features +from anemoi.training.diagnostics.plots import plot_graph_node_features, plot_graph_edge_features from anemoi.training.diagnostics.plots import plot_histogram from anemoi.training.diagnostics.plots import plot_loss from anemoi.training.diagnostics.plots import plot_power_spectrum @@ -298,11 +298,8 @@ def on_validation_batch_end( self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch) -class GraphTrainableFeaturesPlot(BasePlotCallback): - """Visualize the trainable features defined at the data and hidden graph nodes. - - TODO: How best to visualize the learned edge embeddings? Offline, perhaps - using code from @Simon's notebook? - """ +class GraphNodeTrainableFeaturesPlot(BasePlotCallback): + """Visualize the node trainable features defined.""" def __init__(self, config: OmegaConf) -> None: """Initialise the GraphTrainableFeaturesPlot callback. @@ -316,52 +313,50 @@ def __init__(self, config: OmegaConf) -> None: super().__init__(config) self._graph_name_data = config.graph.data self._graph_name_hidden = config.graph.hidden + self.epoch_freq = 5 @rank_zero_only def _plot( self, trainer: pl.Trainer, - latlons: np.ndarray, - features: np.ndarray, - epoch: int, + model: torch.nn.Module, tag: str, exp_log_tag: str, ) -> None: - fig = plot_graph_features(latlons, features) - self._output_figure(trainer.logger, fig, epoch=epoch, tag=tag, exp_log_tag=exp_log_tag) + fig = plot_graph_node_features(model, [self._graph_name_data, self._graph_name_hidden]) + self._output_figure(trainer.logger, fig, epoch=trainer.current_epoch, tag=tag, exp_log_tag=exp_log_tag) @rank_zero_only def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - graph = pl_module.graph_data.cpu().detach() - epoch = trainer.current_epoch - if model.trainable_data is not None: - data_coords = np.rad2deg(graph[self._graph_name_data, "to", self._graph_name_data].ecoords_rad.numpy()) + self.plot(trainer, model, tag="node_trainable_params", exp_log_tag="node_trainable_params") - self.plot( - trainer, - data_coords, - model.trainable_data.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_data", - exp_log_tag="trainable_data", - ) - if model.trainable_hidden is not None: - hidden_coords = np.rad2deg( - graph[self._graph_name_hidden, "to", self._graph_name_hidden].hcoords_rad.numpy(), - ) +class GraphEdgeTrainableFeaturesPlot(BasePlotCallback): + """Trainable edge features plot. - self.plot( - trainer, - hidden_coords, - model.trainable_hidden.trainable.cpu().detach().numpy(), - epoch=epoch, - tag="trainable_hidden", - exp_log_tag="trainable_hidden", - ) + Visualize the trainable features defined at the edges between meshes. + """ + + def __init__(self, config): + super().__init__(config) + self.epoch_freq = 5 + + def _plot( + self, + trainer: pl.Trainer, + model: torch.nn.Module, + tag: str, + exp_log_tag: str, + ) -> None: + fig = plot_graph_edge_features(model) + self._output_figure(trainer.logger, fig, epoch=trainer.current_epoch, tag=tag, exp_log_tag=exp_log_tag) + + def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: + if trainer.current_epoch % self.epoch_freq == 0 and pl_module.global_rank == 0: + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + self._plot(trainer, model, tag="edge_trainable_params", exp_log_tag="edge_trainable_params") class PlotLoss(BasePlotCallback): diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index b2004cf4..8a037858 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -13,15 +13,18 @@ import matplotlib.pyplot as plt import matplotlib.style as mplstyle import numpy as np +import torch from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap from matplotlib.colors import TwoSlopeNorm +from matplotlib.collections import LineCollection from pyshtools.expand import SHGLQ from pyshtools.expand import SHExpandGLQ from scipy.interpolate import griddata from anemoi.training.diagnostics.maps import Coastlines from anemoi.training.diagnostics.maps import EquirectangularProjection +from anemoi.models.layers.mapper import GraphEdgeMixin if TYPE_CHECKING: from matplotlib.figure import Figure @@ -625,36 +628,165 @@ def scatter_plot( fig.colorbar(psc, ax=ax) -def plot_graph_features( - latlons: np.ndarray, - features: np.ndarray, -) -> Figure: - """Plot trainable graph features. +def edge_plot( + fig, + ax, + src_coords: np.ndarray, + dst_coords: np.ndarray, + data: np.ndarray, + cmap: str = "coolwarm", + title: str | None = None, +) -> None: + """Lat-lon line plot. Parameters ---------- - latlons : np.ndarray - Latitudes and longitudes - features : np.ndarray - Trainable Features + fig : _type_ + Figure object handle + ax : _type_ + Axis object handle + src_coords : np.ndarray of shape (num_edges, 2) + Source latitudes and longitudes. + dst_coords : np.ndarray of shape (num_edges, 2) + Destination latitudes and longitudes. + data : np.ndarray of shape (num_edges, 1) + Data to plot + cmap : str, optional + Colormap string from matplotlib, by default "viridis". + title : str, optional + Title for plot, by default None + """ + edge_lines = np.stack([src_coords, dst_coords], axis=1) + lc = LineCollection(edge_lines, cmap=cmap, linewidths=1) + lc.set_array(data) + + psc = ax.add_collection(lc) + + xmin, xmax = edge_lines[:, 0, 0].min(), edge_lines[:, 0, 0].max() + ymin, ymax = edge_lines[:, 1, 1].min(), edge_lines[:, 1, 1].max() + ax.set_xlim((xmin - 0.1, xmax + 0.1)) + ax.set_ylim((ymin - 0.1, ymax + 0.1)) + + continents.plot_continents(ax) + + if title is not None: + ax.set_title(title) + + ax.set_aspect("auto", adjustable=None) + _hide_axes_ticks(ax) + fig.colorbar(psc, ax=ax) + + +def sincos_to_latlon(sincos_coords: torch.Tensor) -> torch.Tensor: + """Get the lat/lon coordinates from the model. + + Parameters + ---------- + sincos_coords: torch.Tensor of shape (N, 4) + Sine and cosine of latitude and longitude coordinates. + + Returns + ------- + torch.Tensor of shape (N, 2) + Lat/lon coordinates. + """ + ndim = sincos_coords.shape[1] // 2 + sin_y, cos_y = sincos_coords[:, :ndim], sincos_coords[:, ndim:] + return torch.atan2(sin_y, cos_y) + + +def plot_graph_node_features(model, nodes_name: list[str]) -> Figure: + """Plot trainable graph node features. + + Parameters + ---------- + model: + Model object + force_global_view : bool, optional + Show the entire globe, by default True. Returns ------- Figure Figure object handle - """ - nplots = features.shape[-1] - figsize = (nplots * 4, 3) - fig, ax = plt.subplots(1, nplots, figsize=figsize) + nrows = len(nodes_name) + ncols = min([getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name]) + figsize = (ncols * 4, nrows * 3) + fig, ax = plt.subplots(nrows, ncols, figsize=figsize) - lat, lon = latlons[:, 0], latlons[:, 1] + for row, mesh in enumerate(nodes_name): + sincos_coords = getattr(model, f"latlons_{mesh}") + latlons = sincos_to_latlon(sincos_coords).cpu().numpy() + features = getattr(model, f"trainable_{mesh}").trainable.cpu().detach().numpy() - pc = EquirectangularProjection() - pc_lon, pc_lat = pc(lon, lat) + lat, lon = latlons[:, 0], latlons[:, 1] - for i in range(nplots): - ax_ = ax[i] if nplots > 1 else ax - scatter_plot(fig, ax_, lon=pc_lon, lat=pc_lat, data=features[..., i]) + for i in range(ncols): + ax_ = ax[row, i] if ncols > 1 else ax[row] + scatter_plot( + fig, + ax_, + lon=lon, + lat=lat, + data=features[..., i], + title=f"{mesh} trainable feature #{i + 1}", + ) return fig + + +def plot_graph_edge_features(model, q_extreme_limit: float = 0.05) -> Figure: + """Plot trainable graph edge features. + + Parameters + ---------- + model: AneomiModelEncProcDec + Model object + q_extreme_limit : float, optional + Plot top & bottom quantile of edges trainable values, by default 0.05 (5%). + + Returns + ------- + Figure + Figure object handle + """ + trainable_modules = { + (model._graph_name_data, model._graph_name_hidden): model.encoder, + (model._graph_name_hidden, model._graph_name_data): model.decoder, + } + + if isinstance(model.processor, GraphEdgeMixin): + trainable_modules[(model._graph_name_hidden, model._graph_name_hidden)] = model.processor + + ncols = min([module.trainable.trainable.shape[1] for module in trainable_modules.values()]) + nrows = len(trainable_modules) + figsize = (ncols * 4, nrows * 3) + fig, ax = plt.subplots(nrows, ncols, figsize=figsize) + + for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()): + src_coords = sincos_to_latlon(getattr(model, f"latlons_{src}")).cpu().numpy() + dst_coords = sincos_to_latlon(getattr(model, f"latlons_{dst}")).cpu().numpy() + edge_index = graph_mapper.edge_index_base.cpu().numpy() + edge_features = graph_mapper.trainable.trainable.cpu().detach().numpy() + + for i in range(ncols): + ax_ = ax[row, i] if ncols > 1 else ax[row] + feature = edge_features[..., i] + + # Get mask of feature values over top and bottom percentiles + top_perc = np.quantile(feature, 1 - q_extreme_limit) + bottom_perc = np.quantile(feature, q_extreme_limit) + + mask = (feature >= top_perc) | (feature <= bottom_perc) + + edge_plot( + fig, + ax_, + src_coords[edge_index[0, mask]][:, ::-1], + dst_coords[edge_index[1, mask]][:, ::-1], + feature[mask], + title=f"{src} -> {dst} trainable feature #{i + 1}", + ) + + return fig \ No newline at end of file From 4aeb1a5e52de87da4ebed93aa15ac7eb902021d4 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Tue, 1 Oct 2024 14:45:08 +0000 Subject: [PATCH 11/40] feat: set default extra callbacks --- src/anemoi/training/config/diagnostics/eval_rollout.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index db7dd7c9..8d2ac2fb 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -70,10 +70,10 @@ checkpoint: save_frequency: null # Does not scale with rollout num_models_saved: 0 -callbacks: +callbacks: [] # Add callbacks here -plot_callbacks: +plot_callbacks: [] # Add extra plot callbacks here log: From 816b3af1e9d96b38dcdb4dd5e39256f5294297c0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 10:54:27 +0000 Subject: [PATCH 12/40] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/training/diagnostics/callbacks/__init__.py | 11 +++++++---- src/anemoi/training/diagnostics/callbacks/plotting.py | 3 ++- src/anemoi/training/diagnostics/plots.py | 8 ++++---- 3 files changed, 13 insertions(+), 9 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index c634785b..e1e85c81 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -37,10 +37,13 @@ plotting.PlotSample, ], ), - ("diagnostics.plot.learned_features", [ - plotting.GraphNodeTrainableFeaturesPlot, - plotting.GraphEdgeTrainableFeaturesPlot, - ]), + ( + "diagnostics.plot.learned_features", + [ + plotting.GraphNodeTrainableFeaturesPlot, + plotting.GraphEdgeTrainableFeaturesPlot, + ], + ), ] diff --git a/src/anemoi/training/diagnostics/callbacks/plotting.py b/src/anemoi/training/diagnostics/callbacks/plotting.py index 9a562ca2..b1e982a1 100644 --- a/src/anemoi/training/diagnostics/callbacks/plotting.py +++ b/src/anemoi/training/diagnostics/callbacks/plotting.py @@ -32,7 +32,8 @@ from pytorch_lightning.utilities import rank_zero_only from anemoi.training.diagnostics.plots import init_plot_settings -from anemoi.training.diagnostics.plots import plot_graph_node_features, plot_graph_edge_features +from anemoi.training.diagnostics.plots import plot_graph_edge_features +from anemoi.training.diagnostics.plots import plot_graph_node_features from anemoi.training.diagnostics.plots import plot_histogram from anemoi.training.diagnostics.plots import plot_loss from anemoi.training.diagnostics.plots import plot_power_spectrum diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 8a037858..6d2d84f9 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -14,17 +14,17 @@ import matplotlib.style as mplstyle import numpy as np import torch +from anemoi.models.layers.mapper import GraphEdgeMixin +from matplotlib.collections import LineCollection from matplotlib.colors import BoundaryNorm from matplotlib.colors import ListedColormap from matplotlib.colors import TwoSlopeNorm -from matplotlib.collections import LineCollection from pyshtools.expand import SHGLQ from pyshtools.expand import SHExpandGLQ from scipy.interpolate import griddata from anemoi.training.diagnostics.maps import Coastlines from anemoi.training.diagnostics.maps import EquirectangularProjection -from anemoi.models.layers.mapper import GraphEdgeMixin if TYPE_CHECKING: from matplotlib.figure import Figure @@ -757,7 +757,7 @@ def plot_graph_edge_features(model, q_extreme_limit: float = 0.05) -> Figure: } if isinstance(model.processor, GraphEdgeMixin): - trainable_modules[(model._graph_name_hidden, model._graph_name_hidden)] = model.processor + trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor ncols = min([module.trainable.trainable.shape[1] for module in trainable_modules.values()]) nrows = len(trainable_modules) @@ -789,4 +789,4 @@ def plot_graph_edge_features(model, q_extreme_limit: float = 0.05) -> Figure: title=f"{src} -> {dst} trainable feature #{i + 1}", ) - return fig \ No newline at end of file + return fig From 644038fcc46582dfa75bdd3537fff83457d84ee7 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 2 Oct 2024 11:05:37 +0000 Subject: [PATCH 13/40] fix: typing & refactoring --- .../training/diagnostics/callbacks/plotting.py | 4 +--- src/anemoi/training/diagnostics/plots.py | 16 +++++++--------- 2 files changed, 8 insertions(+), 12 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plotting.py b/src/anemoi/training/diagnostics/callbacks/plotting.py index b1e982a1..c086f9d8 100644 --- a/src/anemoi/training/diagnostics/callbacks/plotting.py +++ b/src/anemoi/training/diagnostics/callbacks/plotting.py @@ -312,8 +312,6 @@ def __init__(self, config: OmegaConf) -> None: """ super().__init__(config) - self._graph_name_data = config.graph.data - self._graph_name_hidden = config.graph.hidden self.epoch_freq = 5 @rank_zero_only @@ -324,7 +322,7 @@ def _plot( tag: str, exp_log_tag: str, ) -> None: - fig = plot_graph_node_features(model, [self._graph_name_data, self._graph_name_hidden]) + fig = plot_graph_node_features(model) self._output_figure(trainer.logger, fig, epoch=trainer.current_epoch, tag=tag, exp_log_tag=exp_log_tag) @rank_zero_only diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 6d2d84f9..62cd3d53 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -587,7 +587,7 @@ def scatter_plot( Parameters ---------- - fig : _type_ + fig : Figure Figure object handle ax : matplotlib.axes Axis object handle @@ -629,8 +629,8 @@ def scatter_plot( def edge_plot( - fig, - ax, + fig: Figure, + ax: plt.Axes, src_coords: np.ndarray, dst_coords: np.ndarray, data: np.ndarray, @@ -695,22 +695,20 @@ def sincos_to_latlon(sincos_coords: torch.Tensor) -> torch.Tensor: return torch.atan2(sin_y, cos_y) -def plot_graph_node_features(model, nodes_name: list[str]) -> Figure: +def plot_graph_node_features(model: torch.nn.Module) -> Figure: """Plot trainable graph node features. Parameters ---------- - model: + model: AneomiModelEncProcDec Model object - force_global_view : bool, optional - Show the entire globe, by default True. Returns ------- Figure Figure object handle """ - nrows = len(nodes_name) + nrows = len(nodes_name := model._graph_data.node_types) ncols = min([getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name]) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize) @@ -736,7 +734,7 @@ def plot_graph_node_features(model, nodes_name: list[str]) -> Figure: return fig -def plot_graph_edge_features(model, q_extreme_limit: float = 0.05) -> Figure: +def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0.05) -> Figure: """Plot trainable graph edge features. Parameters From 8356cd47204952f013a4fb40b63ae64aed337cd6 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 2 Oct 2024 11:10:52 +0000 Subject: [PATCH 14/40] fix: remove list comprehension --- src/anemoi/training/diagnostics/plots.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py index 62cd3d53..e667ec0e 100644 --- a/src/anemoi/training/diagnostics/plots.py +++ b/src/anemoi/training/diagnostics/plots.py @@ -709,7 +709,7 @@ def plot_graph_node_features(model: torch.nn.Module) -> Figure: Figure object handle """ nrows = len(nodes_name := model._graph_data.node_types) - ncols = min([getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name]) + ncols = min(getattr(model, f"trainable_{m}").trainable.shape[1] for m in nodes_name) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize) @@ -757,7 +757,7 @@ def plot_graph_edge_features(model: torch.nn.Module, q_extreme_limit: float = 0. if isinstance(model.processor, GraphEdgeMixin): trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor - ncols = min([module.trainable.trainable.shape[1] for module in trainable_modules.values()]) + ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values()) nrows = len(trainable_modules) figsize = (ncols * 4, nrows * 3) fig, ax = plt.subplots(nrows, ncols, figsize=figsize) From 930e4d285c0355870060a76f924b71afbbbf11f6 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 2 Oct 2024 12:29:58 +0000 Subject: [PATCH 15/40] Refactor according to PR - Prefill config with callbacks - Warn on deprecations for old config - Expand config enabled - Add back SWA - Fix logging callback - Add flag to disable checkpointing - Add testing --- .../config/diagnostics/eval_rollout.yaml | 25 ++-- .../diagnostics/callbacks/__init__.py | 128 ++++++++++++------ .../{checkpointing.py => checkpoint.py} | 0 .../diagnostics/callbacks/evaluation.py | 11 +- .../diagnostics/callbacks/learning_rate.py | 4 +- .../callbacks/{plotting.py => plot.py} | 0 .../training/diagnostics/callbacks/swa.py | 63 +++++++++ tests/diagnostics/test_callbacks.py | 81 +++++++++++ 8 files changed, 250 insertions(+), 62 deletions(-) rename src/anemoi/training/diagnostics/callbacks/{checkpointing.py => checkpoint.py} (100%) rename src/anemoi/training/diagnostics/callbacks/{plotting.py => plot.py} (100%) create mode 100644 src/anemoi/training/diagnostics/callbacks/swa.py create mode 100644 tests/diagnostics/test_callbacks.py diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index db7dd7c9..ebddc4f1 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -1,11 +1,20 @@ --- -eval: - enabled: False - # use this to evaluate the model over longer rollouts, every so many validation batches - rollout: 12 - frequency: 20 + +callbacks: + # Add callbacks here + - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval + rollout: 12 + frequency: 20 + plot: enabled: True + + callbacks: + # Add extra plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + asynchronous: True frequency: 750 sample_idx: 0 @@ -43,7 +52,6 @@ plot: parameter_groups: moisture: [tp, cp, tcw] sfc_wind: [10u, 10v] - learned_features: False longrollout: enabled: False rollout: [60] @@ -58,6 +66,7 @@ debug: profiler: False checkpoint: + enabled: True every_n_minutes: save_frequency: 30 # Approximate, as this is checked at the end of training steps num_models_saved: 3 # If set to k, saves the 'last' k model weights in the training. @@ -70,11 +79,7 @@ checkpoint: save_frequency: null # Does not scale with rollout num_models_saved: 0 -callbacks: - # Add callbacks here -plot_callbacks: - # Add extra plot callbacks here log: wandb: diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index eaba0b75..2a0e80d2 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -8,17 +8,21 @@ from __future__ import annotations import logging +import warnings from datetime import timedelta from typing import TYPE_CHECKING +from typing import Callable +from typing import Iterable from hydra.utils import instantiate from omegaconf import DictConfig -from anemoi.training.diagnostics.callbacks import plotting -from anemoi.training.diagnostics.callbacks.checkpointing import AnemoiCheckpoint +from anemoi.training.diagnostics.callbacks import plot +from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint from anemoi.training.diagnostics.callbacks.evaluation import RolloutEval from anemoi.training.diagnostics.callbacks.id import ParentUUIDCallback from anemoi.training.diagnostics.callbacks.learning_rate import LearningRateMonitor +from anemoi.training.diagnostics.callbacks.swa import StochasticWeightAveraging if TYPE_CHECKING: from pytorch_lightning.callbacks import Callback @@ -26,23 +30,58 @@ LOGGER = logging.getLogger(__name__) +def nestedget(conf: DictConfig, key, default): + """ + Get a nested key from a DictConfig object + + E.g. + >>> nestedget(config, "diagnostics.log.wandb.enabled", False) + """ + keys = key.split(".") + for k in keys: + conf = conf.get(k, default) + if not isinstance(conf, (dict, DictConfig)): + break + return conf + + # Callbacks to add according to flags in the config -CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str, list[type[Callback]] | type[Callback]]] = [ - (["diagnostics.log.wandb.enabled", "diagnostics.log.mlflow.enabled"], LearningRateMonitor), - ("diagnostics.eval.enabled", RolloutEval), +# Can be function to check status from config +CONFIG_ENABLED_CALLBACKS: list[tuple[list[str] | str | Callable[[DictConfig], bool], type[Callback]]] = [ + ("training.swa.enabled", StochasticWeightAveraging), + ( + lambda config: nestedget(config, "diagnostics.log.wandb.enabled", False) + or nestedget(config, "diagnostics.log.mflow.enabled", False), + LearningRateMonitor, + ), + ( + lambda config: config.diagnostics.plot.enabled + and ( + nestedget(config, "diagnostics.plot.parameters_histogram", None) + or nestedget(config, "diagnostics.plot.parameters_spectrum", None) + ) + is not None, + plot.PlotAdditionalMetrics, + ), +] + +DEPRECATED_CONFIGS: list[tuple[list[str] | str, type[Callback]]] = [ ( - "diagnostics.plot.enabled", - [ - plotting.PlotLoss, - plotting.PlotSample, - ], + "diagnostics.eval.enabled", + lambda config: RolloutEval( + config, rollout=config.diagnostics.eval.rollout, frequency=config.diagnostics.eval.frequency + ), ), - ("diagnostics.plot.learned_features", plotting.GraphTrainableFeaturesPlot), + ("diagnostics.plot.learned_features", plot.GraphTrainableFeaturesPlot), + (["diagnostics.plot.enabled", "diagnostics.plot.longrollout.enabled"], plot.LongRolloutPlots), ] def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None: """Get checkpointing callback""" + if not config.diagnostics.checkpoint.get("enabled", True): + return [] + checkpoint_settings = { "dirpath": config.hardware.paths.checkpoints, "verbose": False, @@ -103,32 +142,28 @@ def _get_config_enabled_callbacks(config: DictConfig) -> list[Callback]: """ callbacks = [] - def nestedget(conf: DictConfig, key, default): - keys = key.split(".") - for k in keys: - conf = conf.get(k, default) - if not isinstance(conf, DictConfig): - break - return conf + def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]): + """Check key in config.""" + if isinstance(key, Callable): + return key(config) + elif isinstance(key, str): + return nestedget(config, key, False) + elif isinstance(key, Iterable): + return all(nestedget(config, k, False) for k in key) + return nestedget(config, key, False) + + for deprecated_key, callback_list in DEPRECATED_CONFIGS: + if check_key(config, deprecated_key): + warnings.warn( + f"Deprecated config {deprecated_key} found. Please update your config file to use the new callback initialisation method.", + DeprecationWarning, + ) + callbacks.append(callback_list(config)) for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS: - if isinstance(enable_key, list): - if not any(nestedget(config, key, False) for key in enable_key): - continue - elif not nestedget(config, enable_key, False): - continue - - if isinstance(callback_list, list): - callbacks.extend(map(lambda x: x(config), callback_list)) - else: + if check_key(config, enable_key): callbacks.append(callback_list(config)) - if config.diagnostics.plot.enabled: - if (config.diagnostics.plot.parameters_histogram or config.diagnostics.plot.parameters_spectrum) is not None: - callbacks.append(plotting.PlotAdditionalMetrics(config)) - if config.diagnostics.plot.get("longrollout") and config.diagnostics.plot.longrollout.enabled: - callbacks.append(plotting.LongRolloutPlots(config)) - return callbacks @@ -141,17 +176,20 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 E.g.: ``` callbacks: - swa: _target_: pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging - swa_lr: 1e-4 - swa_epoch_start: 123 - annealing_epochs: 5 - annealing_strategy: cos - device: null + - _target_: anemoi.training.diagnostics.callbacks.RolloutEval + rollout: 1 + frequency: 12 ``` - Set `config.diagnostics.plot_callbacks` to a list of plotting callback configurations + Set `config.diagnostics.plot.callbacks` to a list of plot callback configurations will only be added if `config.diagnostics.plot.enabled` is set to True. + A callback must take a `DictConfig` in its `__init__` method as the first argument, + which will be the complete configuration object. + + Some callbacks are added by default, depending on the configuration. + See CONFIG_ENABLED_CALLBACKS for more information. + Parameters ---------- config : DictConfig @@ -174,15 +212,15 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 # Base callbacks for callback in config.diagnostics.get("callbacks", []): # Instantiate new callbacks - trainer_callbacks.append(instantiate(callback)) + trainer_callbacks.append(instantiate(callback, config)) # Plotting callbacks if config.diagnostics.plot.enabled: - for callback in config.diagnostics.get("plot_callbacks", []): + for callback in config.diagnostics.plot.get("callbacks", []): # Instantiate new callbacks - trainer_callbacks.append(instantiate(callback)) + trainer_callbacks.append(instantiate(callback, config)) - # Extend with backward compatible callbacks + # Extend with config enabled callbacks trainer_callbacks.extend(_get_config_enabled_callbacks(config)) # Parent UUID callback @@ -191,4 +229,4 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 return trainer_callbacks -__all__ = ["get_callbacks", "RolloutEval", "LearningRateMonitor", "plotting"] +__all__ = ["get_callbacks"] diff --git a/src/anemoi/training/diagnostics/callbacks/checkpointing.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py similarity index 100% rename from src/anemoi/training/diagnostics/callbacks/checkpointing.py rename to src/anemoi/training/diagnostics/callbacks/checkpoint.py diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index cbc6b854..a91de1f5 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -25,7 +25,7 @@ class RolloutEval(Callback): """Evaluates the model performance over a (longer) rollout window.""" - def __init__(self, config: OmegaConf) -> None: + def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None: """Initialize RolloutEval callback. Parameters @@ -35,14 +35,15 @@ def __init__(self, config: OmegaConf) -> None: """ super().__init__() + self.config = config LOGGER.debug( "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", - config.diagnostics.eval.rollout, - config.diagnostics.eval.frequency, + rollout, + frequency, ) - self.rollout = config.diagnostics.eval.rollout - self.frequency = config.diagnostics.eval.frequency + self.rollout = rollout + self.frequency = frequency def _eval( self, diff --git a/src/anemoi/training/diagnostics/callbacks/learning_rate.py b/src/anemoi/training/diagnostics/callbacks/learning_rate.py index 5e1b84f9..9839a120 100644 --- a/src/anemoi/training/diagnostics/callbacks/learning_rate.py +++ b/src/anemoi/training/diagnostics/callbacks/learning_rate.py @@ -21,6 +21,6 @@ class LearningRateMonitor(pl_LearningRateMonitor): """Provide LearningRateMonitor from pytorch_lightning as a callback.""" - def __init__(self, config: DictConfig): - super().__init__(logging_interval="step", log_momentum=False) + def __init__(self, config: DictConfig, logging_interval: str = "step", log_momentum: bool = False) -> None: + super().__init__(logging_interval=logging_interval, log_momentum=log_momentum) self.config = config diff --git a/src/anemoi/training/diagnostics/callbacks/plotting.py b/src/anemoi/training/diagnostics/callbacks/plot.py similarity index 100% rename from src/anemoi/training/diagnostics/callbacks/plotting.py rename to src/anemoi/training/diagnostics/callbacks/plot.py diff --git a/src/anemoi/training/diagnostics/callbacks/swa.py b/src/anemoi/training/diagnostics/callbacks/swa.py new file mode 100644 index 00000000..8e3b26f5 --- /dev/null +++ b/src/anemoi/training/diagnostics/callbacks/swa.py @@ -0,0 +1,63 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging + +if TYPE_CHECKING: + from omegaconf import OmegaConf + + +LOGGER = logging.getLogger(__name__) + + +class StochasticWeightAveraging(pl_StochasticWeightAveraging): + """Provide StochasticWeightAveraging from pytorch_lightning as a callback.""" + + def __init__( + self, + config: OmegaConf, + swa_lrs: int | None = None, + swa_epoch_start: int | None = None, + annealing_epoch: int | None = None, + annealing_strategy: str | None = None, + device: str | None = None, + **kwargs, + ) -> None: + """ + Stochastic Weight Averaging Callback. + + Parameters + ---------- + config : OmegaConf + Full configuration object + swa_lrs : int, optional + Stochastic Weight Averaging Learning Rate, by default None + swa_epoch_start : int, optional + Epoch start, by default 0.75 * config.training.max_epochs + annealing_epoch : int, optional + Annealing Epoch, by default 0.25 * config.training.max_epochs + annealing_strategy : str, optional + Annealing Strategy, by default 'cos' + device : str, optional + Device to use, by default None + """ + kwargs["swa_lrs"] = swa_lrs or config.training.swa.lr + kwargs["swa_epoch_start"] = swa_epoch_start or min( + int(0.75 * config.training.max_epochs), + config.training.max_epochs - 1, + ) + kwargs["annealing_epoch"] = annealing_epoch or max(int(0.25 * config.training.max_epochs), 1) + kwargs["annealing_strategy"] = annealing_strategy or "cos" + kwargs["device"] = device + + super().__init__(**kwargs) + self.config = config diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py new file mode 100644 index 00000000..cf927742 --- /dev/null +++ b/tests/diagnostics/test_callbacks.py @@ -0,0 +1,81 @@ +# ruff: noqa: ANN001, ANN201 +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import omegaconf +import pytest +import yaml + +from anemoi.training.diagnostics.callbacks import get_callbacks + +default_config = """ +diagnostics: + callbacks: [] + + plot: + enabled: False + callbacks: [] + + debug: + # this will detect and trace back NaNs / Infs etc. but will slow down training + anomaly_detection: False + + profiler: False + + checkpoint: + enabled: False + + log: {} +""" + + +def test_no_extra_callbacks_set(): + # No extra callbacks set + config = omegaconf.OmegaConf.create(yaml.safe_load(default_config)) + callbacks = get_callbacks(config) + assert len(callbacks) == 1 # ParentUUIDCallback + + +def test_deprecation_warning(): + # Test deprecation warning + with pytest.warns(DeprecationWarning): + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.update({"eval": {"enabled": True, "rollout": 1, "frequency": 1}}) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_config_enabled_callback(): + # Add logging callback + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.callbacks.append({"log": {"mlflow": {"enabled": True}}}) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_callback(): + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.callbacks.append({"_target_": "anemoi.training.diagnostics.callbacks.id.ParentUUIDCallback"}) + callbacks = get_callbacks(config) + assert len(callbacks) == 2 + + +def test_add_plotting_callback(monkeypatch): + # Add plotting callback + import anemoi.training.diagnostics.callbacks.plot as plot + + class PlotLoss: + def __init__(self, config: omegaconf.DictConfig): + pass + + monkeypatch.setattr(plot, "PlotLoss", PlotLoss) + + config = omegaconf.OmegaConf.create(default_config) + config.diagnostics.plot.enabled = True + config.diagnostics.plot.callbacks = [{"_target_": "anemoi.training.diagnostics.callbacks.plot.PlotLoss"}] + callbacks = get_callbacks(config) + assert len(callbacks) == 2 From 52ea91f7a03a8816b7db6340916e9c9e2a4715da Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 4 Oct 2024 08:30:28 +0000 Subject: [PATCH 16/40] Update deprecation warning --- .../diagnostics/callbacks/__init__.py | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 2a0e80d2..d81867ca 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -69,11 +69,16 @@ def nestedget(conf: DictConfig, key, default): ( "diagnostics.eval.enabled", lambda config: RolloutEval( - config, rollout=config.diagnostics.eval.rollout, frequency=config.diagnostics.eval.frequency + config, + rollout=config.diagnostics.eval.rollout, + frequency=config.diagnostics.eval.frequency, ), ), ("diagnostics.plot.learned_features", plot.GraphTrainableFeaturesPlot), - (["diagnostics.plot.enabled", "diagnostics.plot.longrollout.enabled"], plot.LongRolloutPlots), + ( + ["diagnostics.plot.enabled", "diagnostics.plot.longrollout.enabled"], + plot.LongRolloutPlots, + ), ] @@ -103,10 +108,18 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non frequency = timedelta(minutes=frequency_dict["save_frequency"]) else: target = key - ckpt_frequency_save_dict[target] = (config.hardware.files.checkpoint[key], frequency, n_saved) + ckpt_frequency_save_dict[target] = ( + config.hardware.files.checkpoint[key], + frequency, + n_saved, + ) if not config.diagnostics.profiler: - for save_key, (name, save_frequency, save_n_models) in ckpt_frequency_save_dict.items(): + for save_key, ( + name, + save_frequency, + save_n_models, + ) in ckpt_frequency_save_dict.items(): if save_frequency is not None: LOGGER.debug("Checkpoint callback at %s = %s ...", save_key, save_frequency) return ( @@ -154,8 +167,10 @@ def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]): for deprecated_key, callback_list in DEPRECATED_CONFIGS: if check_key(config, deprecated_key): + suggested_change = f""" - _target_: {callback_list.__module__}.{callback_list.__name__}""" warnings.warn( - f"Deprecated config {deprecated_key} found. Please update your config file to use the new callback initialisation method.", + f"Deprecated config {deprecated_key} found. Please update your config file to use the new callback initialisation method." + + f"This will be removed in a future release.\n Add the following to the `callbacks` list:\n{suggested_change}", DeprecationWarning, ) callbacks.append(callback_list(config)) From bb8b9bbf09a12f1ed843965eb0642b5268ec7beb Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 10 Oct 2024 15:02:22 +0000 Subject: [PATCH 17/40] Refactor: Remove backwards compatability, - Split plots - Rename, lr to optimiser - Refactor plotting callbacks to be more init config --- .../diagnostics/callbacks/pretraining.yaml | 4 + .../config/diagnostics/callbacks/rollout.yaml | 4 + .../config/diagnostics/eval_rollout.yaml | 67 +-- .../config/diagnostics/plot/detailed.yaml | 65 ++ .../config/diagnostics/plot/none.yaml | 1 + .../config/diagnostics/plot/simple.yaml | 38 ++ .../diagnostics/callbacks/__init__.py | 54 +- .../diagnostics/callbacks/learning_rate.py | 26 - .../callbacks/{swa.py => optimiser.py} | 22 +- .../training/diagnostics/callbacks/plot.py | 564 +++++++++++++----- .../callbacks/{id.py => provenance.py} | 0 11 files changed, 547 insertions(+), 298 deletions(-) create mode 100644 src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml create mode 100644 src/anemoi/training/config/diagnostics/callbacks/rollout.yaml create mode 100644 src/anemoi/training/config/diagnostics/plot/detailed.yaml create mode 100644 src/anemoi/training/config/diagnostics/plot/none.yaml create mode 100644 src/anemoi/training/config/diagnostics/plot/simple.yaml delete mode 100644 src/anemoi/training/diagnostics/callbacks/learning_rate.py rename src/anemoi/training/diagnostics/callbacks/{swa.py => optimiser.py} (78%) rename src/anemoi/training/diagnostics/callbacks/{id.py => provenance.py} (100%) diff --git a/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml new file mode 100644 index 00000000..f96b9bed --- /dev/null +++ b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml @@ -0,0 +1,4 @@ +# Add callbacks here +# - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval +# rollout: ${diagnostics.eval.rollout} +# frequency: ${diagnostics.eval.frequency} diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout.yaml new file mode 100644 index 00000000..7a949b53 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout.yaml @@ -0,0 +1,4 @@ +# Add callbacks here +- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval + rollout: ${diagnostics.eval.rollout} + frequency: ${diagnostics.eval.frequency} diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/eval_rollout.yaml index 78b12fb4..804df9d7 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/eval_rollout.yaml @@ -1,62 +1,13 @@ --- +defaults: + - plot: detailed + - callbacks: pretraining -callbacks: - # Add callbacks here - - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval - rollout: 12 - frequency: 20 - -plot: - enabled: True - - callbacks: - # Add extra plot callbacks here - - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot - - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot - - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss - - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample - - asynchronous: True - frequency: 750 - sample_idx: 0 - per_sample: 6 - parameters: - - z_500 - - t_850 - - u_850 - - v_850 - - 2t - - 10u - - 10v - - sp - - tp - - cp - #Defining the accumulation levels for precipitation related fields and the colormap - accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm - cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] - precip_and_related_fields: [tp, cp] - # Histogram and Spectrum plots - parameters_histogram: - - z_500 - - tp - - 2t - - 10u - - 10v - parameters_spectrum: - - z_500 - - tp - - 2t - - 10u - - 10v - # group parameters by categories when visualizing contributions to the loss - # one-parameter groups are possible to highlight individual parameters - parameter_groups: - moisture: [tp, cp, tcw] - sfc_wind: [10u, 10v] - longrollout: - enabled: False - rollout: [60] - frequency: 20 # every X epochs +eval: + enabled: False + # use this to evaluate the model over longer rollouts, every so many validation batches + rollout: 12 + frequency: 20 debug: # this will detect and trace back NaNs / Infs etc. but will slow down training @@ -66,8 +17,8 @@ debug: # remember to also activate the tensorboard logger (below) profiler: False +enable_checkpointing: True checkpoint: - enabled: True every_n_minutes: save_frequency: 30 # Approximate, as this is checked at the end of training steps num_models_saved: 3 # If set to k, saves the 'last' k model weights in the training. diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml new file mode 100644 index 00000000..c36c8e2a --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -0,0 +1,65 @@ +enabled: True # Enable plotting callbacks +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot + epoch_frequency: 5 + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: 0 + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum + # batch_frequency: 100 # Override for batch frequency + sample_idx: 0 + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram + sample_idx: 0 + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + # - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + # rollout: [60] + # batch_frequency: 10 + # epoch_frequency: 20 diff --git a/src/anemoi/training/config/diagnostics/plot/none.yaml b/src/anemoi/training/config/diagnostics/plot/none.yaml new file mode 100644 index 00000000..bc114417 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/none.yaml @@ -0,0 +1 @@ +enabled: false diff --git a/src/anemoi/training/config/diagnostics/plot/simple.yaml b/src/anemoi/training/config/diagnostics/plot/simple.yaml new file mode 100644 index 00000000..a2bff503 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/simple.yaml @@ -0,0 +1,38 @@ +enabled: True # Enable plotting callbacks +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 10 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: 0 + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 37e0e2dc..2079f047 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -20,9 +20,9 @@ from anemoi.training.diagnostics.callbacks import plot from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint from anemoi.training.diagnostics.callbacks.evaluation import RolloutEval -from anemoi.training.diagnostics.callbacks.id import ParentUUIDCallback -from anemoi.training.diagnostics.callbacks.learning_rate import LearningRateMonitor -from anemoi.training.diagnostics.callbacks.swa import StochasticWeightAveraging +from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor +from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging +from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback if TYPE_CHECKING: from pytorch_lightning.callbacks import Callback @@ -54,43 +54,12 @@ def nestedget(conf: DictConfig, key, default): or nestedget(config, "diagnostics.log.mflow.enabled", False), LearningRateMonitor, ), - ( - lambda config: config.diagnostics.plot.enabled - and ( - nestedget(config, "diagnostics.plot.parameters_histogram", None) - or nestedget(config, "diagnostics.plot.parameters_spectrum", None) - ) - is not None, - plot.PlotAdditionalMetrics, - ), -] - -DEPRECATED_CONFIGS: list[tuple[list[str] | str, type[Callback]]] = [ - ( - "diagnostics.eval.enabled", - lambda config: RolloutEval( - config, - rollout=config.diagnostics.eval.rollout, - frequency=config.diagnostics.eval.frequency, - ), - ), - ( - "diagnostics.plot.learned_features", - [ - plot.GraphNodeTrainableFeaturesPlot, - plot.GraphEdgeTrainableFeaturesPlot, - ], - ), - ( - ["diagnostics.plot.enabled", "diagnostics.plot.longrollout.enabled"], - plot.LongRolloutPlots, - ), ] def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | None: """Get checkpointing callback""" - if not config.diagnostics.checkpoint.get("enabled", True): + if not config.diagnostics.get("enable_checkpointing", True): return [] checkpoint_settings = { @@ -106,6 +75,7 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint] | Non } ckpt_frequency_save_dict = {} + for key, frequency_dict in config.diagnostics.checkpoint.items(): frequency = frequency_dict["save_frequency"] n_saved = frequency_dict["num_models_saved"] @@ -171,16 +141,6 @@ def check_key(config, key: str | Iterable[str] | Callable[[DictConfig], bool]): return all(nestedget(config, k, False) for k in key) return nestedget(config, key, False) - for deprecated_key, callback_list in DEPRECATED_CONFIGS: - if check_key(config, deprecated_key): - suggested_change = f""" - _target_: {callback_list.__module__}.{callback_list.__name__}""" - warnings.warn( - f"Deprecated config {deprecated_key} found. Please update your config file to use the new callback initialisation method." - + f"This will be removed in a future release.\n Add the following to the `callbacks` list:\n{suggested_change}", - DeprecationWarning, - ) - callbacks.append(callback_list(config)) - for enable_key, callback_list in CONFIG_ENABLED_CALLBACKS: if check_key(config, enable_key): callbacks.append(callback_list(config)) @@ -231,13 +191,13 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 trainer_callbacks.extend(checkpoint_callback) # Base callbacks - for callback in config.diagnostics.get("callbacks", []): + for callback in config.diagnostics.get("callbacks", None) or []: # Instantiate new callbacks trainer_callbacks.append(instantiate(callback, config)) # Plotting callbacks if config.diagnostics.plot.enabled: - for callback in config.diagnostics.plot.get("callbacks", []): + for callback in config.diagnostics.plot.get("callbacks", None) or []: # Instantiate new callbacks trainer_callbacks.append(instantiate(callback, config)) diff --git a/src/anemoi/training/diagnostics/callbacks/learning_rate.py b/src/anemoi/training/diagnostics/callbacks/learning_rate.py deleted file mode 100644 index 9839a120..00000000 --- a/src/anemoi/training/diagnostics/callbacks/learning_rate.py +++ /dev/null @@ -1,26 +0,0 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. - -from __future__ import annotations - -import logging -from typing import TYPE_CHECKING - -from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor - -LOGGER = logging.getLogger(__name__) - -if TYPE_CHECKING: - from omegaconf import DictConfig - - -class LearningRateMonitor(pl_LearningRateMonitor): - """Provide LearningRateMonitor from pytorch_lightning as a callback.""" - - def __init__(self, config: DictConfig, logging_interval: str = "step", log_momentum: bool = False) -> None: - super().__init__(logging_interval=logging_interval, log_momentum=log_momentum) - self.config = config diff --git a/src/anemoi/training/diagnostics/callbacks/swa.py b/src/anemoi/training/diagnostics/callbacks/optimiser.py similarity index 78% rename from src/anemoi/training/diagnostics/callbacks/swa.py rename to src/anemoi/training/diagnostics/callbacks/optimiser.py index 8e3b26f5..07de3078 100644 --- a/src/anemoi/training/diagnostics/callbacks/swa.py +++ b/src/anemoi/training/diagnostics/callbacks/optimiser.py @@ -10,13 +10,26 @@ import logging from typing import TYPE_CHECKING +from pytorch_lightning.callbacks import LearningRateMonitor as pl_LearningRateMonitor from pytorch_lightning.callbacks.stochastic_weight_avg import StochasticWeightAveraging as pl_StochasticWeightAveraging +LOGGER = logging.getLogger(__name__) + if TYPE_CHECKING: - from omegaconf import OmegaConf + from omegaconf import DictConfig -LOGGER = logging.getLogger(__name__) +class LearningRateMonitor(pl_LearningRateMonitor): + """Provide LearningRateMonitor from pytorch_lightning as a callback.""" + + def __init__( + self, + config: DictConfig, + logging_interval: str = "step", + log_momentum: bool = False, + ) -> None: + super().__init__(logging_interval=logging_interval, log_momentum=log_momentum) + self.config = config class StochasticWeightAveraging(pl_StochasticWeightAveraging): @@ -24,7 +37,7 @@ class StochasticWeightAveraging(pl_StochasticWeightAveraging): def __init__( self, - config: OmegaConf, + config: DictConfig, swa_lrs: int | None = None, swa_epoch_start: int | None = None, annealing_epoch: int | None = None, @@ -32,8 +45,7 @@ def __init__( device: str | None = None, **kwargs, ) -> None: - """ - Stochastic Weight Averaging Callback. + """Stochastic Weight Averaging Callback. Parameters ---------- diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index c086f9d8..39aa387b 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -82,7 +82,7 @@ def __init__(self, config: OmegaConf) -> None: super().__init__() self.config = config self.save_basedir = config.hardware.paths.plots - self.plot_frequency = config.diagnostics.plot.frequency + self.post_processors = None self.pre_processors = None self.latlons = None @@ -135,9 +135,13 @@ def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: st @abstractmethod @rank_zero_only def _plot( - *args: list, - **kwargs: dict, - ) -> None: ... + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + epoch: int, + **kwargs, + ) -> None: + """Plotting function to be implemented by subclasses.""" @rank_zero_only def _async_plot( @@ -161,37 +165,138 @@ def _async_plot( sys.exit(1) +class BasePerBatchPlotCallback(BasePlotCallback): + """Base Callback for plotting at the end of each batch.""" + + def __init__(self, config: OmegaConf, batch_frequency: int | None = None): + """Initialise the BasePerBatchPlotCallback. + + Parameters + ---------- + config : OmegaConf + Config object + batch_frequency : int, optional + Batch Frequency to plot at, by default None + If not given, uses default from config at `diagnostics.plot.frequency.batch` + + """ + super().__init__(config) + self.batch_frequency = batch_frequency or self.config.diagnostics.plot.frequency.batch + + @abstractmethod + @rank_zero_only + def _plot( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + outputs: list[torch.Tensor], + batch: torch.Tensor, + batch_idx: int, + epoch: int, + **kwargs, + ) -> None: + """Plotting function to be implemented by subclasses.""" + + @rank_zero_only + def on_validation_batch_end( + self, + trainer, + pl_module, + output, + batch: torch.Tensor, + batch_idx: int, + **kwargs, + ) -> None: + if batch_idx % self.batch_frequency == 0: + self.plot( + trainer, + pl_module, + output, + batch, + batch_idx, + epoch=trainer.current_epoch, + **kwargs, + ) + + +class BasePerEpochPlotCallback(BasePlotCallback): + """Base Callback for plotting at the end of each epoch.""" + + def __init__(self, config: OmegaConf, epoch_frequency: int | None = None): + """Initialise the BasePerEpochPlotCallback. + + Parameters + ---------- + config : OmegaConf + Config object + epoch_frequency : int, optional + Epoch frequency to plot at, by default None + If not given, uses default from config at `diagnostics.plot.frequency.epoch` + """ + super().__init__(config) + self.epoch_frequency = epoch_frequency or self.config.diagnostics.plot.frequency.epoch + + @rank_zero_only + def on_validation_epoch_end( + self, + trainer: pl.Trainer, + pl_module: pl.LightningModule, + **kwargs, + ) -> None: + if trainer.current_epoch % self.epoch_frequency == 0: + self.plot(trainer, pl_module, epoch=trainer.current_epoch, **kwargs) + + class LongRolloutPlots(BasePlotCallback): """Evaluates the model performance over a (longer) rollout window.""" - def __init__(self, config) -> None: - """Initialize RolloutEval callback. + def __init__( + self, + config: OmegaConf, + rollout: list[int], + batch_frequency: int, + epoch_frequency: int = 1, + sample_idx: int = 0, + ) -> None: + """Initialise RolloutEval callback. Parameters ---------- - config : dict - Dictionary with configuration settings + config : OmegaConf + Config object + rollout : list[int] + Rollout steps to plot at + batch_frequency : int + Batch frequency to plot at + epoch_frequency : int, optional + Epoch frequency to plot at, by default 1 + sample_idx : int, optional + Sample to plot, by default 0 """ super().__init__(config) + self.epoch_frequency = epoch_frequency + self.batch_frequency = batch_frequency + LOGGER.debug( "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", - config.diagnostics.plot.longrollout.rollout, - config.diagnostics.plot.longrollout.frequency, + rollout, + epoch_frequency, ) - self.rollout = config.diagnostics.plot.longrollout.rollout - self.eval_frequency = config.diagnostics.plot.longrollout.frequency - self.sample_idx = self.config.diagnostics.plot.sample_idx + self.rollout = rollout + self.sample_idx = sample_idx @rank_zero_only def _plot( self, trainer, pl_module: pl.LightningModule, + output: list[torch.Tensor], batch: torch.Tensor, batch_idx, epoch, ) -> None: + _ = output start_time = time.time() @@ -274,7 +379,10 @@ def _plot( tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0", exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}", ) - LOGGER.info("Time taken to plot samples after longer rollout: %s seconds", int(time.time() - start_time)) + LOGGER.info( + "Time taken to plot samples after longer rollout: %s seconds", + int(time.time() - start_time), + ) @rank_zero_only def on_validation_batch_end( @@ -285,8 +393,7 @@ def on_validation_batch_end( batch: torch.Tensor, batch_idx: int, ) -> None: - _ = output - if (batch_idx) % self.plot_frequency == 0 and (trainer.current_epoch + 1) % self.eval_frequency == 0: + if (batch_idx) % self.batch_frequency == 0 and (trainer.current_epoch + 1) % self.epoch_frequency == 0: precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, @@ -296,88 +403,121 @@ def on_validation_batch_end( context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() with context: - self._plot(trainer, pl_module, batch, batch_idx, epoch=trainer.current_epoch) + self.plot(trainer, pl_module, output, batch, batch_idx) -class GraphNodeTrainableFeaturesPlot(BasePlotCallback): +class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback): """Visualize the node trainable features defined.""" - def __init__(self, config: OmegaConf) -> None: + def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None: """Initialise the GraphTrainableFeaturesPlot callback. Parameters ---------- config : OmegaConf Config object - + epoch_frequency: int | None, optional + Override for frequency to plot at, by default None """ - super().__init__(config) - self.epoch_freq = 5 + super().__init__(config, epoch_frequency=epoch_frequency) @rank_zero_only def _plot( self, trainer: pl.Trainer, - model: torch.nn.Module, - tag: str, - exp_log_tag: str, + pl_module: pl.LightningModule, + epoch: int, ) -> None: + _ = epoch + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + fig = plot_graph_node_features(model) - self._output_figure(trainer.logger, fig, epoch=trainer.current_epoch, tag=tag, exp_log_tag=exp_log_tag) - @rank_zero_only - def on_validation_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model + tag = "node_trainable_params" + exp_log_tag = "node_trainable_params" - self.plot(trainer, model, tag="node_trainable_params", exp_log_tag="node_trainable_params") + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag=tag, + exp_log_tag=exp_log_tag, + ) -class GraphEdgeTrainableFeaturesPlot(BasePlotCallback): +class GraphEdgeTrainableFeaturesPlot(BasePerEpochPlotCallback): """Trainable edge features plot. Visualize the trainable features defined at the edges between meshes. """ - def __init__(self, config): - super().__init__(config) - self.epoch_freq = 5 + def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None: + """Plot trainable edge features. + + Parameters + ---------- + config : OmegaConf + Config object + epoch_frequency : int | None, optional + Override for frequency to plot at, by default None + """ + super().__init__(config, epoch_frequency=epoch_frequency) + @rank_zero_only def _plot( self, trainer: pl.Trainer, - model: torch.nn.Module, - tag: str, - exp_log_tag: str, + pl_module: pl.LightningModule, + epoch: int, ) -> None: + _ = epoch + + model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model fig = plot_graph_edge_features(model) - self._output_figure(trainer.logger, fig, epoch=trainer.current_epoch, tag=tag, exp_log_tag=exp_log_tag) - def on_validation_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None: - if trainer.current_epoch % self.epoch_freq == 0 and pl_module.global_rank == 0: - model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model - self._plot(trainer, model, tag="edge_trainable_params", exp_log_tag="edge_trainable_params") + tag = "edge_trainable_params" + exp_log_tag = "edge_trainable_params" + + self._output_figure( + trainer.logger, + fig, + epoch=trainer.current_epoch, + tag=tag, + exp_log_tag=exp_log_tag, + ) -class PlotLoss(BasePlotCallback): +class PlotLoss(BasePerBatchPlotCallback): """Plots the unsqueezed loss over rollouts.""" - def __init__(self, config: OmegaConf) -> None: + def __init__( + self, + config: OmegaConf, + parameter_groups: dict[dict[str, list[str]]], + batch_frequency: int | None = None, + ) -> None: """Initialise the PlotLoss callback. Parameters ---------- config : OmegaConf Object with configuration settings + parameter_groups : dict + Dictionary with parameter groups with parameter names as keys + batch_frequency : int, optional + Override for batch frequency, by default None """ - super().__init__(config) + super().__init__(config, batch_frequency=batch_frequency) self.parameter_names = None - self.parameter_groups = self.config.diagnostics.plot.parameter_groups + self.parameter_groups = parameter_groups if self.parameter_groups is None: self.parameter_groups = {} @cached_property - def sort_and_color_by_parameter_group(self) -> tuple[np.ndarray, np.ndarray, dict, list]: + def sort_and_color_by_parameter_group( + self, + ) -> tuple[np.ndarray, np.ndarray, dict, list]: """Sort parameters by group and prepare colors.""" def automatically_determine_group(name: str) -> str: @@ -414,7 +554,7 @@ def automatically_determine_group(name: str) -> str: # join parameter groups that appear only once and are not given in config-file unique_group_list = np.array( [ - unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other" + (unique_group_list[tn] if count > 1 or unique_group_list[tn] in self.parameter_groups else "other") for tn, count in enumerate(group_counts) ], ) @@ -461,7 +601,12 @@ def automatically_determine_group(name: str) -> str: string_length = 0 legend_patches.append(mpatches.Patch(color=bar_color_per_group[group_idx], label=text_label[:-2])) - return sort_by_parameter_group, bar_color_per_group[group_inverse], xticks, legend_patches + return ( + sort_by_parameter_group, + bar_color_per_group[group_inverse], + xticks, + legend_patches, + ) @rank_zero_only def _plot( @@ -503,40 +648,61 @@ def _plot( exp_log_tag=f"loss_sample_rstep{rollout_step:02d}_rank{pl_module.local_rank:01d}", ) - def on_validation_batch_end( - self, - trainer: pl.Trainer, - pl_module: pl.LightningModule, - outputs: list[torch.Tensor], - batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) - -class PlotSample(BasePlotCallback): +class PlotSample(BasePerBatchPlotCallback): """Plots a post-processed sample: input, target and prediction.""" - def __init__(self, config: OmegaConf) -> None: + def __init__( + self, + config: OmegaConf, + sample_idx: int, + parameters: list[str], + accumulation_levels_plot: list[float], + cmap_accumulation: list[str], + precip_and_related_fields: list[str] | None = None, + per_sample: int = 6, + batch_frequency: int | None = None, + ) -> None: """Initialise the PlotSample callback. Parameters ---------- config : OmegaConf Config object - + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + accumulation_levels_plot : list[float] + Accumulation levels to plot + cmap_accumulation : list[str] + Colors of the accumulation levels + precip_and_related_fields : list[str] | None, optional + Precip variable names, by default None + per_sample : int, optional + Number of plots per sample, by default 6 + batch_frequency : int, optional + Batch frequency to plot at, by default None """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info("Using defined accumulation colormap for fields: %s", self.precip_and_related_fields) + super().__init__(config, batch_frequency=batch_frequency) + self.sample_idx = sample_idx + self.parameters = parameters + + self.precip_and_related_fields = precip_and_related_fields + self.accumulation_levels_plot = accumulation_levels_plot + self.cmap_accumulation = cmap_accumulation + self.per_sample = per_sample + + LOGGER.info( + "Using defined accumulation colormap for fields: %s", + self.precip_and_related_fields, + ) @rank_zero_only def _plot( self, trainer: pl.Trainer, - pl_module: pl.Lightning_module, + pl_module: pl.LightningModule, outputs: list[torch.Tensor], batch: torch.Tensor, batch_idx: int, @@ -547,8 +713,11 @@ def _plot( # Build dictionary of indices and parameters to be plotted diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic plot_parameters_dict = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in diagnostics, + ) + for name in self.parameters } # When running in Async mode, it might happen that in the last epoch these tensors @@ -578,10 +747,10 @@ def _plot( for rollout_step in range(pl_module.rollout): fig = plot_predicted_multilevel_flat_sample( plot_parameters_dict, - self.config.diagnostics.plot.per_sample, + self.per_sample, self.latlons, - self.config.diagnostics.plot.accumulation_levels_plot, - self.config.diagnostics.plot.cmap_accumulation, + self.accumulation_levels_plot, + self.cmap_accumulation, data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], @@ -596,40 +765,82 @@ def _plot( exp_log_tag=f"val_pred_sample_rstep{rollout_step:02d}_rank{local_rank:01d}", ) - def on_validation_batch_end( + +class BasePlotAdditionalMetrics(BasePerBatchPlotCallback): + """Base processing class for additional metrics.""" + + def process( self, - trainer: pl.Trainer, - pl_module: pl.Lightning_module, - outputs: list[torch.Tensor], + pl_module: pl.LightningModule, + outputs: list, batch: torch.Tensor, - batch_idx: int, - ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) + ) -> tuple[np.ndarray, np.ndarray]: + # When running in Async mode, it might happen that in the last epoch these tensors + # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA + # but internal ones would be on the cpu), The lines below allow to address this problem + if self.pre_processors is None: + # Copy to be used across all the training cycle + self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() + if self.post_processors is None: + # Copy to be used across all the training cycle + self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() + if self.latlons is None: + self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) + + batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor = batch[ + self.sample_idx, + pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, + ..., + pl_module.data_indices.internal_data.output.full, + ].cpu() + data = self.post_processors(input_tensor).numpy() + output_tensor = self.post_processors( + torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), + in_place=False, + ).numpy() + return data, output_tensor -class PlotAdditionalMetrics(BasePlotCallback): +class PlotSpectrum(BasePlotAdditionalMetrics): """Plots TP related metric comparing target and prediction. The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. - Power Spectrum - - Histograms """ - def __init__(self, config: OmegaConf) -> None: - """Initialise the PlotAdditionalMetrics callback. + def __init__( + self, + config: OmegaConf, + sample_idx: int, + parameters: list[str], + precip_and_related_fields: list[str], + batch_frequency: int | None = None, + ) -> None: + """Initialise the PlotSpectrum callback. Parameters ---------- config : OmegaConf Config object - + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + precip_and_related_fields : list[str] | None, optional + Precip variable names, by default None + batch_frequency : int | None, optional + Override for batch frequency, by default None """ - super().__init__(config) - self.sample_idx = self.config.diagnostics.plot.sample_idx - self.precip_and_related_fields = self.config.diagnostics.plot.precip_and_related_fields - LOGGER.info("Using precip histogram plotting method for fields: %s.", self.precip_and_related_fields) + super().__init__(config, batch_frequency=batch_frequency) + self.sample_idx = sample_idx + self.parameters = parameters + self.precip_and_related_fields = precip_and_related_fields + LOGGER.info( + "Using precip histogram plotting method for fields: %s.", + self.precip_and_related_fields, + ) @rank_zero_only def _plot( @@ -643,89 +854,118 @@ def _plot( ) -> None: logger = trainer.logger - # When running in Async mode, it might happen that in the last epoch these tensors - # have been moved to the cpu (and then the denormalising would fail as the 'input_tensor' would be on CUDA - # but internal ones would be on the cpu), The lines below allow to address this problem - if self.pre_processors is None: - # Copy to be used across all the training cycle - self.pre_processors = copy.deepcopy(pl_module.model.pre_processors).cpu() - if self.post_processors is None: - # Copy to be used across all the training cycle - self.post_processors = copy.deepcopy(pl_module.model.post_processors).cpu() - if self.latlons is None: - self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) local_rank = pl_module.local_rank - batch = pl_module.model.pre_processors(batch, in_place=False) - input_tensor = batch[ - self.sample_idx, - pl_module.multi_step - 1 : pl_module.multi_step + pl_module.rollout + 1, - ..., - pl_module.data_indices.internal_data.output.full, - ].cpu() - data = self.post_processors(input_tensor).numpy() - output_tensor = self.post_processors( - torch.cat(tuple(x[self.sample_idx : self.sample_idx + 1, ...].cpu() for x in outputs[1])), - in_place=False, - ).numpy() + data, output_tensor = self.process(pl_module, outputs, batch) for rollout_step in range(pl_module.rollout): - if self.config.diagnostics.plot.parameters_histogram is not None: - # Build dictionary of inidicies and parameters to be plotted - - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict_histogram = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_histogram - } - - fig = plot_histogram( - plot_parameters_dict_histogram, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, - ) + # Build dictionary of inidicies and parameters to be plotted - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + plot_parameters_dict_histogram = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in diagnostics, ) + for name in self.parameters + } - if self.config.diagnostics.plot.parameters_spectrum is not None: - # Build dictionary of inidicies and parameters to be plotted - diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - - plot_parameters_dict_spectrum = { - pl_module.data_indices.model.output.name_to_index[name]: (name, name not in diagnostics) - for name in self.config.diagnostics.plot.parameters_spectrum - } - - fig = plot_power_spectrum( - plot_parameters_dict_spectrum, - self.latlons, - data[0, ...].squeeze(), - data[rollout_step + 1, ...].squeeze(), - output_tensor[rollout_step, ...], - ) + fig = plot_histogram( + plot_parameters_dict_histogram, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + precip_and_related_fields=self.precip_and_related_fields, + ) - self._output_figure( - logger, - fig, - epoch=epoch, - tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", - ) + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", + ) - def on_validation_batch_end( + +class PlotHistogram(BasePlotAdditionalMetrics): + """Plots TP related metric comparing target and prediction. + + The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. + + - Histograms + """ + + def __init__( + self, + config: OmegaConf, + sample_idx: int, + parameters: list[str], + precip_and_related_fields: list[str] | None = None, + batch_frequency: int | None = None, + ) -> None: + """Initialise the PlotHistogram callback. + + Parameters + ---------- + config : OmegaConf + Config object + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + precip_and_related_fields : list[str] | None, optional + Precip variable names, by default None + batch_frequency : int | None, optional + Override for batch frequency, by default None + """ + super().__init__(config, batch_frequency=batch_frequency) + self.sample_idx = sample_idx + self.parameters = parameters + self.precip_and_related_fields = precip_and_related_fields + LOGGER.info( + "Using precip histogram plotting method for fields: %s.", + self.precip_and_related_fields, + ) + + @rank_zero_only + def _plot( self, trainer: pl.Trainer, pl_module: pl.LightningModule, - outputs: list[torch.Tensor], + outputs: list, batch: torch.Tensor, batch_idx: int, + epoch: int, ) -> None: - if batch_idx % self.plot_frequency == 0: - self.plot(trainer, pl_module, outputs, batch, batch_idx, epoch=trainer.current_epoch) + logger = trainer.logger + + local_rank = pl_module.local_rank + data, output_tensor = self.process(pl_module, outputs, batch) + + for rollout_step in range(pl_module.rollout): + + # Build dictionary of inidicies and parameters to be plotted + diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic + + plot_parameters_dict_spectrum = { + pl_module.data_indices.model.output.name_to_index[name]: ( + name, + name not in diagnostics, + ) + for name in self.parameters + } + + fig = plot_power_spectrum( + plot_parameters_dict_spectrum, + self.latlons, + data[0, ...].squeeze(), + data[rollout_step + 1, ...].squeeze(), + output_tensor[rollout_step, ...], + ) + + self._output_figure( + logger, + fig, + epoch=epoch, + tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", + ) diff --git a/src/anemoi/training/diagnostics/callbacks/id.py b/src/anemoi/training/diagnostics/callbacks/provenance.py similarity index 100% rename from src/anemoi/training/diagnostics/callbacks/id.py rename to src/anemoi/training/diagnostics/callbacks/provenance.py From 0349be2c25ffa1cbcb2c17847c9c7193ef0319ae Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 10 Oct 2024 16:07:30 +0000 Subject: [PATCH 18/40] Fix tests --- tests/diagnostics/test_callbacks.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/tests/diagnostics/test_callbacks.py b/tests/diagnostics/test_callbacks.py index cf927742..dd1d63b4 100644 --- a/tests/diagnostics/test_callbacks.py +++ b/tests/diagnostics/test_callbacks.py @@ -7,7 +7,6 @@ # nor does it submit to any jurisdiction. import omegaconf -import pytest import yaml from anemoi.training.diagnostics.callbacks import get_callbacks @@ -26,8 +25,8 @@ profiler: False + enable_checkpointing: False checkpoint: - enabled: False log: {} """ @@ -40,15 +39,6 @@ def test_no_extra_callbacks_set(): assert len(callbacks) == 1 # ParentUUIDCallback -def test_deprecation_warning(): - # Test deprecation warning - with pytest.warns(DeprecationWarning): - config = omegaconf.OmegaConf.create(default_config) - config.diagnostics.update({"eval": {"enabled": True, "rollout": 1, "frequency": 1}}) - callbacks = get_callbacks(config) - assert len(callbacks) == 2 - - def test_add_config_enabled_callback(): # Add logging callback config = omegaconf.OmegaConf.create(default_config) @@ -59,7 +49,9 @@ def test_add_config_enabled_callback(): def test_add_callback(): config = omegaconf.OmegaConf.create(default_config) - config.diagnostics.callbacks.append({"_target_": "anemoi.training.diagnostics.callbacks.id.ParentUUIDCallback"}) + config.diagnostics.callbacks.append( + {"_target_": "anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback"}, + ) callbacks = get_callbacks(config) assert len(callbacks) == 2 From 1e97ff13170ddd4ec8fbf058f441ae25c54e4748 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 15 Oct 2024 08:38:09 +0000 Subject: [PATCH 19/40] PR Fixes - Remove enabled from plotting callbacks - Connect sample_idx in config --- .../training/config/diagnostics/plot/detailed.yaml | 10 ++++++---- src/anemoi/training/config/diagnostics/plot/none.yaml | 2 +- .../training/config/diagnostics/plot/simple.yaml | 6 ++++-- src/anemoi/training/diagnostics/callbacks/__init__.py | 7 +++---- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index c36c8e2a..d04e5130 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -1,4 +1,3 @@ -enabled: True # Enable plotting callbacks asynchronous: True # Whether to plot asynchronously frequency: # Frequency of the plotting batch: 750 @@ -17,6 +16,9 @@ parameters: - tp - cp +# Sample index +sample_idx: 0 + # Precipitation and related fields precip_and_related_fields: [tp, cp] @@ -32,7 +34,7 @@ callbacks: moisture: [tp, cp, tcw] sfc_wind: [10u, 10v] - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample - sample_idx: 0 + sample_idx: ${diagnostics.plot.sample_idx} per_sample : 6 parameters: ${diagnostics.plot.parameters} #Defining the accumulation levels for precipitation related fields and the colormap @@ -42,7 +44,7 @@ callbacks: - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum # batch_frequency: 100 # Override for batch frequency - sample_idx: 0 + sample_idx: ${diagnostics.plot.sample_idx} precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} parameters: - z_500 @@ -51,7 +53,7 @@ callbacks: - 10u - 10v - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram - sample_idx: 0 + sample_idx: ${diagnostics.plot.sample_idx} precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} parameters: - z_500 diff --git a/src/anemoi/training/config/diagnostics/plot/none.yaml b/src/anemoi/training/config/diagnostics/plot/none.yaml index bc114417..3101f292 100644 --- a/src/anemoi/training/config/diagnostics/plot/none.yaml +++ b/src/anemoi/training/config/diagnostics/plot/none.yaml @@ -1 +1 @@ -enabled: false +callbacks: [] diff --git a/src/anemoi/training/config/diagnostics/plot/simple.yaml b/src/anemoi/training/config/diagnostics/plot/simple.yaml index a2bff503..2a987ccb 100644 --- a/src/anemoi/training/config/diagnostics/plot/simple.yaml +++ b/src/anemoi/training/config/diagnostics/plot/simple.yaml @@ -1,4 +1,3 @@ -enabled: True # Enable plotting callbacks asynchronous: True # Whether to plot asynchronously frequency: # Frequency of the plotting batch: 750 @@ -17,6 +16,9 @@ parameters: - tp - cp +# Sample index +sample_idx: 0 + # Precipitation and related fields precip_and_related_fields: [tp, cp] @@ -29,7 +31,7 @@ callbacks: moisture: [tp, cp, tcw] sfc_wind: [10u, 10v] - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample - sample_idx: 0 + sample_idx: ${diagnostics.plot.sample_idx} per_sample : 6 parameters: ${diagnostics.plot.parameters} #Defining the accumulation levels for precipitation related fields and the colormap diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index 2079f047..f6159f4e 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -196,10 +196,9 @@ def get_callbacks(config: DictConfig) -> list[Callback]: # noqa: C901 trainer_callbacks.append(instantiate(callback, config)) # Plotting callbacks - if config.diagnostics.plot.enabled: - for callback in config.diagnostics.plot.get("callbacks", None) or []: - # Instantiate new callbacks - trainer_callbacks.append(instantiate(callback, config)) + for callback in config.diagnostics.plot.get("callbacks", None) or []: + # Instantiate new callbacks + trainer_callbacks.append(instantiate(callback, config)) # Extend with config enabled callbacks trainer_callbacks.extend(_get_config_enabled_callbacks(config)) From 460c8ba0e45d555a73ff7da521f2893d6c2f23bc Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 18 Oct 2024 08:42:35 +0000 Subject: [PATCH 20/40] Update Changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5b3e745c..18c179da 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Keep it human-readable, your future self will thank you! ### Added ### Fixed +- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) ### Changed ## [0.2.0 - Feature release](https://github.com/ecmwf/anemoi-training/compare/0.1.0...0.2.0) - 2024-10-16 @@ -51,7 +52,6 @@ Keep it human-readable, your future self will thank you! ### Changed - Updated configuration examples in documentation and corrected links - [#46](https://github.com/ecmwf/anemoi-training/pull/46) -- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) - Remove credential prompt from mlflow login, replace with seed refresh token via web - [#78](https://github.com/ecmwf/anemoi-training/pull/78) - Update CODEOWNERS From 21c05dec83a33a40c4126c2aa40354c593b2e4ba Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 21 Oct 2024 14:39:35 +0100 Subject: [PATCH 21/40] Refactor rollout (#87) Refactor rollout logic --- CHANGELOG.md | 2 + docs/modules/diagnostics.rst | 13 ++- src/anemoi/training/config/config.yaml | 2 +- .../config/dataloader/native_grid.yaml | 2 + .../diagnostics/callbacks/pretraining.yaml | 3 - .../{rollout.yaml => rollout_eval.yaml} | 4 +- .../{eval_rollout.yaml => evaluation.yaml} | 5 -- .../config/diagnostics/plot/detailed.yaml | 9 ++- src/anemoi/training/data/datamodule.py | 6 +- .../diagnostics/callbacks/evaluation.py | 37 +++------ .../training/diagnostics/callbacks/plot.py | 20 ++--- src/anemoi/training/train/forecaster.py | 81 ++++++++++++++----- 12 files changed, 97 insertions(+), 87 deletions(-) rename src/anemoi/training/config/diagnostics/callbacks/{rollout.yaml => rollout_eval.yaml} (53%) rename src/anemoi/training/config/diagnostics/{eval_rollout.yaml => evaluation.yaml} (91%) diff --git a/CHANGELOG.md b/CHANGELOG.md index cd80e245..de4ff07c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ Keep it human-readable, your future self will thank you! ### Fixed - Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) +- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) + - Enable longer validation rollout than training ### Changed ## [0.2.0 - Feature release](https://github.com/ecmwf/anemoi-training/compare/0.1.0...0.2.0) - 2024-10-16 diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst index c46f201c..a6a2f744 100644 --- a/docs/modules/diagnostics.rst +++ b/docs/modules/diagnostics.rst @@ -21,18 +21,17 @@ functionality to use both Weights & Biases and Tensorboard. The callbacks can also be used to evaluate forecasts over longer rollouts beyond the forecast time that the model is trained on. The -number of rollout steps (or forecast iteration steps) is set using -``config.eval.rollout = *num_of_rollout_steps*``. +number of rollout steps for verification (or forecast iteration steps) +is set using ``config.dataloader.validation_rollout = +*num_of_rollout_steps*``. Note the user has the option to evaluate the callbacks asynchronously (using the following config option ``config.diagnostics.plot.asynchronous``, which means that the model training doesn't stop whilst the callbacks are being evaluated). -However, note that callbacks can still be slow, and therefore the -plotting callbacks can be switched off by setting -``config.diagnostics.plot.enabled`` to ``False`` or all the callbacks -can be completely switched off by setting -``config.diagnostics.eval.enabled`` to ``False``. +Callbacks are configured in the config file under the +``config.diagnostics.callbacks`` key, and plotting callbacks under the +``config.diagnostics.plot`` key. Below is the documentation for the default callbacks provided, but it is also possible for users to add callbacks using the same structure: diff --git a/src/anemoi/training/config/config.yaml b/src/anemoi/training/config/config.yaml index 2045da93..a379acfd 100644 --- a/src/anemoi/training/config/config.yaml +++ b/src/anemoi/training/config/config.yaml @@ -1,7 +1,7 @@ defaults: - data: zarr - dataloader: native_grid -- diagnostics: eval_rollout +- diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn diff --git a/src/anemoi/training/config/dataloader/native_grid.yaml b/src/anemoi/training/config/dataloader/native_grid.yaml index e6d50801..d7aa4f6d 100644 --- a/src/anemoi/training/config/dataloader/native_grid.yaml +++ b/src/anemoi/training/config/dataloader/native_grid.yaml @@ -45,6 +45,8 @@ training: frequency: ${data.frequency} drop: [] +validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks + validation: dataset: ${dataloader.dataset} start: 2021 diff --git a/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml index f96b9bed..1eb35f69 100644 --- a/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml +++ b/src/anemoi/training/config/diagnostics/callbacks/pretraining.yaml @@ -1,4 +1 @@ # Add callbacks here -# - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval -# rollout: ${diagnostics.eval.rollout} -# frequency: ${diagnostics.eval.frequency} diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml similarity index 53% rename from src/anemoi/training/config/diagnostics/callbacks/rollout.yaml rename to src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml index 7a949b53..d7daf8d0 100644 --- a/src/anemoi/training/config/diagnostics/callbacks/rollout.yaml +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml @@ -1,4 +1,4 @@ # Add callbacks here - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval - rollout: ${diagnostics.eval.rollout} - frequency: ${diagnostics.eval.frequency} + rollout: ${dataloader.validation_rollout} + frequency: 20 diff --git a/src/anemoi/training/config/diagnostics/eval_rollout.yaml b/src/anemoi/training/config/diagnostics/evaluation.yaml similarity index 91% rename from src/anemoi/training/config/diagnostics/eval_rollout.yaml rename to src/anemoi/training/config/diagnostics/evaluation.yaml index 804df9d7..9647fe1c 100644 --- a/src/anemoi/training/config/diagnostics/eval_rollout.yaml +++ b/src/anemoi/training/config/diagnostics/evaluation.yaml @@ -3,11 +3,6 @@ defaults: - plot: detailed - callbacks: pretraining -eval: - enabled: False - # use this to evaluate the model over longer rollouts, every so many validation batches - rollout: 12 - frequency: 20 debug: # this will detect and trace back NaNs / Infs etc. but will slow down training diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index d04e5130..3f6a862f 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -61,7 +61,8 @@ callbacks: - 2t - 10u - 10v - # - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots - # rollout: [60] - # batch_frequency: 10 - # epoch_frequency: 20 + - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + rollout: + - ${dataloader.validation_rollout} + batch_frequency: 10 + epoch_frequency: 20 diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py index b714c965..b6c4335a 100644 --- a/src/anemoi/training/data/datamodule.py +++ b/src/anemoi/training/data/datamodule.py @@ -126,10 +126,8 @@ def ds_train(self) -> NativeGridDataset: @cached_property def ds_valid(self) -> NativeGridDataset: r = self.rollout - if self.config.diagnostics.eval.enabled: - r = max(r, self.config.diagnostics.eval.rollout) - if self.config.diagnostics.plot.get("longrollout") and self.config.diagnostics.plot.longrollout.enabled: - r = max(r, max(self.config.diagnostics.plot.longrollout.rollout)) + r = max(r, self.config.dataloader.get("validation_rollout", 1)) + assert self.config.dataloader.training.end < self.config.dataloader.validation.start, ( f"Training end date {self.config.dataloader.training.end} is not before" f"validation start date {self.config.dataloader.validation.start}" diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index a91de1f5..f3c4cb9a 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -32,6 +32,10 @@ def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None: ---------- config : dict Dictionary with configuration settings + rollout : int + Rollout length for evaluation + frequency : int + Frequency of evaluation, per batch """ super().__init__() @@ -53,33 +57,14 @@ def _eval( loss = torch.zeros(1, dtype=batch.dtype, device=pl_module.device, requires_grad=False) metrics = {} - # start rollout - batch = pl_module.model.pre_processors(batch, in_place=False) - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= self.rollout + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" + assert batch.shape[1] >= self.rollout + pl_module.multi_step, ( + "Batch length not sufficient for requested validation rollout length! " + f"Set `dataloader.validation_rollout` to at least {self.rollout + pl_module.multi_step}" + ) with torch.no_grad(): - for rollout_step in range(self.rollout): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - y = batch[ - :, - pl_module.multi_step + rollout_step, - ..., - pl_module.data_indices.internal_data.output.full, - ] # target, shape = (bs, latlon, nvar) - # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += pl_module.loss(y_pred, y) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) - - metrics_next, _ = pl_module.calculate_val_metrics(y_pred, y, rollout_step) + for loss_next, metrics_next, _ in pl_module.rollout_step(batch, rollout=self.rollout, validation_mode=True): + loss += loss_next metrics.update(metrics_next) # scale loss @@ -88,7 +73,7 @@ def _eval( def _log(self, pl_module: pl.LightningModule, loss: torch.Tensor, metrics: dict, bs: int) -> None: pl_module.log( - f"val_r{self.rollout}_wmse", + f"val_r{self.rollout}_{getattr(pl_module.loss, 'name', pl_module.loss.__class__.__name__.lower())}", loss, on_epoch=True, on_step=True, diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 817e5537..c80e42c1 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -324,17 +324,10 @@ def _plot( self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) local_rank = pl_module.local_rank - batch = pl_module.model.pre_processors(batch, in_place=False) - # prepare input tensor for rollout from preprocessed batch - x = batch[ - :, - 0 : pl_module.multi_step, - ..., - pl_module.data_indices.internal_data.input.full, - ] # (bs, multi_step, latlon, nvar) - assert ( - batch.shape[1] >= max(self.rollout) + pl_module.multi_step - ), "Batch length not sufficient for requested rollout length!" + assert batch.shape[1] >= self.rollout + pl_module.multi_step, ( + "Batch length not sufficient for requested validation rollout length! " + f"Set `dataloader.validation_rollout` to at least {self.rollout + pl_module.multi_step}" + ) # prepare input tensor for plotting input_tensor_0 = batch[ @@ -347,10 +340,7 @@ def _plot( # start rollout with torch.no_grad(): - for rollout_step in range(max(self.rollout)): - y_pred = pl_module(x) # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) - - x = pl_module.advance_input(x, y_pred, batch, rollout_step) + for rollout_step, (_, _, y_pred) in enumerate(pl_module.rollout_step(batch, rollout=max(self.rollout))): if (rollout_step + 1) in self.rollout: # prepare true output tensor for plotting diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 6738848a..bfec6bae 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -11,7 +11,9 @@ import math import os from collections import defaultdict +from collections.abc import Generator from collections.abc import Mapping +from typing import Optional import numpy as np import pytorch_lightning as pl @@ -222,17 +224,39 @@ def advance_input( ] return x - def _step( + def rollout_step( self, batch: torch.Tensor, - batch_idx: int, + rollout: Optional[int] = None, # noqa: FA100 validation_mode: bool = False, - ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: - del batch_idx - loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) + ) -> Generator[tuple[torch.Tensor, dict, list], None, None]: + """ + Rollout step for the forecaster. + + Will run pre_processors on batch, but not post_processors on predictions. + + Parameters + ---------- + batch : torch.Tensor + Batch to use for rollout + rollout : int | None, optional + Number of times to rollout for, by default None + validation_mode : bool, optional + Whether in validation mode, and to calculate validation metrics, by default False + If False, metrics will be empty + + Yields + ------ + Generator[tuple[torch.Tensor, dict, list], None, None] + Loss value, metrics, and predictions (per step) + + Returns + ------- + None + None + """ # for validation not normalized in-place because remappers cannot be applied in-place batch = self.model.pre_processors(batch, in_place=not validation_mode) - metrics = {} # start rollout of preprocessed batch x = batch[ @@ -241,29 +265,51 @@ def _step( ..., self.data_indices.internal_data.input.full, ] # (bs, multi_step, latlon, nvar) + msg = ( + "Batch length not sufficient for requested multi_step length!" + f", {batch.shape[1]} !>= {rollout + self.multi_step}" + ) + assert batch.shape[1] >= rollout + self.multi_step, msg - y_preds = [] - for rollout_step in range(self.rollout): + for rollout_step in range(rollout or self.rollout): # prediction at rollout step rollout_step, shape = (bs, latlon, nvar) y_pred = self(x) y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss - loss += checkpoint(self.loss, y_pred, y, use_reentrant=False) + loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) x = self.advance_input(x, y_pred, batch, rollout_step) + metrics_next = {} if validation_mode: - metrics_next, y_preds_next = self.calculate_val_metrics( + metrics_next = self.calculate_val_metrics( y_pred, y, rollout_step, - enable_plot=self.enable_plot, ) - metrics.update(metrics_next) - y_preds.extend(y_preds_next) + yield loss, metrics_next, y_pred + + def _step( + self, + batch: torch.Tensor, + batch_idx: int, + validation_mode: bool = False, + ) -> tuple[torch.Tensor, Mapping[str, torch.Tensor]]: + del batch_idx + loss = torch.zeros(1, dtype=batch.dtype, device=self.device, requires_grad=False) + metrics = {} + y_preds = [] + + for loss_next, metrics_next, y_preds_next in self.rollout_step( + batch, + rollout=self.rollout, + validation_mode=validation_mode, + ): + loss += loss_next + metrics.update(metrics_next) + y_preds.extend(y_preds_next) - # scale loss loss *= 1.0 / self.rollout return loss, metrics, y_preds @@ -272,10 +318,8 @@ def calculate_val_metrics( y_pred: torch.Tensor, y: torch.Tensor, rollout_step: int, - enable_plot: bool = False, ) -> tuple[dict, list]: metrics = {} - y_preds = [] y_postprocessed = self.model.post_processors(y, in_place=False) y_pred_postprocessed = self.model.post_processors(y_pred, in_place=False) for mkey, indices in self.metric_ranges_validation.items(): @@ -283,10 +327,7 @@ def calculate_val_metrics( y_pred_postprocessed[..., indices], y_postprocessed[..., indices], ) - - if enable_plot: - y_preds.append(y_pred) - return metrics, y_preds + return metrics def training_step(self, batch: torch.Tensor, batch_idx: int) -> torch.Tensor: train_loss, _, _ = self._step(batch, batch_idx) From 3c5e1448f7cd3379871912f5f2049fb28f2f39e1 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 21 Oct 2024 14:56:21 +0000 Subject: [PATCH 22/40] Remove batch frequency from LongRolloutPlots --- src/anemoi/training/config/diagnostics/plot/detailed.yaml | 1 - src/anemoi/training/diagnostics/callbacks/plot.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index 3f6a862f..b27bee93 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -64,5 +64,4 @@ callbacks: - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots rollout: - ${dataloader.validation_rollout} - batch_frequency: 10 epoch_frequency: 20 diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index c80e42c1..c20a4cf3 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -260,7 +260,6 @@ def __init__( self, config: OmegaConf, rollout: list[int], - batch_frequency: int, epoch_frequency: int = 1, sample_idx: int = 0, ) -> None: @@ -272,8 +271,6 @@ def __init__( Config object rollout : list[int] Rollout steps to plot at - batch_frequency : int - Batch frequency to plot at epoch_frequency : int, optional Epoch frequency to plot at, by default 1 sample_idx : int, optional @@ -282,7 +279,6 @@ def __init__( super().__init__(config) self.epoch_frequency = epoch_frequency - self.batch_frequency = batch_frequency LOGGER.debug( "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", @@ -389,7 +385,7 @@ def on_validation_batch_end( batch: torch.Tensor, batch_idx: int, ) -> None: - if (batch_idx) % self.batch_frequency == 0 and (trainer.current_epoch + 1) % self.epoch_frequency == 0: + if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.epoch_frequency == 0: precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, From 382728c561ca53309636f1bb3783e680368fb32c Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 22 Oct 2024 16:12:30 +0000 Subject: [PATCH 23/40] Remove TP reference --- src/anemoi/training/diagnostics/callbacks/plot.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index c20a4cf3..15b22b7a 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -886,11 +886,9 @@ def _plot( class PlotHistogram(BasePlotAdditionalMetrics): - """Plots TP related metric comparing target and prediction. + """Plots histograms comparing target and prediction. The actual increment (output - input) is plot for prognostic variables while the output is plot for diagnostic ones. - - - Histograms """ def __init__( From 6fa66ccb2b0097ce2392fefa4dc44463c0aee22d Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 23 Oct 2024 10:49:37 +0000 Subject: [PATCH 24/40] Remove missing config reference --- src/anemoi/training/train/forecaster.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index bfec6bae..6dc6a179 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -126,8 +126,6 @@ def __init__( LOGGER.debug("Rollout max : %d", self.rollout_max) LOGGER.debug("Multistep: %d", self.multi_step) - self.enable_plot = config.diagnostics.plot.enabled - self.model_comm_group_id = int(os.environ.get("SLURM_PROCID", "0")) // config.hardware.num_gpus_per_model self.model_comm_group_rank = int(os.environ.get("SLURM_PROCID", "0")) % config.hardware.num_gpus_per_model self.model_comm_num_groups = math.ceil( From 110fb640a0320022ed53a46924b7ce5eb6ecee11 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 23 Oct 2024 15:37:58 +0000 Subject: [PATCH 25/40] Swapped histogram and spectrum --- src/anemoi/training/diagnostics/callbacks/plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 15b22b7a..dc44b3da 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -868,7 +868,7 @@ def _plot( for name in self.parameters } - fig = plot_histogram( + fig = plot_power_spectrum( plot_parameters_dict_histogram, data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), @@ -951,7 +951,7 @@ def _plot( for name in self.parameters } - fig = plot_power_spectrum( + fig = plot_histogram( plot_parameters_dict_spectrum, self.latlons, data[0, ...].squeeze(), From 23cc785d52d5523450ef50494025f3696031bb1f Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 23 Oct 2024 15:42:41 +0000 Subject: [PATCH 26/40] Update copyright notice --- src/anemoi/training/diagnostics/callbacks/__init__.py | 7 +++---- src/anemoi/training/diagnostics/callbacks/checkpoint.py | 4 +++- src/anemoi/training/diagnostics/callbacks/evaluation.py | 4 +++- src/anemoi/training/diagnostics/callbacks/optimiser.py | 4 +++- src/anemoi/training/diagnostics/callbacks/plot.py | 8 +++++--- src/anemoi/training/diagnostics/callbacks/provenance.py | 4 +++- 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/__init__.py b/src/anemoi/training/diagnostics/callbacks/__init__.py index f6159f4e..4b0921f1 100644 --- a/src/anemoi/training/diagnostics/callbacks/__init__.py +++ b/src/anemoi/training/diagnostics/callbacks/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. @@ -8,7 +10,6 @@ from __future__ import annotations import logging -import warnings from datetime import timedelta from typing import TYPE_CHECKING from typing import Callable @@ -17,9 +18,7 @@ from hydra.utils import instantiate from omegaconf import DictConfig -from anemoi.training.diagnostics.callbacks import plot from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint -from anemoi.training.diagnostics.callbacks.evaluation import RolloutEval from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor from anemoi.training.diagnostics.callbacks.optimiser import StochasticWeightAveraging from anemoi.training.diagnostics.callbacks.provenance import ParentUUIDCallback diff --git a/src/anemoi/training/diagnostics/callbacks/checkpoint.py b/src/anemoi/training/diagnostics/callbacks/checkpoint.py index 6953d8fb..cb95f5a4 100644 --- a/src/anemoi/training/diagnostics/callbacks/checkpoint.py +++ b/src/anemoi/training/diagnostics/callbacks/checkpoint.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index f3c4cb9a..d09378b0 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/diagnostics/callbacks/optimiser.py b/src/anemoi/training/diagnostics/callbacks/optimiser.py index 07de3078..bff82fcb 100644 --- a/src/anemoi/training/diagnostics/callbacks/optimiser.py +++ b/src/anemoi/training/diagnostics/callbacks/optimiser.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index dc44b3da..500a50f3 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -1,12 +1,14 @@ -# ruff: noqa: ANN001 - -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +# ruff: noqa: ANN001 + from __future__ import annotations import copy diff --git a/src/anemoi/training/diagnostics/callbacks/provenance.py b/src/anemoi/training/diagnostics/callbacks/provenance.py index 3b5da7a6..414f0311 100644 --- a/src/anemoi/training/diagnostics/callbacks/provenance.py +++ b/src/anemoi/training/diagnostics/callbacks/provenance.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. From 51a455d9679a0431202390f3337e353bb9e184cb Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 24 Oct 2024 14:13:29 +0000 Subject: [PATCH 27/40] Fix issues with split of PlotAdditionalMetrics --- .../config/diagnostics/plot/detailed.yaml | 1 - .../training/diagnostics/callbacks/plot.py | 28 +++++++------------ 2 files changed, 10 insertions(+), 19 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index b27bee93..5a27e5a0 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -45,7 +45,6 @@ callbacks: - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum # batch_frequency: 100 # Override for batch frequency sample_idx: ${diagnostics.plot.sample_idx} - precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} parameters: - z_500 - tp diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 500a50f3..85d7e007 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -816,7 +816,6 @@ def __init__( config: OmegaConf, sample_idx: int, parameters: list[str], - precip_and_related_fields: list[str], batch_frequency: int | None = None, ) -> None: """Initialise the PlotSpectrum callback. @@ -829,19 +828,12 @@ def __init__( Sample to plot parameters : list[str] Parameters to plot - precip_and_related_fields : list[str] | None, optional - Precip variable names, by default None batch_frequency : int | None, optional Override for batch frequency, by default None """ super().__init__(config, batch_frequency=batch_frequency) self.sample_idx = sample_idx self.parameters = parameters - self.precip_and_related_fields = precip_and_related_fields - LOGGER.info( - "Using precip histogram plotting method for fields: %s.", - self.precip_and_related_fields, - ) @rank_zero_only def _plot( @@ -862,7 +854,7 @@ def _plot( # Build dictionary of inidicies and parameters to be plotted diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict_histogram = { + plot_parameters_dict_spectrum = { pl_module.data_indices.model.output.name_to_index[name]: ( name, name not in diagnostics, @@ -871,19 +863,19 @@ def _plot( } fig = plot_power_spectrum( - plot_parameters_dict_histogram, + plot_parameters_dict_spectrum, + self.latlons, data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], - precip_and_related_fields=self.precip_and_related_fields, ) self._output_figure( logger, fig, epoch=epoch, - tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", + tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", ) @@ -945,7 +937,7 @@ def _plot( # Build dictionary of inidicies and parameters to be plotted diagnostics = [] if self.config.data.diagnostic is None else self.config.data.diagnostic - plot_parameters_dict_spectrum = { + plot_parameters_dict_histogram = { pl_module.data_indices.model.output.name_to_index[name]: ( name, name not in diagnostics, @@ -954,17 +946,17 @@ def _plot( } fig = plot_histogram( - plot_parameters_dict_spectrum, - self.latlons, + plot_parameters_dict_histogram, data[0, ...].squeeze(), data[rollout_step + 1, ...].squeeze(), output_tensor[rollout_step, ...], + self.precip_and_related_fields, ) self._output_figure( logger, fig, epoch=epoch, - tag=f"gnn_pred_val_spec_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_spec_rstep_{rollout_step:02d}_rank{local_rank:01d}", + tag=f"gnn_pred_val_histo_rstep_{rollout_step:02d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_histo_rstep_{rollout_step:02d}_rank{local_rank:01d}", ) From 3c6e1af29d61066b3df9a51f149b1909b5e53075 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 25 Oct 2024 07:42:56 +0000 Subject: [PATCH 28/40] Fix CHANGELOG --- CHANGELOG.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9f900546..3e481ad8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.1...HEAD) +### Fixed +- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) +- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) + - Enable longer validation rollout than training + ## [0.2.1 - Bugfix: resuming mlflow runs](https://github.com/ecmwf/anemoi-training/compare/0.2.0...0.2.1) - 2024-10-24 ### Added @@ -20,13 +25,9 @@ Keep it human-readable, your future self will thank you! - Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79) ### Fixed -- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60) -- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87) - - Enable longer validation rollout than training - Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83) - Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99] (https://github.com/ecmwf/anemoi-training/pull/99) - ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100 - - Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83) - Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99) - ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100) From 86059a9ab3b22d4c9a0f5b9965674426bb8c54d7 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 25 Oct 2024 10:49:23 +0000 Subject: [PATCH 29/40] Fix documentation for callbacks --- docs/modules/diagnostics.rst | 68 ++++++++++++++++++++++++--- docs/user-guide/configuring.rst | 4 +- docs/user-guide/tracking.rst | 2 +- src/anemoi/training/config/debug.yaml | 2 +- 4 files changed, 65 insertions(+), 11 deletions(-) diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst index a6a2f744..8686770a 100644 --- a/docs/modules/diagnostics.rst +++ b/docs/modules/diagnostics.rst @@ -17,7 +17,9 @@ during training. It is split into two parts: By default, anemoi-training uses MLFlow tracker, but it includes functionality to use both Weights & Biases and Tensorboard. -**Callbacks** +########### + Callbacks +########### The callbacks can also be used to evaluate forecasts over longer rollouts beyond the forecast time that the model is trained on. The @@ -25,13 +27,65 @@ number of rollout steps for verification (or forecast iteration steps) is set using ``config.dataloader.validation_rollout = *num_of_rollout_steps*``. -Note the user has the option to evaluate the callbacks asynchronously -(using the following config option -``config.diagnostics.plot.asynchronous``, which means that the model -training doesn't stop whilst the callbacks are being evaluated). Callbacks are configured in the config file under the -``config.diagnostics.callbacks`` key, and plotting callbacks under the -``config.diagnostics.plot`` key. +``config.diagnostics`` key. + +For regular callbacks, they can be provided as a list of dictionaries +underneath the ``config.diagnostics.callbacks`` key. Each dictionary +must have a ``_target`` key which is used by hydra to instantiate the +callback, any other kwarg is passed to the callback's constructor. + +.. code:: yaml + + callbacks: + - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval + rollout: ${dataloader.validation_rollout} + frequency: 20 + +Plotting callbacks are configured in a similar way, but they are +specified underneath the ``config.diagnostics.plot.callbacks`` key. + +This is done to ensure seperation and ease of configuration between +experiments. + +``config.diagnostics.plot`` is a broader config file specifying the +parameters to plot, as well as the plotting frequency, and +asynchronosity. + +Setting ``config.diagnostics.plot.asynchronous``, means that the model +training doesn't stop whilst the callbacks are being evaluated) + +.. code:: yaml + + plot: + asynchronous: True # Whether to plot asynchronously + frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + + # Parameters to plot + parameters: + - z_500 + - t_850 + - u_850 + + # Sample index + sample_idx: 0 + + # Precipitation and related fields + precip_and_related_fields: [tp, cp] + + callbacks: + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} Below is the documentation for the default callbacks provided, but it is also possible for users to add callbacks using the same structure: diff --git a/docs/user-guide/configuring.rst b/docs/user-guide/configuring.rst index 35efec3f..307cebf0 100644 --- a/docs/user-guide/configuring.rst +++ b/docs/user-guide/configuring.rst @@ -21,7 +21,7 @@ settings at the top as follows: defaults: - data: zarr - dataloader: native_grid - - diagnostics: eval_rollout + - diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn @@ -100,7 +100,7 @@ match the dataset you provide. defaults: - data: zarr - dataloader: native_grid - - diagnostics: eval_rollout + - diagnostics: evaluation - hardware: example - graph: multi_scale - model: transformer # Change from default group diff --git a/docs/user-guide/tracking.rst b/docs/user-guide/tracking.rst index cab5e851..f97182d7 100644 --- a/docs/user-guide/tracking.rst +++ b/docs/user-guide/tracking.rst @@ -33,7 +33,7 @@ the same experiment. Within the MLflow experiments tab, it is possible to define different namespaces. To create a new namespace, the user just needs to pass an 'experiment_name' -(``config.diagnostics.eval_rollout.log.mlflow.experiment_name``) to the +(``config.diagnostics.evaluation.log.mlflow.experiment_name``) to the mlflow logger. **Parent-Child Runs** diff --git a/src/anemoi/training/config/debug.yaml b/src/anemoi/training/config/debug.yaml index 5be3e9f4..32d58153 100644 --- a/src/anemoi/training/config/debug.yaml +++ b/src/anemoi/training/config/debug.yaml @@ -1,7 +1,7 @@ defaults: - data: zarr - dataloader: native_grid -- diagnostics: eval_rollout +- diagnostics: evaluation - hardware: example - graph: multi_scale - model: gnn From 0bce490e2b40f25e249cc76f040cff19983c49da Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 25 Oct 2024 11:00:19 +0000 Subject: [PATCH 30/40] Add all callback submodules to docs --- docs/modules/diagnostics.rst | 26 ++++++++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst index 8686770a..28eac7c7 100644 --- a/docs/modules/diagnostics.rst +++ b/docs/modules/diagnostics.rst @@ -17,9 +17,7 @@ during training. It is split into two parts: By default, anemoi-training uses MLFlow tracker, but it includes functionality to use both Weights & Biases and Tensorboard. -########### - Callbacks -########### +**Callbacks** The callbacks can also be used to evaluate forecasts over longer rollouts beyond the forecast time that the model is trained on. The @@ -90,7 +88,27 @@ training doesn't stop whilst the callbacks are being evaluated) Below is the documentation for the default callbacks provided, but it is also possible for users to add callbacks using the same structure: -.. automodule:: anemoi.training.diagnostics.callbacks +.. automodule:: anemoi.training.diagnostics.callbacks.checkpoint + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.evaluation + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.optimiser + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.plot + :members: + :no-undoc-members: + :show-inheritance: + +.. automodule:: anemoi.training.diagnostics.callbacks.provenance :members: :no-undoc-members: :show-inheritance: From d6e1d9c5741e750f59dacfad4a917855625647a3 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 25 Oct 2024 13:03:30 +0100 Subject: [PATCH 31/40] Apply suggestions from code review Co-authored-by: Sara Hahner <44293258+sahahner@users.noreply.github.com> --- src/anemoi/training/diagnostics/callbacks/plot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 85d7e007..db0b6b0e 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -322,9 +322,9 @@ def _plot( self.latlons = np.rad2deg(pl_module.latlons_data.clone().cpu().numpy()) local_rank = pl_module.local_rank - assert batch.shape[1] >= self.rollout + pl_module.multi_step, ( + assert batch.shape[1] >= max(self.rollout) + pl_module.multi_step, ( "Batch length not sufficient for requested validation rollout length! " - f"Set `dataloader.validation_rollout` to at least {self.rollout + pl_module.multi_step}" + f"Set `dataloader.validation_rollout` to at least {max(self.rollout) + pl_module.multi_step}" ) # prepare input tensor for plotting @@ -397,7 +397,7 @@ def on_validation_batch_end( context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() with context: - self.plot(trainer, pl_module, output, batch, batch_idx) + self.plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback): From 6073d84f1ebcac571ce07432a108d754ff75d3c3 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 25 Oct 2024 12:48:43 +0000 Subject: [PATCH 32/40] Fix init args issue in RolloutPlots --- .../config/diagnostics/plot/detailed.yaml | 2 ++ .../training/diagnostics/callbacks/plot.py | 34 ++++++++++++++----- 2 files changed, 27 insertions(+), 9 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index 5a27e5a0..7c01b575 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -64,3 +64,5 @@ callbacks: rollout: - ${dataloader.validation_rollout} epoch_frequency: 20 + sample_idx: ${diagnostics.plot.sample_idx} + parameters: ${diagnostics.plot.parameters} diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index db0b6b0e..59a2deb7 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -146,7 +146,7 @@ def _plot( self, trainer: pl.Trainer, pl_module: pl.LightningModule, - epoch: int, + *args, **kwargs, ) -> None: """Plotting function to be implemented by subclasses.""" @@ -262,10 +262,14 @@ def __init__( self, config: OmegaConf, rollout: list[int], + sample_idx: int, + parameters: list[str], + accumulation_levels_plot: list[float] | None = None, + cmap_accumulation: list[str] | None = None, + per_sample: int = 6, epoch_frequency: int = 1, - sample_idx: int = 0, ) -> None: - """Initialise RolloutEval callback. + """Initialise LongRolloutPlots callback. Parameters ---------- @@ -273,10 +277,18 @@ def __init__( Config object rollout : list[int] Rollout steps to plot at + sample_idx : int + Sample to plot + parameters : list[str] + Parameters to plot + accumulation_levels_plot : list[float] | None + Accumulation levels to plot, by default None + cmap_accumulation : list[str] | None + Colors of the accumulation levels, by default None + per_sample : int, optional + Number of plots per sample, by default 6 epoch_frequency : int, optional Epoch frequency to plot at, by default 1 - sample_idx : int, optional - Sample to plot, by default 0 """ super().__init__(config) @@ -289,6 +301,10 @@ def __init__( ) self.rollout = rollout self.sample_idx = sample_idx + self.accumulation_levels_plot = accumulation_levels_plot + self.cmap_accumulation = cmap_accumulation + self.per_sample = per_sample + self.parameters = parameters @rank_zero_only def _plot( @@ -312,7 +328,7 @@ def _plot( name, name not in self.config.data.get("diagnostic", []), ) - for name in self.config.diagnostics.plot.parameters + for name in self.parameters } if self.post_processors is None: @@ -357,10 +373,10 @@ def _plot( fig = plot_predicted_multilevel_flat_sample( plot_parameters_dict, - self.config.diagnostics.plot.per_sample, + self.per_sample, self.latlons, - self.config.diagnostics.plot.get("accumulation_levels_plot", None), - self.config.diagnostics.plot.get("cmap_accumulation", None), + self.accumulation_levels_plot, + self.cmap_accumulation, data_0.squeeze(), data_rollout_step.squeeze(), output_tensor[0, 0, :, :], # rolloutstep, first member From f1d883f538498cfc0d7b158c7300ff9e3b25fce9 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Fri, 25 Oct 2024 14:27:17 +0000 Subject: [PATCH 33/40] Add rollout_eval config --- .../config/diagnostics/plot/detailed.yaml | 6 -- .../config/diagnostics/plot/rollout_eval.yaml | 68 +++++++++++++++++++ 2 files changed, 68 insertions(+), 6 deletions(-) create mode 100644 src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index 7c01b575..6ff7875e 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -60,9 +60,3 @@ callbacks: - 2t - 10u - 10v - - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots - rollout: - - ${dataloader.validation_rollout} - epoch_frequency: 20 - sample_idx: ${diagnostics.plot.sample_idx} - parameters: ${diagnostics.plot.parameters} diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml new file mode 100644 index 00000000..7c01b575 --- /dev/null +++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml @@ -0,0 +1,68 @@ +asynchronous: True # Whether to plot asynchronously +frequency: # Frequency of the plotting + batch: 750 + epoch: 5 + +# Parameters to plot +parameters: +- z_500 +- t_850 +- u_850 +- v_850 +- 2t +- 10u +- 10v +- sp +- tp +- cp + +# Sample index +sample_idx: 0 + +# Precipitation and related fields +precip_and_related_fields: [tp, cp] + +callbacks: + # Add plot callbacks here + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot + - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot + epoch_frequency: 5 + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss + # group parameters by categories when visualizing contributions to the loss + # one-parameter groups are possible to highlight individual parameters + parameter_groups: + moisture: [tp, cp, tcw] + sfc_wind: [10u, 10v] + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample + sample_idx: ${diagnostics.plot.sample_idx} + per_sample : 6 + parameters: ${diagnostics.plot.parameters} + #Defining the accumulation levels for precipitation related fields and the colormap + accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm + cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"] + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum + # batch_frequency: 100 # Override for batch frequency + sample_idx: ${diagnostics.plot.sample_idx} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram + sample_idx: ${diagnostics.plot.sample_idx} + precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} + parameters: + - z_500 + - tp + - 2t + - 10u + - 10v + - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots + rollout: + - ${dataloader.validation_rollout} + epoch_frequency: 20 + sample_idx: ${diagnostics.plot.sample_idx} + parameters: ${diagnostics.plot.parameters} From 66bd3062f773b1846939f0861d15ad175a5a7f6d Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 08:58:14 +0000 Subject: [PATCH 34/40] Add training mode to rollout step --- .../training/diagnostics/callbacks/evaluation.py | 7 ++++++- src/anemoi/training/diagnostics/callbacks/plot.py | 4 +++- src/anemoi/training/train/forecaster.py | 15 +++++++++++---- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index d09378b0..4619fedd 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -65,7 +65,12 @@ def _eval( ) with torch.no_grad(): - for loss_next, metrics_next, _ in pl_module.rollout_step(batch, rollout=self.rollout, validation_mode=True): + for loss_next, metrics_next, _ in pl_module.rollout_step( + batch, + rollout=self.rollout, + validation_mode=True, + training_mode=True, + ): loss += loss_next metrics.update(metrics_next) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 59a2deb7..5519d7da 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -354,7 +354,9 @@ def _plot( # start rollout with torch.no_grad(): - for rollout_step, (_, _, y_pred) in enumerate(pl_module.rollout_step(batch, rollout=max(self.rollout))): + for rollout_step, (_, _, y_pred) in enumerate( + pl_module.rollout_step(batch, rollout=max(self.rollout), validation_mode=False, training_mode=False), + ): if (rollout_step + 1) in self.rollout: # prepare true output tensor for plotting diff --git a/src/anemoi/training/train/forecaster.py b/src/anemoi/training/train/forecaster.py index 9c01b1dc..03ac2782 100644 --- a/src/anemoi/training/train/forecaster.py +++ b/src/anemoi/training/train/forecaster.py @@ -15,6 +15,7 @@ from collections.abc import Generator from collections.abc import Mapping from typing import Optional +from typing import Union import numpy as np import pytorch_lightning as pl @@ -227,8 +228,9 @@ def rollout_step( self, batch: torch.Tensor, rollout: Optional[int] = None, # noqa: FA100 + training_mode: bool = True, validation_mode: bool = False, - ) -> Generator[tuple[torch.Tensor, dict, list], None, None]: + ) -> Generator[tuple[Union[torch.Tensor, None], dict, list], None, None]: # noqa: FA100 """ Rollout step for the forecaster. @@ -238,15 +240,19 @@ def rollout_step( ---------- batch : torch.Tensor Batch to use for rollout - rollout : int | None, optional + rollout : Optional[int], optional Number of times to rollout for, by default None + If None, will use self.rollout + training_mode : bool, optional + Whether in training mode and to calculate the loss, by default True + If False, loss will be None validation_mode : bool, optional Whether in validation mode, and to calculate validation metrics, by default False If False, metrics will be empty Yields ------ - Generator[tuple[torch.Tensor, dict, list], None, None] + Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] Loss value, metrics, and predictions (per step) Returns @@ -276,7 +282,7 @@ def rollout_step( y = batch[:, self.multi_step + rollout_step, ..., self.data_indices.internal_data.output.full] # y includes the auxiliary variables, so we must leave those out when computing the loss - loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) + loss = checkpoint(self.loss, y_pred, y, use_reentrant=False) if training_mode else None x = self.advance_input(x, y_pred, batch, rollout_step) @@ -303,6 +309,7 @@ def _step( for loss_next, metrics_next, y_preds_next in self.rollout_step( batch, rollout=self.rollout, + training_mode=True, validation_mode=validation_mode, ): loss += loss_next From 8dfe25d70f523138141a714599df09177a070796 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 08:58:36 +0000 Subject: [PATCH 35/40] Force LongRolloutPlots to plot in serial --- src/anemoi/training/diagnostics/callbacks/plot.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 5519d7da..a1725153 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -415,7 +415,8 @@ def on_validation_batch_end( context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() with context: - self.plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) + # Issue with running asyncronously, so call the plot function directly + self._plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback): From 942e06f803e121bf72ae73b4e33a5cc4e9fbe6a6 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 10:19:33 +0000 Subject: [PATCH 36/40] Add warning to LongRolloutPlots when async --- src/anemoi/training/diagnostics/callbacks/plot.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index a1725153..e82cf593 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -388,8 +388,8 @@ def _plot( logger, fig, epoch=epoch, - tag=f"gnn_pred_val_sample_rstep{rollout_step:03d}_batch{batch_idx:04d}_rank0", - exp_log_tag=f"val_pred_sample_rstep{rollout_step:03d}_rank{local_rank:01d}", + tag=f"gnn_pred_val_sample_rstep{rollout_step + 1:03d}_batch{batch_idx:04d}_rank0", + exp_log_tag=f"val_pred_sample_rstep{rollout_step + 1:03d}_rank{local_rank:01d}", ) LOGGER.info( "Time taken to plot samples after longer rollout: %s seconds", @@ -414,6 +414,9 @@ def on_validation_batch_end( dtype = precision_mapping.get(prec) context = torch.autocast(device_type=batch.device.type, dtype=dtype) if dtype is not None else nullcontext() + if self.config.diagnostics.plot.asynchronous: + LOGGER.warning("Asynchronous plotting not supported for long rollout plots.") + with context: # Issue with running asyncronously, so call the plot function directly self._plot(trainer, pl_module, output, batch, batch_idx, trainer.current_epoch) From 84072a6b14b33528961343bbc2048013c41d3621 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 10:57:22 +0000 Subject: [PATCH 37/40] Fix asserrt calculation --- src/anemoi/training/diagnostics/callbacks/evaluation.py | 2 +- src/anemoi/training/diagnostics/callbacks/plot.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index 4619fedd..6873918a 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -61,7 +61,7 @@ def _eval( assert batch.shape[1] >= self.rollout + pl_module.multi_step, ( "Batch length not sufficient for requested validation rollout length! " - f"Set `dataloader.validation_rollout` to at least {self.rollout + pl_module.multi_step}" + f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" ) with torch.no_grad(): diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index e82cf593..5ed3cf61 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -340,7 +340,7 @@ def _plot( assert batch.shape[1] >= max(self.rollout) + pl_module.multi_step, ( "Batch length not sufficient for requested validation rollout length! " - f"Set `dataloader.validation_rollout` to at least {max(self.rollout) + pl_module.multi_step}" + f"Set `dataloader.validation_rollout` to at least {max(self.rollout)}" ) # prepare input tensor for plotting From 30dfd451b7769fd9c23059fb2f2a4926e0d87f0f Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 13:39:45 +0000 Subject: [PATCH 38/40] Apply post_processors before plotting in LongRolloutPlots --- src/anemoi/training/diagnostics/callbacks/plot.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 5ed3cf61..dcef7c22 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -344,7 +344,7 @@ def _plot( ) # prepare input tensor for plotting - input_tensor_0 = batch[ + input_tensor_0 = pl_module.model.pre_processors(batch, in_place=False)[ self.sample_idx, pl_module.multi_step - 1, ..., @@ -355,7 +355,12 @@ def _plot( # start rollout with torch.no_grad(): for rollout_step, (_, _, y_pred) in enumerate( - pl_module.rollout_step(batch, rollout=max(self.rollout), validation_mode=False, training_mode=False), + pl_module.rollout_step( + batch, + rollout=max(self.rollout), + validation_mode=False, + training_mode=False, + ), ): if (rollout_step + 1) in self.rollout: From eebaf16225fa25a75a221a9729bcc5d2b55b8e14 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 14:01:48 +0000 Subject: [PATCH 39/40] Fix reference to batch --- src/anemoi/training/diagnostics/callbacks/plot.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index dcef7c22..d55d9e79 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -344,7 +344,8 @@ def _plot( ) # prepare input tensor for plotting - input_tensor_0 = pl_module.model.pre_processors(batch, in_place=False)[ + input_batch = pl_module.model.pre_processors(batch, in_place=False) + input_tensor_0 = input_batch[ self.sample_idx, pl_module.multi_step - 1, ..., @@ -365,7 +366,7 @@ def _plot( if (rollout_step + 1) in self.rollout: # prepare true output tensor for plotting - input_tensor_rollout_step = batch[ + input_tensor_rollout_step = input_batch[ self.sample_idx, pl_module.multi_step + rollout_step, # (pl_module.multi_step - 1) + (rollout_step + 1) ..., From 8b2a30ed21ff0e70e21a70453050025efd892380 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Mon, 28 Oct 2024 16:36:35 +0000 Subject: [PATCH 40/40] Fix debug config --- src/anemoi/training/config/debug.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/config/debug.yaml b/src/anemoi/training/config/debug.yaml index 32d58153..a6143bb6 100644 --- a/src/anemoi/training/config/debug.yaml +++ b/src/anemoi/training/config/debug.yaml @@ -18,7 +18,7 @@ defaults: diagnostics: plot: - enabled: False + callbacks: [] hardware: files: graph: ???