Refactor Callbacks #60

Merged
merged 58 commits into from
Oct 29, 2024
Changes from 46 commits
b12fac8
Refactor Callbacks
HCookie Sep 24, 2024
29a8477
Update changelog
HCookie Sep 24, 2024
15824be
Fix TypeError
HCookie Sep 24, 2024
4077bf4
Move to hydra.instantiate
HCookie Sep 25, 2024
494d39d
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Sep 25, 2024
fe37c02
Add __all__
HCookie Sep 25, 2024
2d8275c
Add to base config
HCookie Sep 25, 2024
230eb0e
Fix nested list
HCookie Sep 25, 2024
5547b20
Fix nested get issue
HCookie Sep 26, 2024
1d80cfb
Fix type checking
HCookie Sep 27, 2024
e79dfc7
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 1, 2024
96ab74c
feat: edge plot in callbacks
JPXKQX Oct 1, 2024
4aeb1a5
feat: set default extra callbacks
JPXKQX Oct 1, 2024
816b3af
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 2, 2024
644038f
fix: typing & refactoring
JPXKQX Oct 2, 2024
8356cd4
fix: remove list comprehension
JPXKQX Oct 2, 2024
930e4d2
Refactor according to PR
HCookie Oct 2, 2024
52ea91f
Update deprecation warning
HCookie Oct 4, 2024
0dd81b7
Merge branch 'fxi/refactor_callbacks' into feature/graph-features-cal…
JPXKQX Oct 4, 2024
332f746
Merge pull request #71 from ecmwf/feature/graph-features-callback
HCookie Oct 4, 2024
bb8b9bb
Refactor: Remove backwards compatability,
HCookie Oct 10, 2024
0349be2
Fix tests
HCookie Oct 10, 2024
1e97ff1
PR Fixes
HCookie Oct 15, 2024
d7f713e
Merge branch 'develop' into fix/refactor_callbacks
HCookie Oct 18, 2024
ebfaf90
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 18, 2024
460c8ba
Update Changelog
HCookie Oct 18, 2024
5671c7e
Merge branch 'develop' into fix/refactor_callbacks
HCookie Oct 21, 2024
21c05de
Refactor rollout (#87)
HCookie Oct 21, 2024
3c5e144
Remove batch frequency from LongRolloutPlots
HCookie Oct 21, 2024
5742754
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 21, 2024
8671543
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 22, 2024
382728c
Remove TP reference
HCookie Oct 22, 2024
6fa66cc
Remove missing config reference
HCookie Oct 23, 2024
110fb64
Swapped histogram and spectrum
HCookie Oct 23, 2024
23cc785
Update copyright notice
HCookie Oct 23, 2024
bfe76f3
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 23, 2024
5a6880e
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 24, 2024
51a455d
Fix issues with split of PlotAdditionalMetrics
HCookie Oct 24, 2024
3318675
Merge branch 'fxi/refactor_callbacks' of github.com:ecmwf/anemoi-trai…
HCookie Oct 24, 2024
77bd65d
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 24, 2024
3c6e1af
Fix CHANGELOG
HCookie Oct 25, 2024
86059a9
Fix documentation for callbacks
HCookie Oct 25, 2024
0bce490
Add all callback submodules to docs
HCookie Oct 25, 2024
f5057c6
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 25, 2024
d6e1d9c
Apply suggestions from code review
HCookie Oct 25, 2024
6073d84
Fix init args issue in RolloutPlots
HCookie Oct 25, 2024
f1d883f
Add rollout_eval config
HCookie Oct 25, 2024
66bd306
Add training mode to rollout step
HCookie Oct 28, 2024
8dfe25d
Force LongRolloutPlots to plot in serial
HCookie Oct 28, 2024
942e06f
Add warning to LongRolloutPlots when async
HCookie Oct 28, 2024
8e6ab30
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 28, 2024
84072a6
Fix asserrt calculation
HCookie Oct 28, 2024
42b59e5
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 28, 2024
30dfd45
Apply post_processors before plotting in LongRolloutPlots
HCookie Oct 28, 2024
eebaf16
Fix reference to batch
HCookie Oct 28, 2024
b31da0e
Merge branch 'develop' into fxi/refactor_callbacks
HCookie Oct 28, 2024
8b2a30e
Fix debug config
HCookie Oct 28, 2024
7fe2c05
Merge remote-tracking branch 'origin/develop' into fix/refactor_callb…
HCookie Oct 29, 2024
9 changes: 8 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -10,6 +10,11 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-training/compare/0.2.1...HEAD)

### Fixed
- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)
- Refactored rollout. [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
- Enable longer validation rollout than training

## [0.2.1 - Bugfix: resuming mlflow runs](https://github.com/ecmwf/anemoi-training/compare/0.2.0...0.2.1) - 2024-10-24

### Added
@@ -20,7 +25,9 @@ Keep it human-readable, your future self will thank you!
- Feature: New `Boolean1DMask` class. Enables rollout training for limited area models. [#79](https://github.com/ecmwf/anemoi-training/pulls/79)

### Fixed

- Mlflow-sync to handle creation of new experiments in the remote server [#83] (https://github.com/ecmwf/anemoi-training/pull/83)
- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99] (https://github.com/ecmwf/anemoi-training/pull/99)
- ci: fix pyshtools install error (#100) https://github.com/ecmwf/anemoi-training/pull/100
- Mlflow-sync to handle creation of new experiments in the remote server [#83](https://github.com/ecmwf/anemoi-training/pull/83)
- Fix for multi-gpu when using mlflow due to refactoring of _get_mlflow_run_params function [#99](https://github.com/ecmwf/anemoi-training/pull/99)
- ci: fix pyshtools install error [#100](https://github.com/ecmwf/anemoi-training/pull/100)
97 changes: 84 additions & 13 deletions docs/modules/diagnostics.rst
Original file line number Diff line number Diff line change
@@ -21,23 +21,94 @@ functionality to use both Weights & Biases and Tensorboard.

The callbacks can also be used to evaluate forecasts over longer
rollouts beyond the forecast time that the model is trained on. The
number of rollout steps (or forecast iteration steps) is set using
``config.eval.rollout = *num_of_rollout_steps*``.

Note the user has the option to evaluate the callbacks asynchronously
(using the following config option
``config.diagnostics.plot.asynchronous``, which means that the model
training doesn't stop whilst the callbacks are being evaluated).
However, note that callbacks can still be slow, and therefore the
plotting callbacks can be switched off by setting
``config.diagnostics.plot.enabled`` to ``False`` or all the callbacks
can be completely switched off by setting
``config.diagnostics.eval.enabled`` to ``False``.
number of rollout steps for verification (or forecast iteration steps)
is set using ``config.dataloader.validation_rollout =
*num_of_rollout_steps*``.

Callbacks are configured in the config file under the
``config.diagnostics`` key.

Regular callbacks are provided as a list of dictionaries under the
``config.diagnostics.callbacks`` key. Each dictionary must have a
``_target_`` key, which Hydra uses to instantiate the callback; any
other key is passed to the callback's constructor as a keyword argument.

.. code:: yaml

   callbacks:
     - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
       rollout: ${dataloader.validation_rollout}
       frequency: 20
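
How a ``_target_`` entry becomes an object can be sketched with the standard library alone. The `instantiate` helper below is a simplified stand-in for `hydra.utils.instantiate`, not Hydra's implementation, and the two targets (`fractions.Fraction`, `datetime.timedelta`) are placeholders chosen only so the snippet runs without anemoi-training installed.

```python
import importlib
from typing import Any


def instantiate(entry: dict[str, Any]) -> Any:
    """Build an object from a dict that carries a ``_target_`` dotted path.

    Simplified stand-in for ``hydra.utils.instantiate``: every key other
    than ``_target_`` is passed to the constructor as a keyword argument.
    """
    kwargs = {key: value for key, value in entry.items() if key != "_target_"}
    module_name, _, attr_name = entry["_target_"].rpartition(".")
    target = getattr(importlib.import_module(module_name), attr_name)
    return target(**kwargs)


# Placeholder entries (assumption): a real config would point at callback
# classes such as anemoi.training.diagnostics.callbacks.evaluation.RolloutEval.
callback_entries = [
    {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 20},
    {"_target_": "datetime.timedelta", "minutes": 30},
]

callbacks = [instantiate(entry) for entry in callback_entries]
```

Because the list is plain data, adding or removing a callback for an experiment is a config change rather than a code change.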

Plotting callbacks are configured in a similar way, but they are
specified underneath the ``config.diagnostics.plot.callbacks`` key.

This is done to ensure separation and ease of configuration between
experiments.

``config.diagnostics.plot`` is a broader config file specifying the
parameters to plot, as well as the plotting frequency and whether
plotting runs asynchronously.

Setting ``config.diagnostics.plot.asynchronous`` to ``True`` means that
model training does not stop whilst the callbacks are being evaluated.

.. code:: yaml

   plot:
     asynchronous: True # Whether to plot asynchronously
     frequency: # Frequency of the plotting
       batch: 750
       epoch: 5

     # Parameters to plot
     parameters:
       - z_500
       - t_850
       - u_850

     # Sample index
     sample_idx: 0

     # Precipitation and related fields
     precip_and_related_fields: [tp, cp]

     callbacks:
       - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
         # group parameters by categories when visualizing contributions to the loss
         # one-parameter groups are possible to highlight individual parameters
         parameter_groups:
           moisture: [tp, cp, tcw]
           sfc_wind: [10u, 10v]
       - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
         sample_idx: ${diagnostics.plot.sample_idx}
         per_sample : 6
         parameters: ${diagnostics.plot.parameters}
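
The ``${diagnostics.plot.sample_idx}`` style values are interpolations: at load time they are replaced by the value stored at that dotted path, so shared settings are written once and reused by each callback. In a Hydra setup this is handled by OmegaConf; the sketch below is a simplified standard-library stand-in (the `lookup` and `resolve` helpers are hypothetical and only handle values that are a single whole-string reference).

```python
import re
from typing import Any

# Matches a value that is exactly one reference, e.g. "${diagnostics.plot.sample_idx}".
_REFERENCE = re.compile(r"^\$\{([^}]+)\}$")


def lookup(root: dict[str, Any], dotted_key: str) -> Any:
    """Follow a dotted key such as 'diagnostics.plot.sample_idx' from the root."""
    node: Any = root
    for part in dotted_key.split("."):
        node = node[part]
    return node


def resolve(node: Any, root: dict[str, Any]) -> Any:
    """Recursively replace whole-string '${a.b.c}' values with what they point to."""
    if isinstance(node, dict):
        return {key: resolve(value, root) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve(item, root) for item in node]
    if isinstance(node, str):
        match = _REFERENCE.match(node)
        if match:
            return resolve(lookup(root, match.group(1)), root)
    return node


config = {
    "diagnostics": {
        "plot": {
            "sample_idx": 0,
            "parameters": ["z_500", "t_850", "u_850"],
            "callbacks": [
                {
                    "_target_": "anemoi.training.diagnostics.callbacks.plot.PlotSample",
                    "sample_idx": "${diagnostics.plot.sample_idx}",
                    "per_sample": 6,
                    "parameters": "${diagnostics.plot.parameters}",
                },
            ],
        },
    },
}

resolved = resolve(config, config)
plot_sample = resolved["diagnostics"]["plot"]["callbacks"][0]
```

After resolution, `plot_sample` carries the concrete sample index and parameter list that would be handed to the callback's constructor.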

Below is the documentation for the default callbacks provided, but it is
also possible for users to add callbacks using the same structure:

.. automodule:: anemoi.training.diagnostics.callbacks
.. automodule:: anemoi.training.diagnostics.callbacks.checkpoint
   :members:
   :no-undoc-members:
   :show-inheritance:

.. automodule:: anemoi.training.diagnostics.callbacks.evaluation
   :members:
   :no-undoc-members:
   :show-inheritance:

.. automodule:: anemoi.training.diagnostics.callbacks.optimiser
   :members:
   :no-undoc-members:
   :show-inheritance:

.. automodule:: anemoi.training.diagnostics.callbacks.plot
   :members:
   :no-undoc-members:
   :show-inheritance:

.. automodule:: anemoi.training.diagnostics.callbacks.provenance
   :members:
   :no-undoc-members:
   :show-inheritance:
4 changes: 2 additions & 2 deletions docs/user-guide/configuring.rst
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ settings at the top as follows:
defaults:
- data: zarr
- dataloader: native_grid
- diagnostics: eval_rollout
- diagnostics: evaluation
- hardware: example
- graph: multi_scale
- model: gnn
@@ -100,7 +100,7 @@ match the dataset you provide.
defaults:
- data: zarr
- dataloader: native_grid
- diagnostics: eval_rollout
- diagnostics: evaluation
- hardware: example
- graph: multi_scale
- model: transformer # Change from default group
2 changes: 1 addition & 1 deletion docs/user-guide/tracking.rst
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ the same experiment.
Within the MLflow experiments tab, it is possible to define different
namespaces. To create a new namespace, the user just needs to pass an
'experiment_name'
(``config.diagnostics.eval_rollout.log.mlflow.experiment_name``) to the
(``config.diagnostics.evaluation.log.mlflow.experiment_name``) to the
mlflow logger.

**Parent-Child Runs**
2 changes: 1 addition & 1 deletion src/anemoi/training/config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- data: zarr
- dataloader: native_grid
- diagnostics: eval_rollout
- diagnostics: evaluation
- hardware: example
- graph: multi_scale
- model: gnn
2 changes: 2 additions & 0 deletions src/anemoi/training/config/dataloader/native_grid.yaml
Original file line number Diff line number Diff line change
@@ -45,6 +45,8 @@ training:
  frequency: ${data.frequency}
  drop: []

validation_rollout: 1 # number of rollouts to use for validation, must be equal or greater than rollout expected by callbacks

validation:
  dataset: ${dataloader.dataset}
  start: 2021
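
The comment on `validation_rollout` states a constraint: it must be at least as large as the rollout any callback expects. A minimal sketch of such a check follows, assuming callback entries are plain dictionaries as in the config files; `max_callback_rollout` and `check_validation_rollout` are hypothetical helpers written for illustration, not functions from anemoi-training.

```python
from typing import Any


def max_callback_rollout(callbacks: list[dict[str, Any]]) -> int:
    """Return the largest rollout requested by any callback entry (0 if none)."""
    largest = 0
    for entry in callbacks:
        rollout = entry.get("rollout", 0)
        # In the configs, LongRolloutPlots is given a list of rollouts,
        # while RolloutEval is given a single integer.
        values = rollout if isinstance(rollout, list) else [rollout]
        largest = max([largest, *values])
    return largest


def check_validation_rollout(validation_rollout: int, callbacks: list[dict[str, Any]]) -> None:
    """Raise if a callback asks for more rollout steps than validation provides."""
    needed = max_callback_rollout(callbacks)
    if needed > validation_rollout:
        msg = (
            f"validation_rollout={validation_rollout} is smaller than "
            f"the rollout of {needed} requested by a callback"
        )
        raise ValueError(msg)


callbacks = [
    {
        "_target_": "anemoi.training.diagnostics.callbacks.evaluation.RolloutEval",
        "rollout": 12,
        "frequency": 20,
    },
    {
        "_target_": "anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots",
        "rollout": [6, 12],
        "epoch_frequency": 20,
    },
]

check_validation_rollout(12, callbacks)  # passes: 12 >= 12
```

Writing the callbacks' `rollout` as `${dataloader.validation_rollout}`, as the default configs do, satisfies this constraint by construction.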
2 changes: 1 addition & 1 deletion src/anemoi/training/config/debug.yaml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
defaults:
- data: zarr
- dataloader: native_grid
- diagnostics: eval_rollout
- diagnostics: evaluation
- hardware: example
- graph: multi_scale
- model: gnn
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Add callbacks here
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Add callbacks here
- _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
rollout: ${dataloader.validation_rollout}
frequency: 20
Original file line number Diff line number Diff line change
@@ -1,53 +1,8 @@
---
eval:
enabled: False
# use this to evaluate the model over longer rollouts, every so many validation batches
rollout: 12
frequency: 20
plot:
enabled: True
asynchronous: True
frequency: 750
sample_idx: 0
per_sample: 6
parameters:
- z_500
- t_850
- u_850
- v_850
- 2t
- 10u
- 10v
- sp
- tp
- cp
#Defining the accumulation levels for precipitation related fields and the colormap
accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
precip_and_related_fields: [tp, cp]
# Histogram and Spectrum plots
parameters_histogram:
- z_500
- tp
- 2t
- 10u
- 10v
parameters_spectrum:
- z_500
- tp
- 2t
- 10u
- 10v
# group parameters by categories when visualizing contributions to the loss
# one-parameter groups are possible to highlight individual parameters
parameter_groups:
moisture: [tp, cp, tcw]
sfc_wind: [10u, 10v]
learned_features: False
longrollout:
enabled: False
rollout: [60]
frequency: 20 # every X epochs
defaults:
- plot: detailed
- callbacks: pretraining


debug:
# this will detect and trace back NaNs / Infs etc. but will slow down training
@@ -57,6 +12,7 @@ debug:
# remember to also activate the tensorboard logger (below)
profiler: False

enable_checkpointing: True
checkpoint:
every_n_minutes:
save_frequency: 30 # Approximate, as this is checked at the end of training steps
68 changes: 68 additions & 0 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
asynchronous: True # Whether to plot asynchronously
frequency: # Frequency of the plotting
  batch: 750
  epoch: 5

# Parameters to plot
parameters:
  - z_500
  - t_850
  - u_850
  - v_850
  - 2t
  - 10u
  - 10v
  - sp
  - tp
  - cp

# Sample index
sample_idx: 0

# Precipitation and related fields
precip_and_related_fields: [tp, cp]

callbacks:
  # Add plot callbacks here
  - _target_: anemoi.training.diagnostics.callbacks.plot.GraphNodeTrainableFeaturesPlot
  - _target_: anemoi.training.diagnostics.callbacks.plot.GraphEdgeTrainableFeaturesPlot
    epoch_frequency: 5
  - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
    # group parameters by categories when visualizing contributions to the loss
    # one-parameter groups are possible to highlight individual parameters
    parameter_groups:
      moisture: [tp, cp, tcw]
      sfc_wind: [10u, 10v]
  - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
    sample_idx: ${diagnostics.plot.sample_idx}
    per_sample : 6
    parameters: ${diagnostics.plot.parameters}
    #Defining the accumulation levels for precipitation related fields and the colormap
    accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
    cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
    precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}

  - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
    # batch_frequency: 100 # Override for batch frequency
    sample_idx: ${diagnostics.plot.sample_idx}
    parameters:
      - z_500
      - tp
      - 2t
      - 10u
      - 10v
  - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram
    sample_idx: ${diagnostics.plot.sample_idx}
    precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
    parameters:
      - z_500
      - tp
      - 2t
      - 10u
      - 10v
  - _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
    rollout:
      - ${dataloader.validation_rollout}
    epoch_frequency: 20
    sample_idx: ${diagnostics.plot.sample_idx}
    parameters: ${diagnostics.plot.parameters}
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/none.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
callbacks: []
40 changes: 40 additions & 0 deletions src/anemoi/training/config/diagnostics/plot/simple.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
asynchronous: True # Whether to plot asynchronously
frequency: # Frequency of the plotting
  batch: 750
  epoch: 10

# Parameters to plot
parameters:
  - z_500
  - t_850
  - u_850
  - v_850
  - 2t
  - 10u
  - 10v
  - sp
  - tp
  - cp

# Sample index
sample_idx: 0

# Precipitation and related fields
precip_and_related_fields: [tp, cp]

callbacks:
  # Add plot callbacks here
  - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
    # group parameters by categories when visualizing contributions to the loss
    # one-parameter groups are possible to highlight individual parameters
    parameter_groups:
      moisture: [tp, cp, tcw]
      sfc_wind: [10u, 10v]
  - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
    sample_idx: ${diagnostics.plot.sample_idx}
    per_sample : 6
    parameters: ${diagnostics.plot.parameters}
    #Defining the accumulation levels for precipitation related fields and the colormap
    accumulation_levels_plot: [0, 0.05, 0.1, 0.25, 0.5, 1, 1.5, 2, 3, 4, 5, 6, 7, 100] # in mm
    cmap_accumulation: ["#ffffff", "#04e9e7", "#019ff4", "#0300f4", "#02fd02", "#01c501", "#008e00", "#fdf802", "#e5bc00", "#fd9500", "#fd0000", "#d40000", "#bc0000", "#f800fd"]
    precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}
6 changes: 2 additions & 4 deletions src/anemoi/training/data/datamodule.py
Original file line number Diff line number Diff line change
@@ -129,10 +129,8 @@ def ds_train(self) -> NativeGridDataset:
     @cached_property
     def ds_valid(self) -> NativeGridDataset:
         r = self.rollout
-        if self.config.diagnostics.eval.enabled:
-            r = max(r, self.config.diagnostics.eval.rollout)
-        if self.config.diagnostics.plot.get("longrollout") and self.config.diagnostics.plot.longrollout.enabled:
-            r = max(r, max(self.config.diagnostics.plot.longrollout.rollout))
+        r = max(r, self.config.dataloader.get("validation_rollout", 1))
+
         assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
             f"Training end date {self.config.dataloader.training.end} is not before"
             f"validation start date {self.config.dataloader.validation.start}"
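
The effect of the new `ds_valid` line can be illustrated in isolation. The sketch below assumes `self.rollout` holds the rollout used for training and replaces the config object with a plain dictionary; `validation_dataset_rollout` is a hypothetical helper, not part of the datamodule.

```python
from typing import Any


def validation_dataset_rollout(training_rollout: int, dataloader_config: dict[str, Any]) -> int:
    """Rollout length for the validation dataset, mirroring the new ``ds_valid`` line."""
    # A missing key falls back to 1, so the training rollout acts as the lower bound.
    return max(training_rollout, dataloader_config.get("validation_rollout", 1))


longer_validation = validation_dataset_rollout(1, {"validation_rollout": 12})
missing_key = validation_dataset_rollout(4, {})
shorter_validation = validation_dataset_rollout(4, {"validation_rollout": 2})
```

The validation dataset is therefore never built with a shorter rollout than training, and a single `dataloader.validation_rollout` setting replaces the per-callback lookups under `diagnostics.eval` and `diagnostics.plot.longrollout`.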