Skip to content

Commit

Permalink
Fix ValueError for missing send/recv nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
TaekyungHeo committed Jun 11, 2024
1 parent 22fc788 commit 9be4248
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/converter/pytorch_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
ALL_TO_ALL,
BROADCAST,
COMM_COLL_NODE,
COMM_RECV_NODE,
COMM_SEND_NODE,
COMP_NODE,
REDUCE_SCATTER,
GlobalMetadata,
Expand Down Expand Up @@ -283,6 +285,13 @@ def convert_nodes(self, pytorch_nodes: Dict[int, PyTorchNode], chakra_nodes: Dic
]
)

elif chakra_node.type in {COMM_SEND_NODE, COMM_RECV_NODE}:
chakra_gpu_node.attr.extend(
[
ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size),
]
)

chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node

def convert_to_chakra_node(self, chakra_nodes: Dict[int, ChakraNode], pytorch_node: PyTorchNode) -> ChakraNode:
Expand Down Expand Up @@ -334,6 +343,12 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i
Returns:
int: The corresponding Chakra node type.
"""
if "sendrecv" in pytorch_node.name.lower():
return COMM_SEND_NODE
if "send" in pytorch_node.name.lower():
return COMM_SEND_NODE
if "recv" in pytorch_node.name.lower():
return COMM_RECV_NODE
if (
pytorch_node.is_gpu_op()
and ("ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name)
Expand Down

0 comments on commit 9be4248

Please sign in to comment.