Pipeline example code refactor
MayDomine committed May 6, 2024
1 parent 290c1e3 commit a742092
Showing 3 changed files with 8 additions and 8 deletions.
bmtrain/pipe/comm.py (2 changes: 1 addition & 1 deletion)
@@ -107,7 +107,7 @@ def send_backward_recv_forward(self, backward_grad, need_data=False):
             data = self.get_data()
         else:
             forward_state = [None]
-            data = None
+            data = [None]
         return forward_state, data
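Note on the comm.py change: the branch that carries real data already returns a list, so returning a bare None forced callers to special-case the no-data path before indexing. Returning [None] keeps the interface uniform. A minimal sketch of the pattern, with a hypothetical caller that is not part of the BMTrain source:

def recv(need_data):
    # Both branches now return a list, mirroring the fix above.
    if need_data:
        data = ["payload"]   # stands in for self.get_data()
    else:
        data = [None]        # previously a bare None
    return data

first = recv(False)[0]       # safe on either branch; no TypeError on indexing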
bmtrain/pipe/schedule.py (4 changes: 2 additions & 2 deletions)
@@ -26,6 +26,8 @@ def backward_func(inp, backward_step, output, grad_output, optim_manager=None):
     if not isinstance(grad_output, Iterable):
         grad_output = [grad_output]
     backward_step(output[0], grad_output[0])
+    current_stream = torch.cuda.current_stream()
+    current_stream.wait_stream(bmt.config['load_stream'])
     input_grad = [None]
     if inp is not None:
         input_grad = []
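The two added lines order the current compute stream after BMTrain's load_stream, so tensors produced on the side stream are not consumed before their kernels finish. A self-contained sketch of the same wait_stream pattern using generic streams rather than BMTrain's config dict:

import torch

if torch.cuda.is_available():
    side = torch.cuda.Stream()                 # e.g. a parameter-load stream
    x = torch.ones(1024, device="cuda")
    with torch.cuda.stream(side):
        y = x * 2                              # kernel enqueued on the side stream
    # Make subsequent work on the current stream wait for the side stream,
    # the same ordering the diff inserts after backward_step.
    torch.cuda.current_stream().wait_stream(side)
    z = y + 1                                  # safely ordered after the multiply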
@@ -164,7 +166,5 @@ def pipeline_forward_backward(model, data_iterator, forward_step, backward_step,
         commander.send_prev(input_grad)
     blocklist = model.get_blocklist()
     # blocklist.reduce_tied_module()
-
-    bmt.synchronize()
example/pipe_train.py (10 changes: 5 additions & 5 deletions)
@@ -6,6 +6,7 @@
 from bmtrain.global_var import config
 from bmtrain import inspect
 from bmtrain.pipe import pipeline_forward_backward
+from typing import Iterable

 def main():
     bmt.init_distributed(
@@ -51,7 +52,7 @@ def data_loader():
                 torch.full_like(targets, -100, dtype=torch.long)
             )
             pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
-            yield enc_input, pos, pos<enc_length[:, None], targets
+            yield enc_input, pos, mask, targets

     if config['tp_size'] > 1:
         loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True)
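In the data_loader change, the inline comparison pos < enc_length[:, None] is replaced by a named mask variable carrying the same length-based padding mask. A small standalone illustration of what that broadcast comparison computes (CPU tensors, values chosen only for the example):

import torch

enc_length = torch.tensor([3, 1])        # number of valid tokens per sequence
pos = torch.arange(4).repeat(2, 1)       # position ids, one row per sequence
mask = pos < enc_length[:, None]         # broadcast (2, 1) against (2, 4)
# mask == tensor([[ True,  True,  True, False],
#                 [ True, False, False, False]])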
@@ -67,7 +68,6 @@ def data_loader():
     bmt.synchronize()
     avg_time_recorder = bmt.utils.AverageRecorder()
     avg_loss_recorder = bmt.utils.AverageRecorder()
-    global_loss_items = []

     def forward_step(model, input, data):
         enc_input, pos, mask, targets = data
@@ -89,20 +89,20 @@ def backward_step(output, grad_output):
         output = optim_manager.scale_loss(output)
         output = output / bmt.config['micros']
         torch.autograd.backward(output, grad_tensors=grad_output)
-        current_stream = torch.cuda.current_stream()
-        current_stream.wait_stream(bmt.config['load_stream'])


     for iteration in range(10):
         # load data
+        global_loss_items = []
         st = time.time()
         rank = bmt.config["topology"].pipe_rank
         # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager)
-        optim_manager.zero_grad()
         pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros)
         grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, 1.0, norm_type=2)
         optim_manager.step()
+        optim_manager.zero_grad()
+        bmt.synchronize()
         # record time and loss
         iteration_time = time.time() - st
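Taken together, the training-loop edits move the stream wait out of backward_step (it now lives in schedule.py's backward_func), reset global_loss_items each iteration, zero gradients after optim_manager.step() rather than before the pipeline call, and add a bmt.synchronize() barrier in place of the one removed from pipeline_forward_backward. A skeleton of the resulting per-iteration ordering, condensed from the diff above and omitting logging:

for iteration in range(10):
    global_loss_items = []           # fresh per-iteration loss buffer
    st = time.time()
    pipeline_forward_backward(model, data_loader(), forward_step,
                              backward_step, micro, num_micros)
    grad_norm = optim_manager.clip_grad_norm(
        optim_manager.optimizers[0].param_groups, 1.0, norm_type=2)
    optim_manager.step()             # apply the accumulated gradients
    optim_manager.zero_grad()        # clear them after stepping, not before
    bmt.synchronize()                # barrier so the elapsed time below is honest
    iteration_time = time.time() - st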
