From bc4747596582c0bf0f9e94ecdf575f1187037eeb Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 18 Apr 2024 10:54:20 -0700 Subject: [PATCH 1/6] set to none --- src/benchmarks/fsdp/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py index 9a082370..d80d7d6f 100644 --- a/src/benchmarks/fsdp/train.py +++ b/src/benchmarks/fsdp/train.py @@ -86,7 +86,7 @@ def on_trace_ready(p): batch_start = time.monotonic() # Zero-gradients. - optim.zero_grad() + optim.zero_grad(set_to_none=True) # Run forward pass. with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision): From dbcd7f1405d359eb1a4f2bca425ac740adf8889e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 18 Apr 2024 11:02:12 -0700 Subject: [PATCH 2/6] avoid allocating another sharded tensor --- src/olmo_core/distributed/fsdp/flat_param_handle.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py index 6002b975..fbff9016 100644 --- a/src/olmo_core/distributed/fsdp/flat_param_handle.py +++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py @@ -254,7 +254,8 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F # Cast sharded ``params_data`` to ``dtype``. if dtype is not None: - self.params_sharded_data_lp = self.params_data.sharded_data.to(dtype) + self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data) + self.params_sharded_data_lp.copy_(self.params_data.sharded_data) # Initialize unsharded, padded gradient. if set_grads and self.params_unsharded_grad is None: From 96a34d4fe97d32236baa9ac53e016d337e5d4e9f Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 18 Apr 2024 11:18:21 -0700 Subject: [PATCH 3/6] try allocating one fewer --- .../distributed/fsdp/flat_param_handle.py | 32 +++++++------------ 1 file changed, 11 insertions(+), 21 deletions(-) diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py index fbff9016..71ec518e 100644 --- a/src/olmo_core/distributed/fsdp/flat_param_handle.py +++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py @@ -61,11 +61,6 @@ class FlatParamHandle: the same shape as the sharded version of ``params_data``. """ - params_sharded_grad_tmp: Optional[torch.Tensor] = None - """ - Temporary storage for the local consolidated sharded grads during the reduce-scatter. 
- """ - process_group: Optional[dist.ProcessGroup] = None device: Optional[torch.device] = None @@ -365,10 +360,6 @@ def pre_reduce_scatter_grads_( Stream.current(self.device).record_for(self.params_unsharded_grad) self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype) - self.params_sharded_grad_tmp = torch.empty( - self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device - ) - def reduce_scatter_grads_( self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None ): @@ -381,11 +372,8 @@ def reduce_scatter_grads_( if not self._ran_pre_reduce_scatter_grads: self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype) - assert self.params_sharded_grad_tmp is not None else: - assert self.params_sharded_grad_tmp is not None Stream.current(self.device).record_for(self.params_unsharded_grad) - Stream.current(self.device).record_for(self.params_sharded_grad_tmp) self._ran_pre_reduce_scatter_grads = False @@ -395,13 +383,22 @@ def reduce_scatter_grads_( # Reduce the unsharded padded grad for all params. # NOTE: Only NCCL supports reduce-scatter. So with other backends we use all-reduce. + new_sharded_grad: torch.Tensor if dist.get_backend() == dist.Backend.NCCL: # Get chunks corresponding to each rank. grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad) - dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group) + new_sharded_grad = grad_chunks[get_rank(group=self.process_group)] + dist.reduce_scatter(new_sharded_grad, grad_chunks, group=self.process_group) else: dist.all_reduce(self.params_unsharded_grad, group=self.process_group) - self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad)) + new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad) + + # Cast the reduce-scatter target to the right dtype, potentially accumulating it into + # the existing gradient. + if self.params_sharded_grad is None: + self.params_sharded_grad = new_sharded_grad.to(grad_dtype) + else: + self.params_sharded_grad.add_(new_sharded_grad) # Deallocate the unsharded padded grad. # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make @@ -409,13 +406,6 @@ def reduce_scatter_grads_( Stream.current(self.device).record_for(self.params_unsharded_grad) self.params_unsharded_grad = None - # Cast the reduce-scatter target to the right dtype, potentially accumulating it into - # the existing gradient. - if self.params_sharded_grad is None: - self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype) - else: - self.params_sharded_grad.add_(self.params_sharded_grad_tmp) - # At this point each param will be sharded again, and we set the grad for each param as a view # into the sharded grad. 
offset = 0 From 8c7f3a929b8bbe9a71b96b3d45e2a2718b3f416e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 18 Apr 2024 11:21:34 -0700 Subject: [PATCH 4/6] clone when needed --- src/olmo_core/distributed/fsdp/flat_param_handle.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py index 71ec518e..9d3a415d 100644 --- a/src/olmo_core/distributed/fsdp/flat_param_handle.py +++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py @@ -396,7 +396,10 @@ def reduce_scatter_grads_( # Cast the reduce-scatter target to the right dtype, potentially accumulating it into # the existing gradient. if self.params_sharded_grad is None: - self.params_sharded_grad = new_sharded_grad.to(grad_dtype) + if new_sharded_grad.dtype == grad_dtype: + self.params_sharded_grad = new_sharded_grad.clone() + else: + self.params_sharded_grad = new_sharded_grad.to(grad_dtype) else: self.params_sharded_grad.add_(new_sharded_grad) @@ -405,6 +408,7 @@ def reduce_scatter_grads_( # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes. Stream.current(self.device).record_for(self.params_unsharded_grad) self.params_unsharded_grad = None + del new_sharded_grad # At this point each param will be sharded again, and we set the grad for each param as a view # into the sharded grad. From ddabdb22af5fabc72eb97db249c978cdb7a2a20e Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 18 Apr 2024 11:29:43 -0700 Subject: [PATCH 5/6] post reduce-scatter --- .../distributed/fsdp/flat_param_handle.py | 27 +++++++++++++------ src/olmo_core/distributed/fsdp/fsdp.py | 4 +++ 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/src/olmo_core/distributed/fsdp/flat_param_handle.py b/src/olmo_core/distributed/fsdp/flat_param_handle.py index 9d3a415d..34f6d1e5 100644 --- a/src/olmo_core/distributed/fsdp/flat_param_handle.py +++ b/src/olmo_core/distributed/fsdp/flat_param_handle.py @@ -383,18 +383,29 @@ def reduce_scatter_grads_( # Reduce the unsharded padded grad for all params. # NOTE: Only NCCL supports reduce-scatter. So with other backends we use all-reduce. - new_sharded_grad: torch.Tensor if dist.get_backend() == dist.Backend.NCCL: # Get chunks corresponding to each rank. grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad) - new_sharded_grad = grad_chunks[get_rank(group=self.process_group)] - dist.reduce_scatter(new_sharded_grad, grad_chunks, group=self.process_group) + dist.reduce_scatter( + grad_chunks[get_rank(group=self.process_group)], grad_chunks, group=self.process_group + ) else: dist.all_reduce(self.params_unsharded_grad, group=self.process_group) - new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad) - # Cast the reduce-scatter target to the right dtype, potentially accumulating it into - # the existing gradient. + def post_reduce_scatter_grads_( + self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None + ): + """ + Finalize sharded gradients after the reduce-scatter. + """ + grad_dtype = grad_dtype or self.params_data.dtype + grad_reduce_dtype = grad_reduce_dtype or grad_dtype + + assert self.params_unsharded_grad is not None + new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad) + + # Cast the new sharded gradient to the right dtype, potentially accumulating it into + # the existing sharded gradient. 
if self.params_sharded_grad is None: if new_sharded_grad.dtype == grad_dtype: self.params_sharded_grad = new_sharded_grad.clone() @@ -404,8 +415,8 @@ def reduce_scatter_grads_( self.params_sharded_grad.add_(new_sharded_grad) # Deallocate the unsharded padded grad. - # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make - # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes. + # NOTE: Since we're potentially using a separate stream here, we need to make + # sure `params_unsharded_grad` is not deallocated before this finishes. Stream.current(self.device).record_for(self.params_unsharded_grad) self.params_unsharded_grad = None del new_sharded_grad diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py index 57de6ff2..ae191658 100644 --- a/src/olmo_core/distributed/fsdp/fsdp.py +++ b/src/olmo_core/distributed/fsdp/fsdp.py @@ -657,6 +657,10 @@ def _reduce_scatter_grads(self): for handle in self.state.flat_param_handles: handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype) + with self.state.post_backward_stream(wait_stream=self.state.reduce_stream): + for handle in self.state.flat_param_handles: + handle.post_reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype) + def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]: count = 0 while prefetch_queue and count < self.max_prefetch_count: From ed25dca834df41c61cdb0ec5967e14395223a787 Mon Sep 17 00:00:00 2001 From: epwalsh Date: Thu, 18 Apr 2024 11:46:16 -0700 Subject: [PATCH 6/6] show memory usage when profiling --- src/benchmarks/fsdp/train.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py index d80d7d6f..245ba9d2 100644 --- a/src/benchmarks/fsdp/train.py +++ b/src/benchmarks/fsdp/train.py @@ -78,6 +78,8 @@ def on_trace_ready(p): on_trace_ready=on_trace_ready, ) + print_rank0(torch.cuda.memory_summary()) + print_rank0("Starting training...") batch_times: deque[float] = deque([], 50) with profiler as p: @@ -110,6 +112,9 @@ def on_trace_ready(p): f" throughput/seconds_per_batch={batch_time:.3f}", ) + if i == 2: + print_rank0(torch.cuda.memory_summary()) + if p is not None: p.step()
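
For reference, the allocation-saving pattern that patches 3-5 converge on boils down to: reduce-scatter directly into this rank's chunk of the flat unsharded gradient buffer, then copy the shard out before that buffer is freed. A minimal standalone sketch follows; it uses only the torch.distributed calls already appearing in the diffs, while the function and variable names (reduce_grads_in_place, flat_grad, grad_dtype) are illustrative assumptions rather than olmo_core API, and it assumes flat_grad is padded so its length divides evenly by the world size.

# Sketch of reducing into this rank's own chunk of the flat gradient buffer
# (hypothetical names; assumes an initialized process group).
import torch
import torch.distributed as dist


def reduce_grads_in_place(flat_grad: torch.Tensor, grad_dtype: torch.dtype) -> torch.Tensor:
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # The chunks are views into ``flat_grad``, so no extra memory is allocated here.
    chunks = list(flat_grad.chunk(world_size))

    if dist.get_backend() == dist.Backend.NCCL:
        # Only NCCL supports reduce-scatter: write the reduced shard straight into
        # this rank's own chunk instead of a separately allocated temporary tensor.
        dist.reduce_scatter(chunks[rank], chunks)
    else:
        # Other backends: all-reduce the whole buffer, then keep only our chunk.
        dist.all_reduce(flat_grad)

    shard = chunks[rank]
    # The shard is only a view into ``flat_grad``, which the caller deallocates right
    # after the collective, so the result must own its memory: clone when the dtype
    # already matches, otherwise ``.to()`` performs the copy as part of the cast.
    return shard.clone() if shard.dtype == grad_dtype else shard.to(grad_dtype)

Because the reduce-scatter output aliases the flat buffer that is freed immediately afterwards, the clone-when-dtypes-match branch introduced in patch 4 is what keeps the surviving sharded gradient from pointing at released memory.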