args.pipeline_mode=pipe to use torch.distributed.pipeline.sync.Pipe
pbelevich committed Feb 24, 2021
1 parent ea14d26 commit 77b3259
Showing 2 changed files with 53 additions and 23 deletions.
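For context, here is a minimal, self-contained sketch of the torch.distributed.pipeline.sync.Pipe API that the new pipeline_mode='pipe' path relies on. It is not taken from this repository: the layer shapes, the two-CUDA-device setup, and the port are illustrative. The pattern, though, matches what the diff below does: initialize RPC even for a single process, place each stage on its device, wrap the nn.Sequential in Pipe, and call .local_value() on the RRef returned by forward.

import os
import torch
import torch.nn as nn
from torch.distributed import rpc
from torch.distributed.pipeline.sync import Pipe

# Pipe requires the RPC framework to be initialized, even in a single process.
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
rpc.init_rpc("worker", rank=0, world_size=1)

# Each stage must already sit on its target device before wrapping it in Pipe;
# Pipe then moves activations between devices on its own.
stage0 = nn.Linear(16, 16).to("cuda:0")
stage1 = nn.Linear(16, 4).to("cuda:1")
model = Pipe(nn.Sequential(stage0, stage1), chunks=4)  # chunks = number of micro-batches

x = torch.randn(8, 16, device="cuda:0")
out = model(x).local_value()  # forward() returns an RRef, hence .local_value()
print(out.shape)              # torch.Size([8, 4]), resident on cuda:1

rpc.shutdown()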
examples/BERT/cross_lingual_mlm_task.py — 48 additions & 16 deletions
@@ -11,6 +11,7 @@
import torch.nn as nn
import torch.optim as optim
from torch.distributed.optim import DistributedOptimizer
from torch.distributed.rpc import RRef
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

@@ -21,6 +22,7 @@
from torchtext.experimental.transforms import sentencepiece_tokenizer
from transforms import PretrainedSPVocab
from torchtext.experimental.models.utils import count_model_param
from torch.distributed.pipeline.sync import Pipe


def collate_batch(batch_data, args, mask_id, pad_id, text_transform):
@@ -58,7 +60,7 @@ def evaluate(data_source, model, mask_id, pad_id, ntokens, criterion, args, devi
return total_loss / (len(data_source) - 1) # Set batch # to 1 for inference


def local_step(model, data, targets, criterion, optimizer, ntokens):
def local_step(model, data, targets, criterion, optimizer, ntokens, args):
optimizer.zero_grad()
output = model(data)
loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -69,7 +71,18 @@ def local_step(model, data, targets, criterion, optimizer, ntokens):
return res


def dist_step(model, data, targets, criterion, optimizer, ntokens):
def pipe_step(model, data, targets, criterion, optimizer, ntokens, args):
optimizer.zero_grad()
output = model(data).local_value() # Because torch.distributed.pipeline.sync.Pipe.forward returns RRef
loss = criterion(output.view(-1, ntokens), targets.view(-1))
loss.backward()
res = loss.item()
torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
optimizer.step()
return res


def rpc_step(model, data, targets, criterion, optimizer, ntokens, args):
with dist_autograd.context() as context_id:
output = model(data)
loss = criterion(output.view(-1, ntokens), targets.view(-1))
@@ -91,7 +104,7 @@ def train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
for batch, (data, targets) in enumerate(dataloader):
data = data.to(devices[0])
targets = targets.to(devices[-1])
loss = step_impl(model, data, targets, criterion, optimizer, ntokens)
loss = step_impl(model, data, targets, criterion, optimizer, ntokens, args)

total_loss += loss
if batch % args.log_interval == 0 and batch > 0:
@@ -171,12 +184,19 @@ def text_transform(x: str) -> List:
print("Allocating memory")
if args.pipeline_mode == 'sp':
model = SingleProcessPipeline(shards, devices)

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
else:
elif args.pipeline_mode == 'pipe':
model = Pipe(SingleProcessPipeline(shards, devices, to_device=False), chunks=args.batch_size // args.split_size)
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.75)
elif args.pipeline_mode == 'cpu' or args.pipeline_mode == 'cuda':
workers = [f"worker{i+1}" for i in range(len(devices))]
model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=(RemoteBaseCUDARPC if args.pipeline_mode == 'cuda' else RemoteBaseCPURPC))
if args.pipeline_mode == 'cpu':
impl = RemoteBaseCPURPC
elif args.pipeline_mode == 'cuda':
impl = RemoteBaseCUDARPC
model = RPCPipeline(shards, devices, workers, split_size=args.split_size, remote_base_class=impl)
optimizer = DistributedOptimizer(
optim.Adam,
model.parameter_rrefs(),
@@ -199,16 +219,25 @@

epoch_start_time = time.time()
last_lr = scheduler.get_last_lr()[0] if scheduler is not None else args.lr

if args.pipeline_mode == 'sp':
step = local_step
elif args.pipeline_mode == 'pipe':
step = pipe_step
else:
step = rpc_step

if args.pipeline_mode == 'cpu':
train_devices = ["cpu"] # Because "TensorPipe RPC backend only supports CPU tensors by default, please move your tensors to CPU before sending them over RPC"
else:
train_devices = devices

train(model, mask_id, pad_id, train_loss_log, train_data, text_transform,
optimizer, criterion, ntokens, epoch, last_lr, args,
devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
local_step if args.pipeline_mode == 'sp' else dist_step)
optimizer, criterion, ntokens, epoch, last_lr, args, train_devices, step)

# Turn on evaluation mode which disables dropout.
model.eval()
val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args,
devices if args.pipeline_mode == 'sp' or args.pipeline_mode == 'cuda' else ["cpu"],
text_transform)
val_loss = evaluate(val_data, model, mask_id, pad_id, ntokens, criterion, args, train_devices, text_transform)
val_loss_log.append(val_loss)
print('-' * 89)
print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
@@ -253,7 +282,7 @@ def _forward(x):
print('-' * 89)


def run_worker(rank, args):
def run_worker(rank, world_size, args):
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '29500'
options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=256)
Expand All @@ -265,7 +294,7 @@ def run_worker(rank, args):
rpc.init_rpc(
"master",
rank=rank,
world_size=args.gpus+1,
world_size=world_size,
rpc_backend_options=options
)
run_main(args)
@@ -278,7 +307,7 @@ def run_worker(rank, args):
rpc.init_rpc(
f"worker{rank}",
rank=rank,
world_size=args.gpus+1,
world_size=world_size,
rpc_backend_options=options
)
pass
@@ -337,5 +366,8 @@ def run_worker(rank, args):

if args.pipeline_mode == 'sp':
run_main(args)
elif args.pipeline_mode == 'pipe':
# Because torch.distributed.pipeline.sync.Pipe.forward returns RRef and requires RPC
mp.spawn(run_worker, args=(1, args), nprocs=1, join=True)
else:
mp.spawn(run_worker, args=(args,), nprocs=args.gpus+1, join=True)
mp.spawn(run_worker, args=(args.gpus+1, args), nprocs=args.gpus+1, join=True)
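Taken together, run_main now dispatches on four values of args.pipeline_mode: 'sp' keeps the single-process SingleProcessPipeline trained with local_step; 'pipe' wraps SingleProcessPipeline(shards, devices, to_device=False) in Pipe with chunks=args.batch_size // args.split_size and trains with the new pipe_step; 'cpu' and 'cuda' build an RPCPipeline on RemoteBaseCPURPC or RemoteBaseCUDARPC and train with rpc_step (the former dist_step). run_worker takes the world size explicitly so that 'pipe' mode can spawn a single process that still initializes RPC (Pipe.forward returns an RRef and therefore needs it), while the RPC modes keep spawning one master plus args.gpus workers; only the 'cpu' mode forces tensors onto the CPU, per the TensorPipe note in the diff.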
examples/BERT/pipeline.py — 5 additions & 7 deletions
@@ -5,6 +5,7 @@
import threading
import concurrent.futures


class ToDevice(nn.Module):
def __init__(self, device):
super().__init__()
@@ -15,19 +16,17 @@ def forward(self, x):


class SingleProcessPipeline(nn.Sequential):
def __init__(self, shards, devices):
def __init__(self, shards, devices, to_device=True):
super().__init__()
assert len(shards) == len(devices)
self.devices = devices
self.seq = nn.Sequential()

with concurrent.futures.ThreadPoolExecutor() as executor:
concurrent.futures.wait([executor.submit(lambda s, d: s.to(d), shards[i], devices[i]) for i in range(len(shards))])

for i, shard in enumerate(shards):
self.seq.add_module(f'Shard({devices[i]})', shard)
if i != len(shards)-1:
self.seq.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))
self.add_module(f'Shard({devices[i]})', shard)
if to_device and i != len(shards)-1:
self.add_module(f'ToDevice({devices[i+1]})', ToDevice(devices[i+1]))


class RemoteBaseCPURPC(nn.Module):
@@ -87,4 +86,3 @@ def parameter_rrefs(self):
for shard in self.shards:
remote_params.extend(shard.remote().parameter_rrefs().to_here())
return remote_params
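In pipeline.py, SingleProcessPipeline now registers each shard, and optionally a ToDevice glue module, directly on itself instead of on an inner self.seq, and gains a to_device flag. Registering the shards as direct children is what lets the 'pipe' mode hand the whole object to Pipe, which partitions an nn.Sequential by the devices of its immediate children, and to_device=False drops the ToDevice modules there since Pipe moves activations between partition devices itself; this reading is inferred from the Pipe API rather than stated in the commit.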
