diff --git a/src/anemoi/graphs/create.py b/src/anemoi/graphs/create.py index 7a1d2a9..53861a9 100644 --- a/src/anemoi/graphs/create.py +++ b/src/anemoi/graphs/create.py @@ -9,31 +9,9 @@ logger = logging.getLogger(__name__) -def generate_graph(graph_config: DotDict) -> HeteroData: - """Generate a graph from a configuration. - - Parameters - ---------- - graph_config : DotDict - Configuration for the nodes and edges (and its attributes). - - Returns - ------- - HeteroData - Graph. - """ - graph = HeteroData() - - for name, nodes_cfg in graph_config.nodes.items(): - graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) - - for edges_cfg in graph_config.edges: - graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {})) - - return graph - - class GraphCreator: + """Graph creator.""" + def __init__( self, path, @@ -55,9 +33,18 @@ def init(self): if self._path_readable() and not self.overwrite: raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.") - def load(self) -> HeteroData: + def generate_graph(self) -> HeteroData: config = DotDict.from_file(self.config) - graph = generate_graph(config) + + graph = HeteroData() + for name, nodes_cfg in config.nodes.items(): + graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {})) + + for edges_cfg in config.edges: + graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform( + graph, edges_cfg.get("attributes", {}) + ) + return graph def save(self, graph: HeteroData) -> None: @@ -67,7 +54,7 @@ def save(self, graph: HeteroData) -> None: def create(self): self.init() - graph = self.load() + graph = self.generate_graph() self.save(graph) def _path_readable(self) -> bool: diff --git a/src/anemoi/graphs/edges/__init__.py b/src/anemoi/graphs/edges/__init__.py index 29875d0..edd07db 100644 --- a/src/anemoi/graphs/edges/__init__.py +++ b/src/anemoi/graphs/edges/__init__.py @@ -1,4 +1,4 @@ -from .connections import CutOffEdgeBuilder -from .connections import KNNEdgeBuilder +from .builder import CutOffEdgeBuilder +from .builder import KNNEdgeBuilder __all__ = ["KNNEdgeBuilder", "CutOffEdgeBuilder"] diff --git a/src/anemoi/graphs/edges/attributes.py b/src/anemoi/graphs/edges/attributes.py index 9e7509f..47787b3 100644 --- a/src/anemoi/graphs/edges/attributes.py +++ b/src/anemoi/graphs/edges/attributes.py @@ -6,8 +6,6 @@ import numpy as np import torch -from scipy.sparse import coo_matrix -from sklearn.preprocessing import normalize from torch_geometric.data import HeteroData from anemoi.graphs.edges.directional import directional_edge_features @@ -51,13 +49,13 @@ def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tens @dataclass -class HaversineDistance(BaseEdgeAttribute): +class EdgeLength(BaseEdgeAttribute): """Edge length feature.""" norm: str = "l1" invert: bool = True - def compute(self, graph: HeteroData, src_name: str, dst_name: str): + def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> np.ndarray: """Compute haversine distance (in kilometers) between nodes connected by edges.""" assert src_name in graph.node_types, f"Node {src_name} not found in graph." assert dst_name in graph.node_types, f"Node {dst_name} not found in graph." @@ -65,15 +63,7 @@ def compute(self, graph: HeteroData, src_name: str, dst_name: str): src_coords = graph[src_name].x.numpy()[edge_index[0]] dst_coords = graph[dst_name].x.numpy()[edge_index[1]] edge_lengths = haversine_distance(src_coords, dst_coords) - return coo_matrix((edge_lengths, (edge_index[1], edge_index[0]))) - - def normalize(self, values) -> np.ndarray: - """Normalize the edge length. - - This method scales the edge lengths to a unit norm, computing the norms - for each source node (axis=1). - """ - return normalize(values, norm="l1", axis=1).data + return edge_lengths def post_process(self, values: np.ndarray) -> torch.Tensor: if self.invert: diff --git a/src/anemoi/graphs/edges/connections.py b/src/anemoi/graphs/edges/builder.py similarity index 100% rename from src/anemoi/graphs/edges/connections.py rename to src/anemoi/graphs/edges/builder.py diff --git a/src/anemoi/graphs/nodes/__init__.py b/src/anemoi/graphs/nodes/__init__.py index 5458495..beecc98 100644 --- a/src/anemoi/graphs/nodes/__init__.py +++ b/src/anemoi/graphs/nodes/__init__.py @@ -1,4 +1,4 @@ -from .nodes import NPZNodes -from .nodes import ZarrNodes +from .builder import NPZNodes +from .builder import ZarrNodes __all__ = ["ZarrNodes", "NPZNodes"] diff --git a/src/anemoi/graphs/nodes/nodes.py b/src/anemoi/graphs/nodes/builder.py similarity index 100% rename from src/anemoi/graphs/nodes/nodes.py rename to src/anemoi/graphs/nodes/builder.py diff --git a/tests/conftest.py b/tests/conftest.py index 80ebfaa..b6b8ba0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ import numpy as np import pytest import torch +import yaml from torch_geometric.data import HeteroData lats = [-0.15, 0, 0.15] @@ -50,3 +51,42 @@ def graph_nodes_and_edges() -> HeteroData: graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords) graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]]) return graph + + +@pytest.fixture +def config_file(tmp_path) -> tuple[str, str]: + """Mock grid_definition_path with files for 3 resolutions.""" + cfg = { + "nodes": { + "test_nodes": { + "node_builder": { + "_target_": "anemoi.graphs.nodes.NPZNodes", + "grid_definition_path": str(tmp_path), + "resolution": "o16", + }, + } + }, + "edges": [ + { + "nodes": {"src_name": "test_nodes", "dst_name": "test_nodes"}, + "edge_builder": { + "_target_": "anemoi.graphs.edges.KNNEdgeBuilder", + "num_nearest_neighbours": 3, + }, + "attributes": { + "dist_norm": { + "_target_": "anemoi.graphs.edges.attributes.EdgeLength", + "norm": "l1", + "invert": True, + }, + "directional_features": {"_target_": "anemoi.graphs.edges.attributes.DirectionalFeatures"}, + }, + }, + ], + } + file_name = "config.yaml" + + with (tmp_path / file_name).open("w") as file: + yaml.dump(cfg, file) + + return tmp_path, file_name diff --git a/tests/nodes/test_npz.py b/tests/nodes/test_npz.py index 8642e39..ebe88d9 100644 --- a/tests/nodes/test_npz.py +++ b/tests/nodes/test_npz.py @@ -2,7 +2,7 @@ import torch from torch_geometric.data import HeteroData -from anemoi.graphs.nodes.nodes import NPZNodes +from anemoi.graphs.nodes.builder import NPZNodes from anemoi.graphs.nodes.weights import AreaWeights from anemoi.graphs.nodes.weights import UniformWeights diff --git a/tests/nodes/test_zarr.py b/tests/nodes/test_zarr.py index e9a5234..ddf804f 100644 --- a/tests/nodes/test_zarr.py +++ b/tests/nodes/test_zarr.py @@ -3,29 +3,29 @@ import zarr from torch_geometric.data import HeteroData -from anemoi.graphs.nodes import nodes +from anemoi.graphs.nodes import builder from anemoi.graphs.nodes.weights import AreaWeights from anemoi.graphs.nodes.weights import UniformWeights def test_init(mocker, mock_zarr_dataset): """Test ZarrNodes initialization.""" - mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) - node_builder = nodes.ZarrNodes("dataset.zarr") - assert isinstance(node_builder, nodes.BaseNodeBuilder) - assert isinstance(node_builder, nodes.ZarrNodes) + mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) + node_builder = builder.ZarrNodes("dataset.zarr") + assert isinstance(node_builder, builder.BaseNodeBuilder) + assert isinstance(node_builder, builder.ZarrNodes) def test_fail_init(): """Test ZarrNodes initialization with invalid resolution.""" with pytest.raises(zarr.errors.PathNotFoundError): - nodes.ZarrNodes("invalid_path.zarr") + builder.ZarrNodes("invalid_path.zarr") def test_register_nodes(mocker, mock_zarr_dataset): """Test ZarrNodes register correctly the nodes.""" - mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset) - node_builder = nodes.ZarrNodes("dataset.zarr") + mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset) + node_builder = builder.ZarrNodes("dataset.zarr") graph = HeteroData() graph = node_builder.register_nodes(graph, "test_nodes") @@ -39,8 +39,8 @@ def test_register_nodes(mocker, mock_zarr_dataset): @pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights]) def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class): """Test ZarrNodes register correctly the weights.""" - mocker.patch.object(nodes, "open_dataset", return_value=None) - node_builder = nodes.ZarrNodes("dataset.zarr") + mocker.patch.object(builder, "open_dataset", return_value=None) + node_builder = builder.ZarrNodes("dataset.zarr") config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}} graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config) diff --git a/tests/test_graphs.py b/tests/test_graphs.py index 846ee89..0ceb171 100644 --- a/tests/test_graphs.py +++ b/tests/test_graphs.py @@ -6,9 +6,23 @@ # nor does it submit to any jurisdiction. -def test_graphs(): - pass +from pathlib import Path +import torch +from torch_geometric.data import HeteroData -if __name__ == "__main__": - test_graphs() +from anemoi.graphs import create + + +def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]): + """Test GraphCreator workflow.""" + tmp_path, config_name = config_file + graph_path = tmp_path / "graph.pt" + config_path = tmp_path / config_name + + create.GraphCreator(graph_path, config_path).create() + + graph = torch.load(graph_path) + assert isinstance(graph, HeteroData) + assert "test_nodes" in graph.node_types + assert ("test_nodes", "to", "test_nodes") in graph.edge_types