I'm trying to understand how to enable user buffer registration for NCCL with NVLink SHARP on a node with H100 GPUs using PyTorch.
I believe that NVLS works when NCCL uses its own internal staging buffers. If I force NCCL_ALGO=NVLS,NVLSTree and run a simple program, I seem to get a successful result without any error messages.
Code (launched with `torchrun --standalone --nproc_per_node 8`)
```python
import torch
import torch.distributed
import torch.profiler


def main():
    options = torch.distributed.ProcessGroupNCCL.Options(is_high_priority_stream=True)
    torch.distributed.init_process_group(backend="nccl", pg_options=options)
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)

    a = torch.randn((16_384, 16_384), dtype=torch.float32, device="cuda")
    # Warm-up collective (also initializes the NCCL communicator).
    torch.distributed.all_reduce(a)

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        torch.distributed.all_reduce(a)
    prof.export_chrome_trace(f"nvls_{rank}.json")


if __name__ == "__main__":
    main()
```
In order to enable user buffer registration, I figured the easiest option would be to use CUDA graphs. When I do so, however, I start seeing some warning messages in the logs, and the code then seems to behave as before (e.g., the same number of SMs used by NCCL's kernels). The error seems to come from the `cuMulticastBindAddr` call (`src/transport/nvls.cc`, line 580 at commit 6aae379).
Code (launched with `torchrun --standalone --nproc_per_node 8`)
```python
import torch
import torch.distributed
import torch.profiler


def main():
    options = torch.distributed.ProcessGroupNCCL.Options(is_high_priority_stream=True)
    torch.distributed.init_process_group(backend="nccl", pg_options=options)
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)

    a = torch.randn((16_384, 16_384), dtype=torch.float32, device="cuda")
    # Warm-up collective so the NCCL communicator exists before graph capture.
    torch.distributed.all_reduce(a)

    # Capture the collective in a CUDA graph, then profile a replay.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        torch.distributed.all_reduce(a)

    with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
        g.replay()
    prof.export_chrome_trace(f"nvls_{rank}.json")


if __name__ == "__main__":
    main()
```
Full logs
```
NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
NCCL INFO Bootstrap: Using eth0:10.7.147.73<0>
NCCL INFO cudaDriverVersion 12040
NCCL INFO NCCL version 2.24.3+cuda12.2
NCCL INFO Comm config Blocking set to 1
NCCL INFO NET/Plugin: Failed to find ncclNetPlugin_v9 symbol.
NCCL INFO NET/Plugin: Loaded net plugin NCCL RDMA Plugin v8 (v8)
NCCL INFO NET/Plugin: Failed to find ncclCollNetPlugin_v9 symbol.
NCCL INFO NET/Plugin: Loaded collnet plugin SHARP (v8)
NCCL INFO Plugin Path : /opt/hpcx/nccl_rdma_sharp_plugin/lib/libnccl-net.so
NCCL INFO P2P plugin IBext_v8
NCCL INFO NCCL_SOCKET_IFNAME set by environment to eth0
NCCL INFO NCCL_IB_PCI_RELAXED_ORDERING set by environment to 1.
NCCL INFO NET/IB : Using [0]ibp0:1/IB/SHARP [1]ibp1:1/IB/SHARP [2]ibp2:1/IB/SHARP [3]ibp3:1/IB/SHARP [4]ibp4:1/IB/SHARP [5]ibp5:1/IB/SHARP [6]ibp6:1/IB/SHARP [7]ibp7:1/IB/SHARP [RO]; OOB eth0:10.7.147.73<0>
NCCL INFO PROFILER/Plugin: Could not find: libnccl-profiler.so.
NCCL INFO Using network IBext_v8
NCCL INFO ncclCommInitRankConfig comm 0x57a7c7b82cf0 rank 0 nranks 8 cudaDev 0 nvmlDev 0 busId 1a000 commId 0xdf5506ed2c65f234 - Init START
NCCL INFO RAS client listening socket at ::1<28028>
NCCL INFO Bootstrap timings total 0.619019 (create 0.000032, send 0.000090, recv 0.105137, ring 0.000181, delay 0.000000)
NCCL INFO MNNVL busId 0x1a000 fabric UUID 0.0 cliqueId 0x0 state 3 healthMask 0x0
NCCL INFO Setting affinity for GPU 0 to 55555555,55555555,55555555,55555555
NCCL INFO NVLS multicast support is available on dev 0
NCCL INFO comm 0x57a7c7b82cf0 rank 0 nRanks 8 nNodes 1 localRanks 8 localRank 0 MNNVL 0
NCCL INFO Channel 00/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 01/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 02/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 03/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 04/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 05/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 06/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 07/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 08/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 09/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 10/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 11/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 12/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 13/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 14/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 15/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 16/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 17/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 18/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 19/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 20/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 21/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 22/24 : 0 1 2 3 4 5 6 7
NCCL INFO Channel 23/24 : 0 1 2 3 4 5 6 7
NCCL INFO Trees [0] 1/-1/-1->0->-1 [1] 1/-1/-1->0->-1 [2] 1/-1/-1->0->-1 [3] 1/-1/-1->0->-1 [4] 1/-1/-1->0->-1 [5] 1/-1/-1->0->-1 [6] 1/-1/-1->0->-1 [7] 1/-1/-1->0->-1 [8] 1/-1/-1->0->-1 [9] 1/-1/-1->0->-1 [10] 1/-1/-1->0->-1 [11] 1/-1/-1->0->-1 [12] 1/-1/-1->0->-1 [13] 1/-1/-1->0->-1 [14] 1/-1/-1->0->-1 [15] 1/-1/-1->0->-1 [16] 1/-1/-1->0->-1 [17] 1/-1/-1->0->-1 [18] 1/-1/-1->0->-1 [19] 1/-1/-1->0->-1 [20] 1/-1/-1->0->-1 [21] 1/-1/-1->0->-1 [22] 1/-1/-1->0->-1 [23] 1/-1/-1->0->-1
NCCL INFO P2P Chunksize set to 524288
NCCL INFO Check P2P Type intraNodeP2pSupport 1 directMode 0
NCCL INFO [Proxy Service] Device 0 CPU core 46
NCCL INFO [Proxy Service UDS] Device 0 CPU core 48
NCCL INFO NVLS Creating Multicast group nranks 8 size 2097152 on rank 0
NCCL INFO NVLS Created Multicast group 78c7f38a7860 nranks 8 size 2097152 on rank 0
NCCL INFO NVLS rank 0 (dev 0) alloc done, ucptr 0x78c85bc00000 ucgran 2097152 mcptr 0x78c85be00000 mcgran 2097152 size 2097152 (24576)
NCCL INFO NCCL_ALGO set by environment to NVLS,NVLSTree
NCCL INFO Enabled NCCL Func/Proto/Algo Matrix:
Function | LL LL128 Simple | Tree Ring CollNetDirect CollNetChain NVLS NVLSTree PAT
Broadcast | 1 2 1 | 0 0 0 0 1 1 0
Reduce | 1 2 1 | 0 0 0 0 1 1 0
AllGather | 1 2 1 | 0 0 0 0 1 1 0
ReduceScatter | 1 2 1 | 0 0 0 0 1 1 0
AllReduce | 1 2 1 | 0 0 0 0 1 1 0
NCCL INFO threadThresholds 8/8/64 | 64/8/64 | 512 | 512
NCCL INFO 24 coll channels, 24 collnet channels, 16 nvls channels, 32 p2p channels, 32 p2p channels per peer
NCCL INFO CC Off, workFifoBytes 1048576
NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v4 symbol.
NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v3 symbol.
NCCL INFO TUNER/Plugin: Failed to find ncclTunerPlugin_v2 symbol, using internal tuner instead.
NCCL INFO ncclCommInitRankConfig comm 0x57a7c7b82cf0 rank 0 nranks 8 cudaDev 0 nvmlDev 0 busId 1a000 commId 0xdf5506ed2c65f234 - Init COMPLETE
NCCL INFO Init timings - ncclCommInitRankConfig: rank 0 nranks 8 total 3.57 (kernels 0.29, alloc 1.32, bootstrap 0.62, allgathers 0.01, topo 1.07, graphs 0.01, connections 0.23, rest 0.02)
NCCL INFO NVLS comm 0x57a7c7b82cf0 headRank 0 nHeads 8 buffSize 1048576 nvlsPerRankSize 33554432 nvlsTotalSize 268435456
NCCL INFO NVLS Creating Multicast group nranks 8 size 536870912 on rank 0
NCCL INFO NVLS Created Multicast group 78c7f319dfc0 nranks 8 size 536870912 on rank 0
NCCL INFO NVLS rank 0 (dev 0) alloc done, ucptr 0x31c0000000 ucgran 2097152 mcptr 0x31e0000000 mcgran 536870912 size 536870912 (268435456)
NCCL INFO NVLS Creating Multicast group nranks 8 size 1073741824 on rank 0
NCCL INFO NVLS Created Multicast group 57a7c8741cc0 nranks 8 size 1073741824 on rank 0
transport/nvls.cc:580 NCCL WARN Cuda failure 1 'invalid argument'
transport/nvls.cc:703 NCCL WARN rank 0 failed to NVLS register sendbuff 0x78c800000000 sendbuffSize 1073741824 recvbuff 0x78c800000000 recvbuffSize 1073741824
NCCL INFO misc/socket.cc:880 -> 3
NCCL INFO misc/socket.cc:880 -> 3
NCCL INFO misc/socket.cc:880 -> 3
NCCL INFO misc/socket.cc:880 -> 3
NCCL INFO misc/socket.cc:880 -> 3
NCCL INFO NVLS Unbind MC handle 78c7f38a7860 size 2097152 dev 0
NCCL INFO NVLS Unmap mem UC handle 0x78c7f38a7e10(0x78c85bc00000) MC handle 0x78c7f38a7860(0x78c85be00000)
NCCL INFO NVLS Unbind MC handle 78c7f319dfc0 size 536870912 dev 0
NCCL INFO NVLS Unmap mem UC handle 0x78c7f30cb130(0x31c0000000) MC handle 0x78c7f319dfc0(0x31e0000000)
NCCL INFO comm 0x57a7c7b82cf0 rank 0 nranks 8 cudaDev 0 busId 1a000 - Abort COMPLETE
```
Could you help me understand what is causing that error? Is it an issue in my code? In my cluster's setup? In PyTorch's code? I need help narrowing the problem down so that I can seek specific help from the relevant teams.
After reading up on NCCL's docs about user buffer registration, I also tried a few other approaches:
- `NCCL_LEGACY_CUDA_REGISTER=1`, to "force" NCCL to attempt registration even if PyTorch's caching allocator uses `cudaMalloc`
- `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`, to switch PyTorch's allocator to `cuMem`-based calls
- `TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK=1`, to make PyTorch invoke `ncclCommRegister` on each buffer it allocates
None of the combinations I tried makes those errors go away. Let me know if you want me to try a specific combination and report back the results.
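For reference, here is a minimal sketch of how I combine those settings with the CUDA-graph script above. I'm assuming that setting them via `os.environ` before `torch` is imported is equivalent to exporting them in the shell before `torchrun` (which is what I normally do); the variable names and values are exactly the ones listed above.

```python
import os

# One combination of the settings listed above. They need to be in place
# before the CUDA caching allocator and the NCCL communicator initialize,
# so they are set before importing torch (normally I export them in the
# shell before launching torchrun instead).
os.environ["NCCL_ALGO"] = "NVLS,NVLSTree"
os.environ["NCCL_LEGACY_CUDA_REGISTER"] = "1"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
os.environ["TORCH_NCCL_USE_TENSOR_REGISTER_ALLOCATOR_HOOK"] = "1"

import torch
import torch.distributed


def main():
    torch.distributed.init_process_group(backend="nccl")
    rank = torch.distributed.get_rank()
    torch.cuda.set_device(rank)

    a = torch.randn((16_384, 16_384), dtype=torch.float32, device="cuda")
    # Warm-up collective so the NCCL communicator exists before graph capture.
    torch.distributed.all_reduce(a)

    # Capture the collective in a CUDA graph and replay it.
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g):
        torch.distributed.all_reduce(a)
    g.replay()
    torch.cuda.synchronize()


if __name__ == "__main__":
    main()
```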
I am using the latest stable PyTorch version (v2.5.1), built for CUDA 12.4, and the latest stable NCCL version (v2.24.3), installed from PyPI and injected into my application using `LD_PRELOAD`.
Thanks!
Hi Luca,
If you want to enable NVLS UB, you have to follow https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/bufferreg.html#nvlink-sharp-buffer-registration
AFAIK, PyTorch has an issue supporting NVLS UB. Currently PyTorch uses expandable_segments, which causes a VM segment to span multiple physical memory segments. That makes all UB registration fail in NCCL, since NCCL has no view of the user's memory layout and cannot support multi-segment registration.
The right way for PyTorch to register buffers is to:
- Use the cuMem* API in a way that follows NCCL's requirements in the link above