From 87ef0deba56474be03dbf26ff71168817aa778f2 Mon Sep 17 00:00:00 2001 From: Arunabh Date: Wed, 24 Jan 2024 16:14:43 +0100 Subject: [PATCH 1/2] Fix attribute retrieval in import_graph --- hiddenlayer/pytorch_builder.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..7641426 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -79,7 +79,13 @@ def import_graph(hl_graph, model, args, input_names=None, verbose=False): # Op op = torch_node.kind() # Parameters - params = {k: torch_node[k] for k in torch_node.attributeNames()} + params = {} + for k in torch_node.attributeNames(): + try: + params[k] = getattr(torch_node, k) + except AttributeError: + # Handle the case where the attribute is not present + params[k] = None # Inputs/outputs # TODO: inputs = [i.unique() for i in node.inputs()] outputs = [o.unique() for o in torch_node.outputs()] From ea322bba283c57aed1e8e5b91751946a89cb2b53 Mon Sep 17 00:00:00 2001 From: Arunabh Date: Sun, 11 Feb 2024 13:41:48 +0100 Subject: [PATCH 2/2] Update pytorch_builder.py --- hiddenlayer/pytorch_builder.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 7641426..af35117 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -26,6 +26,13 @@ ht.Rename(op=r"BatchNormalization", to="BatchNorm"), ] +# https://github.com/pytorch/pytorch/blob/2efe4d809fdc94501fc38bf429e9a8d4205b51b6/torch/utils/tensorboard/_pytorch_graph.py#L384 +def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + sel = node.kindOf(key) + return getattr(node, sel)(key) + +torch._C.Node.__getitem__ = _node_get def dump_pytorch_graph(graph): """List all the nodes in a PyTorch graph.""" @@ -79,13 +86,7 @@ def import_graph(hl_graph, model, args, input_names=None, verbose=False): # Op op = torch_node.kind() # Parameters - params = {} - for k in torch_node.attributeNames(): - try: - params[k] = getattr(torch_node, k) - except AttributeError: - # Handle the case where the attribute is not present - params[k] = None + params = {k: torch_node[k] for k in torch_node.attributeNames()} # Inputs/outputs # TODO: inputs = [i.unique() for i in node.inputs()] outputs = [o.unique() for o in torch_node.outputs()]