Refactor node attributes #64

Open
wants to merge 11 commits into base: develop
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -22,6 +22,7 @@ Keep it human-readable, your future self will thank you!
- configurability of the dropout probability in the MultiHeadSelfAttention module
- Variable Bounding as configurable model layers [#13](https://github.com/ecmwf/anemoi-models/issues/13)
- GraphTransformerMapperBlock chunking to reduce memory usage during inference [#46](https://github.com/ecmwf/anemoi-models/pull/46)
- New `NamedNodesAttributes` class to handle node attributes in a more flexible way [#64](https://github.com/ecmwf/anemoi-models/pull/64)

### Changed
- Bugfixes for CI
42 changes: 41 additions & 1 deletion src/anemoi/models/layers/graph.py
@@ -12,6 +12,7 @@
import torch
from torch import Tensor
from torch import nn
from torch_geometric.data import HeteroData


class TrainableTensor(nn.Module):
@@ -36,8 +37,47 @@ def __init__(self, tensor_size: int, trainable_size: int) -> None:
def forward(self, x: Tensor, batch_size: int) -> Tensor:
latent = [einops.repeat(x, "e f -> (repeat e) f", repeat=batch_size)]
if self.trainable is not None:
-            latent.append(einops.repeat(self.trainable, "e f -> (repeat e) f", repeat=batch_size))
+            latent.append(einops.repeat(self.trainable.to(x.device), "e f -> (repeat e) f", repeat=batch_size))
return torch.cat(
latent,
dim=-1, # feature dimension
)
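
For context, a minimal usage sketch of TrainableTensor (not part of the diff; the sizes are illustrative assumptions). The `.to(x.device)` change above ensures the trainable part follows the input's device:

import torch
from anemoi.models.layers.graph import TrainableTensor

tt = TrainableTensor(tensor_size=10, trainable_size=2)  # 10 nodes, 2 learnable features per node
x = torch.randn(10, 4)  # fixed per-node features, e.g. sin/cos coordinates
out = tt(x, batch_size=3)  # shape (3 * 10, 4 + 2): batch-repeated input plus trainable columns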


class NamedNodesAttributes(torch.nn.Module):
"""Named Node Attributes Module."""

def __init__(self, num_trainable_params: int, graph_data: HeteroData) -> None:
"""Initialize NamedNodesAttributes."""
super().__init__()

self.num_trainable_params = num_trainable_params
self.register_fixed_attributes(graph_data)

self.trainable_tensors = nn.ModuleDict()
for nodes_name in self.nodes_names:
self.register_coordinates(nodes_name, graph_data[nodes_name].x)
self.register_tensor(nodes_name)

def register_fixed_attributes(self, graph_data: HeteroData) -> None:
"""Register fixed attributes."""
self.nodes_names = list(graph_data.node_types)
self.num_nodes = {nodes_name: graph_data[nodes_name].num_nodes for nodes_name in self.nodes_names}
self.coord_dims = {nodes_name: 2 * graph_data[nodes_name].x.shape[1] for nodes_name in self.nodes_names}
self.attr_ndims = {
nodes_name: self.coord_dims[nodes_name] + self.num_trainable_params for nodes_name in self.nodes_names
}

def register_coordinates(self, name: str, node_coords: torch.Tensor) -> None:
"""Register coordinates."""
sin_cos_coords = torch.cat([torch.sin(node_coords), torch.cos(node_coords)], dim=-1)
self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)

def register_tensor(self, name: str) -> None:
"""Register a trainable tensor."""
self.trainable_tensors[name] = TrainableTensor(self.num_nodes[name], self.num_trainable_params)

def forward(self, name: str, batch_size: int) -> Tensor:
"""Forward pass."""
latlons = getattr(self, f"latlons_{name}")
return self.trainable_tensors[name](latlons, batch_size)
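
For reference, a minimal usage sketch of the new class (not part of the diff; the graph layout and sizes are illustrative assumptions, consistent with the tests below):

import torch
from torch_geometric.data import HeteroData
from anemoi.models.layers.graph import NamedNodesAttributes

graph = HeteroData()
graph["data"].x = torch.rand(40, 2)  # toy node coordinates in radians
graph["hidden"].x = torch.rand(10, 2)

node_attributes = NamedNodesAttributes(num_trainable_params=8, graph_data=graph)
x_hidden = node_attributes("hidden", batch_size=2)
# x_hidden.shape == (2 * 10, 2 * 2 + 8): sin/cos coordinates plus trainable columns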
65 changes: 14 additions & 51 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -22,7 +22,7 @@
from torch_geometric.data import HeteroData

from anemoi.models.distributed.shapes import get_shape_shards
-from anemoi.models.layers.graph import TrainableTensor
+from anemoi.models.layers.graph import NamedNodesAttributes

LOGGER = logging.getLogger(__name__)

@@ -56,42 +56,33 @@ def __init__(

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)

-        self.multi_step = model_config.training.multistep_input
-
-        self._define_tensor_sizes(model_config)
-
-        # Create trainable tensors
-        self._create_trainable_attributes()
-
-        # Register lat/lon of nodes
-        self._register_latlon("data", self._graph_name_data)
-        self._register_latlon("hidden", self._graph_name_hidden)
-
        self.data_indices = data_indices

+        self.multi_step = model_config.training.multistep_input
        self.num_channels = model_config.model.num_channels

-        input_dim = self.multi_step * self.num_input_channels + self.latlons_data.shape[1] + self.trainable_data_size
+        self.node_attributes = NamedNodesAttributes(model_config.model.trainable_parameters.hidden, self._graph_data)

+        input_dim = self.multi_step * self.num_input_channels + self.node_attributes.attr_ndims[self._graph_name_data]

# Encoder data -> hidden
self.encoder = instantiate(
model_config.model.encoder,
in_channels_src=input_dim,
-            in_channels_dst=self.latlons_hidden.shape[1] + self.trainable_hidden_size,
+            in_channels_dst=self.node_attributes.attr_ndims[self._graph_name_hidden],
hidden_dim=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_data, "to", self._graph_name_hidden)],
-            src_grid_size=self._data_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
)

# Processor hidden -> hidden
self.processor = instantiate(
model_config.model.processor,
num_channels=self.num_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_hidden)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._hidden_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
)

# Decoder hidden -> data
@@ -102,8 +93,8 @@ def __init__(
hidden_dim=self.num_channels,
out_channels_dst=self.num_output_channels,
sub_graph=self._graph_data[(self._graph_name_hidden, "to", self._graph_name_data)],
-            src_grid_size=self._hidden_grid_size,
-            dst_grid_size=self._data_grid_size,
+            src_grid_size=self.node_attributes.num_nodes[self._graph_name_hidden],
+            dst_grid_size=self.node_attributes.num_nodes[self._graph_name_data],
)

# Instantiation of model output bounding functions (e.g., to ensure outputs like TP are positive definite)
@@ -133,34 +124,6 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
self._internal_output_idx,
), f"Internal model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

-    def _define_tensor_sizes(self, config: DotDict) -> None:
-        self._data_grid_size = self._graph_data[self._graph_name_data].num_nodes
-        self._hidden_grid_size = self._graph_data[self._graph_name_hidden].num_nodes
-
-        self.trainable_data_size = config.model.trainable_parameters.data
-        self.trainable_hidden_size = config.model.trainable_parameters.hidden
-
-    def _register_latlon(self, name: str, nodes: str) -> None:
-        """Register lat/lon buffers.
-
-        Parameters
-        ----------
-        name : str
-            Name to store the lat-lon coordinates of the nodes.
-        nodes : str
-            Name of nodes to map
-        """
-        coords = self._graph_data[nodes].x
-        sin_cos_coords = torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)
-        self.register_buffer(f"latlons_{name}", sin_cos_coords, persistent=True)
-
-    def _create_trainable_attributes(self) -> None:
-        """Create all trainable attributes."""
-        self.trainable_data = TrainableTensor(trainable_size=self.trainable_data_size, tensor_size=self._data_grid_size)
-        self.trainable_hidden = TrainableTensor(
-            trainable_size=self.trainable_hidden_size, tensor_size=self._hidden_grid_size
-        )

def _run_mapper(
self,
mapper: nn.Module,
@@ -210,12 +173,12 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
x_data_latent = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
-                self.trainable_data(self.latlons_data, batch_size=batch_size),
+                self.node_attributes(self._graph_name_data, batch_size=batch_size),
),
dim=-1, # feature dimension
)

-        x_hidden_latent = self.trainable_hidden(self.latlons_hidden, batch_size=batch_size)
+        x_hidden_latent = self.node_attributes(self._graph_name_hidden, batch_size=batch_size)

# get shard shapes
shard_shapes_data = get_shape_shards(x_data_latent, 0, model_comm_group)
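As a worked example of the new input_dim computation (all values are illustrative assumptions, not taken from the PR):

multi_step = 2  # model_config.training.multistep_input
num_input_channels = 90  # number of input variables
attr_ndims_data = 2 * 2 + 8  # sin/cos of 2D coordinates plus 8 trainable parameters

input_dim = multi_step * num_input_channels + attr_ndims_data  # 2 * 90 + 12 = 192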
80 changes: 80 additions & 0 deletions tests/layers/test_graph.py
@@ -8,10 +8,14 @@
# nor does it submit to any jurisdiction.


import einops
import numpy as np
import pytest
import torch
from torch import nn
from torch_geometric.data import HeteroData

from anemoi.models.layers.graph import NamedNodesAttributes
from anemoi.models.layers.graph import TrainableTensor


@@ -62,3 +66,79 @@ def test_forward_no_trainable(self, init, x):
batch_size = 5
output = trainable_tensor(x, batch_size)
assert output.shape == (batch_size * x.shape[0], tensor_size + trainable_size)


class TestNamedNodesAttributes:
"""Test suite for the NamedNodesAttributes class.

This class contains test cases to verify the functionality of the NamedNodesAttributes class,
including initialization, attribute registration, and forward pass operations.
"""

nodes_names: list[str] = ["nodes1", "nodes2"]
ndim: int = 2
num_trainable_params: int = 8

@pytest.fixture
def graph_data(self):
graph = HeteroData()
for i, nodes_name in enumerate(TestNamedNodesAttributes.nodes_names):
graph[nodes_name].x = TestNamedNodesAttributes.get_n_random_coords(10 + 5 ** (i + 1))
return graph

@staticmethod
def get_n_random_coords(n: int) -> torch.Tensor:
coords = torch.rand(n, TestNamedNodesAttributes.ndim)
coords[:, 0] = np.pi * (coords[:, 0] - 1 / 2)
coords[:, 1] = 2 * np.pi * coords[:, 1]
return coords

@pytest.fixture
def nodes_attributes(self, graph_data: HeteroData) -> NamedNodesAttributes:
return NamedNodesAttributes(TestNamedNodesAttributes.num_trainable_params, graph_data)

def test_init(self, nodes_attributes):
assert isinstance(nodes_attributes, NamedNodesAttributes)

for nodes_name in self.nodes_names:
assert isinstance(nodes_attributes.num_nodes[nodes_name], int)
assert nodes_attributes.coord_dims[nodes_name] == 2 * TestNamedNodesAttributes.ndim
assert (
nodes_attributes.attr_ndims[nodes_name] - 2 * TestNamedNodesAttributes.ndim
== TestNamedNodesAttributes.num_trainable_params
)
assert isinstance(nodes_attributes.trainable_tensors[nodes_name], TrainableTensor)

def test_forward(self, nodes_attributes, graph_data):
batch_size = 3
for nodes_name in self.nodes_names:
output = nodes_attributes(nodes_name, batch_size)

expected_shape = (
batch_size * graph_data[nodes_name].num_nodes,
2 * TestNamedNodesAttributes.ndim + TestNamedNodesAttributes.num_trainable_params,
)
assert output.shape == expected_shape

# Check if the first part of the output matches the sin-cos transformed coordinates
latlons = getattr(nodes_attributes, f"latlons_{nodes_name}")
repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size)
assert torch.allclose(output[:, : 2 * TestNamedNodesAttributes.ndim], repeated_latlons)

# Check if the last part of the output is trainable (requires grad)
assert output[:, 2 * TestNamedNodesAttributes.ndim :].requires_grad

def test_forward_no_trainable(self, graph_data):
no_trainable_attributes = NamedNodesAttributes(0, graph_data)
batch_size = 2

for nodes_name in self.nodes_names:
output = no_trainable_attributes(nodes_name, batch_size)

expected_shape = batch_size * graph_data[nodes_name].num_nodes, 2 * TestNamedNodesAttributes.ndim
assert output.shape == expected_shape

# Check if the output exactly matches the sin-cos transformed coordinates
latlons = getattr(no_trainable_attributes, f"latlons_{nodes_name}")
repeated_latlons = einops.repeat(latlons, "n f -> (b n) f", b=batch_size)
assert torch.allclose(output, repeated_latlons)