Skip to content

Commit

Permalink
Rename frequency to batch_frequency in RolloutEval (#118)
Browse files Browse the repository at this point in the history
* Rename frequency to batch_frequency in RolloutEval

* Revert value

* Update docstring

* Rename frequency to form of every_n_
- Improves readability

* Update changelog.
  • Loading branch information
HCookie authored Nov 6, 2024
1 parent 0e639e4 commit 7ec2e38
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ Keep it human-readable, your future self will thank you!


### Changed
- Renamed frequency keys in callbacks configuration. [#118](https://github.com/ecmwf/anemoi-training/pull/118)
- Modified training configuration to support max_steps and tied lr iterations to max_steps by default [#67](https://github.com/ecmwf/anemoi-training/pull/67)

## [0.2.2 - Maintenance: pin python <3.13](https://github.com/ecmwf/anemoi-training/compare/0.2.1...0.2.2) - 2024-10-28
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Add callbacks here
- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
rollout: ${dataloader.validation_rollout}
frequency: 20
every_n_batches: 20
4 changes: 2 additions & 2 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot
epoch_frequency: 5
every_n_epochs: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
Expand All @@ -43,7 +43,7 @@ callbacks:
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}

- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# batch_frequency: 100 # Override for batch frequency
# every_n_batches: 100 # Override for batch frequency
sample_idx: ${diagnostics.plot.sample_idx}
parameters:
- z_500
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/training/config/diagnostics/plot/rollout_eval.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ callbacks:
# Add plot callbacks here
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot
epoch_frequency: 5
every_n_epochs: 5
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
Expand All @@ -43,7 +43,7 @@ callbacks:
precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}

- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
# batch_frequency: 100 # Override for batch frequency
# every_n_batches: 100 # Override for batch frequency
sample_idx: ${diagnostics.plot.sample_idx}
parameters:
- z_500
Expand All @@ -63,6 +63,6 @@ callbacks:
- _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
rollout:
- ${dataloader.validation_rollout}
epoch_frequency: 20
every_n_epochs: 20
sample_idx: ${diagnostics.plot.sample_idx}
parameters: ${diagnostics.plot.parameters}
14 changes: 7 additions & 7 deletions src/anemoi/training/diagnostics/callbacks/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
class RolloutEval(Callback):
"""Evaluates the model performance over a (longer) rollout window."""

def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None:
def __init__(self, config: OmegaConf, rollout: int, every_n_batches: int) -> None:
"""Initialize RolloutEval callback.
Parameters
Expand All @@ -36,20 +36,20 @@ def __init__(self, config: OmegaConf, rollout: int, frequency: int) -> None:
Dictionary with configuration settings
rollout : int
Rollout length for evaluation
frequency : int
Frequency of evaluation, per batch
every_n_batches : int
Frequency of rollout evaluation, runs every `n` validation batches
"""
super().__init__()
self.config = config

LOGGER.debug(
"Setting up RolloutEval callback with rollout = %d, frequency = %d ...",
"Setting up RolloutEval callback with rollout = %d, every_n_batches = %d ...",
rollout,
frequency,
every_n_batches,
)
self.rollout = rollout
self.frequency = frequency
self.every_n_batches = every_n_batches

def _eval(
self,
Expand Down Expand Up @@ -113,7 +113,7 @@ def on_validation_batch_end(
batch_idx: int,
) -> None:
del outputs # outputs are not used
if batch_idx % self.frequency == 0:
if batch_idx % self.every_n_batches == 0:
precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,
Expand Down
62 changes: 31 additions & 31 deletions src/anemoi/training/diagnostics/callbacks/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,20 @@ def _async_plot(
class BasePerBatchPlotCallback(BasePlotCallback):
"""Base Callback for plotting at the end of each batch."""

def __init__(self, config: OmegaConf, batch_frequency: int | None = None):
def __init__(self, config: OmegaConf, every_n_batches: int | None = None):
"""Initialise the BasePerBatchPlotCallback.
Parameters
----------
config : OmegaConf
Config object
batch_frequency : int, optional
every_n_batches : int, optional
Batch Frequency to plot at, by default None
If not given, uses default from config at `diagnostics.plot.frequency.batch`
"""
super().__init__(config)
self.batch_frequency = batch_frequency or self.config.diagnostics.plot.frequency.batch
self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch

@abstractmethod
@rank_zero_only
Expand All @@ -216,7 +216,7 @@ def on_validation_batch_end(
batch_idx: int,
**kwargs,
) -> None:
if batch_idx % self.batch_frequency == 0:
if batch_idx % self.every_n_batches == 0:
self.plot(
trainer,
pl_module,
Expand All @@ -231,19 +231,19 @@ def on_validation_batch_end(
class BasePerEpochPlotCallback(BasePlotCallback):
"""Base Callback for plotting at the end of each epoch."""

def __init__(self, config: OmegaConf, epoch_frequency: int | None = None):
def __init__(self, config: OmegaConf, every_n_epochs: int | None = None):
"""Initialise the BasePerEpochPlotCallback.
Parameters
----------
config : OmegaConf
Config object
epoch_frequency : int, optional
every_n_epochs : int, optional
Epoch frequency to plot at, by default None
If not given, uses default from config at `diagnostics.plot.frequency.epoch`
"""
super().__init__(config)
self.epoch_frequency = epoch_frequency or self.config.diagnostics.plot.frequency.epoch
self.every_n_epochs = every_n_epochs or self.config.diagnostics.plot.frequency.epoch

@rank_zero_only
def on_validation_epoch_end(
Expand All @@ -252,7 +252,7 @@ def on_validation_epoch_end(
pl_module: pl.LightningModule,
**kwargs,
) -> None:
if trainer.current_epoch % self.epoch_frequency == 0:
if trainer.current_epoch % self.every_n_epochs == 0:
self.plot(trainer, pl_module, epoch=trainer.current_epoch, **kwargs)


Expand All @@ -268,7 +268,7 @@ def __init__(
accumulation_levels_plot: list[float] | None = None,
cmap_accumulation: list[str] | None = None,
per_sample: int = 6,
epoch_frequency: int = 1,
every_n_epochs: int = 1,
) -> None:
"""Initialise LongRolloutPlots callback.
Expand All @@ -288,17 +288,17 @@ def __init__(
Colors of the accumulation levels, by default None
per_sample : int, optional
Number of plots per sample, by default 6
epoch_frequency : int, optional
every_n_epochs : int, optional
Epoch frequency to plot at, by default 1
"""
super().__init__(config)

self.epoch_frequency = epoch_frequency
self.every_n_epochs = every_n_epochs

LOGGER.debug(
"Setting up callback for plots with long rollout: rollout = %d, frequency = every %d epoch ...",
rollout,
epoch_frequency,
every_n_epochs,
)
self.rollout = rollout
self.sample_idx = sample_idx
Expand Down Expand Up @@ -412,7 +412,7 @@ def on_validation_batch_end(
batch: torch.Tensor,
batch_idx: int,
) -> None:
if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.epoch_frequency == 0:
if (batch_idx) == 0 and (trainer.current_epoch + 1) % self.every_n_epochs == 0:
precision_mapping = {
"16-mixed": torch.float16,
"bf16-mixed": torch.bfloat16,
Expand All @@ -432,17 +432,17 @@ def on_validation_batch_end(
class GraphNodeTrainableFeaturesPlot(BasePerEpochPlotCallback):
"""Visualize the node trainable features defined."""

def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None:
def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None:
"""Initialise the GraphTrainableFeaturesPlot callback.
Parameters
----------
config : OmegaConf
Config object
epoch_frequency: int | None, optional
every_n_epochs: int | None, optional
Override for frequency to plot at, by default None
"""
super().__init__(config, epoch_frequency=epoch_frequency)
super().__init__(config, every_n_epochs=every_n_epochs)

@rank_zero_only
def _plot(
Expand Down Expand Up @@ -474,17 +474,17 @@ class GraphEdgeTrainableFeaturesPlot(BasePerEpochPlotCallback):
Visualize the trainable features defined at the edges between meshes.
"""

def __init__(self, config: OmegaConf, epoch_frequency: int | None = None) -> None:
def __init__(self, config: OmegaConf, every_n_epochs: int | None = None) -> None:
"""Plot trainable edge features.
Parameters
----------
config : OmegaConf
Config object
epoch_frequency : int | None, optional
every_n_epochs : int | None, optional
Override for frequency to plot at, by default None
"""
super().__init__(config, epoch_frequency=epoch_frequency)
super().__init__(config, every_n_epochs=every_n_epochs)

@rank_zero_only
def _plot(
Expand Down Expand Up @@ -517,7 +517,7 @@ def __init__(
self,
config: OmegaConf,
parameter_groups: dict[dict[str, list[str]]],
batch_frequency: int | None = None,
every_n_batches: int | None = None,
) -> None:
"""Initialise the PlotLoss callback.
Expand All @@ -527,11 +527,11 @@ def __init__(
Object with configuration settings
parameter_groups : dict
Dictionary with parameter groups with parameter names as keys
batch_frequency : int, optional
every_n_batches : int, optional
Override for batch frequency, by default None
"""
super().__init__(config, batch_frequency=batch_frequency)
super().__init__(config, every_n_batches=every_n_batches)
self.parameter_names = None
self.parameter_groups = parameter_groups
if self.parameter_groups is None:
Expand Down Expand Up @@ -689,7 +689,7 @@ def __init__(
cmap_accumulation: list[str],
precip_and_related_fields: list[str] | None = None,
per_sample: int = 6,
batch_frequency: int | None = None,
every_n_batches: int | None = None,
) -> None:
"""Initialise the PlotSample callback.
Expand All @@ -709,10 +709,10 @@ def __init__(
Precip variable names, by default None
per_sample : int, optional
Number of plots per sample, by default 6
batch_frequency : int, optional
every_n_batches : int, optional
Batch frequency to plot at, by default None
"""
super().__init__(config, batch_frequency=batch_frequency)
super().__init__(config, every_n_batches=every_n_batches)
self.sample_idx = sample_idx
self.parameters = parameters

Expand Down Expand Up @@ -850,7 +850,7 @@ def __init__(
config: OmegaConf,
sample_idx: int,
parameters: list[str],
batch_frequency: int | None = None,
every_n_batches: int | None = None,
) -> None:
"""Initialise the PlotSpectrum callback.
Expand All @@ -862,10 +862,10 @@ def __init__(
Sample to plot
parameters : list[str]
Parameters to plot
batch_frequency : int | None, optional
every_n_batches : int | None, optional
Override for batch frequency, by default None
"""
super().__init__(config, batch_frequency=batch_frequency)
super().__init__(config, every_n_batches=every_n_batches)
self.sample_idx = sample_idx
self.parameters = parameters

Expand Down Expand Up @@ -925,7 +925,7 @@ def __init__(
sample_idx: int,
parameters: list[str],
precip_and_related_fields: list[str] | None = None,
batch_frequency: int | None = None,
every_n_batches: int | None = None,
) -> None:
"""Initialise the PlotHistogram callback.
Expand All @@ -939,10 +939,10 @@ def __init__(
Parameters to plot
precip_and_related_fields : list[str] | None, optional
Precip variable names, by default None
batch_frequency : int | None, optional
every_n_batches : int | None, optional
Override for batch frequency, by default None
"""
super().__init__(config, batch_frequency=batch_frequency)
super().__init__(config, every_n_batches=every_n_batches)
self.sample_idx = sample_idx
self.parameters = parameters
self.precip_and_related_fields = precip_and_related_fields
Expand Down

0 comments on commit 7ec2e38

Please sign in to comment.