
NCCL_SOCKET_IFNAME has no effect during pytorch distributed training with multiple NICs #1580

Open
hanruijiang opened this issue Jan 18, 2025 · 10 comments

Comments

@hanruijiang

I am trying to use PyTorch for multi-node distributed data-parallel training on 2 Debian servers, each with 3 RTX 3090s installed.

Each server has 2 NICs: a 1 Gb/s port (eno1) assigned to the 192.168.3.* network segment, and a 10 Gb/s port (enp37s0f0) assigned to the 192.168.5.* network segment.

I want them to communicate over the 10 Gb/s port during training. However, they send data over the 1 Gb/s port and receive it over the 10 Gb/s port.

I tried setting NCCL_SOCKET_IFNAME=enp37s0f0 as an environment variable, writing it to /etc/nccl.conf, and setting it in the Python script (os.environ['NCCL_SOCKET_IFNAME'] = 'enp37s0f0'). None of them worked.
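
For reference, roughly how the variable was set in each attempt (the value has to be visible in every rank's environment before NCCL initializes):

export NCCL_SOCKET_IFNAME=enp37s0f0                               # shell environment, before launching
echo 'NCCL_SOCKET_IFNAME=enp37s0f0' | sudo tee -a /etc/nccl.conf  # system-wide NCCL config file
# in the Python script: os.environ['NCCL_SOCKET_IFNAME'] = 'enp37s0f0' before init_process_group()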

Now, I can only temporarily solve this problem by modifying the routing table.

ip route add 192.168.5.0/24 dev enp37s0f0 metric 100

However, I want to add more 10 Gb/s NICs later to increase communication bandwidth during training, which will require NCCL_SOCKET_IFNAME to select multiple interfaces, and the routing table cannot express that.

Why does NCCL_SOCKET_IFNAME have no effect here, and how should I fix this?

training script

ddp.py

import os
    
import torch
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

import torch.profiler

import torchvision
import torchvision.transforms as transforms

import argparse
import random
import numpy as np
from PIL import Image

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

class RandomDataset(torch.utils.data.Dataset):
    def __init__(self, transform, image_size=224, num_samples=1000):
        self.num_samples = num_samples
        self.data = torch.randint(0, 256, (num_samples, 3, image_size, image_size), dtype=torch.uint8)  # generate random images
        self.targets = torch.randint(0, 10, (num_samples,), dtype=torch.long)  # generate random labels (10 classes)
        self.transform = transform

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        image = self.data[idx]
        label = self.targets[idx]
        image = Image.fromarray(image.permute(1, 2, 0).numpy())  # convert to PIL image
        image = self.transform(image)
        return image, label

def main():

    num_epochs_default = 50
    image_size_default = 224 # 512 # 
    batch_size_default = 256 # 1024
    learning_rate_default = 0.1
    random_seed_default = 0

    # Each process runs on 1 GPU device specified by the local_rank argument.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--local-rank", type=int, help="Local rank. Necessary for using the torch.distributed.launch utility.")
    parser.add_argument("--num_epochs", type=int, help="Number of training epochs.", default=num_epochs_default)
    parser.add_argument("--image_size", type=int, help="Image size in the dataset.", default=image_size_default)
    parser.add_argument("--batch_size", type=int, help="Training batch size for one process.", default=batch_size_default)
    parser.add_argument("--learning_rate", type=float, help="Learning rate.", default=learning_rate_default)
    parser.add_argument("--random_seed", type=int, help="Random seed.", default=random_seed_default)
    argv = parser.parse_args()

    local_rank = argv.local_rank
    num_epochs = argv.num_epochs
    image_size = argv.image_size
    batch_size = argv.batch_size
    learning_rate = argv.learning_rate
    random_seed = argv.random_seed

    # We need to use seeds to make sure that the models initialized in different processes are the same
    set_random_seeds(random_seed=random_seed)

    # Initializes the distributed backend, which will take care of synchronizing nodes/GPUs
    torch.distributed.init_process_group(backend="nccl")

    # Encapsulate the model on the GPU assigned to the current process
    model = torchvision.models.resnet18(pretrained=False)

    device = torch.device("cuda:{}".format(local_rank))
    model = model.to(device)
    ddp_model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank)

    # Prepare dataset and dataloader
    transform = transforms.Compose([
        # transforms.RandomCrop(image_size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    # Data should be prefetched
    # Download should be set to be False, because it is not multiprocess safe
    train_set = RandomDataset(image_size=image_size, transform=transform)

    # Restricts data loading to a subset of the dataset exclusive to the current process
    train_sampler = DistributedSampler(dataset=train_set)

    train_loader = DataLoader(dataset=train_set, batch_size=batch_size, sampler=train_sampler, num_workers=8)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)
    
    # Loop over the dataset multiple times
    for epoch in range(num_epochs):

        ddp_model.train()

        for data in train_loader:
            inputs, labels = data[0].to(device), data[1].to(device)
            optimizer.zero_grad()
            outputs = ddp_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

if __name__ == "__main__":

    main()

launch on master node

python -m torch.distributed.launch \
--nproc_per_node=3 --nnodes=2 --node_rank=0 \
--master_addr=192.168.5.12 --master_port=1234 \
ddp.py

launch on worker node

python -m torch.distributed.launch \
--nproc_per_node=3 --nnodes=2 --node_rank=1 \
--master_addr=192.168.5.12 --master_port=1234 \
ddp.py
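
(Note: torch.distributed.launch is deprecated in recent PyTorch releases. A roughly equivalent torchrun invocation for the master node is sketched below, assuming the script is changed to read LOCAL_RANK from the environment instead of the --local-rank argument; the worker node would use --node_rank=1.)

torchrun --nproc_per_node=3 --nnodes=2 --node_rank=0 \
--master_addr=192.168.5.12 --master_port=1234 \
ddp.py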

environment

conda create -n  huggingface python=3.11 conda-forge::mamba -y
conda activate huggingface
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia -y
mamba install conda-forge::deepspeed xformers::xformers -y
mamba install -c conda-forge transformers diffusers accelerate -y

system info

$ uname -a
Linux cx-12 6.1.0-29-amd64 #1 SMP PREEMPT_DYNAMIC Debian 6.1.123-1 (2025-01-02) x86_64 GNU/Linux

$ nvidia-smi
Sat Jan 18 23:36:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.142                Driver Version: 550.142        CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        On  |   00000000:06:00.0 Off |                  N/A |
|  0%   30C    P8             28W /  250W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA GeForce RTX 3090        On  |   00000000:21:00.0 Off |                  N/A |
|  0%   31C    P8             22W /  250W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   2  NVIDIA GeForce RTX 3090        On  |   00000000:41:00.0 Off |                  N/A |
|  0%   32C    P8             27W /  250W |       1MiB /  24576MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                                                         
+-----------------------------------------------------------------------------------------+
| Processes:                                                                              |
|  GPU   GI   CI        PID   Type   Process name                              GPU Memory |
|        ID   ID                                                               Usage      |
|=========================================================================================|
|  No running processes found                                                             |
+-----------------------------------------------------------------------------------------+

network info

on master

$ ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
    inet6 ::1/128 scope host noprefixroute 
       valid_lft forever preferred_lft forever
2: eno1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP group default qlen 1000
    link/ether ac:1f:6b:e7:45:12 brd ff:ff:ff:ff:ff:ff
    altname enp4s0
    inet 192.168.3.12/20 brd 192.168.15.255 scope global dynamic eno1
       valid_lft 3599sec preferred_lft 3599sec
    inet6 fe80::ae1f:6bff:fee7:4512/64 scope link 
       valid_lft forever preferred_lft forever
3: eno2: <NO-CARRIER,BROADCAST,MULTICAST,UP> mtu 1500 qdisc mq state DOWN group default qlen 1000
    link/ether ac:1f:6b:e7:45:13 brd ff:ff:ff:ff:ff:ff
    altname enp5s0
4: enp37s0f0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP group default qlen 1000
    link/ether a0:36:90:2a:ef:08 brd ff:ff:ff:ff:ff:ff
    inet 192.168.5.12/20 brd 192.168.15.255 scope global dynamic enp37s0f0
       valid_lft 2416sec preferred_lft 2416sec
    inet6 fe80::a236:90ff:fe2a:ef08/64 scope link 
       valid_lft forever preferred_lft forever
5: enp37s0f1: <NO-CARRIER,BROADCAST,MULTICAST,UP> mtu 1500 qdisc mq state DOWN group default qlen 1000
    link/ether a0:36:90:2a:ef:09 brd ff:ff:ff:ff:ff:ff

$ ip route show
default via 192.168.1.1 dev eno1 
172.17.0.0/16 dev docker0 proto kernel scope link src 172.17.0.1 linkdown 
192.168.0.0/20 dev eno1 proto kernel scope link src 192.168.3.12 
192.168.0.0/20 dev enp37s0f0 proto kernel scope link src 192.168.5.12

on worker

$ ip a
1: lo: <LOOPBACK,UP,LOWER_UP> mtu 65536 qdisc noqueue state UNKNOWN group default qlen 1000
    link/loopback 00:00:00:00:00:00 brd 00:00:00:00:00:00
    inet 127.0.0.1/8 scope host lo
       valid_lft forever preferred_lft forever
    inet6 ::1/128 scope host noprefixroute 
       valid_lft forever preferred_lft forever
2: eno1: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP group default qlen 1000
    link/ether ac:1f:6b:e7:21:0e brd ff:ff:ff:ff:ff:ff
    altname enp4s0
    inet 192.168.3.15/20 brd 192.168.15.255 scope global dynamic eno1
       valid_lft 3362sec preferred_lft 3362sec
    inet6 fe80::ae1f:6bff:fee7:210e/64 scope link 
       valid_lft forever preferred_lft forever
3: eno2: <NO-CARRIER,BROADCAST,MULTICAST,UP> mtu 1500 qdisc mq state DOWN group default qlen 1000
    link/ether ac:1f:6b:e7:21:0f brd ff:ff:ff:ff:ff:ff
    altname enp5s0
4: enp37s0f0: <BROADCAST,MULTICAST,UP,LOWER_UP> mtu 1500 qdisc mq state UP group default qlen 1000
    link/ether a0:36:90:2a:76:88 brd ff:ff:ff:ff:ff:ff
    inet 192.168.5.15/20 brd 192.168.15.255 scope global dynamic enp37s0f0
       valid_lft 2386sec preferred_lft 2386sec
    inet6 fe80::a236:90ff:fe2a:7688/64 scope link 
       valid_lft forever preferred_lft forever
5: enp37s0f1: <NO-CARRIER,BROADCAST,MULTICAST,UP> mtu 1500 qdisc mq state DOWN group default qlen 1000
    link/ether a0:36:90:2a:76:89 brd ff:ff:ff:ff:ff:ff
6: docker0: <NO-CARRIER,BROADCAST,MULTICAST,UP> mtu 1500 qdisc noqueue state DOWN group default

$ ip route show
default via 192.168.1.1 dev eno1 
172.17.0.0/16 dev docker0 proto kernel scope link src 172.17.0.1 linkdown 
192.168.0.0/20 dev eno1 proto kernel scope link src 192.168.3.15 
192.168.0.0/20 dev enp37s0f0 proto kernel scope link src 192.168.5.15
@haltingstate

haltingstate commented Jan 18, 2025

I am also getting the same error. DeepSpeed and NCCL appear to be bugged on Debian 12.

DeepSpeed is built on NCCL. NCCL is defaulting to the 1-gigabit adapter, even when it is configured to use only the 10-gigabit adapter.

NCCL_SOCKET_IFNAME does not appear to work properly on Debian systems.

@AddyLaddy
Collaborator

I'm surprised that NCCL didn't choose the fastest network adapter. If you could share the NCCL_DEBUG=INFO log file we may be able to help determine the cause.

@kiskra-nvidia
Member

I'm guessing this is because both eno1 and enp37s0f0 are on the same subnet (192.168.0.0/20) and the Linux kernel ends up choosing the former (because it's listed first in the routing table?).

I'm curious: during NCCL communication, what are the local and remote IP addresses of the socket connections used by NCCL -- could you check with something like ss -t?
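
(For example, something along these lines, using the subnets from this report, should capture the relevant sockets -- the grep filter is just for convenience:)

ss -tn | grep -E '192\.168\.(3|5)\.'    # numeric list of established TCP connections touching either subnet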

Is there a reason why you don't use separate subnets for different NICs? That would be the classic solution to such problems...

I'm guessing you can tweak the routing table by flipping the order of entries or adjusting the metrics so that the enp37s0f0 route is preferred, and then it will "just work", at least for two NICs. But, as you point out, it won't scale -- if you add more NICs, it will still use only one for sending...

Your ip route add 192.168.5.0/24 dev enp37s0f0 metric 100 trick is a round-about way of doing what I suggested above -- using different subnets. I'm wondering if metric 100 was actually required for you to get it to work? This should actually scale to more NICs, provided that they are on different /24 subnets.
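
(As a concrete sketch of that per-subnet layout -- enp37s0f1 and 192.168.6.0/24 below are hypothetical, purely to illustrate a second 10 Gb/s NIC:)

ip route add 192.168.5.0/24 dev enp37s0f0
ip route add 192.168.6.0/24 dev enp37s0f1    # hypothetical additional NIC on its own /24
export NCCL_SOCKET_IFNAME=enp37s0f0,enp37s0f1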

@hanruijiang
Author

> I'm surprised that NCCL didn't choose the fastest network adapter. If you could share the NCCL_DEBUG=INFO log file we may be able to help determine the cause.

I collected these logs on the master and worker nodes:

nccl_master.log
nccl_worker.log

I also used ifstat to record network usage during training.

on master node

ifstat -i enp37s0f0,eno1 1 > ifstat_master.log

on worker node

ifstat -i enp37s0f0,eno1 1 > ifstat_worker.log

ifstat_master.log
ifstat_worker.log

@haltingstate

@kiskra-nvidia

> I'm guessing this is because both eno1 and enp37s0f0 are on the same subnet (192.168.0.0/20) and the Linux kernel ends up choosing the former (because it's listed first in the routing table?).

Is the Linux kernel choosing the network interface to use? Or is NCCL choosing the network interface?

If I

  • have two network interfaces I want to use, on different rails (own switches)
  • have two Ethernet interfaces I do not want to use
  • and I add to NCCL_SOCKET_IFNAME the two network interfaces I want to use
    THEN
  • should NCCL use all 4 interfaces, based upon what the kernel wants?
  • or should NCCL use the 2 network interfaces I specified and nothing else?

Is there a way of stopping NCCL from using a network interface, if I don't want it to use that interface?

  1. How do I say "use interfaces A and B" but "do not use interface Docker"?

@haltingstate

@kiskra-nvidia

This system appears to have

  • one 1 Gb/s Ethernet interface, eno1
  • one 10 Gb/s interface (enp37s0f0)
  • one Docker interface

In this instance

  1. Is it possible it's using the "Docker" interface for some reason, because it was not excluded?

  2. How do you say "use interfaces A and B, but exclude C"?

NCCL_SOCKET_IFNAME=ino1,10gigbit1,^docker ?

Can you use ^ after the "=", or do you have to use ^=?

@hanruijiang
Author

> I'm curious: during NCCL communication, what are the local and remote IP addresses of the socket connections used by NCCL -- could you check with something like ss -t?

I ran ss -t during training on both the master and worker nodes:

ss-t_worker.log
ss-t_master.log

> Is there a reason why you don't use separate subnets for different NICs? That would be the classic solution to such problems...

I modified the IP routing table like this, and it works:

default via 192.168.1.1 dev eno1 
172.17.0.0/16 dev docker0 proto kernel scope link src 172.17.0.1 linkdown 
192.168.3.0/24 dev eno1 scope link 
192.168.5.0/24 dev enp37s0f0 scope link 
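
(For reference, a sketch of commands that would produce a table like this, assuming the kernel-created 192.168.0.0/20 routes are deleted first; DHCP renewals may re-add them unless the change is made persistent:)

ip route del 192.168.0.0/20 dev eno1
ip route del 192.168.0.0/20 dev enp37s0f0
ip route add 192.168.3.0/24 dev eno1
ip route add 192.168.5.0/24 dev enp37s0f0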

@haltingstate

haltingstate commented Jan 19, 2025

@hanruijiang

ESTAB           0                0                                    192.168.3.12:ssh                                192.168.5.13:52834  
ESTAB           0                0                                    192.168.3.12:ssh                                192.168.5.13:42416
ESTAB           0                0                                    192.168.3.12:59970                              192.168.5.15:55573 
---
ESTAB           0                305756                               192.168.3.12:37178                              192.168.5.15:43603

This is a connection between the 192.168.5.15 10-gigabit interface and the 192.168.3.12 eno1 interface.

There is no connection to 192.168.3.12 shown in the logs for either the client or server.


@kiskra-nvidia

Search for "192.168.3." in the worker and master log.

Here:

https://github.com/user-attachments/files/18467806/nccl_master.log

https://github.com/user-attachments/files/18467841/nccl_worker.log

There is no logging of

  1. The port/IP/interface being used for listening (either "eno1" or "192.168.3.*")

  2. The connection to 192.168.3.* being opened

Clearly some logging message is missing.

  • as the log shows only 192.168.5.* IPs, the connection should be FROM 192.168.5.* TO 192.168.5.*
  • there is no indication or reason why it would be listening on 192.168.3.*
  • there is no indication or reason why it would be connecting to 192.168.3.*

Could this be out-of-band data, or because the out-of-band interface was not set, and the out-of-band listening/sending connection is not being logged?

Maybe because NCCL_OOB_NET_IFNAME is not set?

Is NCCL_OOB_NET_ENABLE enabled by default now?

Is NCCL_OOB_NET_IFNAME missing logging prints for connections?

@kiskra-nvidia
Member

@hanruijiang

> I collected these logs on the master and worker nodes:
>
> nccl_master.log
> nccl_worker.log
>
> I ran ss -t during training on both the master and worker nodes:
>
> ss-t_worker.log
> ss-t_master.log

Thank you! These logs confirm that NCCL connects to the correct destination IP addresses (192.168.5.12 and 192.168.5.15) but, given the routing tables, the Linux kernel chooses to get there over the eno1 interface, which is why the source addresses are 192.168.3.15 or 192.168.3.12, respectively.
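
(A quick way to see that kernel decision directly is a route lookup for the peer's address; with the original overlapping /20 routes, on the master this should report something like "dev eno1 src 192.168.3.12":)

ip route get 192.168.5.15    # shows which interface and source address the kernel picks for this destination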

> I modified the IP routing table like this, and it works:
>
> default via 192.168.1.1 dev eno1
> 172.17.0.0/16 dev docker0 proto kernel scope link src 172.17.0.1 linkdown
> 192.168.3.0/24 dev eno1 scope link
> 192.168.5.0/24 dev enp37s0f0 scope link

Yes, exactly, that should work -- thank you for confirming.

@kiskra-nvidia
Member

@haltingstate

I understand your disappointment that the outcome ends up being different than what you requested. But NCCL's ability to follow your request is subject to the assumptions made in the code regarding how a network should be configured, as well as to the limitations of the available programming interfaces exposed by the underlying Linux kernel.

> Is the Linux kernel choosing the network interface to use? Or is NCCL choosing the network interface?

Short answer: the Linux kernel.

Longer answer: NCCL is choosing the destination IP address, which is subject to the NCCL_SOCKET_IFNAME filtering. The assumption in the code is that, if the node has multiple network interfaces, they will be on separate subnets, so the choice of the destination IP address indirectly selects the network interface to use.

That assumption breaks if multiple network interfaces are on the same subnet, which is what we are dealing with here. NCCL could possibly be made to work in this scenario by utilizing the bind-before-connect technique to request a specific source IP address (the address of the local interface we want to use), but we currently don't do it. Who knows what unforeseen scenarios that currently happen to work fine would break if we decided to force such a change.
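
(For illustration only, the effect of pinning the source address can be mimicked from the shell with the addresses from this thread; option spellings vary slightly between netcat variants:)

ping -I 192.168.5.12 -c 3 192.168.5.15    # iputils ping accepts a source address (or interface name) via -I
nc -s 192.168.5.12 192.168.5.15 43603     # netcat binds to the given source address before connecting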

Strictly speaking, selecting the source IP address is not equivalent to selecting a particular interface either -- what if two local interfaces have the same IP address? The Linux kernel does in fact have an API that allows binding a socket to a device by interface name (SO_BINDTODEVICE), but it's a privileged operation guarded by CAP_NET_RAW (oops!).

What I'm trying to say is that everything has limitations; every option has pros and cons.

> If I
>
> • have two network interfaces I want to use, on different rails (own switches)
> • have two Ethernet interfaces I do not want to use
> • and I add to NCCL_SOCKET_IFNAME the two network interfaces I want to use
>   THEN
> • should NCCL use all 4 interfaces, based upon what the kernel wants?
> • or should NCCL use the 2 network interfaces I specified and nothing else?
>
> Is there a way of stopping NCCL from using a network interface, if I don't want it to use that interface?

NCCL_SOCKET_IFNAME instructs NCCL to limit what interfaces it's willing to consider. Whether the Linux kernel will obey the same limits when it comes time to send data around is subject to the assumptions listed earlier.

> 1. How do I say "use interfaces A and B" but "do not use interface Docker"?

NCCL_SOCKET_IFNAME==A,B. Everything not listed explicitly is filtered out.

> Is it possible it's using the "Docker" interface for some reason, because it was not excluded?

In practice, no, because the Docker interface is on a different subnet.

> How do you say "use interfaces A and B, but exclude C"?

See above. C does not need to be specified.

> NCCL_SOCKET_IFNAME=ino1,10gigbit1,^docker ?
> Can you use ^ after the "=", or do you have to use ^=?

No, that syntax doesn't work. = (exact match) and ^ (invert) must be at the beginning of the filter string (if both are to be used, ^ must come first).
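
(Hedged examples for this particular box, following those rules, with interface names taken from the ip a output above:)

export NCCL_SOCKET_IFNAME==enp37s0f0        # exact match: consider only the 10 Gb/s NIC
export NCCL_SOCKET_IFNAME=^eno,docker,lo    # exclusion list: skip anything starting with eno, docker, or lo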

> There is no logging of
>
> 1. The port/IP/interface being used for listening (either "eno1" or "192.168.3.*")
> 2. The connection to 192.168.3.* being opened

Correct. You won't find these in the log, because the first one is not being done, and as to the second one, the address is not chosen explicitly by NCCL, but rather the Linux kernel chooses it implicitly.

> as the log shows only 192.168.5.* IPs, the connection should be FROM 192.168.5.* TO 192.168.5.*

In fact, the connection is FROM "anywhere" (unspecified) TO 192.168.5.*. Because NCCL does not specify the FROM address, the Linux kernel chooses it on its own, based on the TO address and the routing table. That's how FROM ends up being 192.168.3.*.

> there is no indication or reason why it would be listening on 192.168.3.*
> there is no indication or reason why it would be connecting to 192.168.3.*

Both of the above statements are true. NCCL never listens on 192.168.3.* and also never connects to 192.168.3.*. As I said, the source, not the destination, ends up being 192.168.3.*.

> Could this be out-of-band data, or because the out-of-band interface was not set, and the out-of-band listening/sending connection is not being logged?
> Maybe because NCCL_OOB_NET_IFNAME is not set?

No. OOB is used during bootstrap (early initialization) and is separate from the code used later for user data exchange, although it is subject to the same assumptions/limitations as listed above (but in a way it "doesn't matter", since OOB data exchanges are relatively low-volume).

> Is NCCL_OOB_NET_ENABLE enabled by default now?

No, it's still opt-in.

> Is NCCL_OOB_NET_IFNAME missing logging prints for connections?

That's possible, given that many bootstrap connections are extremely short-lived. We try to limit the logging to what we think adds value.
