This project explores PyTorch's distributed ML training capabilities, specifically the Distributed Data Parallel (DDP) strategy.
Disclaimer: You will not be able to run this code directly, as it depends on a large dataset that is not in this repo (a 20-class slice of https://www.kaggle.com/datasets/ambityga/imagenet100). I was able to upload 4 trained models using Git LFS, so model_results.ipynb can be executed locally, but the actual training took place on EC2 instances.
file/dir | description |
---|---|
model_results.ipynb | used for testing trained models and calculating results |
dataloader_visualization.ipynb | visualizing data |
ddp_trainer.py | trainer class definition for DDP training |
multi_node_trainer.py | driver code for single and multinode training |
scripts/ | bash scripts used to launch the training code on EC2 VMs with PyTorch's torchrun |
Single-node training:
bash scripts/one_node_train.sh <TRAIN_TIME> <SAVE_NAME> <BATCH_SIZE>
Example:
bash scripts/one_node_train.sh 3.0 single_node 32
Multi-node training:
bash scripts/multi_node_train.sh <TRAIN_TIME> <WORKER_NUM> <WORLD_SIZE> <SAVE_NAME> <BATCH_SIZE>
Example (run on each participating node):
bash scripts/multi_node_train.sh 3.0 0 4 multi_node 32
bash scripts/multi_node_train.sh 3.0 1 4 multi_node 32
bash scripts/multi_node_train.sh 3.0 2 4 multi_node 32
bash scripts/multi_node_train.sh 3.0 3 4 multi_node 32
See useful_server_commands.txt for more examples. Note: the master node's IP address must be configured in scripts/multi_node_train.sh before launching.
TRAIN_TIME = Total training time in hours
WORKER_NUM = This node's rank (an integer from 0 to WORLD_SIZE - 1)
WORLD_SIZE = Total number of participating workers (nodes)
SAVE_NAME = Name used when saving snapshots (for fault tolerance) and training metrics
BATCH_SIZE = Batch size per device
Import:
from torch.distributed import init_process_group, destroy_process_group
Description:
init_process_group initializes the process group that lets the per-GPU worker processes (one per GPU, per node) communicate and synchronize; destroy_process_group tears it down when training ends.
Code:
from multi_node_trainer.py
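A minimal sketch of how this setup and teardown is typically done with torchrun-managed environment variables (illustrative; the repo's actual code is in multi_node_trainer.py):

```python
import os
import torch
from torch.distributed import init_process_group, destroy_process_group

def ddp_setup():
    # torchrun exports RANK, WORLD_SIZE, and LOCAL_RANK for every process it spawns,
    # so init_process_group can read them via the default "env://" init method.
    init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

def ddp_cleanup():
    # Tear down the process group once training is finished.
    destroy_process_group()
```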
Description:
The newer form of python -m torch.distributed.launch --use_env train_script.py. torchrun is included in PyTorch >= 1.11.
Bash Script example for running torchrun:
from scripts/multi_node_train.sh
Important Import:
from torch.utils.data.distributed import DistributedSampler
Description:
A distributed sampler that coordinates batch sampling across the cluster, ensuring that each GPU receives a different, non-overlapping set of batches.
Code:
from multi_node_trainer.py
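A minimal sketch of pairing a DataLoader with DistributedSampler (illustrative; the repo's actual code is in multi_node_trainer.py):

```python
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

def build_dataloader(dataset, batch_size):
    # The sampler splits the dataset across ranks so each GPU sees a distinct shard.
    sampler = DistributedSampler(dataset, shuffle=True)
    # shuffle stays False on the DataLoader itself; the sampler handles shuffling.
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler,
                      shuffle=False, pin_memory=True)

# Inside the training loop, call loader.sampler.set_epoch(epoch) once per epoch so the
# shuffling order differs between epochs.
```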
PyTorch Import:
from torch.nn.parallel import DistributedDataParallel as DDP
Description:
DDP wrapper for torch.nn.Module models; it keeps a replica of the model on each process and synchronizes gradients across processes during the backward pass.
Code:
from ddp_trainer.py
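A minimal sketch of wrapping a model in DDP (illustrative; the repo's actual code is in ddp_trainer.py):

```python
import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def wrap_model(model: torch.nn.Module) -> DDP:
    local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
    # Move the model to this process's GPU, then wrap it; DDP all-reduces gradients
    # across processes during backward() so every replica stays in sync.
    model = model.to(local_rank)
    return DDP(model, device_ids=[local_rank])
```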
Trainer Import:
from ddp_trainer import Trainer
Description:
Trainer class to assist with DDP training. It provides built-in snapshotting for fault-tolerant training sessions: if one worker fails, training can simply be restarted and the latest snapshot is loaded on every worker so training continues where it left off. All state related to training, such as run time, epoch number, loss histories, and the model state with the best validation loss (and more), is automatically restored after a failure. A sketch of this pattern follows the list below.
Snapshotting Code:
checkpoint saves
checkpoint loads
checkpoint load condition
saving final training metrics
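The pattern looks roughly like this (a minimal sketch; names and paths are illustrative, not the Trainer's exact API):

```python
import os
import torch

SNAPSHOT_PATH = "snapshot.pt"  # hypothetical path; the Trainer derives it from SAVE_NAME

def save_snapshot(model, optimizer, epoch, metrics):
    # model.module unwraps the DDP wrapper so the raw weights are saved.
    torch.save({
        "MODEL_STATE": model.module.state_dict(),
        "OPTIMIZER_STATE": optimizer.state_dict(),
        "EPOCH": epoch,
        "METRICS": metrics,  # run time, loss histories, best validation state, ...
    }, SNAPSHOT_PATH)

def load_snapshot(model, optimizer, device):
    snapshot = torch.load(SNAPSHOT_PATH, map_location=device)
    model.module.load_state_dict(snapshot["MODEL_STATE"])
    optimizer.load_state_dict(snapshot["OPTIMIZER_STATE"])
    return snapshot["EPOCH"], snapshot["METRICS"]

# Load condition: only resume from a snapshot if one already exists on disk, e.g.
# if os.path.exists(SNAPSHOT_PATH):
#     start_epoch, metrics = load_snapshot(model, optimizer, device)
```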
For testing, I used a modified VGG19 model (the classifier was replaced to output 20 classes instead of 1000) and trained it on a 20-class slice of ImageNet. The encoder started from pretrained weights but was left unfrozen, so gradients must be computed for every parameter.
Exact model creation code given below:
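Roughly, the setup looks like this (a minimal sketch assuming torchvision >= 0.13's weights API; see the repo for the exact code):

```python
import torch.nn as nn
from torchvision import models

def build_model(num_classes: int = 20) -> nn.Module:
    # Start from ImageNet-pretrained VGG19 weights; nothing is frozen, so gradients
    # flow through the entire encoder during training.
    model = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1)
    # Swap the final classifier layer: 1000 ImageNet classes -> 20 classes.
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, num_classes)
    return model
```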
For PyTorch DDP testing I trained the same model with the same data on 1, 2, 4, and 8 AWS EC2 instances.
Instance Type | g4dn.2xlarge - 8 vCPUs - 1 Nvidia T4 GPU |
---|---|
Network Speed | up to 25 Gbps |
In each experiment, I trained for 2 hours.
AMI:
All EC2 nodes:
Note: Each Node has 1 GPU
Metric | 1 Node | 2 Nodes | 4 Nodes | 8 Nodes |
---|---|---|---|---|
Learning Rate (x 1e-4) | 1.00 | 1.41 | 2.00 | 2.83 |
Global Batch Size | 32 | 64 | 128 | 256 |
Steps Per Epoch | 650 | 375 | 163 | 82 |
Avg Epoch Time (min) | 8.76 | 5.40 | 3.35 | 2.37 |
Epoch Number Reached at 2 hr | 14 | 23 | 36 | 54 |
Scaling Efficiency | 1.00 | 0.81 | 0.65 | 0.46 |
Best Validation Loss | 0.3795 | 0.3646 | 0.3558 | 0.3522 |
Test Loss | 0.4485 | 0.4351 | 0.4286 | 0.4287 |
Top1 Test Accuracy | 86.90% | 87.30% | 87.50% | 87.50% |
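For reference, the learning rates and scaling efficiencies above are consistent with square-root learning-rate scaling and with efficiency computed as T_1 / (N * T_N) from the average epoch times (my reading of the table, not a statement from the training code):

```python
import math

base_lr = 1e-4
t1 = 8.76  # avg epoch time on 1 node, in minutes

for n, tn in zip([1, 2, 4, 8], [8.76, 5.40, 3.35, 2.37]):
    lr = base_lr * math.sqrt(n)   # 1.00e-4, 1.41e-4, 2.00e-4, 2.83e-4
    eff = t1 / (n * tn)           # 1.00, 0.81, 0.65, 0.46
    print(f"{n} node(s): lr = {lr:.2e}, scaling efficiency = {eff:.2f}")
```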
Scale up before scaling out!
As the two graphs above show, the overall rate of training improved little, if at all, as the number of nodes increased. In theory, training should get faster once the right balance of learning rate and batch size is found.
Depending on network speed, the communication overhead from gradient synchronization between nodes is very likely to reduce scaling efficiency. I would recommend interconnects of 100 Gbps or more, but even speeds like that do not compare to the bandwidth between directly linked GPUs within a single machine. Because of this, I recommend training on multi-GPU systems before trying to scale out to multiple nodes: one system with 8 GPUs is very likely better than 8 systems with 1 GPU each (not certain, but probable, unless network transfer speeds are very high). This may seem obvious, but it wasn't so obvious to me until this experiment, mostly because 8 nodes also bring roughly 8x the CPU cores, and it's hard to know in advance how much those extra cores factor in. More testing will be conducted in the future!