Skip to content

Commit

Permalink
fix: run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jun 7, 2024
1 parent 3b988f5 commit c352167
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
12 changes: 9 additions & 3 deletions src/anemoi/models/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,21 @@
from anemoi.utils.config import DotDict


class AnemoiGraphSchema:
def __init__(self, graph_data: dict, config: DotDict) -> None:
self.hidden_name = config.graphs.hidden_mesh.name
self.mesh_names = [name for name in graph_data if isinstance(name, str)]
self.input_meshes = [k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self.hidden_name and k[2] != k[0]]
self.output_meshes = [k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self.hidden_name and k[2] != k[0]]
self.input_meshes = [
k[0] for k in graph_data if isinstance(k, tuple) and k[2] == self.hidden_name and k[2] != k[0]
]
self.output_meshes = [
k[2] for k in graph_data if isinstance(k, tuple) and k[0] == self.hidden_name and k[2] != k[0]
]
self.num_nodes = {name: graph_data[name]["coords"].shape[0] for name in self.mesh_names}
self.num_node_features = {name: 2 * graph_data[name]["coords"].shape[1] for name in self.mesh_names}
self.num_trainable_params = {
name: config.model.trainable_parameters["data" if name != "hidden" else name] for name in self.graph.mesh_names
name: config.model.trainable_parameters["data" if name != "hidden" else name]
for name in self.graph.mesh_names
}

def get_node_emb_size(self, name: str) -> int:
Expand Down
6 changes: 4 additions & 2 deletions src/anemoi/models/models/encoder_processor_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.distributed.shapes import get_shape_shards
from anemoi.models.layers.graph import TrainableTensor
from anemoi.models.graph import AnemoiGraphSchema
from anemoi.models.layers.graph import TrainableTensor

LOGGER = logging.getLogger(__name__)

Expand Down Expand Up @@ -196,7 +196,9 @@ def forward(self, x: Tensor, model_comm_group: Optional[ProcessGroup] = None) ->
dim=-1, # feature dimension
)

x_hidden_latent = self.trainable_tensors[self.graph.hidden_name](getattr(self, f"latlons_{self.graph.hidden_name}"), batch_size=batch_size)
x_hidden_latent = self.trainable_tensors[self.graph.hidden_name](
getattr(self, f"latlons_{self.graph.hidden_name}"), batch_size=batch_size
)

# get shard shapes
shard_shapes_data = {name: get_shape_shards(data, 0, model_comm_group) for name, data in x_data_latent.items()}
Expand Down

0 comments on commit c352167

Please sign in to comment.