diff --git a/src/converter/pytorch_converter.py b/src/converter/pytorch_converter.py index e7c49d9b..2371dccb 100644 --- a/src/converter/pytorch_converter.py +++ b/src/converter/pytorch_converter.py @@ -362,7 +362,7 @@ def get_collective_comm_type(self, name: str) -> int: comm_type_mapping = { "allreduce": ALL_REDUCE, "alltoall": ALL_TO_ALL, - "gather": ALL_GATHER, + "allgather": ALL_GATHER, "reducescatter": REDUCE_SCATTER, "broadcast": BROADCAST, # Additional cases can be added here