Skip to content

Commit

Permalink
Identify process group init nodes as METADATA nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jun 25, 2024
1 parent e6c22a4 commit 912ef70
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
6 changes: 4 additions & 2 deletions src/converter/pytorch_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,10 @@ def convert_nodes(self, pytorch_nodes: Dict[int, PyTorchNode], chakra_nodes: Dic
cases for GPU nodes and collective communication types.
"""
for _, pytorch_node in pytorch_nodes.items():
if (pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP) or (
pytorch_node.get_op_type() == PyTorchNodeType.LABEL
if (
(pytorch_node.get_op_type() == PyTorchNodeType.CPU_OP)
or (pytorch_node.get_op_type() == PyTorchNodeType.LABEL)
or (pytorch_node.get_op_type() == PyTorchNodeType.METADATA)
):
chakra_node = self.convert_to_chakra_node(chakra_nodes, pytorch_node)
chakra_nodes[chakra_node.id] = chakra_node
Expand Down
5 changes: 4 additions & 1 deletion src/converter/pytorch_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class PyTorchNodeType(Enum):
CPU_OP = 1
GPU_OP = 2
LABEL = 3 # Non-operator nodes
METADATA = 4


class PyTorchNode:
Expand Down Expand Up @@ -114,7 +115,9 @@ def get_op_type(self) -> PyTorchNodeType:
Returns
PyTorchNodeType: The type of the PyTorch operation.
"""
if self.is_gpu_op():
if "process_group:init" in self.name:
return PyTorchNodeType.METADATA
elif self.is_gpu_op():
return PyTorchNodeType.GPU_OP
elif hasattr(self, "op_schema") or hasattr(self, "outputs"):
return PyTorchNodeType.CPU_OP
Expand Down

0 comments on commit 912ef70

Please sign in to comment.