diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py index 41219036..85a7d7ab 100644 --- a/src/benchmarks/fsdp/train.py +++ b/src/benchmarks/fsdp/train.py @@ -117,8 +117,8 @@ def on_trace_ready(p): print_rank0( f"Batch [{i+1}/{num_batches}]:\n" f" loss={loss.item():.3f}\n" - f" throughput/seconds_per_batch={batch_time:.3f}\n", - f" grad/total_norm={norm_str}", + f" throughput/seconds_per_batch={batch_time:.3f}\n" + f" grad/total_norm={norm_str}" ) if profile and i == 2: