diff --git a/bmtrain/pipe/comm.py b/bmtrain/pipe/comm.py
index 32099a0..d65f0f1 100644
--- a/bmtrain/pipe/comm.py
+++ b/bmtrain/pipe/comm.py
@@ -107,7 +107,7 @@ def send_backward_recv_forward(self, backward_grad, need_data=False):
             data = self.get_data()
         else:
             forward_state = [None]
-            data = None
+            data = [None]
 
         return forward_state, data
 
diff --git a/bmtrain/pipe/schedule.py b/bmtrain/pipe/schedule.py
index abd898d..29299fb 100644
--- a/bmtrain/pipe/schedule.py
+++ b/bmtrain/pipe/schedule.py
@@ -26,6 +26,8 @@ def backward_func(inp, backward_step, output, grad_output, optim_manager=None):
     if not isinstance(grad_output, Iterable):
         grad_output = [grad_output]
     backward_step(output[0], grad_output[0])
+    current_stream = torch.cuda.current_stream()
+    current_stream.wait_stream(bmt.config['load_stream'])
     input_grad = [None]
     if inp is not None:
         input_grad = []
@@ -164,7 +166,5 @@ def pipeline_forward_backward(model, data_iterator, forward_step, backward_step,
         commander.send_prev(input_grad)
 
     blocklist = model.get_blocklist()
     # blocklist.reduce_tied_module()
-
-    bmt.synchronize()
 
diff --git a/example/pipe_train.py b/example/pipe_train.py
index c3c757f..64204a0 100644
--- a/example/pipe_train.py
+++ b/example/pipe_train.py
@@ -6,6 +6,7 @@ from bmtrain.global_var import config
 from bmtrain import inspect
 from bmtrain.pipe import pipeline_forward_backward
+from typing import Iterable
 
 
 def main():
     bmt.init_distributed(
@@ -51,7 +52,7 @@ def data_loader():
             torch.full_like(targets, -100, dtype=torch.long)
         )
         pos = torch.arange(enc_input.size(1)).long().cuda().repeat(enc_input.size(0), 1)
-        yield enc_input, pos, pos, targets
+        yield enc_input, pos, mask, targets
 
     loss_func = bmt.loss.FusedCrossEntropy(ignore_index=-100, parallel=True)
 
@@ -67,7 +68,6 @@ def data_loader():
     bmt.synchronize()
     avg_time_recorder = bmt.utils.AverageRecorder()
     avg_loss_recorder = bmt.utils.AverageRecorder()
-    global_loss_items = []
 
     def forward_step(model, input, data):
         enc_input, pos, mask, targets = data
@@ -89,20 +89,20 @@ def backward_step(output, grad_output):
         output = optim_manager.scale_loss(output)
         output = output / bmt.config['micros']
         torch.autograd.backward(output, grad_tensors=grad_output)
-        current_stream = torch.cuda.current_stream()
-        current_stream.wait_stream(bmt.config['load_stream'])
 
     for iteration in range(10):
         # load data
+        global_loss_items = []
         st = time.time()
 
         rank = bmt.config["topology"].pipe_rank
         # global_loss, grad_norm = pipeline_forward_backward(model, data_loader(), micro , num_micros, optim_manager)
+        optim_manager.zero_grad()
         pipeline_forward_backward(model, data_loader(), forward_step, backward_step, micro , num_micros)
         grad_norm = optim_manager.clip_grad_norm(optim_manager.optimizers[0].param_groups, 1.0, norm_type=2)
         optim_manager.step()
-        optim_manager.zero_grad()
+        bmt.synchronize()
 
         # record time and loss
         iteration_time = time.time() - st
 