Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix the displayed loss in the sft trainer for gradient accumulation > 1 #102

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions verl/trainer/fsdp_sft_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,11 @@ def training_step(self, batch: TensorDict):

micro_batches = batch.split(self.config.data.micro_batch_size)
n_micro_batches = len(micro_batches)
step_loss = 0
for micro_batch in micro_batches:
loss = self._compute_loss(batch=micro_batch) / n_micro_batches
loss.backward()
step_loss += loss.item()

self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad)

Expand All @@ -275,8 +277,9 @@ def training_step(self, batch: TensorDict):

log_gpu_memory_usage('After offload weights', logger=logger)

# TODO: all reduce to get accurate loss
return {'train/loss': loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}
step_loss = torch.tensor(step_loss).cuda()
torch.distributed.all_reduce(step_loss, op=torch.distributed.ReduceOp.AVG)
return {'train/loss': step_loss.detach().item(), 'train/lr(1e-3)': lr * 1e3}

def validation_step(self, batch: TensorDict):
self.fsdp_model.eval()
Expand Down
Loading