diff --git a/src/anemoi/models/graph/__init__.py b/src/anemoi/models/graph/__init__.py
index 8144a91..a2e0a5e 100644
--- a/src/anemoi/models/graph/__init__.py
+++ b/src/anemoi/models/graph/__init__.py
@@ -1,15 +1,21 @@
 from anemoi.utils.config import DotDict
+
 
 class AnemoiGraphSchema:
     def __init__(self, graph_data: dict, config: DotDict) -> None:
         self.hidden_name = config.graphs.hidden_mesh.name
         self.mesh_names = [name for name in graph_data if isinstance(name, str)]
-        self.input_meshes = [k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self.hidden_name and k[2] != k[0]]
-        self.output_meshes = [k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self.hidden_name and k[2] != k[0]]
+        self.input_meshes = [
+            k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self.hidden_name and k[2] != k[0]
+        ]
+        self.output_meshes = [
+            k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self.hidden_name and k[2] != k[0]
+        ]
         self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self.mesh_names}
         self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self.mesh_names}
         self.num_trainable_params = {
-            name: config.model.trainable_parameters["data" if name != "hidden" else name] for name in self.graph.mesh_names
+            name: config.model.trainable_parameters["data" if name != self.hidden_name else "hidden"]
+            for name in self.mesh_names
         }
 
     def get_node_emb_size(self, name: str) -> int:
diff --git a/src/anemoi/models/models/encoder_processor_decoder.py b/src/anemoi/models/models/encoder_processor_decoder.py
index a4491c9..127c425 100644
--- a/src/anemoi/models/models/encoder_processor_decoder.py
+++ b/src/anemoi/models/models/encoder_processor_decoder.py
@@ -21,8 +21,8 @@
 
 from anemoi.models.data_indices.collection import IndexCollection
 from anemoi.models.distributed.shapes import get_shape_shards
-from anemoi.models.layers.graph import TrainableTensor
 from anemoi.models.graph import AnemoiGraphSchema
+from anemoi.models.layers.graph import TrainableTensor
 
 LOGGER = logging.getLogger(__name__)
 
@@ -196,7 +196,9 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
             dim=-1,  # feature dimension
         )
 
-        x_hidden_latent = self.trainable_tensors[self.graph.hidden_name](getattr(self, f"latlons_{self.graph.hidden_name}"), batch_size=batch_size)
+        x_hidden_latent = self.trainable_tensors[self.graph.hidden_name](
+            getattr(self, f"latlons_{self.graph.hidden_name}"), batch_size=batch_size
+        )
 
         # get shard shapes
         shard_shapes_data = {name: get_shape_shards(data, 0, model_comm_group) for name, data in x_data_latent.items()}
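For reference, a minimal construction sketch of AnemoiGraphSchema (illustrative, not part of the patch). The mesh names, node counts, and config values below are assumptions; only the conventions that __init__ actually reads are taken from the diff: string keys map node sets to a store with a "coords" tensor of shape (num_nodes, 2), and (source, relation, target) tuple keys mark edge sets. It also assumes DotDict gives attribute access over nested plain dicts, consistent with the config.graphs.hidden_mesh.name lookup above.

import torch

from anemoi.models.graph import AnemoiGraphSchema
from anemoi.utils.config import DotDict

# Toy graph: one data mesh ("era"), one hidden mesh, encoder/processor/decoder edges.
graph_data = {
    "era": {"coords": torch.rand(40, 2)},    # (num_nodes, 2) coordinates per node set
    "hidden": {"coords": torch.rand(10, 2)},
    ("era", "to", "hidden"): {},             # encoder edges; only k[0]/k[2] are read here
    ("hidden", "to", "hidden"): {},          # processor edges, excluded by k[2] != k[0]
    ("hidden", "to", "era"): {},             # decoder edges
}
config = DotDict(
    {
        "graphs": {"hidden_mesh": {"name": "hidden"}},
        "model": {"trainable_parameters": {"data": 8, "hidden": 8}},
    }
)

schema = AnemoiGraphSchema(graph_data, config)
assert schema.input_meshes == ["era"] and schema.output_meshes == ["era"]
assert schema.num_nodes == {"era": 40, "hidden": 10}
assert schema.num_node_features == {"era": 4, "hidden": 4}  # 2 * coords dim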