From 8ec24b6a4d0e0dbec7da39e43c3c1b2cfcb0395d Mon Sep 17 00:00:00 2001
From: Edenzzzz
Date: Fri, 5 Jul 2024 20:02:36 +0800
Subject: [PATCH] [Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap

Co-authored-by: Edenzzzz
---
 colossalai/initialize.py                             | 6 ++++++
 colossalai/legacy/nn/layer/parallel_1d/_operation.py | 1 -
 colossalai/shardformer/shard/shardformer.py          | 4 ----
 examples/language/llama/benchmark.py                 | 2 +-
 4 files changed, 7 insertions(+), 6 deletions(-)

diff --git a/colossalai/initialize.py b/colossalai/initialize.py
index 71d42312ee7d..4e2eff7ce352 100644
--- a/colossalai/initialize.py
+++ b/colossalai/initialize.py
@@ -3,6 +3,12 @@
 
 import os
 
+# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
+# the order of kernel launches on the GPU is the same as on the CPU, so that comm is launched first.
+# see https://github.com/NVIDIA/Megatron-LM/issues/533
+# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
+os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+
 import torch.distributed as dist
 
 from colossalai.accelerator import get_accelerator

diff --git a/colossalai/legacy/nn/layer/parallel_1d/_operation.py b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
index f01da97ba39a..8b8f04ccf456 100644
--- a/colossalai/legacy/nn/layer/parallel_1d/_operation.py
+++ b/colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -81,7 +81,6 @@ def backward(ctx, grad_output):
         handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
         # Delay the start of weight gradient computation shortly (3us) to have
         # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
         grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
 

diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index b54c5827316e..db03eec414c2 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -1,4 +1,3 @@
-import os
 from typing import Dict, List, Tuple
 
 import torch.distributed as dist
@@ -11,9 +10,6 @@
 from .shard_config import ShardConfig
 from .sharder import ModelSharder
 
-# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct
-os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
-
 
 class ShardFormer:
     """

diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 8a35db1f7038..2b7bd50b8766 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -292,7 +292,7 @@ def empty_init():
     with get_profile_context(
         args.profile,
         args.ignore_steps,
-        len(dataloader) - 1,
+        1,  # avoid creating massive log files
         save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
     ) as prof:
         if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
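
Note on why the assignment sits at the very top of colossalai/initialize.py: the CUDA driver reads CUDA_DEVICE_MAX_CONNECTIONS when the CUDA context is created, so the variable must be exported before the first CUDA call. A minimal standalone sketch of this constraint (not part of the patch; the torch usage is illustrative only):

import os

# Must run before anything initializes CUDA: the driver picks up
# CUDA_DEVICE_MAX_CONNECTIONS at context creation, so setting it after
# the first CUDA call has no effect.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import torch  # deliberately imported after the env var is set

if torch.cuda.is_available():
    torch.cuda.init()  # context creation happens here and sees the setting
    print("max connections:", os.environ["CUDA_DEVICE_MAX_CONNECTIONS"])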
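
For readers unfamiliar with the overlap pattern behind the removed dummy op in _operation.py: with a single hardware work queue, kernels reach the GPU in CPU launch order, so issuing the async all-reduce before the weight-gradient matmul is enough to keep the two overlapped. A simplified sketch of that pattern (the function name and signature below are illustrative, not the patched code; it assumes an initialized process group):

import torch
import torch.distributed as dist

def backward_with_overlap(grad_output, grad_input, total_input, group=None):
    # Launch the all-reduce first; with CUDA_DEVICE_MAX_CONNECTIONS=1 the GPU
    # executes kernels in CPU launch order, so no dummy kernel is needed to
    # delay the matmul and keep the communication scheduled ahead of it.
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    grad_weight = grad_output.t().matmul(total_input)  # overlaps with the all-reduce
    handle.wait()  # ensure grad_input is fully reduced before it is consumed
    return grad_input, grad_weight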
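
The benchmark change caps profiling at a single active step instead of len(dataloader) - 1, since a full-epoch trace can easily reach gigabytes. The same idea expressed with the stock torch.profiler API (get_profile_context is ColossalAI's own wrapper; the schedule below is an assumed equivalent shown for illustration):

import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Skip a short warm-up, then record a single active step so the trace stays small.
activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(
    activities=activities,
    schedule=schedule(wait=0, warmup=2, active=1, repeat=1),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("profile/"),
) as prof:
    for _ in range(4):  # stand-in for training steps
        torch.randn(512, 512).matmul(torch.randn(512, 512))
        prof.step()  # advances the profiler schedule once per step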