-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feature/22 feature add tendency training #23
base: develop
Are you sure you want to change the base?
Changes from 8 commits
792c032
fcb2a1e
03b9603
7d973a4
f94e46c
7d49c61
a40ec02
5ae42df
0e19870
2aca330
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,10 +8,12 @@ | |
# | ||
|
||
import uuid | ||
from typing import Optional | ||
|
||
import torch | ||
from anemoi.utils.config import DotDict | ||
from hydra.utils import instantiate | ||
from torch.distributed.distributed_c10d import ProcessGroup | ||
from torch_geometric.data import HeteroData | ||
|
||
from anemoi.models.preprocessing import Processors | ||
|
@@ -39,38 +41,55 @@ class AnemoiModelInterface(torch.nn.Module): | |
Metadata for the model. | ||
data_indices : dict | ||
Indices for the data. | ||
pre_processors : Processors | ||
Pre-processing steps to apply to the data before passing it to the model. | ||
post_processors : Processors | ||
Post-processing steps to apply to the model's output. | ||
model : AnemoiModelEncProcDec | ||
The underlying Anemoi model. | ||
""" | ||
|
||
def __init__( | ||
self, *, config: DotDict, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict | ||
self, | ||
*, | ||
config: DotDict, | ||
graph_data: HeteroData, | ||
statistics: dict, | ||
data_indices: dict, | ||
metadata: dict, | ||
statistics_tendencies: Optional[dict] = None, | ||
) -> None: | ||
super().__init__() | ||
self.config = config | ||
self.id = str(uuid.uuid4()) | ||
self.multi_step = self.config.training.multistep_input | ||
self.prediction_strategy = self.config.training.prediction_strategy | ||
self.graph_data = graph_data | ||
self.statistics = statistics | ||
self.statistics_tendencies = statistics_tendencies | ||
self.metadata = metadata | ||
self.data_indices = data_indices | ||
self._build_model() | ||
|
||
def _build_model(self) -> None: | ||
"""Builds the model and pre- and post-processors.""" | ||
# Instantiate processors | ||
processors = [ | ||
[name, instantiate(processor, data_indices=self.data_indices, statistics=self.statistics)] | ||
for name, processor in self.config.data.processors.items() | ||
# Instantiate processors for state | ||
processors_state = [ | ||
[name, instantiate(processor, statistics=self.statistics, data_indices=self.data_indices)] | ||
for name, processor in self.config.data.processors.state.items() | ||
] | ||
|
||
# Assign the processor list pre- and post-processors | ||
self.pre_processors = Processors(processors) | ||
self.post_processors = Processors(processors, inverse=True) | ||
self.pre_processors_state = Processors(processors_state) | ||
self.post_processors_state = Processors(processors_state, inverse=True) | ||
|
||
# Instantiate processors for tendency | ||
self.pre_processors_tendency = None | ||
self.post_processors_tendency = None | ||
if self.prediction_strategy == "tendency": | ||
processors_tendency = [ | ||
[name, instantiate(processor, statistics=self.statistics_tendencies, data_indices=self.data_indices)] | ||
for name, processor in self.config.data.processors.tendency.items() | ||
] | ||
|
||
self.pre_processors_tendency = Processors(processors_tendency) | ||
self.post_processors_tendency = Processors(processors_tendency, inverse=True) | ||
|
||
# Instantiate the model | ||
self.model = instantiate( | ||
|
@@ -81,8 +100,19 @@ def _build_model(self) -> None: | |
_recursive_=False, # Disables recursive instantiation by Hydra | ||
) | ||
|
||
# Use the forward method of the model directly | ||
self.forward = self.model.forward | ||
def forward(self, x: torch.Tensor, model_comm_group: Optional[ProcessGroup] = None) -> torch.Tensor: | ||
if self.prediction_strategy == "residual": | ||
# Predict state by adding residual connection (just for the prognostic variables) | ||
x_pred = self.model.forward(x, model_comm_group) | ||
x_pred[..., self.model._internal_output_idx] += x[:, -1, :, :, self.model._internal_input_idx] | ||
else: | ||
x_pred = self.model.forward(x, model_comm_group) | ||
|
||
for bounding in self.model.boundings: | ||
# bounding performed in the order specified in the config file | ||
x_pred = bounding(x_pred) | ||
|
||
return x_pred | ||
|
||
def predict_step(self, batch: torch.Tensor) -> torch.Tensor: | ||
"""Prediction step for the model. | ||
|
@@ -97,17 +127,54 @@ def predict_step(self, batch: torch.Tensor) -> torch.Tensor: | |
torch.Tensor | ||
Predicted data. | ||
""" | ||
batch = self.pre_processors(batch, in_place=False) | ||
|
||
with torch.no_grad(): | ||
|
||
assert ( | ||
len(batch.shape) == 4 | ||
), f"The input tensor has an incorrect shape: expected a 4-dimensional tensor, got {batch.shape}!" | ||
x = self.pre_processors_state(batch[:, 0 : self.multi_step, ...], in_place=False) | ||
|
||
# Dimensions are | ||
# batch, timesteps, horizonal space, variables | ||
x = batch[:, 0 : self.multi_step, None, ...] # add dummy ensemble dimension as 3rd index | ||
# batch, timesteps, horizontal space, variables | ||
x = x[..., None, :, :] # add dummy ensemble dimension as 3rd index | ||
if self.prediction_strategy == "tendency": | ||
tendency_hat = self(x) | ||
y_hat = self.add_tendency_to_state(x[:, -1, ...], tendency_hat) | ||
else: | ||
y_hat = self(x) | ||
y_hat = self.post_processors_state(y_hat, in_place=False) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These multiple `if` branches on the prediction strategy complicate the class. If you want an alternative, this is what you could do: Since the current AnemoiModelInterface does not have the behaviour you want, you can instead create another class that behaves as you like (copy-paste the whole AnemoiModelInterface into MyTendencyTrainingAnemoiModelInterface) and use this one instead. Just after copy-pasting, you realise that you have duplicated code between AnemoiModelInterface and MyTendencyTrainingAnemoiModelInterface. To avoid this, you can put the common code into a mother class BaseAnemoiModelInterface. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I see that @JesperDramsch had the same opinion on this #23 (comment) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Sounds good! I’ve already worked on additional interfaces for a different purpose on a local branch, so this should be straightforward to implement. However, if we split the interface, we’ll likely need to divide the current "AnemoiLightningModule" in anemoi-training into two separate modules: one for ForecastingState (ForecastingStateLightningModule) and another for ForecastingTendency (ForecastingTendencyLightningModule). At the moment, the existing AnemoiLightningModule (which is tied to the current interface) contains logic for handling both state and tendency steps forward. |
||
|
||
return y_hat | ||
|
||
y_hat = self(x) | ||
def add_tendency_to_state(self, state_inp: torch.Tensor, tendency: torch.Tensor) -> torch.Tensor: | ||
"""Add the tendency to the state. | ||
|
||
Parameters | ||
---------- | ||
state_inp : torch.Tensor | ||
The input state tensor with full input variables and unprocessed. | ||
tendency : torch.Tensor | ||
The tendency tensor output from model. | ||
|
||
Returns | ||
------- | ||
torch.Tensor | ||
Predicted data. | ||
""" | ||
|
||
state_outp = self.post_processors_tendency( | ||
tendency, in_place=False, data_index=self.data_indices.data.output.full | ||
) | ||
|
||
state_outp[..., self.data_indices.model.output.diagnostic] = self.post_processors_state( | ||
tendency[..., self.data_indices.model.output.diagnostic], | ||
in_place=False, | ||
data_index=self.data_indices.data.output.diagnostic, | ||
) | ||
|
||
state_outp[..., self.data_indices.model.output.prognostic] += state_inp[ | ||
..., self.data_indices.model.input.prognostic | ||
] | ||
|
||
return self.post_processors(y_hat, in_place=False) | ||
return state_outp |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it make sense to make this more visible, e.g. in the name here? Even if we don't have the alternative implemented at the moment — though it would make sense to have it for anemoi (e.g. with observations, the residual is typically not possible).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, I think instead of
self.tendency_mode
there should be a `self.prediction_mode`
which can take the values 'state', 'residual' or 'tendency'. I will try to change that when we open the PR with anemoi-training.