diff --git a/hiddenlayer/pytorch_builder.py b/hiddenlayer/pytorch_builder.py index 702c167..af35117 100644 --- a/hiddenlayer/pytorch_builder.py +++ b/hiddenlayer/pytorch_builder.py @@ -26,6 +26,13 @@ ht.Rename(op=r"BatchNormalization", to="BatchNorm"), ] +# https://github.com/pytorch/pytorch/blob/2efe4d809fdc94501fc38bf429e9a8d4205b51b6/torch/utils/tensorboard/_pytorch_graph.py#L384 +def _node_get(node: torch._C.Node, key: str): + """Gets attributes of a node which is polymorphic over return type.""" + sel = node.kindOf(key) + return getattr(node, sel)(key) + +torch._C.Node.__getitem__ = _node_get def dump_pytorch_graph(graph): """List all the nodes in a PyTorch graph."""