Commit
[Hotfix] Fix CUDA_DEVICE_MAX_CONNECTIONS for comm overlap
Co-authored-by: Edenzzzz <[email protected]>
Edenzzzz authored Jul 5, 2024
1 parent 3420921 commit 8ec24b6
Showing 4 changed files with 7 additions and 6 deletions.
6 changes: 6 additions & 0 deletions colossalai/initialize.py
@@ -3,6 +3,12 @@

import os

# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when overlapping communication and computation,
# the order of kernel launches on the GPU is the same as on the CPU, so that comm is launched first.
# see https://github.com/NVIDIA/Megatron-LM/issues/533
# https://forums.developer.nvidia.com/t/how-many-streams-maximum-number-of-streams/6571/16
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import torch.distributed as dist

from colossalai.accelerator import get_accelerator
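For readers unfamiliar with this variable, a minimal sketch (not part of the commit; the layout is illustrative) of why it is set at the very top of colossalai/initialize.py: the value is only consulted when the CUDA context is created, so it must be exported before anything touches the GPU. This is also why the same setting is deleted from shardformer.py below, where import order cannot be guaranteed.

```python
# Illustrative sketch only, assuming a CUDA-capable machine.
import os

# Must be set before the first CUDA call; a context that already exists
# keeps whatever connection count it was created with.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

import torch

if torch.cuda.is_available():
    # The first CUDA operation creates the context with a single hardware
    # work queue, so kernels from all streams reach the GPU in the order
    # they were launched on the CPU.
    torch.zeros(1, device="cuda")
```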
1 change: 0 additions & 1 deletion colossalai/legacy/nn/layer/parallel_1d/_operation.py
@@ -81,7 +81,6 @@ def backward(ctx, grad_output):
handle = dist.all_reduce(grad_input, group=gpc.get_group(ctx.parallel_mode), async_op=True)
# Delay the start of weight gradient computation shortly (3us) to have
# all-reduce scheduled first and have GPU resources allocated
_ = torch.empty(1, device=grad_output.device) + 1

grad_weight = grad_output.t().matmul(total_input)
grad_bias = grad_output.sum(dim=0) if use_bias else None
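The deleted line was a scheduling workaround borrowed from Megatron-LM: a tiny dummy kernel delayed the weight-gradient matmul so that the asynchronous all-reduce above it was scheduled first. With CUDA_DEVICE_MAX_CONNECTIONS=1 now set globally, kernel launch order already follows CPU order and the dummy op is redundant. A hedged sketch of the pattern (the helper name and arguments are hypothetical):

```python
import torch
import torch.distributed as dist


def weight_grad_with_overlap(grad_output, grad_input, total_input, group):
    # Enqueue the input-gradient all-reduce asynchronously so it can overlap
    # with the weight-gradient matmul below.
    handle = dist.all_reduce(grad_input, group=group, async_op=True)
    # Previously a dummy op sat here to delay the matmul slightly:
    #   _ = torch.empty(1, device=grad_output.device) + 1
    # With a single device connection the all-reduce is already first in line.
    grad_weight = grad_output.t().matmul(total_input)
    handle.wait()
    return grad_weight
```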
4 changes: 0 additions & 4 deletions colossalai/shardformer/shard/shardformer.py
@@ -1,4 +1,3 @@
import os
from typing import Dict, List, Tuple

import torch.distributed as dist
@@ -11,9 +10,6 @@
from .shard_config import ShardConfig
from .sharder import ModelSharder

# set CUDA_DEVICE_MAX_CONNECTIONS=1 to ensure that when communication and computation overlap, the order of core scheduling is correct
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"


class ShardFormer:
"""
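Removing the setting here matches the move above: by the time shardformer is imported, a CUDA context may already exist, and at that point changing the environment variable no longer affects the running process. A short illustrative sketch of that failure mode (assumes a CUDA device):

```python
import os

import torch

# A CUDA context is created by the first GPU operation...
torch.zeros(1, device="cuda")

# ...so setting the variable afterwards has no effect on this context.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
```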
2 changes: 1 addition & 1 deletion examples/language/llama/benchmark.py
@@ -292,7 +292,7 @@ def empty_init():
with get_profile_context(
args.profile,
args.ignore_steps,
len(dataloader) - 1,
1, # avoid creating massive log files
save_dir=f"profile/{time.strftime('%H:%M', time.localtime())}-{args.plugin}-llama-{args.config}",
) as prof:
if isinstance(plugin, HybridParallelPlugin) and args.pp > 1:
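The benchmark now profiles a single step instead of almost the entire dataloader, which keeps the trace files small. A hedged sketch of the underlying torch.profiler pattern; get_profile_context is assumed to wrap something like this, and the parameter roles are an assumption:

```python
from torch.profiler import ProfilerActivity, profile, schedule, tensorboard_trace_handler

warmup_steps = 1  # assumed to correspond to args.ignore_steps
active_steps = 1  # was roughly len(dataloader) - 1 before this commit

prof = profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=0, warmup=warmup_steps, active=active_steps),
    on_trace_ready=tensorboard_trace_handler("profile/llama-benchmark"),  # illustrative path
)

# Usage: enter the context and call prof.step() once per training iteration;
# only the `active` steps are recorded in the trace.
```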
