Skip to content

Commit

Permalink
bugfix (encoder edge lengths) + refector
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jun 27, 2024
1 parent 0f82ea7 commit a9c5ada
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 59 deletions.
41 changes: 14 additions & 27 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,9 @@
logger = logging.getLogger(__name__)


def generate_graph(graph_config: DotDict) -> HeteroData:
"""Generate a graph from a configuration.
Parameters
----------
graph_config : DotDict
Configuration for the nodes and edges (and its attributes).
Returns
-------
HeteroData
Graph.
"""
graph = HeteroData()

for name, nodes_cfg in graph_config.nodes.items():
graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in graph_config.edges:
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(graph, edges_cfg.get("attributes", {}))

return graph


class GraphCreator:
"""Graph creator."""

def __init__(
self,
path,
Expand All @@ -55,9 +33,18 @@ def init(self):
if self._path_readable() and not self.overwrite:
raise Exception(f"{self.path} already exists. Use overwrite=True to overwrite.")

def load(self) -> HeteroData:
def generate_graph(self) -> HeteroData:
config = DotDict.from_file(self.config)
graph = generate_graph(config)

graph = HeteroData()
for name, nodes_cfg in config.nodes.items():
graph = instantiate(nodes_cfg.node_builder).transform(graph, name, nodes_cfg.get("attributes", {}))

for edges_cfg in config.edges:
graph = instantiate(edges_cfg.edge_builder, **edges_cfg.nodes).transform(
graph, edges_cfg.get("attributes", {})
)

return graph

def save(self, graph: HeteroData) -> None:
Expand All @@ -67,7 +54,7 @@ def save(self, graph: HeteroData) -> None:

def create(self):
self.init()
graph = self.load()
graph = self.generate_graph()
self.save(graph)

def _path_readable(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions src/anemoi/graphs/edges/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .connections import CutOffEdgeBuilder
from .connections import KNNEdgeBuilder
from .builder import CutOffEdgeBuilder
from .builder import KNNEdgeBuilder

__all__ = ["KNNEdgeBuilder", "CutOffEdgeBuilder"]
16 changes: 3 additions & 13 deletions src/anemoi/graphs/edges/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@

import numpy as np
import torch
from scipy.sparse import coo_matrix
from sklearn.preprocessing import normalize
from torch_geometric.data import HeteroData

from anemoi.graphs.edges.directional import directional_edge_features
Expand Down Expand Up @@ -51,29 +49,21 @@ def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> torch.Tens


@dataclass
class HaversineDistance(BaseEdgeAttribute):
class EdgeLength(BaseEdgeAttribute):
"""Edge length feature."""

norm: str = "l1"
invert: bool = True

def compute(self, graph: HeteroData, src_name: str, dst_name: str):
def compute(self, graph: HeteroData, src_name: str, dst_name: str) -> np.ndarray:
"""Compute haversine distance (in kilometers) between nodes connected by edges."""
assert src_name in graph.node_types, f"Node {src_name} not found in graph."
assert dst_name in graph.node_types, f"Node {dst_name} not found in graph."
edge_index = graph[(src_name, "to", dst_name)].edge_index
src_coords = graph[src_name].x.numpy()[edge_index[0]]
dst_coords = graph[dst_name].x.numpy()[edge_index[1]]
edge_lengths = haversine_distance(src_coords, dst_coords)
return coo_matrix((edge_lengths, (edge_index[1], edge_index[0])))

def normalize(self, values) -> np.ndarray:
"""Normalize the edge length.
This method scales the edge lengths to a unit norm, computing the norms
for each source node (axis=1).
"""
return normalize(values, norm="l1", axis=1).data
return edge_lengths

def post_process(self, values: np.ndarray) -> torch.Tensor:
if self.invert:
Expand Down
File renamed without changes.
4 changes: 2 additions & 2 deletions src/anemoi/graphs/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .nodes import NPZNodes
from .nodes import ZarrNodes
from .builder import NPZNodes
from .builder import ZarrNodes

__all__ = ["ZarrNodes", "NPZNodes"]
File renamed without changes.
40 changes: 40 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
import torch
import yaml
from torch_geometric.data import HeteroData

lats = [-0.15, 0, 0.15]
Expand Down Expand Up @@ -50,3 +51,42 @@ def graph_nodes_and_edges() -> HeteroData:
graph["test_nodes"].x = 2 * torch.pi * torch.tensor(coords)
graph[("test_nodes", "to", "test_nodes")].edge_index = torch.tensor([[0, 1], [1, 2], [2, 3], [3, 0]])
return graph


@pytest.fixture
def config_file(tmp_path) -> tuple[str, str]:
"""Mock grid_definition_path with files for 3 resolutions."""
cfg = {
"nodes": {
"test_nodes": {
"node_builder": {
"_target_": "anemoi.graphs.nodes.NPZNodes",
"grid_definition_path": str(tmp_path),
"resolution": "o16",
},
}
},
"edges": [
{
"nodes": {"src_name": "test_nodes", "dst_name": "test_nodes"},
"edge_builder": {
"_target_": "anemoi.graphs.edges.KNNEdgeBuilder",
"num_nearest_neighbours": 3,
},
"attributes": {
"dist_norm": {
"_target_": "anemoi.graphs.edges.attributes.EdgeLength",
"norm": "l1",
"invert": True,
},
"directional_features": {"_target_": "anemoi.graphs.edges.attributes.DirectionalFeatures"},
},
},
],
}
file_name = "config.yaml"

with (tmp_path / file_name).open("w") as file:
yaml.dump(cfg, file)

return tmp_path, file_name
2 changes: 1 addition & 1 deletion tests/nodes/test_npz.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import torch
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes.nodes import NPZNodes
from anemoi.graphs.nodes.builder import NPZNodes
from anemoi.graphs.nodes.weights import AreaWeights
from anemoi.graphs.nodes.weights import UniformWeights

Expand Down
20 changes: 10 additions & 10 deletions tests/nodes/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,29 @@
import zarr
from torch_geometric.data import HeteroData

from anemoi.graphs.nodes import nodes
from anemoi.graphs.nodes import builder
from anemoi.graphs.nodes.weights import AreaWeights
from anemoi.graphs.nodes.weights import UniformWeights


def test_init(mocker, mock_zarr_dataset):
"""Test ZarrNodes initialization."""
mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset)
node_builder = nodes.ZarrNodes("dataset.zarr")
assert isinstance(node_builder, nodes.BaseNodeBuilder)
assert isinstance(node_builder, nodes.ZarrNodes)
mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset)
node_builder = builder.ZarrNodes("dataset.zarr")
assert isinstance(node_builder, builder.BaseNodeBuilder)
assert isinstance(node_builder, builder.ZarrNodes)


