Skip to content

Commit

Permalink
feat: detailed description
Browse files Browse the repository at this point in the history
  • Loading branch information
JPXKQX committed Jul 3, 2024
1 parent 9acfcc4 commit ee1d7a1
Show file tree
Hide file tree
Showing 6 changed files with 170 additions and 30 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,5 @@ _version.py
*.code-workspace

/config*
outputs/
*.pt
7 changes: 6 additions & 1 deletion src/anemoi/graphs/commands/inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@ class Inspect(Command):
timestamp = True

def add_arguments(self, command_parser):
command_parser.add_argument("graph", help="Path to the graph (a .PT file).")
command_parser.add_argument(
"--show_attribute_distributions",
action="store_false",
help="Show distribution of edge/node attributes.",
)
command_parser.add_argument("path", help="Path to the graph (a .PT file).")
command_parser.add_argument("output_path", help="Path to store the inspection results.")

def run(self, args):
Expand Down
160 changes: 143 additions & 17 deletions src/anemoi/graphs/inspector.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,38 @@
import logging
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Optional
from typing import Union

import torch
from torch_geometric.data import HeteroData
from anemoi.utils.humanize import bytes
from anemoi.utils.humanize import number
from anemoi.utils.text import table

from anemoi.graphs.plotting.displots import plot_dist_edge_attributes
from anemoi.graphs.plotting.displots import plot_dist_node_attributes
from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes
from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph
from anemoi.graphs.plotting.interactive_html import plot_nodes
from anemoi.graphs.plotting.interactive_html import plot_orphan_nodes

logger = logging.getLogger(__name__)


@dataclass
class GraphInspectorTool:
"""Inspect the graph."""

graph: Union[HeteroData, str]
output_path: Path

def __post_init__(self):
if not isinstance(self.graph, HeteroData):
self.graph = torch.load(self.graph)
def __init__(
self,
path: Union[str, Path],
output_path: Path,
show_attribute_distributions: Optional[bool] = True,
**kwargs,
):
self.path = path
self.graph = torch.load(self.path)
self.output_path = output_path
self.show_attribute_distributions = show_attribute_distributions

if isinstance(self.output_path, str):
self.output_path = Path(self.output_path)
Expand All @@ -34,21 +41,140 @@ def __post_init__(self):
assert self.output_path.is_dir(), f"Path {self.output_path} is not a directory."
assert os.access(self.output_path, os.W_OK), f"Path {self.output_path} is not writable."

@property
def total_size(self):
"""Total size of the tensors in the graph (in bytes)."""
total_size = 0

for node_store in self.graph.node_stores:
for value in node_store.values():
if isinstance(value, torch.Tensor):
total_size += value.numel() * value.element_size()

for edge_store in self.graph.edge_stores:
for value in edge_store.values():
if isinstance(value, torch.Tensor):
total_size += value.numel() * value.element_size()

return total_size

def get_node_summary(self) -> list[list]:
"""Summary of the nodes in the graph.
Returns
-------
list[list]
Returns a list for each subgraph with the following information:
- Node name.
- Number of nodes.
- List of attribute names.
- Total dimension of the attributes.
- Min. latitude.
- Max. latitude.
- Min. longitude.
- Max. longitude.
"""
node_summary = []
for name, nodes in self.graph.node_items():
attributes = nodes.node_attrs()
attributes.remove("x")

node_summary.append(
[
name,
number(nodes.num_nodes),
", ".join(attributes),
sum(nodes[attr].shape[1] for attr in attributes),
number(nodes.x[:, 0].min().item() / 2 / math.pi * 360),
number(nodes.x[:, 0].max().item() / 2 / math.pi * 360),
number(nodes.x[:, 1].min().item() / 2 / math.pi * 360),
number(nodes.x[:, 1].max().item() / 2 / math.pi * 360),
]
)
return node_summary

def get_edge_summary(self) -> list[list]:
"""Summary of the edges in the graph.
Returns
-------
list[list]
Returns a list for each subgraph with the following information:
- Source node name.
- Destination node name.
- Number of edges.
- Total dimension of the attributes.
- List of attribute names.
"""
edge_summary = []
for (src_nodes, _, dst_nodes), edges in self.graph.edge_items():
attributes = edges.edge_attrs()
attributes.remove("edge_index")

edge_summary.append(
[
src_nodes,
dst_nodes,
number(edges.num_edges),
sum(edges[attr].shape[1] for attr in attributes),
", ".join(attributes),
]
)
return edge_summary

def describe(self) -> None:
"""Describe the graph."""
print()
print(f"📦 Path : {self.path}")
print(f"💽 Size : {bytes(self.total_size)} ({number(self.total_size)})")
print()
print(
table(
self.get_node_summary(),
header=[
"Nodes name",
"Num. nodes",
"Attributes",
"Attribute dim",
"Min. latitude",
"Max. latitude",
"Min. longitude",
"Max. longitude",
],
align=["<", ">", ">", ">", ">", ">", ">", ">"],
margin=3,
)
)
print()
print()
print(
table(
self.get_edge_summary(),
header=["Source", "Destination", "Num. edges", "Attribute dim", "Attributes"],
align=["<", "<", ">", ">", ">"],
margin=3,
)
)
print("🔋 Graph ready.")
print()

def run_all(self):
"""Run all the inspector methods."""
plot_dist_edge_attributes(self.graph, self.output_path / "distribution_edge_attributes.png")
plot_dist_node_attributes(self.graph, self.output_path / "distribution_node_attributes.png")
self.describe()

