
Commit

Merge remote-tracking branch 'origin/develop' into fix/refactor_callbacks
HCookie committed Oct 29, 2024
2 parents 8b2a30e + 8b6c7f4 commit 7fe2c05
Showing 16 changed files with 1,619 additions and 88 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -16,6 +16,8 @@ Keep it human-readable, your future self will thank you!
- Refactored callbacks. [#60](https://github.com/ecmwf/anemoi-training/pulls/60)
- Refactored rollout [#87](https://github.com/ecmwf/anemoi-training/pulls/87)
- Enable longer validation rollout than training
### Added
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)

### Changed

141 changes: 136 additions & 5 deletions docs/modules/losses.rst
@@ -3,22 +3,153 @@
########

This module defines the loss functions used to train the model.

Anemoi-training exposes several loss functions by default, all of which
are subclassed from ``BaseWeightedLoss``. This class enables scalar
multiplication and graph node weighting.

.. automodule:: anemoi.training.losses.weightedloss
:members:
:no-undoc-members:
:show-inheritance:

************************
Default Loss Functions
************************

By default anemoi-training trains the model using a latitude-weighted
mean-squared-error, which is defined in the ``WeightedMSELoss`` class in
``anemoi/training/losses/mse.py``. The loss function can be configured
in the config file at ``config.training.training_loss`` and
``config.training.validation_metrics``.

The following loss functions are available by default:

- ``WeightedMSELoss``: Latitude-weighted mean-squared-error.
- ``WeightedMAELoss``: Latitude-weighted mean-absolute-error.
- ``WeightedHuberLoss``: Latitude-weighted Huber loss.
- ``WeightedLogCoshLoss``: Latitude-weighted log-cosh loss.
- ``WeightedRMSELoss``: Latitude-weighted root-mean-squared-error.
- ``CombinedLoss``: Combined component weighted loss.

These are available in the ``anemoi.training.losses`` module, at
``anemoi.training.losses.{short_name}.{class_name}``.

So for example, to use the ``WeightedMSELoss`` class, you would
reference it in the config as follows:

.. code:: yaml

   # loss function for the model
   training_loss:
     # loss class to initialise
     _target_: anemoi.training.losses.mse.WeightedMSELoss
     # loss function kwargs here

*********
Scalars
*********

In addition to node scaling, the loss function can also be scaled by a
scalar. These are provided by the ``Forecaster`` class, and a user can
define whether to include them in the loss function by setting
``scalars`` in the loss config dictionary.

.. code:: yaml

   # loss function for the model
   training_loss:
     # loss class to initialise
     _target_: anemoi.training.losses.mse.WeightedMSELoss
     scalars: ['scalar1', 'scalar2']

Currently, the following scalars are available for use:

- ``variable``: Scale by the feature/variable weights as defined in the
config ``config.training.loss_scaling``.

The user can define their own loss function using the same structure as
the ``WeightedMSELoss`` class.

********************
Validation Metrics
********************

Validation metrics, as defined in the config file at
``config.training.validation_metrics``, follow the same initialisation
behaviour as the loss function, but can be a list. In this case all
losses are calculated and logged as a dictionary with the corresponding
name.
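
For example, a validation metrics list mirroring the structure of the
default training config (the second entry is purely illustrative) might
look like:

.. code:: yaml

   validation_metrics:
     # each entry is initialised like a training loss
     - _target_: anemoi.training.losses.mse.WeightedMSELoss
       scalars: []
       ignore_nans: True
     - _target_: anemoi.training.losses.mae.WeightedMAELoss
       scalars: []
       ignore_nans: True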

***********************
Custom Loss Functions
***********************

Additionally, you can define your own loss function by subclassing
``BaseWeightedLoss`` and implementing the ``forward`` method, or by
subclassing ``FunctionalWeightedLoss`` and implementing the
``calculate_difference`` function. The latter abstracts the scaling and
node weighting, allowing you to specify only the difference calculation.

.. code:: python

   from anemoi.training.losses.weightedloss import FunctionalWeightedLoss

   class MyLossFunction(FunctionalWeightedLoss):
       def calculate_difference(self, pred, target):
           return (pred - target) ** 2

Then, in the config, set ``_target_`` to the class path and pass any
additional kwargs to the loss function.
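
For instance, if ``MyLossFunction`` above lived in a hypothetical module
``my_package.losses``, the config entry might look like:

.. code:: yaml

   training_loss:
     # hypothetical module path for the custom loss above
     _target_: my_package.losses.MyLossFunction
     # any extra kwargs are forwarded to the loss constructor
     ignore_nans: False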

*****************
Combined Losses
*****************

Building on the simple single loss functions, a user can define a
combined loss, one that weights and combines multiple loss functions.

This can be done by referencing the ``CombinedLoss`` class in the config
file, and setting the ``losses`` key to a list of loss functions to
combine. Each of those losses is then initialised just like the other
losses above.

.. code:: yaml

   training_loss:
     _target_: anemoi.training.losses.combined.CombinedLoss
     losses:
       - _target_: anemoi.training.losses.mse.WeightedMSELoss
       - _target_: anemoi.training.losses.mae.WeightedMAELoss
     scalars: ['variable']
     loss_weights: [1.0, 0.5]

All kwargs passed to ``CombinedLoss`` are passed to each of the loss
functions, and the loss weights are used to scale the individual losses
before combining them.
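
For example (a sketch, assuming the sub-losses accept ``ignore_nans`` as
in the default config), a single kwarg set on ``CombinedLoss`` is
forwarded to both sub-losses:

.. code:: yaml

   training_loss:
     _target_: anemoi.training.losses.combined.CombinedLoss
     losses:
       - _target_: anemoi.training.losses.mse.WeightedMSELoss
       - _target_: anemoi.training.losses.mae.WeightedMAELoss
     loss_weights: [1.0, 0.5]
     # forwarded to every loss listed above
     ignore_nans: True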

.. automodule:: anemoi.training.losses.combined
:members:
:no-undoc-members:
:show-inheritance:

*******************
Utility Functions
*******************

There are also generic functions useful for losses in
``anemoi/training/losses/utils.py``.

``grad_scaler`` is used to automatically scale the loss gradients in the
loss function using the formula in https://arxiv.org/pdf/2306.06079.pdf,
section 4.3.2. This can be switched on in the config by setting the
option ``config.training.loss_gradient_scaling=True``.
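
For example, switching it on corresponds to a single flag in the training
config group (a sketch of the relevant lines from
``config/training/default.yaml``):

.. code:: yaml

   # dynamic rescaling of the loss gradient
   # see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
   loss_gradient_scaling: True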

``ScaleTensor`` is a class that can record and apply arbitrary scaling
factors to tensors. It supports relative indexing and combining multiple
scalars over the same dimensions, and the final scaling tensor is only
constructed at broadcast time, so its shape can be resolved to match the
target tensor exactly.

.. automodule:: anemoi.training.losses.utils
:members:
:no-undoc-members:
:show-inheritance:
24 changes: 24 additions & 0 deletions src/anemoi/training/config/training/default.yaml
@@ -33,11 +33,35 @@ swa:
# use ZeroRedundancyOptimizer ; saves memory for larger models
zero_optimizer: False

# loss functions

# dynamic rescaling of the loss gradient
# see https://arxiv.org/pdf/2306.06079.pdf, section 4.3.2
# don't enable this by default until it's been tested and proven beneficial

# loss function for the model
training_loss:
  # loss class to initialise
  _target_: anemoi.training.losses.mse.WeightedMSELoss
  # Scalars to include in loss calculation
  # Available scalars include 'variable'
  scalars: ['variable']
  ignore_nans: False

loss_gradient_scaling: False

# Validation metrics calculation,
# This may be a list, in which case all metrics will be calculated
# and logged according to their name
validation_metrics:
  # loss class to initialise
  - _target_: anemoi.training.losses.mse.WeightedMSELoss
    # Scalars to include in loss calculation
    # Available scalars include 'variable'
    scalars: []
    # other kwargs
    ignore_nans: True

# length of the "rollout" window (see Keisler's paper)
rollout:
  start: 1
6 changes: 6 additions & 0 deletions src/anemoi/training/diagnostics/callbacks/plot.py
@@ -40,6 +40,7 @@
from anemoi.training.diagnostics.plots import plot_loss
from anemoi.training.diagnostics.plots import plot_power_spectrum
from anemoi.training.diagnostics.plots import plot_predicted_multilevel_flat_sample
from anemoi.training.losses.weightedloss import BaseWeightedLoss

if TYPE_CHECKING:
    import pytorch_lightning as pl
@@ -647,6 +648,11 @@ def _plot(
        parameter_positions = list(pl_module.data_indices.internal_model.output.name_to_index.values())
        # reorder parameter_names by position
        self.parameter_names = [parameter_names[i] for i in np.argsort(parameter_positions)]
        if not isinstance(pl_module.loss, BaseWeightedLoss):
            logging.warning(
                "Loss function must be a subclass of BaseWeightedLoss, or provide `squash`.",
                RuntimeWarning,
            )

        batch = pl_module.model.pre_processors(batch, in_place=False)
        for rollout_step in range(pl_module.rollout):
135 changes: 135 additions & 0 deletions src/anemoi/training/losses/combined.py
@@ -0,0 +1,135 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

from __future__ import annotations

import functools
from typing import Any
from typing import Callable

import torch

from anemoi.training.train.forecaster import GraphForecaster


class CombinedLoss(torch.nn.Module):
    """Combined Loss function."""

    def __init__(
        self,
        *extra_losses: dict[str, Any] | Callable,
        losses: tuple[dict[str, Any] | Callable] | None = None,
        loss_weights: tuple[int, ...],
        **kwargs,
    ):
"""Combined loss function.
Allows multiple losses to be combined into a single loss function,
and the components weighted.
If a sub loss function requires additional weightings or code created tensors,
that must be `included_` for this function, and then controlled by the underlying
loss function configuration.
Parameters
----------
losses: tuple[dict[str, Any]| Callable]
Tuple of losses to initialise with `GraphForecaster.get_loss_function`.
Allows for kwargs to be passed, and weighings controlled.
*extra_losses: dict[str, Any] | Callable
Additional arg form of losses to include in the combined loss.
loss_weights : tuple[int, ...]
Weights of each loss function in the combined loss.
kwargs: Any
Additional arguments to pass to the loss functions
Examples
--------
>>> CombinedLoss(
{"__target__": "anemoi.training.losses.mse.WeightedMSELoss"},
loss_weights=(1.0,),
node_weights=node_weights
)
--------
>>> CombinedLoss(
losses = [anemoi.training.losses.mse.WeightedMSELoss],
loss_weights=(1.0,),
node_weights=node_weights
)
Or from the config,
```
training_loss:
__target__: anemoi.training.losses.combined.CombinedLoss
losses:
- __target__: anemoi.training.losses.mse.WeightedMSELoss
- __target__: anemoi.training.losses.mae.WeightedMAELoss
scalars: ['variable']
loss_weights: [1.0,0.5]
```
"""
        super().__init__()

        losses = (*(losses or []), *extra_losses)

        assert len(losses) == len(loss_weights), "Number of losses and weights must match"
        assert len(losses) > 0, "At least one loss must be provided"

        self.losses = [
            GraphForecaster.get_loss_function(loss, **kwargs) if isinstance(loss, dict) else loss(**kwargs)
            for loss in losses
        ]
        self.loss_weights = loss_weights

    def forward(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
        **kwargs,
    ) -> torch.Tensor:
        """Calculates the combined loss.

        Parameters
        ----------
        pred : torch.Tensor
            Prediction tensor, shape (bs, ensemble, lat*lon, n_outputs)
        target : torch.Tensor
            Target tensor, shape (bs, ensemble, lat*lon, n_outputs)
        kwargs: Any
            Additional arguments to pass to the loss functions.
            Will be passed to all loss functions.

        Returns
        -------
        torch.Tensor
            Combined loss
        """
        loss = None
        for i, loss_fn in enumerate(self.losses):
            if loss is not None:
                loss += self.loss_weights[i] * loss_fn(pred, target, **kwargs)
            else:
                loss = self.loss_weights[i] * loss_fn(pred, target, **kwargs)
        return loss

    @property
    def name(self) -> str:
        return "combined_" + "_".join(getattr(loss, "name", loss.__class__.__name__.lower()) for loss in self.losses)

    def __getattr__(self, name: str) -> Callable:
        """Allow access to underlying attributes of the loss functions."""
        if not all(hasattr(loss, name) for loss in self.losses):
            error_msg = f"Attribute {name} not found in all loss functions"
            raise AttributeError(error_msg)

        @functools.wraps(getattr(self.losses[0], name))
        def hidden_func(*args, **kwargs) -> list[Any]:
            return [getattr(loss, name)(*args, **kwargs) for loss in self.losses]

        return hidden_func