def test_fail_init():
"""Test ZarrNodes initialization with invalid resolution."""
with pytest.raises(zarr.errors.PathNotFoundError):
nodes.ZarrNodes("invalid_path.zarr")
builder.ZarrNodes("invalid_path.zarr")


def test_register_nodes(mocker, mock_zarr_dataset):
"""Test ZarrNodes register correctly the nodes."""
mocker.patch.object(nodes, "open_dataset", return_value=mock_zarr_dataset)
node_builder = nodes.ZarrNodes("dataset.zarr")
mocker.patch.object(builder, "open_dataset", return_value=mock_zarr_dataset)
node_builder = builder.ZarrNodes("dataset.zarr")
graph = HeteroData()

graph = node_builder.register_nodes(graph, "test_nodes")
Expand All @@ -39,8 +39,8 @@ def test_register_nodes(mocker, mock_zarr_dataset):
@pytest.mark.parametrize("attr_class", [UniformWeights, AreaWeights])
def test_register_weights(mocker, graph_with_nodes: HeteroData, attr_class):
"""Test ZarrNodes register correctly the weights."""
mocker.patch.object(nodes, "open_dataset", return_value=None)
node_builder = nodes.ZarrNodes("dataset.zarr")
mocker.patch.object(builder, "open_dataset", return_value=None)
node_builder = builder.ZarrNodes("dataset.zarr")
config = {"test_attr": {"_target_": f"anemoi.graphs.nodes.weights.{attr_class.__name__}"}}

graph = node_builder.register_attributes(graph_with_nodes, "test_nodes", config)
Expand Down
22 changes: 18 additions & 4 deletions tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,23 @@
# nor does it submit to any jurisdiction.


def test_graphs():
pass
from pathlib import Path

import torch
from torch_geometric.data import HeteroData

if __name__ == "__main__":
test_graphs()
from anemoi.graphs import create


def test_graphs(config_file: tuple[Path, str], mock_grids_path: tuple[str, int]):
"""Test GraphCreator workflow."""
tmp_path, config_name = config_file
graph_path = tmp_path / "graph.pt"
config_path = tmp_path / config_name

create.GraphCreator(graph_path, config_path).create()

graph = torch.load(graph_path)
assert isinstance(graph, HeteroData)
assert "test_nodes" in graph.node_types
assert ("test_nodes", "to", "test_nodes") in graph.edge_types

0 comments on commit a9c5ada

Please sign in to comment.