diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py
index aa29b531..49dc7752 100644
--- a/src/converter/pytorch_converter.py
+++ b/src/converter/pytorch_converter.py
@@ -10,6 +10,8 @@
     ALL_TO_ALL,
     BROADCAST,
     COMM_COLL_NODE,
+    COMM_RECV_NODE,
+    COMM_SEND_NODE,
     COMP_NODE,
     REDUCE_SCATTER,
     GlobalMetadata,
@@ -283,6 +285,13 @@ def convert_nodes(self, pytorch_nodes: Dict[int, PyTorchNode], chakra_nodes: Dic
                     ]
                 )
+            elif chakra_node.type in {COMM_SEND_NODE, COMM_RECV_NODE}:
+                chakra_gpu_node.attr.extend(
+                    [
+                        ChakraAttr(name="comm_size", int64_val=pytorch_gpu_node.comm_size),
+                    ]
+                )
+
             chakra_nodes[chakra_gpu_node.id] = chakra_gpu_node
 
     def convert_to_chakra_node(self, chakra_nodes: Dict[int, ChakraNode], pytorch_node: PyTorchNode) -> ChakraNode:
@@ -334,6 +343,12 @@ def get_chakra_node_type_from_pytorch_node(self, pytorch_node: PyTorchNode) -> i
         Returns:
             int: The corresponding Chakra node type.
         """
+        if "sendrecv" in pytorch_node.name.lower():
+            return COMM_SEND_NODE
+        if "send" in pytorch_node.name.lower():
+            return COMM_SEND_NODE
+        if "recv" in pytorch_node.name.lower():
+            return COMM_RECV_NODE
         if (
             pytorch_node.is_gpu_op()
             and ("ncclKernel" in pytorch_node.name or "ncclDevKernel" in pytorch_node.name)
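
For reviewers, here is a minimal, self-contained sketch of the two behaviors this patch adds: substring-based classification of point-to-point ops into `COMM_SEND_NODE` / `COMM_RECV_NODE`, and attaching a `comm_size` attribute to the resulting nodes in `convert_nodes`. The enum values, the `ChakraAttr` dataclass, the sample kernel name, and the `4096` size below are stand-ins for illustration; the real converter uses the generated Chakra protobuf definitions, not this code.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-ins for the generated Chakra protobuf definitions;
# the actual converter imports COMM_SEND_NODE, COMM_RECV_NODE, and
# ChakraAttr from the et_def bindings, and the enum values here are made up.
COMM_SEND_NODE = 6
COMM_RECV_NODE = 7


@dataclass
class ChakraAttr:
    """Simplified stand-in for the protobuf ChakraAttr message."""
    name: str
    int64_val: int


def get_node_type(name: str) -> Optional[int]:
    """Mirror the substring checks this patch adds to
    get_chakra_node_type_from_pytorch_node."""
    lowered = name.lower()
    if "sendrecv" in lowered:
        # SendRecv kernels map to COMM_SEND_NODE. The plain "send" check
        # below would match these names too, so this branch mainly
        # documents the intent.
        return COMM_SEND_NODE
    if "send" in lowered:
        return COMM_SEND_NODE
    if "recv" in lowered:
        return COMM_RECV_NODE
    return None  # fall through to the existing collective/compute checks


# Usage: classify a kernel name (name and size are illustrative) and attach
# a comm_size attribute, as convert_nodes now does for send/recv node types.
node_type = get_node_type("ncclDevKernel_SendRecv")
attrs: List[ChakraAttr] = []
if node_type in {COMM_SEND_NODE, COMM_RECV_NODE}:
    attrs.append(ChakraAttr(name="comm_size", int64_val=4096))

assert node_type == COMM_SEND_NODE
assert attrs[0].name == "comm_size" and attrs[0].int64_val == 4096
```

One design consequence worth noting: the new substring checks run before the existing `is_gpu_op()` guard, so any op whose name contains "send" or "recv" (case-insensitive), including CPU-side ops, is classified as a point-to-point node before the collective and compute checks are reached.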