Create GraphSchema interface to include graph features
JPXKQX committed Jun 7, 2024
1 parent dc0dc7d commit 3b988f5
Showing 2 changed files with 41 additions and 46 deletions.
16 changes: 16 additions & 0 deletions src/anemoi/models/graph/__init__.py
@@ -0,0 +1,16 @@
from anemoi.utils.config import DotDict

class AnemoiGraphSchema:
    def __init__(self, graph_data: dict, config: DotDict) -> None:
        self.hidden_name = config.graphs.hidden_mesh.name
        self.mesh_names = [name for name in graph_data if isinstance(name, str)]
        self.input_meshes = [k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self.hidden_name and k[2] != k[0]]
        self.output_meshes = [k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self.hidden_name and k[2] != k[0]]
        self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self.mesh_names}
        self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self.mesh_names}
        self.num_trainable_params = {
            name: config.model.trainable_parameters["data" if name != "hidden" else name] for name in self.mesh_names
        }

    def get_node_emb_size(self, name: str) -> int:
        return self.num_node_features[name] + self.num_trainable_params[name]
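
Not part of the commit: a minimal usage sketch of the new schema. The graph_data layout, the mesh name "era", the node counts, and the trainable-parameter sizes below are illustrative assumptions, shaped only by how AnemoiGraphSchema reads its inputs.

import torch
from anemoi.utils.config import DotDict  # assumes DotDict accepts a plain nested mapping
from anemoi.models.graph import AnemoiGraphSchema

config = DotDict(
    {
        "graphs": {"hidden_mesh": {"name": "hidden"}},
        "model": {"trainable_parameters": {"data": 8, "hidden": 16}},  # assumed sizes
    }
)
graph_data = {
    "era": {"coords": torch.zeros(40320, 2)},     # data mesh: 40320 nodes with (lat, lon)
    "hidden": {"coords": torch.zeros(10242, 2)},  # hidden mesh
    ("era", "to", "hidden"): None,                # encoder sub-graph (edge store, unused here)
    ("hidden", "to", "era"): None,                # decoder sub-graph
}

graph = AnemoiGraphSchema(graph_data, config)
assert graph.input_meshes == ["era"] and graph.output_meshes == ["era"]
assert graph.num_nodes == {"era": 40320, "hidden": 10242}
assert graph.get_node_emb_size("hidden") == 2 * 2 + 16  # doubled coord features + trainable params
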
71 changes: 25 additions & 46 deletions src/anemoi/models/models/encoder_processor_decoder.py
@@ -22,6 +22,7 @@
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.graph import AnemoiGraphSchema

LOGGER = logging.getLogger(__name__)

@@ -49,67 +50,54 @@ def __init__(
"""
super().__init__()

self._graph_name_hidden = config.graphs.hidden_mesh.name
self._graph_mesh_names = [name for name in graph_data if isinstance(name, str)]
self._graph_input_meshes = [
k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self._graph_name_hidden and k[2] != k[0]
]
self._graph_output_meshes = [
k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self._graph_name_hidden and k[2] != k[0]
]
self.graph = AnemoiGraphSchema(graph_data, config)
self.num_channels = config.model.num_channels
self.multi_step = config.training.multistep_input

self._calculate_shapes_and_indices(data_indices)
self._assert_matching_indices(data_indices)

self.multi_step = config.training.multistep_input

self._define_tensor_sizes(config, graph_data)

self._create_trainable_attributes()

# Register lat/lon
for mesh_key in self._graph_mesh_names:
for mesh_key in self.graph.mesh_names:
self._register_latlon(mesh_key, graph_data[mesh_key]["coords"])

self.num_channels = config.model.num_channels

input_dim = self.multi_step * self.num_input_channels

# Encoder data -> hidden
self.encoders = nn.ModuleDict()
for in_mesh in self._graph_input_meshes:
for in_mesh in self.graph.input_meshes:
self.encoders[in_mesh] = instantiate(
config.model.encoder,
in_channels_src=input_dim + self.num_node_features[in_mesh] + self.num_trainable_params[in_mesh],
in_channels_dst=self.num_node_features[self._graph_name_hidden]
+ self.num_trainable_params[self._graph_name_hidden],
in_channels_src=input_dim + self.graph.get_node_emb_size(in_mesh),
in_channels_dst=self.graph.get_node_emb_size(self.graph.hidden_name),
hidden_dim=self.num_channels,
sub_graph=graph_data[(in_mesh, "to", self._graph_name_hidden)],
src_grid_size=self.num_nodes[in_mesh],
dst_grid_size=self.num_nodes[self._graph_name_hidden],
sub_graph=graph_data[(in_mesh, "to", self.graph.hidden_name)],
src_grid_size=self.graph.num_nodes[in_mesh],
dst_grid_size=self.graph.num_nodes[self.graph.hidden_name],
)

# Processor hidden -> hidden
self.processor = instantiate(
config.model.processor,
num_channels=self.num_channels,
sub_graph=graph_data.get((self._graph_name_hidden, "to", self._graph_name_hidden), None),
src_grid_size=self.num_nodes[self._graph_name_hidden],
dst_grid_size=self.num_nodes[self._graph_name_hidden],
sub_graph=graph_data.get((self.graph.hidden_name, "to", self.graph.hidden_name), None),
src_grid_size=self.graph.num_nodes[self.graph.hidden_name],
dst_grid_size=self.graph.num_nodes[self.graph.hidden_name],
)

# Decoder hidden -> data
self.decoders = nn.ModuleDict()
for out_mesh in self._graph_output_meshes:
for out_mesh in self.graph.output_meshes:
self.decoders[out_mesh] = instantiate(
config.model.decoder,
in_channels_src=self.num_channels,
in_channels_dst=input_dim + self.num_node_features[out_mesh] + self.num_trainable_params[out_mesh],
in_channels_dst=input_dim + self.graph.get_node_emb_size(out_mesh),
hidden_dim=self.num_channels,
out_channels_dst=self.num_output_channels,
sub_graph=graph_data[(self._graph_name_hidden, "to", out_mesh)],
src_grid_size=self.num_nodes[self._graph_name_hidden],
dst_grid_size=self.num_nodes[out_mesh],
sub_graph=graph_data[(self.graph.hidden_name, "to", out_mesh)],
src_grid_size=self.graph.num_nodes[self.graph.hidden_name],
dst_grid_size=self.graph.num_nodes[out_mesh],
)

def _calculate_shapes_and_indices(self, data_indices: dict) -> None:
@@ -130,21 +118,12 @@ def _assert_matching_indices(self, data_indices: dict) -> None:
self._internal_output_idx,
), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}"

def _define_tensor_sizes(self, config: DotDict, graph_data: dict) -> None:
# Define Sizes of different tensors
self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names}
self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names}
self.num_trainable_params = {
name: config.model.trainable_parameters["data" if name != "hidden" else name]
for name in self._graph_mesh_names
}

def _create_trainable_attributes(self) -> None:
"""Create all trainable attributes."""
self.trainable_tensors = nn.ModuleDict()
for mesh in self._graph_mesh_names:
for mesh in self.graph.mesh_names:
self.trainable_tensors[mesh] = TrainableTensor(
trainable_size=self.num_trainable_params[mesh], tensor_size=self.num_nodes[mesh]
trainable_size=self.graph.num_trainable_params[mesh], tensor_size=self.graph.num_nodes[mesh]
)

def _register_latlon(self, name: str, coords: torch.Tensor) -> None:
@@ -208,7 +187,7 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->

# add data positional info (lat/lon)
x_data_latent = {}
for in_mesh in self._graph_input_meshes:
for in_mesh in self.graph.input_meshes:
x_data_latent[in_mesh] = torch.cat(
(
einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"),
@@ -217,7 +196,7 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
dim=-1, # feature dimension
)

x_hidden_latent = self.trainable_tensors[self._graph_name_hidden](self.latlons_hidden, batch_size=batch_size)
x_hidden_latent = self.trainable_tensors[self.graph.hidden_name](getattr(self, f"latlons_{self.graph.hidden_name}"), batch_size=batch_size)

# get shard shapes
shard_shapes_data = {name: get_shape_shards(data, 0, model_comm_group) for name, data in x_data_latent.items()}
@@ -270,8 +249,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
.clone()
)

if out_data_name in self._graph_input_meshes: # check if the mesh is in the input meshes
if out_data_name in self.graph.input_meshes: # check if the mesh is in the input meshes
# residual connection (just for the prognostic variables)
x_out[out_data_name][..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx]

return x_out[self._graph_output_meshes[0]]
return x_out[self.graph.output_meshes[0]]
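
For orientation (not part of the commit): the channel arithmetic that the refactor routes through get_node_emb_size works out as follows; the multi-step depth, variable count, and trainable sizes are made-up illustrative values.

# in_channels_src for the encoder on mesh `in_mesh`:
#   multi_step * num_input_channels + graph.get_node_emb_size(in_mesh)
# where get_node_emb_size(mesh) = 2 * coord_dim + trainable_params[mesh]
multi_step = 2            # assumed config.training.multistep_input
num_input_channels = 98   # assumed number of input variables
coord_dim = 2             # (lat, lon); doubled into 4 node features by the registered coords
trainable_data = 8        # assumed config.model.trainable_parameters["data"]

in_channels_src = multi_step * num_input_channels + (2 * coord_dim + trainable_data)
print(in_channels_src)    # 208
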
