diff --git a/src/anemoi/models/graph/__init__.py b/src/anemoi/models/graph/__init__.py new file mode 100644 index 0000000..8144a91 --- /dev/null +++ b/src/anemoi/models/graph/__init__.py @@ -0,0 +1,16 @@ +from anemoi.utils.config import DotDict + +class AnemoiGraphSchema: + def __init__(self, graph_data: dict, config: DotDict) -> None: + self.hidden_name = config.graphs.hidden_mesh.name + self.mesh_names = [name for name in graph_data if isinstance(name, str)] + self.input_meshes = [k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self.hidden_name and k[2] != k[0]] + self.output_meshes = [k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self.hidden_name and k[2] != k[0]] + self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self.mesh_names} + self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self.mesh_names} + self.num_trainable_params = { + name: config.model.trainable_parameters["data" if name != "hidden" else name] for name in self.mesh_names + } + + def get_node_emb_size(self, name: str) -> int: + return self.num_node_features[name] + self.num_trainable_params[name] diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py index 8203bd8..a4491c9 100644 --- a/src/anemoi/models/models/encoder_processor_decoder.py +++ b/src/anemoi/models/models/encoder_processor_decoder.py @@ -22,6 +22,7 @@ from anemoi.models.data_indices.collection import IndexCollection from anemoi.models.distributed.shapes import get_shape_shards from anemoi.models.layers.graph import TrainableTensor +from anemoi.models.graph import AnemoiGraphSchema LOGGER = logging.getLogger(__name__) @@ -49,67 +50,55 @@ def __init__( """ super().__init__() - self._graph_name_hidden = config.graphs.hidden_mesh.name - self._graph_mesh_names = [name for name in graph_data if isinstance(name, str)] - self._graph_input_meshes = [ - k[0] for k in graph_data if 
isinstance(k, tuple) and k[2] == self._graph_name_hidden and k[2] != k[0] - ] - self._graph_output_meshes = [ - k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self._graph_name_hidden and k[2] != k[0] - ] + self.graph = AnemoiGraphSchema(graph_data, config) + self.num_channels = config.model.num_channels + self.multi_step = config.training.multistep_input self._calculate_shapes_and_indices(data_indices) self._assert_matching_indices(data_indices) - - self.multi_step = config.training.multistep_input - - self._define_tensor_sizes(config, graph_data) - self._create_trainable_attributes() + self._create_trainable_attributes() # Register lat/lon - for mesh_key in self._graph_mesh_names: + for mesh_key in self.graph.mesh_names: self._register_latlon(mesh_key, graph_data[mesh_key]["coords"]) - self.num_channels = config.model.num_channels - input_dim = self.multi_step * self.num_input_channels # Encoder data -> hidden self.encoders = nn.ModuleDict() - for in_mesh in self._graph_input_meshes: + for in_mesh in self.graph.input_meshes: self.encoders[in_mesh] = instantiate( config.model.encoder, - in_channels_src=input_dim + self.num_node_features[in_mesh] + self.num_trainable_params[in_mesh], - in_channels_dst=self.num_node_features[self._graph_name_hidden] - + self.num_trainable_params[self._graph_name_hidden], + in_channels_src=input_dim + self.graph.get_node_emb_size(in_mesh), + in_channels_dst=self.graph.get_node_emb_size(self.graph.hidden_name), hidden_dim=self.num_channels, - sub_graph=graph_data[(in_mesh, "to", self._graph_name_hidden)], - src_grid_size=self.num_nodes[in_mesh], - dst_grid_size=self.num_nodes[self._graph_name_hidden], + sub_graph=graph_data[(in_mesh, "to", self.graph.hidden_name)], + src_grid_size=self.graph.num_nodes[in_mesh], + dst_grid_size=self.graph.num_nodes[self.graph.hidden_name], ) # Processor hidden -> hidden self.processor = instantiate( config.model.processor, num_channels=self.num_channels, - sub_graph=graph_data.get((self._graph_name_hidden, "to", 
self._graph_name_hidden), None), - src_grid_size=self.num_nodes[self._graph_name_hidden], - dst_grid_size=self.num_nodes[self._graph_name_hidden], + sub_graph=graph_data.get((self.graph.hidden_name, "to", self.graph.hidden_name), None), + src_grid_size=self.graph.num_nodes[self.graph.hidden_name], + dst_grid_size=self.graph.num_nodes[self.graph.hidden_name], ) # Decoder hidden -> data self.decoders = nn.ModuleDict() - for out_mesh in self._graph_output_meshes: + for out_mesh in self.graph.output_meshes: self.decoders[out_mesh] = instantiate( config.model.decoder, in_channels_src=self.num_channels, - in_channels_dst=input_dim + self.num_node_features[out_mesh] + self.num_trainable_params[out_mesh], + in_channels_dst=input_dim + self.graph.get_node_emb_size(out_mesh), hidden_dim=self.num_channels, out_channels_dst=self.num_output_channels, - sub_graph=graph_data[(self._graph_name_hidden, "to", out_mesh)], - src_grid_size=self.num_nodes[self._graph_name_hidden], - dst_grid_size=self.num_nodes[out_mesh], + sub_graph=graph_data[(self.graph.hidden_name, "to", out_mesh)], + src_grid_size=self.graph.num_nodes[self.graph.hidden_name], + dst_grid_size=self.graph.num_nodes[out_mesh], ) def _calculate_shapes_and_indices(self, data_indices: dict) -> None: @@ -130,21 +118,12 @@ def _assert_matching_indices(self, data_indices: dict) -> None: self._internal_output_idx, ), f"Model indices must match {self._internal_input_idx} != {self._internal_output_idx}" - def _define_tensor_sizes(self, config: DotDict, graph_data: dict) -> None: - # Define Sizes of different tensors - self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self._graph_mesh_names} - self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self._graph_mesh_names} - self.num_trainable_params = { - name: config.model.trainable_parameters["data" if name != "hidden" else name] - for name in self._graph_mesh_names - } - def _create_trainable_attributes(self) -> None: """Create all 
trainable attributes.""" self.trainable_tensors = nn.ModuleDict() - for mesh in self._graph_mesh_names: + for mesh in self.graph.mesh_names: self.trainable_tensors[mesh] = TrainableTensor( - trainable_size=self.num_trainable_params[mesh], tensor_size=self.num_nodes[mesh] + trainable_size=self.graph.num_trainable_params[mesh], tensor_size=self.graph.num_nodes[mesh] ) def _register_latlon(self, name: str, coords: torch.Tensor) -> None: @@ -208,7 +187,7 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> # add data positional info (lat/lon) x_data_latent = {} - for in_mesh in self._graph_input_meshes: + for in_mesh in self.graph.input_meshes: x_data_latent[in_mesh] = torch.cat( ( einops.rearrange(x, "batch time ensemble grid vars -> (batch ensemble grid) (time vars)"), @@ -217,7 +196,7 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> dim=-1, # feature dimension ) - x_hidden_latent = self.trainable_tensors[self._graph_name_hidden](self.latlons_hidden, batch_size=batch_size) + x_hidden_latent = self.trainable_tensors[self.graph.hidden_name](getattr(self, f"latlons_{self.graph.hidden_name}"), batch_size=batch_size) # get shard shapes shard_shapes_data = {name: get_shape_shards(data, 0, model_comm_group) for name, data in x_data_latent.items()} @@ -270,8 +249,8 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) -> .clone() ) - if out_data_name in self._graph_input_meshes: # check if the mesh is in the input meshes + if out_data_name in self.graph.input_meshes: # check if the mesh is in the input meshes # residual connection (just for the prognostic variables) x_out[out_data_name][..., self._internal_output_idx] += x[:, -1, :, :, self._internal_input_idx] - return x_out[self._graph_output_meshes[0]] + return x_out[self.graph.output_meshes[0]]