From 6e64c91d0a8b0876860888f2c32bf094dc928ed2 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 08:54:23 +0000 Subject: [PATCH 1/9] feat: add NamedNodeAttributes --- src/anemoi/models/layers/graph.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 71703d9..332a4a3 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -11,6 +11,7 @@ import torch from torch import Tensor from torch import nn +from torch_geometric.data import HeteroData class TrainableTensor(nn.Module): @@ -35,8 +36,36 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None: def forward(self, x: Tensor, batch_size: int) -> Tensor: latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)] if self.trainable is not None: - latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size)) + latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size)) return torch.cat( latent, dim=-1, # feature dimension ) + + +class NamedNodesAttributes(torch.nn.Module): + """Named Node Attributes Module.""" + + def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None: + """Initialize NamedNodesAttributes.""" + self.num_trainable_params = num_trainable_params + self.nodes_names = list(graph_data.node_types) + + self.trainable_tensors = nn.ModuleDict() + for nodes_name in self.nodes_names: + self.register_coordinates(nodes_name, graph_data[nodes_name].x) + self.register_tensor(nodes_name, graph_data[nodes_name].num_nodes) + + def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: + """Register coordinates.""" + sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1) + self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) + + def register_tensor(self, name: str, tensor_size: int) -> None: + """Register a trainable tensor.""" + self.trainable_tensors[name] = TrainableTensor(tensor_size, self.num_trainable_params) + + def forward(self, name: str, batch_size: int) -> Tensor: + """Forward pass.""" + latlons = getattr(self, f"latlons_{name}") + return self.trainable_tensors[name](latlons, batch_size) From 5c83f5f6f283e94c19429bfeeb81bf7570cbcfaa Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 09:42:56 +0000 Subject: [PATCH 2/9] feat: use NamedNodesAttributes in AnemoiModelEncProcDec --- src/anemoi/models/layers/graph.py | 17 +++-- .../models/encoder_processor_decoder.py | 65 ++++--------------- 2 files changed, 27 insertions(+), 55 deletions(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 332a4a3..7a4e0d4 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -48,22 +48,31 @@ class NamedNodesAttributes(torch.nn.Module): def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None: """Initialize NamedNodesAttributes.""" + super().__init__() + self.num_trainable_params = num_trainable_params - self.nodes_names = list(graph_data.node_types) + self.register_fixed_attributes(graph_data) self.trainable_tensors = nn.ModuleDict() for nodes_name in self.nodes_names: self.register_coordinates(nodes_name, graph_data[nodes_name].x) - self.register_tensor(nodes_name, graph_data[nodes_name].num_nodes) + self.register_tensor(nodes_name) + + def register_fixed_attributes(self, graph_data: HeteroData) -> None: + """Register fixed attributes.""" + self.nodes_names = list(graph_data.node_types) + self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} + self.coord_dims = {2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} + self.attr_ndims = {self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: """Register coordinates.""" sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1) self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) - def register_tensor(self, name: str, tensor_size: int) -> None: + def register_tensor(self, name: str) -> None: """Register a trainable tensor.""" - self.trainable_tensors[name] = TrainableTensor(tensor_size, self.num_trainable_params) + self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params) def forward(self, name: str, batch_size: int) -> Tensor: """Forward pass.""" diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index c77db6e..592d6d4 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -21,7 +21,7 @@ from torch_geometric.data import HeteroData from anemoi.models.distributed.shapes import get_shape_shards -from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.layers.graph import NamedNodesAttributes LOGGER = logging.getLogger(__name__) @@ -55,33 +55,24 @@ def __init__( self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) - - self.multi_step = model_config.training.multistep_input - - self._define_tensor_sizes(model_config) - - # Create trainable tensors - self._create_trainable_attributes() - - # Register lat/lon of nodes - self._register_latlon("data", self._graph_name_data) - self._register_latlon("hidden", self._graph_name_hidden) - self.data_indices = data_indices + self.multi_step = model_config.training.multistep_input self.num_channels = model_config.model.num_channels - input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size + self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data) + + input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data] # Encoder data -> hidden self.encoder = instantiate( model_config.model.encoder, in_channels_src=input_dim, - in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size, + in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden], hidden_dim=self.num_channels, sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)], - src_grid_size=self._data_grid_size, - dst_grid_size=self._hidden_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_data], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], ) # Processor hidden -> hidden @@ -89,8 +80,8 @@ def __init__( model_config.model.processor, num_channels=self.num_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._hidden_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], ) # Decoder hidden -> data @@ -101,8 +92,8 @@ def __init__( hidden_dim=self.num_channels, out_channels_dst=self.num_output_channels, sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)], - src_grid_size=self._hidden_grid_size, - dst_grid_size=self._data_grid_size, + src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden], + dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data], ) # Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite) @@ -132,34 +123,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotDict) -> None: - self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes - self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes - - self.trainable_data_size = config.model.trainable_parameters.data - self.trainable_hidden_size = config.model.trainable_parameters.hidden - - def _register_latlon(self, name: str, nodes: str) -> None: - """Register lat/lon buffers. - - Parameters - ---------- - name : str - Name to store the lat-lon coordinates of the nodes. - nodes : str - Name of nodes to map - """ - coords = self._graph_data[nodes].x - sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) - self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) - - def _create_trainable_attributes(self) -> None: - """Create all trainable attributes.""" - self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size) - self.trainable_hidden = TrainableTensor( - trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size - ) - def _run_mapper( self, mapper: nn.Module, @@ -209,12 +172,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> x_data_latent = torch.cat( ( einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), - self.trainable_data(self.latlons_data, batch_size=batch_size), + self.node_attributes(self._graph_name_data, batch_size=batch_size), ), dim=-1, # feature dimension ) - x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size) + x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size) # get shard shapes shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group) From fdbf92f2e76189297d3953ec8eadf109444f50ce Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 10:14:37 +0000 Subject: [PATCH 3/9] fix: update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 57078fd..963c62d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you! - configurabilty of the dropout probability in the the MultiHeadSelfAttention module - Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13) - GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46) +- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64) ### Changed - Bugfixes for CI From 659652f38f31c8bf49a061c2658aed26397fce0d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Fri, 18 Oct 2024 10:42:31 +0000 Subject: [PATCH 4/9] fix: typo --- src/anemoi/models/layers/graph.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 7a4e0d4..c3608c5 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -62,8 +62,8 @@ def register_fixed_attributes(self, graph_data: HeteroData) -> None: """Register fixed attributes.""" self.nodes_names = list(graph_data.node_types) self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} - self.coord_dims = {2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} - self.attr_ndims = {self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} + self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} + self.attr_ndims = {nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: """Register coordinates.""" From f10810cd4fedfdac4dba069ee5316b8564def589 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 18 Oct 2024 10:42:57 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/anemoi/models/layers/graph.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index c3608c5..5d96e73 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -63,7 +63,9 @@ def register_fixed_attributes(self, graph_data: HeteroData) -> None: self.nodes_names = list(graph_data.node_types) self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} - self.attr_ndims = {nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names} + self.attr_ndims = { + nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names + } def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: """Register coordinates.""" From a68b044001e077a12b4d2eef1cd4f67a93150219 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Thu, 24 Oct 2024 15:20:10 +0000 Subject: [PATCH 6/9] fix: add tests --- tests/layers/test_graph.py | 78 +++++++++++++++++++++++++++++++++++++- 1 file changed, 76 insertions(+), 2 deletions(-) diff --git a/tests/layers/test_graph.py b/tests/layers/test_graph.py index 58674bd..19844da 100644 --- a/tests/layers/test_graph.py +++ b/tests/layers/test_graph.py @@ -11,9 +11,11 @@ import pytest import torch from torch import nn +from torch_geometric.data import HeteroData +import numpy as np +import einops -from anemoi.models.layers.graph import TrainableTensor - +from anemoi.models.layers.graph import TrainableTensor, NamedNodesAttributes class TestTrainableTensor: @pytest.fixture @@ -62,3 +64,75 @@ def test_forward_no_trainable(self, init, x): batch_size = 5 output = trainable_tensor(x, batch_size) assert output.shape == (batch_size * x.shape[0], tensor_size + trainable_size) + + +class TestNamedNodesAttributes: + """Test suite for the NamedNodesAttributes class. + + This class contains test cases to verify the functionality of the NamedNodesAttributes class, + including initialization, attribute registration, and forward pass operations. + """ + nodes_names: list[str] = ["nodes1", "nodes2"] + ndim: int = 2 + num_trainable_params: int = 8 + + @pytest.fixture + def graph_data(self): + graph = HeteroData() + for i, nodes_name in enumerate(TestNamedNodesAttributes.nodes_names): + graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i+1)) + return graph + + @staticmethod + def get_n_random_coords(n: int) -> torch.Tensor: + coords = torch.rand(n, TestNamedNodesAttributes.ndim) + coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2) + coords[:, 1] = 2 * np.pi * coords[:, 1] + return coords + + @pytest.fixture + def nodes_attributes(self, graph_data: HeteroData) -> NamedNodesAttributes: + return NamedNodesAttributes(TestNamedNodesAttributes.num_trainable_params, graph_data) + + def test_init(self, nodes_attributes): + assert isinstance(nodes_attributes, NamedNodesAttributes) + + for nodes_name in self.nodes_names: + assert isinstance(nodes_attributes.num_nodes[nodes_name], int) + assert nodes_attributes.coord_dims[nodes_name] == 2 * TestNamedNodesAttributes.ndim + assert nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim == TestNamedNodesAttributes.num_trainable_params + assert isinstance(nodes_attributes.trainable_tensors[nodes_name], TrainableTensor) + + def test_forward(self, nodes_attributes, graph_data): + batch_size = 3 + for nodes_name in self.nodes_names: + output = nodes_attributes(nodes_name, batch_size) + + expected_shape = ( + batch_size * graph_data[nodes_name].num_nodes, + 2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params + ) + assert output.shape == expected_shape + + # Check if the first part of the output matches the sin-cos transformed coordinates + latlons = getattr(nodes_attributes, f"latlons_{nodes_name}") + repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size) + assert torch.allclose(output[:, :2*TestNamedNodesAttributes.ndim], repeated_latlons) + + # Check if the last part of the output is trainable (requires grad) + assert output[:, 2*TestNamedNodesAttributes.ndim:].requires_grad + + def test_forward_no_trainable(self, graph_data): + no_trainable_attributes = NamedNodesAttributes(0, graph_data) + batch_size = 2 + + for nodes_name in self.nodes_names: + output = no_trainable_attributes(nodes_name, batch_size) + + expected_shape = batch_size * graph_data[nodes_name].num_nodes, 2 * TestNamedNodesAttributes.ndim + assert output.shape == expected_shape + + # Check if the output exactly matches the sin-cos transformed coordinates + latlons = getattr(no_trainable_attributes, f"latlons_{nodes_name}") + repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size) + assert torch.allclose(output, repeated_latlons) From 23ce1df94b0ee48742b34a7cff5c0c9256eb03d5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 24 Oct 2024 15:20:36 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/layers/test_graph.py | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/tests/layers/test_graph.py b/tests/layers/test_graph.py index 19844da..e69e7d7 100644 --- a/tests/layers/test_graph.py +++ b/tests/layers/test_graph.py @@ -8,14 +8,16 @@ # nor does it submit to any jurisdiction. +import einops +import numpy as np import pytest import torch from torch import nn from torch_geometric.data import HeteroData -import numpy as np -import einops -from anemoi.models.layers.graph import TrainableTensor, NamedNodesAttributes +from anemoi.models.layers.graph import NamedNodesAttributes +from anemoi.models.layers.graph import TrainableTensor + class TestTrainableTensor: @pytest.fixture @@ -72,6 +74,7 @@ class TestNamedNodesAttributes: This class contains test cases to verify the functionality of the NamedNodesAttributes class, including initialization, attribute registration, and forward pass operations. """ + nodes_names: list[str] = ["nodes1", "nodes2"] ndim: int = 2 num_trainable_params: int = 8 @@ -80,7 +83,7 @@ class TestNamedNodesAttributes: def graph_data(self): graph = HeteroData() for i, nodes_name in enumerate(TestNamedNodesAttributes.nodes_names): - graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i+1)) + graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i + 1)) return graph @staticmethod @@ -89,7 +92,7 @@ def get_n_random_coords(n: int) -> torch.Tensor: coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2) coords[:, 1] = 2 * np.pi * coords[:, 1] return coords - + @pytest.fixture def nodes_attributes(self, graph_data: HeteroData) -> NamedNodesAttributes: return NamedNodesAttributes(TestNamedNodesAttributes.num_trainable_params, graph_data) @@ -100,35 +103,38 @@ def test_init(self, nodes_attributes): for nodes_name in self.nodes_names: assert isinstance(nodes_attributes.num_nodes[nodes_name], int) assert nodes_attributes.coord_dims[nodes_name] == 2 * TestNamedNodesAttributes.ndim - assert nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim == TestNamedNodesAttributes.num_trainable_params + assert ( + nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim + == TestNamedNodesAttributes.num_trainable_params + ) assert isinstance(nodes_attributes.trainable_tensors[nodes_name], TrainableTensor) def test_forward(self, nodes_attributes, graph_data): batch_size = 3 for nodes_name in self.nodes_names: output = nodes_attributes(nodes_name, batch_size) - + expected_shape = ( batch_size * graph_data[nodes_name].num_nodes, - 2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params + 2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params, ) assert output.shape == expected_shape # Check if the first part of the output matches the sin-cos transformed coordinates latlons = getattr(nodes_attributes, f"latlons_{nodes_name}") repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size) - assert torch.allclose(output[:, :2*TestNamedNodesAttributes.ndim], repeated_latlons) + assert torch.allclose(output[:, : 2 * TestNamedNodesAttributes.ndim], repeated_latlons) # Check if the last part of the output is trainable (requires grad) - assert output[:, 2*TestNamedNodesAttributes.ndim:].requires_grad + assert output[:, 2 * TestNamedNodesAttributes.ndim :].requires_grad def test_forward_no_trainable(self, graph_data): no_trainable_attributes = NamedNodesAttributes(0, graph_data) batch_size = 2 - + for nodes_name in self.nodes_names: output = no_trainable_attributes(nodes_name, batch_size) - + expected_shape = batch_size * graph_data[nodes_name].num_nodes, 2 * TestNamedNodesAttributes.ndim assert output.shape == expected_shape From 949730063a55cbc3141ca788b26690d29c3463c9 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 28 Oct 2024 11:35:38 +0000 Subject: [PATCH 8/9] feat: drop unused attrs + type hints --- src/anemoi/models/layers/graph.py | 44 ++++++++++++++++++++----------- tests/layers/test_graph.py | 1 - 2 files changed, 29 insertions(+), 16 deletions(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 10a6044..0018e86 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -45,27 +45,41 @@ def forward(self, x: Tensor, batch_size: int) -> Tensor: class NamedNodesAttributes(torch.nn.Module): - """Named Node Attributes Module.""" + """Named Nodes Attributes information. + + Attributes + ---------- + nodes_names : list[str] + List of nodes names in the graph. + num_nodes : dict[str, int] + Number of nodes for each group of nodes. + attr_ndims : dict[str, int] + Total dimension of node attributes (non-trainable + trainable) for each group of nodes. + trainable_tensors : nn.ModuleDict + Dictionary of trainable tensors for each group of nodes. + """ + nodes_names: list[str] + num_nodes: dict[str, int] + attr_ndims: dict[str, int] + trainable_tensors: dict[str, TrainableTensor] def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None: """Initialize NamedNodesAttributes.""" super().__init__() - self.num_trainable_params = num_trainable_params - self.register_fixed_attributes(graph_data) + self.define_fixed_attributes(graph_data, num_trainable_params) self.trainable_tensors = nn.ModuleDict() - for nodes_name in self.nodes_names: - self.register_coordinates(nodes_name, graph_data[nodes_name].x) - self.register_tensor(nodes_name) - - def register_fixed_attributes(self, graph_data: HeteroData) -> None: - """Register fixed attributes.""" - self.nodes_names = list(graph_data.node_types) - self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names} - self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names} + for nodes_name, nodes in graph_data.node_items(): + self.register_coordinates(nodes_name, nodes.x) + self.register_tensor(nodes_name, num_trainable_params) + + def define_fixed_attributes(self, graph_data: HeteroData, num_trainable_params: int) -> None: + """Define fixed attributes.""" + nodes_names = list(graph_data.node_types) + self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in nodes_names} self.attr_ndims = { - nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names + nodes_name: 2 * graph_data[nodes_name].x.shape[1] + num_trainable_params for nodes_name in nodes_names } def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: @@ -73,9 +87,9 @@ def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None: sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1) self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True) - def register_tensor(self, name: str) -> None: + def register_tensor(self, name: str, num_trainable_params: int) -> None: """Register a trainable tensor.""" - self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params) + self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], num_trainable_params) def forward(self, name: str, batch_size: int) -> Tensor: """Forward pass.""" diff --git a/tests/layers/test_graph.py b/tests/layers/test_graph.py index e69e7d7..66456d6 100644 --- a/tests/layers/test_graph.py +++ b/tests/layers/test_graph.py @@ -102,7 +102,6 @@ def test_init(self, nodes_attributes): for nodes_name in self.nodes_names: assert isinstance(nodes_attributes.num_nodes[nodes_name], int) - assert nodes_attributes.coord_dims[nodes_name] == 2 * TestNamedNodesAttributes.ndim assert ( nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim == TestNamedNodesAttributes.num_trainable_params From c725f12a3a52010a28d9f4a1a6988d37437a0fe3 Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Mon, 28 Oct 2024 11:41:02 +0000 Subject: [PATCH 9/9] fix: style --- src/anemoi/models/layers/graph.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/anemoi/models/layers/graph.py b/src/anemoi/models/layers/graph.py index 0018e86..4f4efaa 100644 --- a/src/anemoi/models/layers/graph.py +++ b/src/anemoi/models/layers/graph.py @@ -46,7 +46,7 @@ def forward(self, x: Tensor, batch_size: int) -> Tensor: class NamedNodesAttributes(torch.nn.Module): """Named Nodes Attributes information. - + Attributes ---------- nodes_names : list[str] @@ -58,6 +58,7 @@ class NamedNodesAttributes(torch.nn.Module): trainable_tensors : nn.ModuleDict Dictionary of trainable tensors for each group of nodes. """ + nodes_names: list[str] num_nodes: dict[str, int] attr_ndims: dict[str, int]