diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index 292d2a44..009155c0 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -256,9 +256,11 @@ def training_step(self, batch: TensorDict): micro_batches = batch.split(self.config.data.micro_batch_size) n_micro_batches = len(micro_batches) + step_loss = 0 for micro_batch in micro_batches: loss = self._compute_loss(batch=micro_batch) / n_micro_batches loss.backward() + step_loss += loss.item() self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) @@ -275,8 +277,9 @@ def training_step(self, batch: TensorDict): log_gpu_memory_usage('After offload weights', logger=logger) - # TODO: all reduce to get accurate loss - return {'train/loss': loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} + step_loss = torch.tensor(step_loss).cuda() + torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG) + return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3} def validation_step(self, batch: TensorDict): self.fsdp_model.eval()