make grad clipping optional
epwalsh committed Apr 30, 2024
1 parent 494be77 commit a9f8194
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/benchmarks/fsdp/train.py
@@ -36,6 +36,7 @@ def main(
     mixed_precision: bool = True,
     profile: bool = False,
     trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
+    max_grad_norm: Optional[float] = None,
     **kwargs,
 ):
     model, optim, dataloader = build_components(
@@ -99,21 +100,25 @@ def on_trace_ready(p):
         loss.backward()
 
         # Clip gradient norms.
-        if hasattr(model, "clip_grad_norm_"):
-            model.clip_grad_norm_(1.0)
-        else:
-            clip_grad_norm_(model.parameters(), 1.0)
+        norm: Optional[torch.Tensor] = None
+        if max_grad_norm is not None:
+            if hasattr(model, "clip_grad_norm_"):
+                norm = model.clip_grad_norm_(max_grad_norm)
+            else:
+                norm = clip_grad_norm_(model.parameters(), max_grad_norm)
 
         # Take optimizer step.
         optim.step()
 
         batch_time = time.monotonic() - batch_start
         if i > 0:
             batch_times.append(batch_time)
+        norm_str = f"{norm.item():.3f}" if norm is not None else "n/a"
         print_rank0(
             f"Batch [{i+1}/{num_batches}]:\n"
             f" loss={loss.item():.3f}\n"
-            f" throughput/seconds_per_batch={batch_time:.3f}",
+            f" throughput/seconds_per_batch={batch_time:.3f}\n"
+            f" grad/total_norm={norm_str}",
         )
 
         if profile and i == 2:
@@ -194,6 +199,10 @@ def on_trace_ready(p):
         type=int,
         default=1,
     )
+    parser.add_argument(
+        "--max-grad-norm",
+        type=float,
+    )
     parser.add_argument(
         "--lr",
         type=float,
@@ -241,5 +250,6 @@ def on_trace_ready(p):
         mixed_precision=mixed_precision,
         max_prefetch_count=args.max_prefetch_count,
         learning_rate=args.lr,
+        max_grad_norm=args.max_grad_norm,
         seed=args.seed,
     )
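
For context, the optional-clipping pattern this commit introduces can be sketched in isolation roughly as follows. This is a minimal standalone example, not the benchmark's actual code: `model`, `optim`, `batch`, and `targets` are placeholder names, and the real script wires `max_grad_norm` through `build_components` and an FSDP-wrapped model, where the wrapper exposes its own sharding-aware `clip_grad_norm_` method.

    # Minimal sketch of an optional gradient-clipping step (placeholder names,
    # plain nn.Module; the real benchmark uses an FSDP-wrapped model).
    from typing import Optional

    import torch
    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_


    def train_step(
        model: nn.Module,
        optim: torch.optim.Optimizer,
        batch: torch.Tensor,
        targets: torch.Tensor,
        max_grad_norm: Optional[float] = None,
    ):
        loss = nn.functional.mse_loss(model(batch), targets)
        loss.backward()

        # Clip gradient norms only when a threshold was given.
        norm: Optional[torch.Tensor] = None
        if max_grad_norm is not None:
            if hasattr(model, "clip_grad_norm_"):
                # FSDP-wrapped models provide their own clipping method.
                norm = model.clip_grad_norm_(max_grad_norm)
            else:
                norm = clip_grad_norm_(model.parameters(), max_grad_norm)

        optim.step()
        optim.zero_grad(set_to_none=True)
        return loss, norm

Running the benchmark with clipping enabled would then presumably look like `python src/benchmarks/fsdp/train.py --max-grad-norm 1.0` (other flags omitted here), while leaving the flag off skips the clipping step entirely and reports the gradient norm as "n/a".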
