Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/22 feature add tendency training #23

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
107 changes: 89 additions & 18 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
#

import uuid
from typing import Optional

import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData

from anemoi.models.preprocessing import Processors
Expand Down Expand Up @@ -39,38 +41,55 @@ class AnemoiModelInterface(torch.nn.Module):
Metadata for the model.
data_indices : dict
Indices for the data.
pre_processors : Processors
Pre-processing steps to apply to the data before passing it to the model.
post_processors : Processors
Post-processing steps to apply to the model's output.
model : AnemoiModelEncProcDec
The underlying Anemoi model.
"""

def __init__(
self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
self,
*,
config: DotDict,
graph_data: HeteroData,
statistics: dict,
data_indices: dict,
metadata: dict,
statistics_tendencies: Optional[dict] = None,
) -> None:
super().__init__()
self.config = config
self.id = str(uuid.uuid4())
self.multi_step = self.config.training.multistep_input
self.prediction_strategy = self.config.training.prediction_strategy
self.graph_data = graph_data
self.statistics = statistics
self.statistics_tendencies = statistics_tendencies
self.metadata = metadata
self.data_indices = data_indices
self._build_model()

def _build_model(self) -> None:
"""Builds the model and pre- and post-processors."""
# Instantiate processors
processors = [
[name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)]
for name, processor in self.config.data.processors.items()
# Instantiate processors for state
processors_state = [
[name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)]
for name, processor in self.config.data.processors.state.items()
]

# Assign the processor list pre- and post-processors
self.pre_processors = Processors(processors)
self.post_processors = Processors(processors, inverse=True)
self.pre_processors_state = Processors(processors_state)
self.post_processors_state = Processors(processors_state, inverse=True)

# Instantiate processors for tendency
self.pre_processors_tendency = None
self.post_processors_tendency = None
if self.prediction_strategy == "tendency":
processors_tendency = [
[name, instantiate(processor, statistics=self.statistics_tendencies, data_indices=self.data_indices)]
for name, processor in self.config.data.processors.tendency.items()
]

self.pre_processors_tendency = Processors(processors_tendency)
self.post_processors_tendency = Processors(processors_tendency, inverse=True)

# Instantiate the model
self.model = instantiate(
Expand All @@ -81,8 +100,19 @@ def _build_model(self) -> None:
_recursive_=False, # Disables recursive instantiation by Hydra
)

# Use the forward method of the model directly
self.forward = self.model.forward
def forward(self, x: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make sense to have this better visible, e.g. in the name here? Even if we haven't the alternativ implemented at the moment (but which makes sense to have for anemoi, e.g. with obs the residual is typically not possible).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I think instead of self.tendency_mode there should be a self.prediction_mode which can take the values 'state', 'residual' or 'tendency'. I will try to change that when we open the PR with anemoi-training.

if self.prediction_strategy == "residual":
# Predict state by adding residual connection (just for the prognostic variables)
x_pred = self.model.forward(x, model_comm_group)
x_pred[..., self.model._internal_output_idx] += x[:, -1, :, :, self.model._internal_input_idx]
else:
x_pred = self.model.forward(x, model_comm_group)

for bounding in self.model.boundings:
# bounding performed in the order specified in the config file
x_pred = bounding(x_pred)

return x_pred

def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
"""Prediction step for the model.
Expand All @@ -97,17 +127,58 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
torch.Tensor
Predicted data.
"""
batch = self.pre_processors(batch, in_place=False)

with torch.no_grad():

assert (
len(batch.shape) == 4
), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!"
x = self.pre_processors_state(
batch[:, 0 : self.multi_step, ...], in_place=False, data_index=self.data_indices.data.input.full
)

# Dimensions are
# batch, timesteps, horizonal space, variables
x = batch[:, 0 : self.multi_step, None, ...] # add dummy ensemble dimension as 3rd index
# batch, timesteps, horizontal space, variables
x = x[:, :, None, ...] # add dummy ensemble dimension as 3rd index
if self.prediction_strategy == "tendency":
tendency_hat = self(x)
y_hat = self.add_tendency_to_state(x[:, -1, ...], tendency_hat)
else:
y_hat = self(x)
y_hat = self.post_processors_state(y_hat, in_place=False, data_index=self.data_indices.data.output.full)

return y_hat

y_hat = self(x)
def add_tendency_to_state(self, state_inp: torch.Tensor, tendency: torch.Tensor) -> torch.Tensor:
"""Add the tendency to the state.

Parameters
----------
state_inp : torch.Tensor
The normalized input state tensor with full input variables.
tendency : torch.Tensor
The normalized tendency tensor output from model.

