From fcc0677da00a60a5be95007972c98ed6c26a46ca Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 30 Oct 2024 13:31:03 +0000 Subject: [PATCH 1/5] Rename frequency to batch_frequency in RolloutEval --- .../config/diagnostics/callbacks/rollout_eval.yaml | 2 +- .../training/diagnostics/callbacks/evaluation.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml index d7daf8d0..1552e02c 100644 --- a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml @@ -1,4 +1,4 @@ # Add callbacks here - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval rollout: ${dataloader.validation_rollout} - frequency: 20 + batch_frequency: 100 diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index 6873918a..e5940350 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -27,7 +27,7 @@ class RolloutEval(Callback): """Evaluates the model performance over a (longer) rollout window.""" - def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None: + def __init__(self, config: OmegaConf, rollout: int, batch_frequency: int) -> None: """Initialize RolloutEval callback. Parameters @@ -36,20 +36,20 @@ def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None: Dictionary with configuration settings rollout : int Rollout length for evaluation - frequency : int - Frequency of evaluation, per batch + batch_frequency : int + batch_frequency of evaluation """ super().__init__() self.config = config LOGGER.debug( - "Setting up RolloutEval callback with rollout = %d, frequency = %d ...", + "Setting up RolloutEval callback with rollout = %d, batch_frequency = %d ...", rollout, - frequency, + batch_frequency, ) self.rollout = rollout - self.frequency = frequency + self.batch_frequency = batch_frequency def _eval( self, @@ -113,7 +113,7 @@ def on_validation_batch_end( batch_idx: int, ) -> None: del outputs # outputs are not used - if batch_idx % self.frequency == 0: + if batch_idx % self.batch_frequency == 0: precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, From 57c912d483cf39e7ba8f2d6da8703259493cfe53 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Thu, 31 Oct 2024 08:36:47 +0000 Subject: [PATCH 2/5] Revert value --- .../training/config/diagnostics/callbacks/rollout_eval.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml index 1552e02c..d08c2659 100644 --- a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml @@ -1,4 +1,4 @@ # Add callbacks here - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval rollout: ${dataloader.validation_rollout} - batch_frequency: 100 + batch_frequency: 20 From e970eac4353f24aea73c28fcab366b167b96c599 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Tue, 5 Nov 2024 10:50:23 +0000 Subject: [PATCH 3/5] Update docstring --- src/anemoi/training/diagnostics/callbacks/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index e5940350..fd6302ea 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -37,7 +37,7 @@ def __init__(self, config: OmegaConf, rollout: int, batch_frequency: int) -> Non rollout : int Rollout length for evaluation batch_frequency : int - batch_frequency of evaluation + Frequency of rollout evaluation, runs every `n` validation batches """ super().__init__() From 578940466e71d1cbb850f5e315a282aa156811b7 Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 6 Nov 2024 12:37:44 +0000 Subject: [PATCH 4/5] Rename frequency to form of every_n_ - Improves readability --- .../diagnostics/callbacks/rollout_eval.yaml | 2 +- .../config/diagnostics/plot/detailed.yaml | 4 +- .../config/diagnostics/plot/rollout_eval.yaml | 6 +- .../diagnostics/callbacks/evaluation.py | 12 ++-- .../training/diagnostics/callbacks/plot.py | 62 +++++++++---------- 5 files changed, 43 insertions(+), 43 deletions(-) diff --git a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml index d08c2659..6afa04dc 100644 --- a/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml +++ b/src/anemoi/training/config/diagnostics/callbacks/rollout_eval.yaml @@ -1,4 +1,4 @@ # Add callbacks here - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval rollout: ${dataloader.validation_rollout} - batch_frequency: 20 + every_n_batches: 20 diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml index 6ff7875e..6ed15fa4 100644 --- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml +++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml @@ -26,7 +26,7 @@ callbacks: # Add plot callbacks here - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot - epoch_frequency: 5 + every_n_epochs: 5 - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss # group parameters by categories when visualizing contributions to the loss # one-parameter groups are possible to highlight individual parameters @@ -43,7 +43,7 @@ callbacks: precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum - # batch_frequency: 100 # Override for batch frequency + # every_n_batches: 100 # Override for batch frequency sample_idx: ${diagnostics.plot.sample_idx} parameters: - z_500 diff --git a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml index 7c01b575..4c440e24 100644 --- a/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml +++ b/src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml @@ -26,7 +26,7 @@ callbacks: # Add plot callbacks here - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot - epoch_frequency: 5 + every_n_epochs: 5 - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss # group parameters by categories when visualizing contributions to the loss # one-parameter groups are possible to highlight individual parameters @@ -43,7 +43,7 @@ callbacks: precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields} - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum - # batch_frequency: 100 # Override for batch frequency + # every_n_batches: 100 # Override for batch frequency sample_idx: ${diagnostics.plot.sample_idx} parameters: - z_500 @@ -63,6 +63,6 @@ callbacks: - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots rollout: - ${dataloader.validation_rollout} - epoch_frequency: 20 + every_n_epochs: 20 sample_idx: ${diagnostics.plot.sample_idx} parameters: ${diagnostics.plot.parameters} diff --git a/src/anemoi/training/diagnostics/callbacks/evaluation.py b/src/anemoi/training/diagnostics/callbacks/evaluation.py index fd6302ea..fc812121 100644 --- a/src/anemoi/training/diagnostics/callbacks/evaluation.py +++ b/src/anemoi/training/diagnostics/callbacks/evaluation.py @@ -27,7 +27,7 @@ class RolloutEval(Callback): """Evaluates the model performance over a (longer) rollout window.""" - def __init__(self, config: OmegaConf, rollout: int, batch_frequency: int) -> None: + def __init__(self, config: OmegaConf, rollout: int, every_n_batches: int) -> None: """Initialize RolloutEval callback. Parameters @@ -36,7 +36,7 @@ def __init__(self, config: OmegaConf, rollout: int, batch_frequency: int) -> Non Dictionary with configuration settings rollout : int Rollout length for evaluation - batch_frequency : int + every_n_batches : int Frequency of rollout evaluation, runs every `n` validation batches """ @@ -44,12 +44,12 @@ def __init__(self, config: OmegaConf, rollout: int, batch_frequency: int) -> Non self.config = config LOGGER.debug( - "Setting up RolloutEval callback with rollout = %d, batch_frequency = %d ...", + "Setting up RolloutEval callback with rollout = %d, every_n_batches = %d ...", rollout, - batch_frequency, + every_n_batches, ) self.rollout = rollout - self.batch_frequency = batch_frequency + self.every_n_batches = every_n_batches def _eval( self, @@ -113,7 +113,7 @@ def on_validation_batch_end( batch_idx: int, ) -> None: del outputs # outputs are not used - if batch_idx % self.batch_frequency == 0: + if batch_idx % self.every_n_batches == 0: precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py index 98d16dc3..93f9aa17 100644 --- a/src/anemoi/training/diagnostics/callbacks/plot.py +++ b/src/anemoi/training/diagnostics/callbacks/plot.py @@ -177,20 +177,20 @@ def _async_plot( class BasePerBatchPlotCallback(BasePlotCallback): """Base Callback for plotting at the end of each batch.""" - def __init__(self, config: OmegaConf, batch_frequency: int | None = None): + def __init__(self, config: OmegaConf, every_n_batches: int | None = None): """Initialise the BasePerBatchPlotCallback. Parameters ---------- config : OmegaConf Config object - batch_frequency : int, optional + every_n_batches : int, optional Batch Frequency to plot at, by default None If not given, uses default from config at `diagnostics.plot.frequency.batch` """ super().__init__(config) - self.batch_frequency = batch_frequency or self.config.diagnostics.plot.frequency.batch + self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch @abstractmethod @rank_zero_only @@ -216,7 +216,7 @@ def on_validation_batch_end( batch_idx: int, **kwargs, ) -> None: - if batch_idx % self.batch_frequency == 0: + if batch_idx % self.every_n_batches == 0: self.plot( trainer, pl_module, @@ -231,19 +231,19 @@ def on_validation_batch_end( class BasePerEpochPlotCallback(BasePlotCallback): """Base Callback for plotting at the end of each epoch.""" - def __init__(self, config: OmegaConf, epoch_frequency: int | None = None): + def __init__(self, config: OmegaConf, every_n_epochs: int | None = None): """Initialise the BasePerEpochPlotCallback. Parameters ---------- config : OmegaConf Config object - epoch_frequency : int, optional + every_n_epochs : int, optional Epoch frequency to plot at, by default None If not given, uses default from config at `diagnostics.plot.frequency.epoch` """ super().__init__(config) - self.epoch_frequency = epoch_frequency or self.config.diagnostics.plot.frequency.epoch + self.every_n_epochs = every_n_epochs or self.config.diagnostics.plot.frequency.epoch @rank_zero_only def on_validation_epoch_end( @@ -252,7 +252,7 @@ def on_validation_epoch_end( pl_module: pl.LightningModule, **kwargs, ) -> None: - if trainer.current_epoch % self.epoch_frequency == 0: + if trainer.current_epoch % self.every_n_epochs == 0: self.plot(trainer, pl_module, epoch=trainer.current_epoch, **kwargs) @@ -268,7 +268,7 @@ def __init__( accumulation_levels_plot: list[float] | None = None, cmap_accumulation: list[str] | None = None, per_sample: int = 6, - epoch_frequency: int = 1, + every_n_epochs: int = 1, ) -> None: """Initialise LongRolloutPlots callback. @@ -288,17 +288,17 @@ def __init__( Colors of the accumulation levels, by default None per_sample : int, optional Number of plots per sample, by default 6 - epoch_frequency : int, optional + every_n_epochs : int, optional Epoch frequency to plot at, by default 1 """ super().__init__(config) - self.epoch_frequency = epoch_frequency + self.every_n_epochs = every_n_epochs LOGGER.debug( "Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...", rollout, - epoch_frequency, + every_n_epochs, ) self.rollout = rollout self.sample_idx = sample_idx @@ -412,7 +412,7 @@ def on_validation_batch_end( batch: torch.Tensor, batch_idx: int, ) -> None: - if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.epoch_frequency == 0: + if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0: precision_mapping = { "16-mixed": torch.float16, "bf16-mixed": torch.bfloat16, @@ -432,17 +432,17 @@ def on_validation_batch_end( class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback): """Visualize the node trainable features defined.""" - def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None: + def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None: """Initialise the GraphTrainableFeaturesPlot callback. Parameters ---------- config : OmegaConf Config object - epoch_frequency: int | None, optional + every_n_epochs: int | None, optional Override for frequency to plot at, by default None """ - super().__init__(config, epoch_frequency=epoch_frequency) + super().__init__(config, every_n_epochs=every_n_epochs) @rank_zero_only def _plot( @@ -474,17 +474,17 @@ class GraphEdgeTrainableFeaturesPlot(BasePerEpochPlotCallback): Visualize the trainable features defined at the edges between meshes. """ - def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None: + def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None: """Plot trainable edge features. Parameters ---------- config : OmegaConf Config object - epoch_frequency : int | None, optional + every_n_epochs : int | None, optional Override for frequency to plot at, by default None """ - super().__init__(config, epoch_frequency=epoch_frequency) + super().__init__(config, every_n_epochs=every_n_epochs) @rank_zero_only def _plot( @@ -517,7 +517,7 @@ def __init__( self, config: OmegaConf, parameter_groups: dict[dict[str, list[str]]], - batch_frequency: int | None = None, + every_n_batches: int | None = None, ) -> None: """Initialise the PlotLoss callback. @@ -527,11 +527,11 @@ def __init__( Object with configuration settings parameter_groups : dict Dictionary with parameter groups with parameter names as keys - batch_frequency : int, optional + every_n_batches : int, optional Override for batch frequency, by default None """ - super().__init__(config, batch_frequency=batch_frequency) + super().__init__(config, every_n_batches=every_n_batches) self.parameter_names = None self.parameter_groups = parameter_groups if self.parameter_groups is None: @@ -689,7 +689,7 @@ def __init__( cmap_accumulation: list[str], precip_and_related_fields: list[str] | None = None, per_sample: int = 6, - batch_frequency: int | None = None, + every_n_batches: int | None = None, ) -> None: """Initialise the PlotSample callback. @@ -709,10 +709,10 @@ def __init__( Precip variable names, by default None per_sample : int, optional Number of plots per sample, by default 6 - batch_frequency : int, optional + every_n_batches : int, optional Batch frequency to plot at, by default None """ - super().__init__(config, batch_frequency=batch_frequency) + super().__init__(config, every_n_batches=every_n_batches) self.sample_idx = sample_idx self.parameters = parameters @@ -850,7 +850,7 @@ def __init__( config: OmegaConf, sample_idx: int, parameters: list[str], - batch_frequency: int | None = None, + every_n_batches: int | None = None, ) -> None: """Initialise the PlotSpectrum callback. @@ -862,10 +862,10 @@ def __init__( Sample to plot parameters : list[str] Parameters to plot - batch_frequency : int | None, optional + every_n_batches : int | None, optional Override for batch frequency, by default None """ - super().__init__(config, batch_frequency=batch_frequency) + super().__init__(config, every_n_batches=every_n_batches) self.sample_idx = sample_idx self.parameters = parameters @@ -925,7 +925,7 @@ def __init__( sample_idx: int, parameters: list[str], precip_and_related_fields: list[str] | None = None, - batch_frequency: int | None = None, + every_n_batches: int | None = None, ) -> None: """Initialise the PlotHistogram callback. @@ -939,10 +939,10 @@ def __init__( Parameters to plot precip_and_related_fields : list[str] | None, optional Precip variable names, by default None - batch_frequency : int | None, optional + every_n_batches : int | None, optional Override for batch frequency, by default None """ - super().__init__(config, batch_frequency=batch_frequency) + super().__init__(config, every_n_batches=every_n_batches) self.sample_idx = sample_idx self.parameters = parameters self.precip_and_related_fields = precip_and_related_fields From 4fd1977a5da3b17fb48d06f302fdfcb41df754cd Mon Sep 17 00:00:00 2001 From: Harrison Cook Date: Wed, 6 Nov 2024 12:42:49 +0000 Subject: [PATCH 5/5] Update changelog. --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 124ccf50..ecb520f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,6 +23,7 @@ Keep it human-readable, your future self will thank you! - Add synchronisation workflow [#92](https://github.com/ecmwf/anemoi-training/pull/92) ### Changed +- Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118) - Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67) ## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28