if self.show_attribute_distributions:
plot_dist_edge_attributes(self.graph, self.output_path / "distribution_edge_attributes.png")
plot_dist_node_attributes(self.graph, self.output_path / "distribution_node_attributes.png")

plot_orphan_nodes(self.graph, self.output_path / "orphan_nodes.html")

logger.info("Saving interactive plots of nodes ...")
for nodes_name, nodes_store in self.graph.node_items():
ofile = self.output_path / f"{nodes_name}_nodes.html"
title = f"Map of {nodes_name} nodes"
plot_nodes(title, nodes_store.x[:, 0].numpy(), nodes_store.x[:, 1].numpy(), out_file=ofile)
plot_interactive_nodes(title, nodes_store.x[:, 0].numpy(), nodes_store.x[:, 1].numpy(), out_file=ofile)

logger.info("Saving interactive plots of subgraphs ...")
for src_nodes, _, dst_nodes in self.graph.edge_types:
ofile = self.output_path / f"{src_nodes}_to_{dst_nodes}.html"
plot_interactive_subgraph(self.graph, (src_nodes, dst_nodes), out_file=ofile)


if __name__ == "__main__":
GraphInspectorTool("my_graph.pt", "output").run_all()
4 changes: 1 addition & 3 deletions src/anemoi/graphs/normalizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,12 @@ def normalize(self, values: np.ndarray) -> np.ndarray:
return values / np.linalg.norm(values)
if self.norm == "unit-max":
return values / np.amax(values)
if self.norm == "unit-sum":
return values / np.sum(values)
if self.norm == "unit-std":
std = np.std(values)
if std == 0:
logger.warning(f"Std. dev. of the {self.__class__.__name__} is 0. Cannot normalize.")
return values
return values / std
raise ValueError(
f"Weight normalization \"{values}\" is not valid. Options are: 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'."
f"Weight normalization \"{values}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'."
)
20 changes: 14 additions & 6 deletions src/anemoi/graphs/plotting/displots.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def plot_dist_node_attributes(graph: HeteroData, out_file: Optional[Union[str, P
dim_attrs = sum(attr_dims.values())

# Define the layout
_, axs = plt.subplots(num_nodes, dim_attrs, figsize=(10 * len(graph.node_types), 10), sharex=True, sharey=False)
_, axs = plt.subplots(num_nodes, dim_attrs, figsize=(10 * len(graph.node_types), 10))
if axs.ndim == 1:
axs = axs.reshape(num_nodes, dim_attrs)

Expand All @@ -31,8 +31,12 @@ def plot_dist_node_attributes(graph: HeteroData, out_file: Optional[Union[str, P
for dim in range(attr_values):
if attr_name in nodes_store:
axs[i, j + dim].hist(nodes_store[attr_name][:, dim], bins=50)
axs[i, j + dim].set_ylabel(nodes_name)
axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
if j + dim == 0:
axs[i, j + dim].set_ylabel(nodes_name)
if i == 0:
axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
elif i == num_nodes - 1:
axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
else:
axs[i, j + dim].set_axis_off()

Expand All @@ -50,7 +54,7 @@ def plot_dist_edge_attributes(graph: HeteroData, out_file: Optional[Union[str, P
dim_attrs = sum(attr_dims.values())

# Define the layout
_, axs = plt.subplots(num_edges, dim_attrs, figsize=(10 * len(graph.edge_types), 10), sharex=True, sharey=False)
_, axs = plt.subplots(num_edges, dim_attrs, figsize=(10 * len(graph.edge_types), 10))
if axs.ndim == 1:
axs = axs.reshape(num_edges, dim_attrs)

Expand All @@ -59,8 +63,12 @@ def plot_dist_edge_attributes(graph: HeteroData, out_file: Optional[Union[str, P
for dim in range(attr_values):
if attr_name in edge_store:
axs[i, j + dim].hist(edge_store[attr_name][:, dim], bins=50)
axs[i, j + dim].set_ylabel("".join(edge_name).replace("to", " --> "))
axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
if j + dim == 0:
axs[i, j + dim].set_ylabel("".join(edge_name).replace("to", " --> "))
if i == 0:
axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
elif i == num_edges - 1:
axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}")
else:
axs[i, j + dim].set_axis_off()

Expand Down
8 changes: 5 additions & 3 deletions src/anemoi/graphs/plotting/interactive_html.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from pathlib import Path
from typing import Optional
from typing import Union
Expand All @@ -14,6 +15,8 @@
annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002}
plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False}

logger = logging.getLogger(__name__)


def plot_interactive_subgraph(
graph: HeteroData,
Expand Down Expand Up @@ -126,7 +129,7 @@ def plot_orphan_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] =
orphans[f"{dst_nodes} orphans ({src_nodes} -->)"] = node_list(graph, dst_nodes, mask=list(tail_orphans))

if len(orphans) == 0:
print("No orphan nodes found.")
logger.info("No orphan nodes found.")
return

colorbar = plt.cm.rainbow(np.linspace(0, 1, len(orphans)))
Expand Down Expand Up @@ -159,12 +162,11 @@ def plot_orphan_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] =

if out_file is not None:
fig.write_html(out_file)
print(f"Orphan nodes plot saved to {out_file}.")
else:
fig.show()


def plot_nodes(
def plot_interactive_nodes(
title: str, lats: np.ndarray, lons: np.ndarray, mask: np.ndarray = None, out_file: Optional[str] = None
) -> None:
"""Plot nodes.
Expand Down

0 comments on commit ee1d7a1

Please sign in to comment.