From ee1d7a1e2dc82d46d91f7cdc8140dd346fa2737d Mon Sep 17 00:00:00 2001 From: Mario Santa Cruz Date: Wed, 3 Jul 2024 10:06:53 +0000 Subject: [PATCH] feat: detailed description --- .gitignore | 1 + src/anemoi/graphs/commands/inspect.py | 7 +- src/anemoi/graphs/inspector.py | 160 ++++++++++++++++-- src/anemoi/graphs/normalizer.py | 4 +- src/anemoi/graphs/plotting/displots.py | 20 ++- .../graphs/plotting/interactive_html.py | 8 +- 6 files changed, 170 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 1b49006..438a474 100644 --- a/.gitignore +++ b/.gitignore @@ -188,4 +188,5 @@ _version.py *.code-workspace /config* +outputs/ *.pt diff --git a/src/anemoi/graphs/commands/inspect.py b/src/anemoi/graphs/commands/inspect.py index 8be9b6d..d3ea30e 100644 --- a/src/anemoi/graphs/commands/inspect.py +++ b/src/anemoi/graphs/commands/inspect.py @@ -10,7 +10,12 @@ class Inspect(Command): timestamp = True def add_arguments(self, command_parser): - command_parser.add_argument("graph", help="Path to the graph (a .PT file).") + command_parser.add_argument( + "--show_attribute_distributions", + action="store_false", + help="Show distribution of edge/node attributes.", + ) + command_parser.add_argument("path", help="Path to the graph (a .PT file).") command_parser.add_argument("output_path", help="Path to store the inspection results.") def run(self, args): diff --git a/src/anemoi/graphs/inspector.py b/src/anemoi/graphs/inspector.py index f9b1b55..0350c65 100644 --- a/src/anemoi/graphs/inspector.py +++ b/src/anemoi/graphs/inspector.py @@ -1,31 +1,38 @@ import logging +import math import os -from dataclasses import dataclass from pathlib import Path +from typing import Optional from typing import Union import torch -from torch_geometric.data import HeteroData +from anemoi.utils.humanize import bytes +from anemoi.utils.humanize import number +from anemoi.utils.text import table from anemoi.graphs.plotting.displots import plot_dist_edge_attributes from anemoi.graphs.plotting.displots import plot_dist_node_attributes +from anemoi.graphs.plotting.interactive_html import plot_interactive_nodes from anemoi.graphs.plotting.interactive_html import plot_interactive_subgraph -from anemoi.graphs.plotting.interactive_html import plot_nodes from anemoi.graphs.plotting.interactive_html import plot_orphan_nodes logger = logging.getLogger(__name__) -@dataclass class GraphInspectorTool: """Inspect the graph.""" - graph: Union[HeteroData, str] - output_path: Path - - def __post_init__(self): - if not isinstance(self.graph, HeteroData): - self.graph = torch.load(self.graph) + def __init__( + self, + path: Union[str, Path], + output_path: Path, + show_attribute_distributions: Optional[bool] = True, + **kwargs, + ): + self.path = path + self.graph = torch.load(self.path) + self.output_path = output_path + self.show_attribute_distributions = show_attribute_distributions if isinstance(self.output_path, str): self.output_path = Path(self.output_path) @@ -34,21 +41,140 @@ def __post_init__(self): assert self.output_path.is_dir(), f"Path {self.output_path} is not a directory." assert os.access(self.output_path, os.W_OK), f"Path {self.output_path} is not writable." + @property + def total_size(self): + """Total size of the tensors in the graph (in bytes).""" + total_size = 0 + + for node_store in self.graph.node_stores: + for value in node_store.values(): + if isinstance(value, torch.Tensor): + total_size += value.numel() * value.element_size() + + for edge_store in self.graph.edge_stores: + for value in edge_store.values(): + if isinstance(value, torch.Tensor): + total_size += value.numel() * value.element_size() + + return total_size + + def get_node_summary(self) -> list[list]: + """Summary of the nodes in the graph. + + Returns + ------- + list[list] + Returns a list for each subgraph with the following information: + - Node name. + - Number of nodes. + - List of attribute names. + - Total dimension of the attributes. + - Min. latitude. + - Max. latitude. + - Min. longitude. + - Max. longitude. + """ + node_summary = [] + for name, nodes in self.graph.node_items(): + attributes = nodes.node_attrs() + attributes.remove("x") + + node_summary.append( + [ + name, + number(nodes.num_nodes), + ", ".join(attributes), + sum(nodes[attr].shape[1] for attr in attributes), + number(nodes.x[:, 0].min().item() / 2 / math.pi * 360), + number(nodes.x[:, 0].max().item() / 2 / math.pi * 360), + number(nodes.x[:, 1].min().item() / 2 / math.pi * 360), + number(nodes.x[:, 1].max().item() / 2 / math.pi * 360), + ] + ) + return node_summary + + def get_edge_summary(self) -> list[list]: + """Summary of the edges in the graph. + + Returns + ------- + list[list] + Returns a list for each subgraph with the following information: + - Source node name. + - Destination node name. + - Number of edges. + - Total dimension of the attributes. + - List of attribute names. + """ + edge_summary = [] + for (src_nodes, _, dst_nodes), edges in self.graph.edge_items(): + attributes = edges.edge_attrs() + attributes.remove("edge_index") + + edge_summary.append( + [ + src_nodes, + dst_nodes, + number(edges.num_edges), + sum(edges[attr].shape[1] for attr in attributes), + ", ".join(attributes), + ] + ) + return edge_summary + + def describe(self) -> None: + """Describe the graph.""" + print() + print(f"📦 Path : {self.path}") + print(f"💽 Size : {bytes(self.total_size)} ({number(self.total_size)})") + print() + print( + table( + self.get_node_summary(), + header=[ + "Nodes name", + "Num. nodes", + "Attributes", + "Attribute dim", + "Min. latitude", + "Max. latitude", + "Min. longitude", + "Max. longitude", + ], + align=["<", ">", ">", ">", ">", ">", ">", ">"], + margin=3, + ) + ) + print() + print() + print( + table( + self.get_edge_summary(), + header=["Source", "Destination", "Num. edges", "Attribute dim", "Attributes"], + align=["<", "<", ">", ">", ">"], + margin=3, + ) + ) + print("🔋 Graph ready.") + print() + def run_all(self): """Run all the inspector methods.""" - plot_dist_edge_attributes(self.graph, self.output_path / "distribution_edge_attributes.png") - plot_dist_node_attributes(self.graph, self.output_path / "distribution_node_attributes.png") + self.describe() + + if self.show_attribute_distributions: + plot_dist_edge_attributes(self.graph, self.output_path / "distribution_edge_attributes.png") + plot_dist_node_attributes(self.graph, self.output_path / "distribution_node_attributes.png") + plot_orphan_nodes(self.graph, self.output_path / "orphan_nodes.html") + logger.info("Saving interactive plots of nodes ...") for nodes_name, nodes_store in self.graph.node_items(): ofile = self.output_path / f"{nodes_name}_nodes.html" title = f"Map of {nodes_name} nodes" - plot_nodes(title, nodes_store.x[:, 0].numpy(), nodes_store.x[:, 1].numpy(), out_file=ofile) + plot_interactive_nodes(title, nodes_store.x[:, 0].numpy(), nodes_store.x[:, 1].numpy(), out_file=ofile) + logger.info("Saving interactive plots of subgraphs ...") for src_nodes, _, dst_nodes in self.graph.edge_types: ofile = self.output_path / f"{src_nodes}_to_{dst_nodes}.html" plot_interactive_subgraph(self.graph, (src_nodes, dst_nodes), out_file=ofile) - - -if __name__ == "__main__": - GraphInspectorTool("my_graph.pt", "output").run_all() diff --git a/src/anemoi/graphs/normalizer.py b/src/anemoi/graphs/normalizer.py index 5b3edcd..85613b3 100644 --- a/src/anemoi/graphs/normalizer.py +++ b/src/anemoi/graphs/normalizer.py @@ -18,8 +18,6 @@ def normalize(self, values: np.ndarray) -> np.ndarray: return values / np.linalg.norm(values) if self.norm == "unit-max": return values / np.amax(values) - if self.norm == "unit-sum": - return values / np.sum(values) if self.norm == "unit-std": std = np.std(values) if std == 0: @@ -27,5 +25,5 @@ def normalize(self, values: np.ndarray) -> np.ndarray: return values return values / std raise ValueError( - f"Weight normalization \"{values}\" is not valid. Options are: 'l1', 'l2', 'unit-max' 'unit-sum' or 'unit-std'." + f"Weight normalization \"{values}\" is not valid. Options are: 'l1', 'l2', 'unit-max' or 'unit-std'." ) diff --git a/src/anemoi/graphs/plotting/displots.py b/src/anemoi/graphs/plotting/displots.py index 58d68e2..2773394 100644 --- a/src/anemoi/graphs/plotting/displots.py +++ b/src/anemoi/graphs/plotting/displots.py @@ -22,7 +22,7 @@ def plot_dist_node_attributes(graph: HeteroData, out_file: Optional[Union[str, P dim_attrs = sum(attr_dims.values()) # Define the layout - _, axs = plt.subplots(num_nodes, dim_attrs, figsize=(10 * len(graph.node_types), 10), sharex=True, sharey=False) + _, axs = plt.subplots(num_nodes, dim_attrs, figsize=(10 * len(graph.node_types), 10)) if axs.ndim == 1: axs = axs.reshape(num_nodes, dim_attrs) @@ -31,8 +31,12 @@ def plot_dist_node_attributes(graph: HeteroData, out_file: Optional[Union[str, P for dim in range(attr_values): if attr_name in nodes_store: axs[i, j + dim].hist(nodes_store[attr_name][:, dim], bins=50) - axs[i, j + dim].set_ylabel(nodes_name) - axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}") + if j + dim == 0: + axs[i, j + dim].set_ylabel(nodes_name) + if i == 0: + axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}") + elif i == num_nodes - 1: + axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}") else: axs[i, j + dim].set_axis_off() @@ -50,7 +54,7 @@ def plot_dist_edge_attributes(graph: HeteroData, out_file: Optional[Union[str, P dim_attrs = sum(attr_dims.values()) # Define the layout - _, axs = plt.subplots(num_edges, dim_attrs, figsize=(10 * len(graph.edge_types), 10), sharex=True, sharey=False) + _, axs = plt.subplots(num_edges, dim_attrs, figsize=(10 * len(graph.edge_types), 10)) if axs.ndim == 1: axs = axs.reshape(num_edges, dim_attrs) @@ -59,8 +63,12 @@ def plot_dist_edge_attributes(graph: HeteroData, out_file: Optional[Union[str, P for dim in range(attr_values): if attr_name in edge_store: axs[i, j + dim].hist(edge_store[attr_name][:, dim], bins=50) - axs[i, j + dim].set_ylabel("".join(edge_name).replace("to", " --> ")) - axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}") + if j + dim == 0: + axs[i, j + dim].set_ylabel("".join(edge_name).replace("to", " --> ")) + if i == 0: + axs[i, j + dim].set_title(attr_name if attr_values == 1 else f"{attr_name}_{dim}") + elif i == num_edges - 1: + axs[i, j + dim].set_xlabel(attr_name if attr_values == 1 else f"{attr_name}_{dim}") else: axs[i, j + dim].set_axis_off() diff --git a/src/anemoi/graphs/plotting/interactive_html.py b/src/anemoi/graphs/plotting/interactive_html.py index 286a0b5..93f0c47 100644 --- a/src/anemoi/graphs/plotting/interactive_html.py +++ b/src/anemoi/graphs/plotting/interactive_html.py @@ -1,3 +1,4 @@ +import logging from pathlib import Path from typing import Optional from typing import Union @@ -14,6 +15,8 @@ annotations_style = {"text": "", "showarrow": False, "xref": "paper", "yref": "paper", "x": 0.005, "y": -0.002} plotly_axis_config = {"showgrid": False, "zeroline": False, "showticklabels": False} +logger = logging.getLogger(__name__) + def plot_interactive_subgraph( graph: HeteroData, @@ -126,7 +129,7 @@ def plot_orphan_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = orphans[f"{dst_nodes} orphans ({src_nodes} -->)"] = node_list(graph, dst_nodes, mask=list(tail_orphans)) if len(orphans) == 0: - print("No orphan nodes found.") + logger.info("No orphan nodes found.") return colorbar = plt.cm.rainbow(np.linspace(0, 1, len(orphans))) @@ -159,12 +162,11 @@ def plot_orphan_nodes(graph: HeteroData, out_file: Optional[Union[str, Path]] = if out_file is not None: fig.write_html(out_file) - print(f"Orphan nodes plot saved to {out_file}.") else: fig.show() -def plot_nodes( +def plot_interactive_nodes( title: str, lats: np.ndarray, lons: np.ndarray, mask: np.ndarray = None, out_file: Optional[str] = None ) -> None: """Plot nodes.