diff --git a/CHANGELOG.md b/CHANGELOG.md
index f68f20d..470cfaf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
 - configurability of the dropout probability in the MultiHeadSelfAttention module
 - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
 - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
+- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)
 - Contributors file [#69](https://github.com/ecmwf/anemoi-models/pull/69)

 ### Changed
diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py
index c7dbefc..4f4efaa 100644
--- a/src/anemoi/models/layers/graph.py
+++ b/src/anemoi/models/layers/graph.py
@@ -12,6 +12,7 @@ import torch
 from torch import Tensor
 from torch import nn
+from torch_geometric.data import HeteroData


 class TrainableTensor(nn.Module):
@@ -36,8 +37,62 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
     def forward(self, x: Tensor, batch_size: int) -> Tensor:
         latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
         if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
         return torch.cat(
             latent,
             dim=-1,  # feature dimension
         )
+
+
+class NamedNodesAttributes(torch.nn.Module):
+    """Named nodes attributes information.
+
+    Attributes
+    ----------
+    nodes_names : list[str]
+        List of node-set names in the graph.
+    num_nodes : dict[str, int]
+        Number of nodes for each group of nodes.
+    attr_ndims : dict[str, int]
+        Total dimension of node attributes (non-trainable + trainable) for each group of nodes.
+    trainable_tensors : nn.ModuleDict
+        Dictionary of trainable tensors for each group of nodes.
+    """
+
+    nodes_names: list[str]
+    num_nodes: dict[str, int]
+    attr_ndims: dict[str, int]
+    trainable_tensors: nn.ModuleDict
+
+    def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
+        """Initialize NamedNodesAttributes."""
+        super().__init__()
+
+        self.define_fixed_attributes(graph_data, num_trainable_params)
+
+        self.trainable_tensors = nn.ModuleDict()
+        for nodes_name, nodes in graph_data.node_items():
+            self.register_coordinates(nodes_name, nodes.x)
+            self.register_tensor(nodes_name, num_trainable_params)
+
+    def define_fixed_attributes(self, graph_data: HeteroData, num_trainable_params: int) -> None:
+        """Define the fixed (non-trainable) attributes."""
+        self.nodes_names = list(graph_data.node_types)
+        self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names}
+        self.attr_ndims = {
+            nodes_name: 2 * graph_data[nodes_name].x.shape[1] + num_trainable_params for nodes_name in self.nodes_names
+        }
+
+    def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None:
+        """Register the sin/cos-transformed node coordinates as a persistent buffer."""
+        sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
+        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
+
+    def register_tensor(self, name: str, num_trainable_params: int) -> None:
+        """Register a trainable tensor for the named node set."""
+        self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], num_trainable_params)
+
+    def forward(self, name: str, batch_size: int) -> Tensor:
+        """Return the attributes of the named nodes, repeated along the batch dimension."""
+        latlons = getattr(self, f"latlons_{name}")
+        return self.trainable_tensors[name](latlons, batch_size)
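Context for the graph.py changes above: the new module owns one persistent `latlons_{name}` buffer (sin/cos of the node coordinates) and one `TrainableTensor` per node set, so downstream code can ask for node attributes by name. A minimal usage sketch, assuming torch_geometric is installed; the node-set names and sizes are hypothetical, not values from this patch:

```python
import torch
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes

# Toy heterogeneous graph; names and sizes chosen only for illustration.
graph = HeteroData()
graph["data"].x = torch.rand(40, 2)    # 40 nodes, (lat, lon) in radians
graph["hidden"].x = torch.rand(10, 2)  # 10 nodes, (lat, lon) in radians

node_attributes = NamedNodesAttributes(num_trainable_params=8, graph_data=graph)

# Each call returns the sin/cos-encoded coordinates concatenated with the
# trainable parameters, repeated batch_size times along the node axis.
out = node_attributes("data", batch_size=3)
assert out.shape == (3 * 40, 2 * 2 + 8)  # (batch * num_nodes, attr_ndims["data"])
```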
+ """ + + nodes_names: list[str] + num_nodes: dict[str, int] + attr_ndims: dict[str, int] + trainable_tensors: dict[str, TrainableTensor] + + def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None: + """Initialize NamedNodesAttributes.""" + super().__init__() + + self.define_fixed_attributes(graph_data, num_trainable_params) + + self.trainable_tensors = nn.ModuleDict() + for nodes_name, nodes in graph_data.node_items(): + self.register_coordinates(nodes_name, nodes.x) + self.register_tensor(nodes_name, num_trainable_params) + + def define_fixed_attributes(self, graph_data: HeteroData, num_trainable_params: int) -> None: + """Define fixed attributes.""" + nodes_names = list(graph_data.node_types) + self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in nodes_names} + self.attr_ndims = { + nodes_name: 2 * graph_data[nodes_name].x.shape[1] + num_trainable_params for nodes_name in nodes_names + } + + def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: + """Register coordinates.""" + sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1) + self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) + + def register_tensor(self, name: str, num_trainable_params: int) -> None: + """Register a trainable tensor.""" + self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], num_trainable_params) + + def forward(self, name: str, batch_size: int) -> Tensor: + """Forward pass.""" + latlons = getattr(self, f"latlons_{name}") + return self.trainable_tensors[name](latlons, batch_size) diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index bdb6260..c67c8c0 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -22,7 +22,7 @@ from torch_geometric.data import HeteroData from anemoi.models.distributed.shapes import get_shape_shards -from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.graph import NamedNodesAttributes LOGGER = logging.getLogger(__name__) @@ -56,33 +56,24 @@ def __init__( self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) - - self.multi_step = model_config.training.multistep_input - - self._define_tensor_sizes(model_config) - - # Create trainable tensors - self._create_trainable_attributes() - - # Register lat/lon of nodes - self._register_latlon("data", self._graph_name_data) - self._register_latlon("hidden", self._graph_name_hidden) - self.data_indices = data_indices + self.multi_step = model_config.training.multistep_input self.num_channels = model_config.model.num_channels - input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) + + input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] # Encoder data -> hidden self.encoder = instantiate( model_config.model.encoder, in_channels_src=input_dim, - in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, + in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden], hidden_dim=self.num_channels, sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], - src_grid_size=self._data_grid_size, - 

         # Processor hidden -> hidden
@@ -90,8 +81,8 @@ def __init__(
             model_config.model.processor,
             num_channels=self.num_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
         )

         # Decoder hidden -> data
@@ -102,8 +93,8 @@ def __init__(
             hidden_dim=self.num_channels,
             out_channels_dst=self.num_output_channels,
             sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._data_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
         )

         # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
@@ -133,34 +124,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
             self._internal_output_idx,
         ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

-    def _define_tensor_sizes(self, config: DotDict) -> None:
-        self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
-        self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes
-
-        self.trainable_data_size = config.model.trainable_parameters.data
-        self.trainable_hidden_size = config.model.trainable_parameters.hidden
-
-    def _register_latlon(self, name: str, nodes: str) -> None:
-        """Register lat/lon buffers.
-
-        Parameters
-        ----------
-        name : str
-            Name to store the lat-lon coordinates of the nodes.
-        nodes : str
-            Name of nodes to map
-        """
-        coords = self._graph_data[nodes].x
-        sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
-
-    def _create_trainable_attributes(self) -> None:
-        """Create all trainable attributes."""
-        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
-        self.trainable_hidden = TrainableTensor(
-            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
-        )
-
     def _run_mapper(
         self,
         mapper: nn.Module,
@@ -210,12 +173,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> Tensor:
         x_data_latent = torch.cat(
             (
                 einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
-                self.trainable_data(self.latlons_data, batch_size=batch_size),
+                self.node_attributes(self._graph_name_data, batch_size=batch_size),
             ),
             dim=-1,  # feature dimension
         )

-        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
+        x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

         # get shard shapes
         shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
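For intuition on the rewritten `forward` above: the rearranged input and the per-node attributes are concatenated along the feature dimension, so their leading (node) dimensions must already agree. A standalone shape sketch with hypothetical sizes, using an ensemble size of 1 to keep the arithmetic simple:

```python
import einops
import torch

# Hypothetical sizes: batch=2, time=2 (multi_step), ensemble=1, grid=40, vars=90.
x = torch.rand(2, 2, 1, 40, 90)
batch_size = x.shape[0]

flat = einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)")
# -> (2 * 1 * 40, 2 * 90) = (80, 180)

# Stand-in for self.node_attributes(self._graph_name_data, batch_size=batch_size):
# sin/cos coordinates (4) + trainable parameters (8) per node, repeated per batch member.
node_attrs = torch.rand(batch_size * 40, 12)

x_data_latent = torch.cat((flat, node_attrs), dim=-1)
assert x_data_latent.shape == (80, 192)  # matches input_dim = 2 * 90 + 12
```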
diff --git a/tests/layers/test_graph.py b/tests/layers/test_graph.py
index 58674bd..66456d6 100644
--- a/tests/layers/test_graph.py
+++ b/tests/layers/test_graph.py
@@ -8,10 +8,14 @@
 # nor does it submit to any jurisdiction.

+import einops
+import numpy as np
 import pytest
 import torch
 from torch import nn
+from torch_geometric.data import HeteroData

+from anemoi.models.layers.graph import NamedNodesAttributes
 from anemoi.models.layers.graph import TrainableTensor

@@ -62,3 +66,78 @@ def test_forward_no_trainable(self, init, x):
         batch_size = 5
         output = trainable_tensor(x, batch_size)
         assert output.shape == (batch_size * x.shape[0], tensor_size + trainable_size)
+
+
+class TestNamedNodesAttributes:
+    """Test suite for the NamedNodesAttributes class.
+
+    It verifies initialization, attribute registration, and forward-pass
+    behaviour of the NamedNodesAttributes class.
+    """
+
+    nodes_names: list[str] = ["nodes1", "nodes2"]
+    ndim: int = 2
+    num_trainable_params: int = 8
+
+    @pytest.fixture
+    def graph_data(self):
+        graph = HeteroData()
+        for i, nodes_name in enumerate(TestNamedNodesAttributes.nodes_names):
+            graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i + 1))
+        return graph
+
+    @staticmethod
+    def get_n_random_coords(n: int) -> torch.Tensor:
+        coords = torch.rand(n, TestNamedNodesAttributes.ndim)
+        coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2)
+        coords[:, 1] = 2 * np.pi * coords[:, 1]
+        return coords
+
+    @pytest.fixture
+    def nodes_attributes(self, graph_data: HeteroData) -> NamedNodesAttributes:
+        return NamedNodesAttributes(TestNamedNodesAttributes.num_trainable_params, graph_data)
+
+    def test_init(self, nodes_attributes):
+        assert isinstance(nodes_attributes, NamedNodesAttributes)
+
+        for nodes_name in self.nodes_names:
+            assert isinstance(nodes_attributes.num_nodes[nodes_name], int)
+            assert (
+                nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim
+                == TestNamedNodesAttributes.num_trainable_params
+            )
+            assert isinstance(nodes_attributes.trainable_tensors[nodes_name], TrainableTensor)
+
+    def test_forward(self, nodes_attributes, graph_data):
+        batch_size = 3
+        for nodes_name in self.nodes_names:
+            output = nodes_attributes(nodes_name, batch_size)
+
+            expected_shape = (
+                batch_size * graph_data[nodes_name].num_nodes,
+                2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params,
+            )
+            assert output.shape == expected_shape
+
+            # Check that the first part of the output matches the sin/cos-transformed coordinates
+            latlons = getattr(nodes_attributes, f"latlons_{nodes_name}")
+            repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size)
+            assert torch.allclose(output[:, : 2 * TestNamedNodesAttributes.ndim], repeated_latlons)
+
+            # Check that the last part of the output is trainable (requires grad)
+            assert output[:, 2 * TestNamedNodesAttributes.ndim :].requires_grad
+
+    def test_forward_no_trainable(self, graph_data):
+        no_trainable_attributes = NamedNodesAttributes(0, graph_data)
+        batch_size = 2
+
+        for nodes_name in self.nodes_names:
+            output = no_trainable_attributes(nodes_name, batch_size)
+
+            expected_shape = (batch_size * graph_data[nodes_name].num_nodes, 2 * TestNamedNodesAttributes.ndim)
+            assert output.shape == expected_shape
+
+            # Check that the output exactly matches the sin/cos-transformed coordinates
+            latlons = getattr(no_trainable_attributes, f"latlons_{nodes_name}")
+            repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size)
+            assert torch.allclose(output, repeated_latlons)
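A closing note on the test fixture: `get_n_random_coords` draws coordinates in radians, with latitude in [-π/2, π/2) and longitude in [0, 2π), so the sin/cos encoding in `register_coordinates` is well defined. A quick standalone check mirroring the helper (a sketch, not part of the patch):

```python
import numpy as np
import torch

# Mirrors get_n_random_coords: torch.rand samples uniformly from [0, 1).
coords = torch.rand(1000, 2)
coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2)  # latitude in [-pi/2, pi/2)
coords[:, 1] = 2 * np.pi * coords[:, 1]        # longitude in [0, 2*pi)

assert coords[:, 0].min() >= -np.pi / 2 and coords[:, 0].max() < np.pi / 2
assert coords[:, 1].min() >= 0.0 and coords[:, 1].max() < 2 * np.pi
```

The new tests can then be run with `pytest tests/layers/test_graph.py`.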