Returns
-------
torch.Tensor
Predicted data.
"""

state_outp = self.post_processors_tendency(
tendency, in_place=False, data_index=self.data_indices.data.output.full
)

state_outp[..., self.data_indices.model.output.diagnostic] = self.post_processors_state(
tendency[..., self.data_indices.model.output.diagnostic],
in_place=False,
data_index=self.data_indices.data.output.diagnostic,
)

state_outp[..., self.data_indices.model.output.prognostic] += self.post_processors_state(
state_inp[..., self.data_indices.model.input.prognostic],
in_place=False,
data_index=self.data_indices.data.input.prognostic,
)

return self.post_processors(y_hat, in_place=False)
return state_outp
7 changes: 0 additions & 7 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,11 +259,4 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
.clone()
)

# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

for bounding in self.boundings:
# bounding performed in the order specified in the config file
x_out = bounding(x_out)

return x_out
16 changes: 11 additions & 5 deletions src/anemoi/models/preprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def _invert_key_value_list(self, method_config: dict[str, list[str]]) -> dict[st
for variable in variables
}

def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor:
def forward(
self, x, in_place: bool = True, inverse: bool = False, data_index: Optional[torch.Tensor] = None
) -> Tensor:
"""Process the input tensor.

Parameters
Expand All @@ -106,15 +108,17 @@ def forward(self, x, in_place: bool = True, inverse: bool = False) -> Tensor:
Whether to process the tensor in place
inverse : bool
Whether to inverse transform the input
data_index : torch.Tensor, optional
Normalize only the specified indices, by default.

Returns
-------
torch.Tensor
Processed tensor
"""
if inverse:
return self.inverse_transform(x, in_place=in_place)
return self.transform(x, in_place=in_place)
return self.inverse_transform(x, in_place=in_place, data_index=data_index)
return self.transform(x, in_place=in_place, data_index=data_index)

def transform(self, x, in_place: bool = True) -> Tensor:
"""Process the input tensor."""
Expand Down Expand Up @@ -155,7 +159,7 @@ def __init__(self, processors: list, inverse: bool = False) -> None:
def __repr__(self) -> str:
return f"{self.__class__.__name__} [{'inverse' if self.inverse else 'forward'}]({self.processors})"

def forward(self, x, in_place: bool = True) -> Tensor:
def forward(self, x, in_place: bool = True, data_index: Optional[torch.Tensor] = None) -> Tensor:
"""Process the input tensor.

Parameters
Expand All @@ -164,14 +168,16 @@ def forward(self, x, in_place: bool = True) -> Tensor:
Input tensor
in_place : bool
Whether to process the tensor in place
data_index : Optional[torch.Tensor], optional
Normalize only the specified indices, by default None

Returns
-------
torch.Tensor
Processed tensor
"""
for processor in self.processors.values():
x = processor(x, in_place=in_place, inverse=self.inverse)
x = processor(x, in_place=in_place, inverse=self.inverse, data_index=data_index)

if self.first_run:
self.first_run = False
Expand Down
17 changes: 13 additions & 4 deletions src/anemoi/models/preprocessing/imputer.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,9 @@ def _expand_subset_mask(self, x: torch.Tensor, idx_src: int) -> torch.Tensor:
"""Expand the subset of the mask to the correct shape."""
return self.nan_locations[:, idx_src].expand(*x.shape[:-2], -1)

def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
def transform(
self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Impute missing values in the input tensor."""
if not in_place:
x = x.clone()
Expand All @@ -115,7 +117,9 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
self.nan_locations = torch.isnan(x[idx].squeeze())

# Choose correct index based on number of variables
if x.shape[-1] == self.num_training_input_vars:
if data_index is not None:
index = data_index
elif x.shape[-1] == self.num_training_input_vars:
index = self.index_training_input
elif x.shape[-1] == self.num_inference_input_vars:
index = self.index_inference_input
Expand All @@ -131,13 +135,18 @@ def transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
x[..., idx_dst][self._expand_subset_mask(x, idx_src)] = value
return x

def inverse_transform(self, x: torch.Tensor, in_place: bool = True) -> torch.Tensor:
def inverse_transform(
self, x: torch.Tensor, in_place: bool = True, data_index: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""Impute missing values in the input tensor."""
if not in_place:
x = x.clone()

# Replace original nans with nan again
if x.shape[-1] == self.num_training_output_vars:
# Choose correct index based on number of variables
if data_index is not None:
index = data_index
elif x.shape[-1] == self.num_training_output_vars:
index = self.index_training_output
elif x.shape[-1] == self.num_inference_output_vars:
index = self.index_inference_output
Expand Down
6 changes: 4 additions & 2 deletions tests/preprocessing/test_preprocessor_normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def input_normalizer():
{
"diagnostics": {"log": {"code": {"level": "DEBUG"}}},
"data": {
"normalizer": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]},
"normalizers": {
"state": {"default": "mean-std", "min-max": ["x"], "max": ["y"], "none": ["z"], "mean-std": ["q"]}
},
"forcing": ["z", "q"],
"diagnostic": ["other"],
"remapped": {},
Expand Down Expand Up @@ -68,7 +70,7 @@ def remap_normalizer():
}
name_to_index = {"x": 0, "y": 1, "z": 2, "q": 3, "other": 4}
data_indices = IndexCollection(config=config, name_to_index=name_to_index)
return InputNormalizer(config=config.data.normalizer, data_indices=data_indices, statistics=statistics)
return InputNormalizer(config=config.data.normalizers.state, statistics=statistics, data_indices=data_indices)


def test_normalizer_not_inplace(input_normalizer) -> None:
Expand Down
Loading