FSDP memory usage improvements (#12)
* set to none

* avoid allocating another sharded tensor

* try allocating one fewer

* clone when needed

* post reduce-scatter

* show memory usage when profiling
epwalsh authored Apr 18, 2024
1 parent 8c75d23 commit a4e0ccf
Showing 3 changed files with 40 additions and 25 deletions.
7 changes: 6 additions & 1 deletion src/benchmarks/fsdp/train.py
@@ -78,6 +78,8 @@ def on_trace_ready(p):
         on_trace_ready=on_trace_ready,
     )
 
+    print_rank0(torch.cuda.memory_summary())
+
     print_rank0("Starting training...")
     batch_times: deque[float] = deque([], 50)
     with profiler as p:
@@ -86,7 +88,7 @@ def on_trace_ready(p):
             batch_start = time.monotonic()
 
             # Zero-gradients.
-            optim.zero_grad()
+            optim.zero_grad(set_to_none=True)
 
             # Run forward pass.
             with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
@@ -110,6 +112,9 @@ def on_trace_ready(p):
                 f" throughput/seconds_per_batch={batch_time:.3f}",
             )
 
+            if i == 2:
+                print_rank0(torch.cuda.memory_summary())
+
             if p is not None:
                 p.step()
 
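
Note (not part of the diff): `optim.zero_grad(set_to_none=True)` releases the per-parameter `.grad` tensors instead of refilling them with zeros, so that memory goes back to the CUDA caching allocator between the optimizer step and the next backward pass. A minimal standalone sketch of the behavior:

import torch

model = torch.nn.Linear(1024, 1024)
optim = torch.optim.AdamW(model.parameters())

loss = model(torch.randn(8, 1024)).sum()
loss.backward()
optim.step()

# With set_to_none=True the gradient tensors are dropped entirely rather than
# zeroed in place, so their storage can be reused by the next forward/backward.
optim.zero_grad(set_to_none=True)
print(all(p.grad is None for p in model.parameters()))  # True

The added `torch.cuda.memory_summary()` calls print the caching allocator's statistics (allocated, reserved, and peak bytes), which is how the effect of these changes shows up when profiling.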
54 changes: 30 additions & 24 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -61,11 +61,6 @@ class FlatParamHandle:
     the same shape as the sharded version of ``params_data``.
     """
 
-    params_sharded_grad_tmp: Optional[torch.Tensor] = None
-    """
-    Temporary storage for the local consolidated sharded grads during the reduce-scatter.
-    """
-
     process_group: Optional[dist.ProcessGroup] = None
 
     device: Optional[torch.device] = None
@@ -254,7 +249,8 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
 
         # Cast sharded ``params_data`` to ``dtype``.
         if dtype is not None:
-            self.params_sharded_data_lp = self.params_data.sharded_data.to(dtype)
+            self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data)
+            self.params_sharded_data_lp.copy_(self.params_data.sharded_data)
 
         # Initialize unsharded, padded gradient.
         if set_grads and self.params_unsharded_grad is None:
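
Note (not part of the diff): the `pre_unshard_` change above avoids allocating a separate low-precision sharded tensor with `.to(dtype)`. Assuming `sharded_chunk(...)` returns a view into this rank's slice of the unsharded buffer (which has to be allocated for the all-gather anyway), the cast copy lands in memory that already exists. A standalone sketch with made-up names and sizes:

import torch

world_size, rank, shard_numel = 4, 1, 1024

sharded_data_fp32 = torch.randn(shard_numel)  # stand-in for params_data.sharded_data
all_params_unsharded = torch.empty(world_size * shard_numel, dtype=torch.bfloat16)

# This rank's chunk is a view into the unsharded buffer -- no new allocation.
lp_chunk = all_params_unsharded[rank * shard_numel : (rank + 1) * shard_numel]
lp_chunk.copy_(sharded_data_fp32)  # the fp32 -> bf16 cast happens during the copy

# The low-precision sharded copy shares storage with the unsharded buffer.
assert lp_chunk.untyped_storage().data_ptr() == all_params_unsharded.untyped_storage().data_ptr()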
@@ -364,10 +360,6 @@ def pre_reduce_scatter_grads_(
             Stream.current(self.device).record_for(self.params_unsharded_grad)
             self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)
 
-        self.params_sharded_grad_tmp = torch.empty(
-            self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device
-        )
-
     def reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
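
Note (not part of the diff): dropping `params_sharded_grad_tmp` removes a scratch buffer the size of the sharded gradient that was allocated on every backward pass. Back-of-the-envelope accounting with assumed (not measured) numbers:

total_params = 7_000_000_000   # assumption: a 7B-parameter model
world_size = 64                # assumption
bytes_per_element = 4          # fp32 gradient reduce dtype

saved = total_params // world_size * bytes_per_element
print(f"~{saved / 2**20:.0f} MiB of scratch space per rank")  # ~417 MiB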
@@ -380,11 +372,8 @@
 
         if not self._ran_pre_reduce_scatter_grads:
             self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype)
-            assert self.params_sharded_grad_tmp is not None
         else:
-            assert self.params_sharded_grad_tmp is not None
             Stream.current(self.device).record_for(self.params_unsharded_grad)
-            Stream.current(self.device).record_for(self.params_sharded_grad_tmp)
 
         self._ran_pre_reduce_scatter_grads = False
 
@@ -397,23 +386,40 @@
         if dist.get_backend() == dist.Backend.NCCL:
             # Get chunks corresponding to each rank.
             grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad)
-            dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group)
+            dist.reduce_scatter(
+                grad_chunks[get_rank(group=self.process_group)], grad_chunks, group=self.process_group
+            )
         else:
             dist.all_reduce(self.params_unsharded_grad, group=self.process_group)
-            self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad))
 
-        # Deallocate the unsharded padded grad.
-        # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make
-        # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes.
-        Stream.current(self.device).record_for(self.params_unsharded_grad)
-        self.params_unsharded_grad = None
+    def post_reduce_scatter_grads_(
+        self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
+    ):
+        """
+        Finalize sharded gradients after the reduce-scatter.
+        """
+        grad_dtype = grad_dtype or self.params_data.dtype
+        grad_reduce_dtype = grad_reduce_dtype or grad_dtype
+
+        assert self.params_unsharded_grad is not None
+        new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)
+
-        # Cast the reduce-scatter target to the right dtype, potentially accumulating it into
-        # the existing gradient.
+        # Cast the new sharded gradient to the right dtype, potentially accumulating it into
+        # the existing sharded gradient.
         if self.params_sharded_grad is None:
-            self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype)
+            if new_sharded_grad.dtype == grad_dtype:
+                self.params_sharded_grad = new_sharded_grad.clone()
+            else:
+                self.params_sharded_grad = new_sharded_grad.to(grad_dtype)
         else:
-            self.params_sharded_grad.add_(self.params_sharded_grad_tmp)
+            self.params_sharded_grad.add_(new_sharded_grad)
 
+        # Deallocate the unsharded padded grad.
+        # NOTE: Since we're potentially using a separate stream here, we need to make
+        # sure `params_unsharded_grad` is not deallocated before this finishes.
+        Stream.current(self.device).record_for(self.params_unsharded_grad)
+        self.params_unsharded_grad = None
+        del new_sharded_grad
 
         # At this point each param will be sharded again, and we set the grad for each param as a view
         # into the sharded grad.
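
Note (not part of the diff): with NCCL, the reduce-scatter output can be this rank's own chunk of the already-allocated unsharded gradient buffer, so no separate output tensor is needed. Because that chunk is just a view into the big buffer, `post_reduce_scatter_grads_` then has to clone it (or cast it to a new dtype) before the buffer is released, which is the "clone when needed" bullet. A single-process sketch of the aliasing, without `torch.distributed`:

import torch

world_size, rank, shard_numel = 4, 2, 1024
unsharded_grad = torch.randn(world_size * shard_numel)

grad_chunks = list(unsharded_grad.chunk(world_size))  # views into unsharded_grad, not copies

# Stand-in for dist.reduce_scatter(grad_chunks[rank], grad_chunks, ...): the reduced
# result for this rank is written into its own chunk of the unsharded buffer.
grad_chunks[rank].copy_(torch.stack(grad_chunks).sum(dim=0))

# The chunk still aliases the big buffer, so it must be cloned (or cast)
# out before the unsharded gradient is deallocated.
sharded_grad = grad_chunks[rank].clone()
del grad_chunks
unsharded_grad = None  # the world_size-times-larger buffer can now be reclaimed
print(sharded_grad.shape)  # torch.Size([1024])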
4 changes: 4 additions & 0 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -657,6 +657,10 @@ def _reduce_scatter_grads(self):
             for handle in self.state.flat_param_handles:
                 handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
 
+        with self.state.post_backward_stream(wait_stream=self.state.reduce_stream):
+            for handle in self.state.flat_param_handles:
+                handle.post_reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
+
     def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
         count = 0
         while prefetch_queue and count < self.max_prefetch_count:
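
Note (not part of the diff): running `post_reduce_scatter_grads_` on a stream that waits on the reduce stream keeps the clone/cast ordered after the reduce-scatter without blocking the host. A rough sketch with raw `torch.cuda` streams (the FSDP code uses its own stream wrappers):

import torch

if torch.cuda.is_available():
    reduce_stream = torch.cuda.Stream()
    post_backward_stream = torch.cuda.Stream()

    grad = torch.randn(1 << 20, device="cuda")

    with torch.cuda.stream(reduce_stream):
        grad.mul_(0.5)  # stand-in for the reduce-scatter kernel

    # Order later work on post_backward_stream after work queued on reduce_stream.
    post_backward_stream.wait_stream(reduce_stream)
    with torch.cuda.stream(post_backward_stream):
        sharded_grad = grad.clone()  # stand-in for post_reduce_scatter_grads_

    # Make the default stream wait before the host reads the result.
    torch.cuda.current_stream().wait_stream(post_backward_stream)
    torch.cuda.synchronize()
    print(sharded_grad.sum().item())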
