Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support anemoi.graphs #2

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ dependencies = [
"torch-geometric==2.4",
"einops==0.6.1",
"hydra-core==1.3",
"anemoi-datasets==0.2.1",
"anemoi-utils==0.1.9",
"anemoi-utils>=0.1.9",
]

[project.optional-dependencies]
Expand Down
6 changes: 3 additions & 3 deletions src/anemoi/models/interface/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import uuid

import torch
from anemoi.utils.config import DotConfig
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from torch_geometric.data import HeteroData

from anemoi.models.data_indices.collection import IndexCollection
from anemoi.models.models.encoder_processor_decoder import AnemoiModelEncProcDec
from anemoi.models.preprocessing import Processors

Expand All @@ -22,7 +22,7 @@ class AnemoiModelInterface(torch.nn.Module):
"""Anemoi model on torch level."""

def __init__(
self, *, config: DotConfig, graph_data: HeteroData, statistics: dict, data_indices: dict, metadata: dict
self, *, config: DotDict, graph_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict
) -> None:
super().__init__()
self.config = config
Expand Down
23 changes: 11 additions & 12 deletions src/anemoi/models/layers/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.distributed_c10d import ProcessGroup
from torch_geometric.data import HeteroData
from torch_geometric.typing import Adj
from torch_geometric.typing import PairTensor

Expand Down Expand Up @@ -113,12 +112,12 @@ def pre_process(self, x, shard_shapes, model_comm_group=None):


class GraphEdgeMixin:
def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, trainable_size: int) -> None:
def _register_edges(self, sub_graph: dict, src_size: int, dst_size: int, trainable_size: int) -> None:
"""Register edge dim, attr, index_base, and increment.

Parameters
----------
sub_graph : HeteroData
sub_graph : dict
Sub graph of the full structure
src_size : int
Source size
Expand All @@ -127,9 +126,9 @@ def _register_edges(self, sub_graph: HeteroData, src_size: int, dst_size: int, t
trainable_size : int
Trainable tensor size
"""
self.edge_dim = sub_graph.edge_attr.shape[1] + trainable_size
self.register_buffer("edge_attr", sub_graph.edge_attr, persistent=False)
self.register_buffer("edge_index_base", sub_graph.edge_index, persistent=False)
self.edge_dim = sub_graph["edge_attr"].shape[1] + trainable_size
self.register_buffer("edge_attr", sub_graph["edge_attr"], persistent=False)
self.register_buffer("edge_index_base", sub_graph["edge_index"], persistent=False)
self.register_buffer(
"edge_inc", torch.from_numpy(np.asarray([[src_size], [dst_size]], dtype=np.int64)), persistent=True
)
Expand Down Expand Up @@ -173,7 +172,7 @@ def __init__(
activation: str = "GELU",
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -273,7 +272,7 @@ def __init__(
activation: str = "GELU",
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -344,7 +343,7 @@ def __init__(
activation: str = "GELU",
num_heads: int = 16,
mlp_hidden_ratio: int = 4,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -414,7 +413,7 @@ def __init__(
cpu_offload: bool = False,
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -517,7 +516,7 @@ def __init__(
cpu_offload: bool = False,
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down Expand Up @@ -601,7 +600,7 @@ def __init__(
cpu_offload: bool = False,
activation: str = "SiLU",
mlp_extra_layers: int = 0,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
) -> None:
Expand Down
5 changes: 2 additions & 3 deletions src/anemoi/models/layers/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
from torch.distributed.distributed_c10d import ProcessGroup
from torch.utils.checkpoint import checkpoint
from torch_geometric.data import HeteroData

from anemoi.models.distributed.graph import shard_tensor
from anemoi.models.distributed.khop_edges import sort_edges_1hop
Expand Down Expand Up @@ -170,7 +169,7 @@ def __init__(
mlp_extra_layers: int = 0,
activation: str = "SiLU",
cpu_offload: bool = False,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
**kwargs,
Expand Down Expand Up @@ -257,7 +256,7 @@ def __init__(
mlp_hidden_ratio: int = 4,
activation: str = "GELU",
cpu_offload: bool = False,
sub_graph: Optional[HeteroData] = None,
sub_graph: Optional[dict] = None,
src_grid_size: int = 0,
dst_grid_size: int = 0,
**kwargs,
Expand Down
Loading
Loading