Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/22 feature add tendency training #23

Draft
wants to merge 10 commits into
base: develop
Choose a base branch
from
12 changes: 10 additions & 2 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData

from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
Expand Down Expand Up @@ -96,8 +97,15 @@ def _build_model(self) -> None:
config=self.config, data_indices=self.data_indices, graph_data=self.graph_data
)

# Use the forward method of the model directly
self.forward = self.model.forward
def forward(self, x: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it make sense to have this better visible, e.g. in the name here? Even if we haven't the alternativ implemented at the moment (but which makes sense to have for anemoi, e.g. with obs the residual is typically not possible).

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree, I think instead of self.tendency_mode there should be a self.prediction_mode which can take the values 'state', 'residual' or 'tendency'. I will try to change that when we open the PR with anemoi-training.

if self.tendency_mode:
# Predict tendency
x_pred = self.model.forward(x, model_comm_group)
else:
# Predict state by adding residual connection (just for the prognostic variables)
x_pred = self.model.forward(x, model_comm_group)
x_pred[..., self.model._internal_output_idx] += x[:, -1, :, :, self.model._internal_input_idx]
return x_pred

def predict_step(self, batch: torch.Tensor) -> torch.Tensor:
"""Prediction step for the model.
Expand Down
2 changes: 0 additions & 2 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,4 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
.clone()
)

# residual connection (just for the prognostic variables)
x_out[..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]
return x_out
Loading