From 97ef4556ec6f654a7e80f4969c3d0a347e044d86 Mon Sep 17 00:00:00 2001 From: Joongun Park <8554137+JoongunPark@users.noreply.github.com> Date: Tue, 8 Oct 2024 17:48:08 -0400 Subject: [PATCH] Fix mishandling of All-to-All communication --- src/converter/pytorch_converter.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index ea383a51..01498b14 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -350,6 +350,11 @@ def get_protobuf_node_type_from_json_node( return COMM_SEND_NODE if "recv" in keyword: return COMM_RECV_NODE + # In NCCL, all-to-all communication is implemented using point-to-point + # communications. More details can be found here: + # https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/p2p.html + if "nccl:all_to_all" in keyword: + return COMM_COLL_NODE if "ncclKernel" in json_node.name or "ncclDevKernel" in json_node.name: return COMM_COLL_NODE return COMP_NODE @@ -379,6 +384,10 @@ def get_collective_comm_type(self, name: str) -> int: for key in comm_type_mapping: if key in normalized_name: return comm_type_mapping[key] + # If both COMM_COLL_NAME and ncclDevKernel_SendRecv are present, this is nccl:all_to_all. + if "ncclDevKernel_SendRecv" in name: + return comm_type_mapping["alltoall"] + raise ValueError( f"The name '{name}' does not correspond to a recognized collective communication type. " "The converter determines collective communication types based on the node name of a GPU operator. "