diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py
index 9a082370..245ba9d2 100644
--- a/src/benchmarks/fsdp/train.py
+++ b/src/benchmarks/fsdp/train.py
@@ -78,6 +78,8 @@ def on_trace_ready(p):
         on_trace_ready=on_trace_ready,
     )
 
+    print_rank0(torch.cuda.memory_summary())
+
     print_rank0("Starting training...")
     batch_times: deque[float] = deque([], 50)
     with profiler as p:
@@ -86,7 +88,7 @@ def on_trace_ready(p):
             batch_start = time.monotonic()
 
             # Zero-gradients.
-            optim.zero_grad()
+            optim.zero_grad(set_to_none=True)
 
             # Run forward pass.
             with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
@@ -110,6 +112,9 @@ def on_trace_ready(p):
                 f" throughput/seconds_per_batch={batch_time:.3f}",
             )
 
+            if i == 2:
+                print_rank0(torch.cuda.memory_summary())
+
             if p is not None:
                 p.step()
 
diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py
index 6002b975..34f6d1e5 100644
--- a/src/olmo_core/distributed/fsdp/flat_param_handle.py
+++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -61,11 +61,6 @@ class FlatParamHandle:
     the same shape as the sharded version of ``params_data``.
     """
 
-    params_sharded_grad_tmp: Optional[torch.Tensor] = None
-    """
-    Temporary storage for the local consolidated sharded grads during the reduce-scatter.
-    """
-
    process_group: Optional[dist.ProcessGroup] = None
 
    device: Optional[torch.device] = None
@@ -254,7 +249,8 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
 
         # Cast sharded ``params_data`` to ``dtype``.
         if dtype is not None:
-            self.params_sharded_data_lp = self.params_data.sharded_data.to(dtype)
+            self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data)
+            self.params_sharded_data_lp.copy_(self.params_data.sharded_data)
 
         # Initialize unsharded, padded gradient.
         if set_grads and self.params_unsharded_grad is None:
@@ -364,10 +360,6 @@ def pre_reduce_scatter_grads_(
             Stream.current(self.device).record_for(self.params_unsharded_grad)
             self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)
 
-        self.params_sharded_grad_tmp = torch.empty(
-            self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device
-        )
-
     def reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
@@ -380,11 +372,8 @@ def reduce_scatter_grads_(
 
         if not self._ran_pre_reduce_scatter_grads:
             self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype)
-            assert self.params_sharded_grad_tmp is not None
         else:
-            assert self.params_sharded_grad_tmp is not None
             Stream.current(self.device).record_for(self.params_unsharded_grad)
-            Stream.current(self.device).record_for(self.params_sharded_grad_tmp)
 
         self._ran_pre_reduce_scatter_grads = False
 
@@ -397,23 +386,40 @@ def reduce_scatter_grads_(
         if dist.get_backend() == dist.Backend.NCCL:
             # Get chunks corresponding to each rank.
             grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad)
-            dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group)
+            dist.reduce_scatter(
+                grad_chunks[get_rank(group=self.process_group)], grad_chunks, group=self.process_group
+            )
         else:
             dist.all_reduce(self.params_unsharded_grad, group=self.process_group)
-            self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad))
 
-        # Deallocate the unsharded padded grad.
-        # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make
-        # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes.
-        Stream.current(self.device).record_for(self.params_unsharded_grad)
-        self.params_unsharded_grad = None
+    def post_reduce_scatter_grads_(
+        self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
+    ):
+        """
+        Finalize sharded gradients after the reduce-scatter.
+        """
+        grad_dtype = grad_dtype or self.params_data.dtype
+        grad_reduce_dtype = grad_reduce_dtype or grad_dtype
+
+        assert self.params_unsharded_grad is not None
+        new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)
 
-        # Cast the reduce-scatter target to the right dtype, potentially accumulating it into
-        # the existing gradient.
+        # Cast the new sharded gradient to the right dtype, potentially accumulating it into
+        # the existing sharded gradient.
         if self.params_sharded_grad is None:
-            self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype)
+            if new_sharded_grad.dtype == grad_dtype:
+                self.params_sharded_grad = new_sharded_grad.clone()
+            else:
+                self.params_sharded_grad = new_sharded_grad.to(grad_dtype)
         else:
-            self.params_sharded_grad.add_(self.params_sharded_grad_tmp)
+            self.params_sharded_grad.add_(new_sharded_grad)
+
+        # Deallocate the unsharded padded grad.
+        # NOTE: Since we're potentially using a separate stream here, we need to make
+        # sure `params_unsharded_grad` is not deallocated before this finishes.
+        Stream.current(self.device).record_for(self.params_unsharded_grad)
+        self.params_unsharded_grad = None
+        del new_sharded_grad
 
         # At this point each param will be sharded again, and we set the grad for each param as a view
         # into the sharded grad.
diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py
index 57de6ff2..ae191658 100644
--- a/src/olmo_core/distributed/fsdp/fsdp.py
+++ b/src/olmo_core/distributed/fsdp/fsdp.py
@@ -657,6 +657,10 @@ def _reduce_scatter_grads(self):
             for handle in self.state.flat_param_handles:
                 handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
 
+        with self.state.post_backward_stream(wait_stream=self.state.reduce_stream):
+            for handle in self.state.flat_param_handles:
+                handle.post_reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
+
     def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
         count = 0
         while prefetch_queue and count < self.max_prefetch_count:
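The core trick in the `flat_param_handle.py` change is that the NCCL reduce-scatter now writes its result directly into the local rank's chunk of the flat unsharded-gradient buffer (`grad_chunks[get_rank(...)]`) instead of into a separate `params_sharded_grad_tmp` buffer, with the dtype cast / gradient accumulation and the deallocation of the flat buffer deferred to `post_reduce_scatter_grads_` on the post-backward stream. Below is a minimal, standalone sketch of that in-place reduce-scatter pattern, assuming `torch.distributed` is already initialized (e.g. via `torchrun`) and that the flat buffer's length is divisible by the world size; `reduce_scatter_flat_grad` is a hypothetical helper for illustration, not part of the library:

```python
from typing import Optional

import torch
import torch.distributed as dist


def reduce_scatter_flat_grad(
    unsharded_grad: torch.Tensor, group: Optional[dist.ProcessGroup] = None
) -> torch.Tensor:
    """Reduce-scatter a flat, padded gradient buffer and return this rank's shard."""
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)

    # Equal-sized chunks, one per rank. These are views into `unsharded_grad`,
    # so no extra memory is allocated here.
    chunks = list(unsharded_grad.chunk(world_size))

    if dist.get_backend(group) == dist.Backend.NCCL:
        # Write the reduced result directly into this rank's own chunk instead
        # of a separate temporary buffer.
        dist.reduce_scatter(chunks[rank], chunks, group=group)
    else:
        # Backends without a usable reduce-scatter: all-reduce, then keep our chunk.
        dist.all_reduce(unsharded_grad, group=group)

    # Clone before the flat buffer is released, since `chunks[rank]` is only a view.
    return chunks[rank].clone()
```

Because `chunks[rank]` is just a view into the flat buffer, the caller has to clone (or cast) it before freeing `unsharded_grad`, which mirrors the `clone()` / `.to(grad_dtype)` branch in `post_reduce_scatter_grads_`.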