# ddp_cifar100.py (forked from mryab/efficient-dl-systems)
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torchvision.datasets import CIFAR100

torch.set_num_threads(1)  # one intra-op thread per process, so several workers on one machine do not oversubscribe the CPU


def init_process(local_rank, fn, backend="nccl"):
    """Initialize the distributed environment."""
    # The default init_method is "env://": MASTER_ADDR, MASTER_PORT and WORLD_SIZE must be set
    # (torchrun does this); using local_rank as the global rank assumes single-node training.
    dist.init_process_group(backend, rank=local_rank)
    size = dist.get_world_size()
    fn(local_rank, size)


class Net(nn.Module):
    """
    A very simple model with minimal changes from the tutorial, used for the sake of simplicity.
    Feel free to replace it with EffNetV2-XL once you get comfortable injecting SyncBN into models programmatically.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 32, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(6272, 128)
        self.fc2 = nn.Linear(128, 100)
        self.bn1 = nn.BatchNorm1d(128, affine=False)  # to be replaced with SyncBatchNorm

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        output = self.fc2(x)
        return output
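

# A minimal sketch, not part of the original file: one way to inject SyncBN programmatically is
# PyTorch's built-in converter, which walks the module tree and replaces every BatchNorm*d layer
# (including bn1 above) with a SyncBatchNorm that shares statistics across ranks. SyncBatchNorm
# is intended for GPU training with the NCCL backend; the helper name below is our own.
def convert_to_syncbn(model: nn.Module) -> nn.Module:
    return nn.SyncBatchNorm.convert_sync_batchnorm(model)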


def average_gradients(model):
    """Manually average gradients across all workers after the backward pass."""
    size = float(dist.get_world_size())
    for param in model.parameters():
        # Sum each gradient over all ranks, then divide by the world size to get the mean
        dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
        param.grad.data /= size
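

# A minimal sketch, not part of the original file: instead of calling average_gradients after
# every backward pass (one blocking all-reduce per parameter), the model can be wrapped in
# torch.nn.parallel.DistributedDataParallel, which buckets gradients and overlaps the all-reduce
# with the backward pass. The helper name and its arguments are our own.
def wrap_with_ddp(model: nn.Module, local_rank: int) -> nn.Module:
    from torch.nn.parallel import DistributedDataParallel as DDP

    if torch.cuda.is_available():
        # One GPU per process: pin the replica to its local device
        return DDP(model.cuda(local_rank), device_ids=[local_rank])
    # CPU training with the gloo backend also works, without device_ids
    return DDP(model)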


def run_training(rank, size):
    torch.manual_seed(0)
    dataset = CIFAR100(
        "./cifar",
        transform=transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
            ]
        ),
        download=True,
    )
    # where's the validation dataset?
    # DistributedSampler shards the dataset so that every rank sees a disjoint subset each epoch
    sampler = DistributedSampler(dataset, size, rank)
    loader = DataLoader(dataset, sampler=sampler, batch_size=64)
    model = Net()
    device = torch.device("cpu")  # replace with "cuda" afterwards
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    num_batches = len(loader)
    for epoch in range(10):
        # Make the sampler reshuffle its shards differently every epoch
        sampler.set_epoch(epoch)
        epoch_loss = torch.zeros((1,), device=device)
        for data, target in loader:
            data = data.to(device)
            target = target.to(device)
            optimizer.zero_grad()
            output = model(data)
            loss = F.cross_entropy(output, target)
            epoch_loss += loss.detach()
            loss.backward()
            # Synchronize gradients by hand; DistributedDataParallel would do this automatically
            average_gradients(model)
            optimizer.step()
        # Note: accuracy is computed on the last batch of the epoch only
        acc = (output.argmax(dim=1) == target).float().mean()
        print(f"Rank {dist.get_rank()}, loss: {epoch_loss.item() / num_batches}, acc: {acc.item()}")
    # where's the validation loop?
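

# A minimal sketch, not part of the original file, of what the missing validation loop could look
# like: every rank evaluates its own shard of a held-out split (served through another
# DistributedSampler without shuffling), and the per-rank counts are all-reduced so that each
# process reports the same global accuracy. The helper name and signature are our own.
@torch.no_grad()
def evaluate(model, val_loader, device):
    model.eval()
    correct = torch.zeros(1, device=device)
    total = torch.zeros(1, device=device)
    for data, target in val_loader:
        data, target = data.to(device), target.to(device)
        output = model(data)
        correct += (output.argmax(dim=1) == target).sum()
        total += target.size(0)
    # Aggregate the counts over all workers before computing the global accuracy
    dist.all_reduce(correct, op=dist.ReduceOp.SUM)
    dist.all_reduce(total, op=dist.ReduceOp.SUM)
    model.train()
    return (correct / total).item()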


if __name__ == "__main__":
    # Launched via torchrun, e.g. `torchrun --nproc_per_node=2 ddp_cifar100.py`, which sets LOCAL_RANK and friends
    local_rank = int(os.environ["LOCAL_RANK"])
    init_process(local_rank, fn=run_training, backend="gloo")  # replace with "nccl" when testing on GPUs