Skip to content

Commit

Permalink
Use more stable ordering of nodes during visualization (iree-org#104)
Browse files Browse the repository at this point in the history
Previously, nodes were identified by their id(). This is unstable and
would result in different files being generated for the same input. We
fix this by using our manually specified ordering for the nodes.

Signed-off-by: Harsh Menon <[email protected]>
  • Loading branch information
harsh-nod authored Aug 27, 2024
1 parent 03eccd9 commit fa3fa61
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions shark_turbine/kernel/wave/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
except:
graphviz_disabled = True
from torch import fx
import warnings


def number_nodes(graph: fx.Graph) -> dict[int, int]:
return {id(node): i for i, node in enumerate(graph.nodes)}


def visualize_graph(graph: fx.Graph, file_name: str):
if graphviz_disabled:
raise ImportError("pygraphviz not installed, cannot visualize graph")
node_numbering = number_nodes(graph)
G = pgv.AGraph(directed=True)
for node in graph.nodes:
G.add_node(id(node), label=node.name)
G.add_node(node_numbering[id(node)], label=node.name)
for node in graph.nodes:
for user in node.users.keys():
G.add_edge(id(node), id(user))
G.add_edge(node_numbering[id(node)], node_numbering[id(user)])
G.layout(prog="dot")
G.draw(file_name)

0 comments on commit fa3fa61

Please sign in to comment.