FSDP memory usage improvements #12

Merged · 6 commits · Apr 18, 2024
7 changes: 6 additions & 1 deletion src/benchmarks/fsdp/train.py
@@ -78,6 +78,8 @@ def on_trace_ready(p):
             on_trace_ready=on_trace_ready,
         )
 
+    print_rank0(torch.cuda.memory_summary())
+
     print_rank0("Starting training...")
     batch_times: deque[float] = deque([], 50)
     with profiler as p:
@@ -86,7 +88,7 @@ def on_trace_ready(p):
             batch_start = time.monotonic()
 
             # Zero-gradients.
-            optim.zero_grad()
+            optim.zero_grad(set_to_none=True)
 
             # Run forward pass.
             with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
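
The `set_to_none=True` change above is the standard trick for trimming memory between steps: instead of writing zeros into the existing `.grad` tensors, the optimizer drops them so the caching allocator can reuse that storage until the next backward pass re-creates gradients. A minimal standalone sketch of the effect (not part of this PR; assumes a CUDA device):

```python
# Toy comparison of the two zero_grad() modes; illustrative only.
import torch

model = torch.nn.Linear(4096, 4096).cuda()
optim = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(8, 4096, device="cuda")).sum().backward()
optim.zero_grad(set_to_none=False)       # grads zeroed in place, storage still allocated
print(torch.cuda.memory_allocated())

model(torch.randn(8, 4096, device="cuda")).sum().backward()
optim.zero_grad(set_to_none=True)        # grad tensors released back to the allocator
print(torch.cuda.memory_allocated())     # noticeably lower than the first print
```

In recent PyTorch releases (2.0+) `set_to_none=True` is already the default, so passing it explicitly mainly documents the intent.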
@@ -110,6 +112,9 @@ def on_trace_ready(p):
                 f" throughput/seconds_per_batch={batch_time:.3f}",
             )
 
+            if i == 2:
+                print_rank0(torch.cuda.memory_summary())
+
             if p is not None:
                 p.step()
 
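The two `torch.cuda.memory_summary()` calls bracket start-up and the third training step, which makes it easy to compare allocator state before and after the first few batches warm up the caches. If the full summary is too verbose, a compact per-step snapshot along these lines works as well (a hypothetical helper, not part of this PR):

```python
# Hypothetical compact alternative to torch.cuda.memory_summary(); illustrative only.
import torch

def log_cuda_memory(step: int) -> None:
    gib = 1024 ** 3
    print(
        f"step={step}"
        f" allocated={torch.cuda.memory_allocated() / gib:.2f}GiB"
        f" reserved={torch.cuda.memory_reserved() / gib:.2f}GiB"
        f" peak={torch.cuda.max_memory_allocated() / gib:.2f}GiB"
    )
```
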
54 changes: 30 additions & 24 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -61,11 +61,6 @@ class FlatParamHandle:
     the same shape as the sharded version of ``params_data``.
     """
 
-    params_sharded_grad_tmp: Optional[torch.Tensor] = None
-    """
-    Temporary storage for the local consolidated sharded grads during the reduce-scatter.
-    """
-
     process_group: Optional[dist.ProcessGroup] = None
 
     device: Optional[torch.device] = None
@@ -254,7 +249,8 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
 
         # Cast sharded ``params_data`` to ``dtype``.
         if dtype is not None:
-            self.params_sharded_data_lp = self.params_data.sharded_data.to(dtype)
+            self.params_sharded_data_lp = self.params_data.sharded_chunk(all_params_unsharded_data)
+            self.params_sharded_data_lp.copy_(self.params_data.sharded_data)
 
         # Initialize unsharded, padded gradient.
         if set_grads and self.params_unsharded_grad is None:
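
The `pre_unshard_` change above avoids a fresh allocation for the low-precision copy: `sharded_data.to(dtype)` materializes a brand-new tensor, while taking this rank's `sharded_chunk` of the already-allocated `all_params_unsharded_data` buffer and `copy_`-ing into it reuses memory the all-gather needs anyway, with the cast happening during the copy. A rough standalone sketch of the pattern, with made-up shapes and names rather than the handle's real API:

```python
# Buffer-reuse sketch; names and sizes are illustrative, not OLMo-core's API.
import torch

world_size, rank = 4, 1
sharded_fp32 = torch.randn(1024, device="cuda")                      # ~ params_data.sharded_data
all_unsharded_bf16 = torch.empty(world_size * 1024, device="cuda",   # ~ all_params_unsharded_data
                                 dtype=torch.bfloat16)

# Old: allocates an extra bf16 tensor on top of the all-gather buffer.
lp_old = sharded_fp32.to(torch.bfloat16)

# New: cast directly into this rank's slice of the all-gather buffer.
lp_new = all_unsharded_bf16.view(world_size, -1)[rank]               # ~ sharded_chunk(...)
lp_new.copy_(sharded_fp32)                                           # fp32 -> bf16 cast on copy
```
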
@@ -364,10 +360,6 @@ def pre_reduce_scatter_grads_(
             Stream.current(self.device).record_for(self.params_unsharded_grad)
             self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)
 
-        self.params_sharded_grad_tmp = torch.empty(
-            self.params_data.sharded_shape, dtype=self.params_unsharded_grad.dtype, device=self.device
-        )
-
     def reduce_scatter_grads_(
         self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
     ):
@@ -380,11 +372,8 @@ def reduce_scatter_grads_(
 
         if not self._ran_pre_reduce_scatter_grads:
             self.pre_reduce_scatter_grads_(grad_dtype=grad_dtype, grad_reduce_dtype=grad_reduce_dtype)
-            assert self.params_sharded_grad_tmp is not None
         else:
-            assert self.params_sharded_grad_tmp is not None
             Stream.current(self.device).record_for(self.params_unsharded_grad)
-            Stream.current(self.device).record_for(self.params_sharded_grad_tmp)
 
         self._ran_pre_reduce_scatter_grads = False
 
@@ -397,23 +386,40 @@
         if dist.get_backend() == dist.Backend.NCCL:
             # Get chunks corresponding to each rank.
             grad_chunks = self.params_data.chunk_unsharded(self.params_unsharded_grad)
-            dist.reduce_scatter(self.params_sharded_grad_tmp, grad_chunks, group=self.process_group)
+            dist.reduce_scatter(
+                grad_chunks[get_rank(group=self.process_group)], grad_chunks, group=self.process_group
+            )
         else:
             dist.all_reduce(self.params_unsharded_grad, group=self.process_group)
-            self.params_sharded_grad_tmp.copy_(self.params_data.sharded_chunk(self.params_unsharded_grad))
 
-        # Deallocate the unsharded padded grad.
-        # NOTE: Since we're potentially using a separate stream for this reduce-scatter, we need to make
-        # sure `params_unsharded_grad` is not deallocated before the reduce-scatter finishes.
-        Stream.current(self.device).record_for(self.params_unsharded_grad)
-        self.params_unsharded_grad = None
+    def post_reduce_scatter_grads_(
+        self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
+    ):
+        """
+        Finalize sharded gradients after the reduce-scatter.
+        """
+        grad_dtype = grad_dtype or self.params_data.dtype
+        grad_reduce_dtype = grad_reduce_dtype or grad_dtype
+
+        assert self.params_unsharded_grad is not None
+        new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)
 
-        # Cast the reduce-scatter target to the right dtype, potentially accumulating it into
-        # the existing gradient.
+        # Cast the new sharded gradient to the right dtype, potentially accumulating it into
+        # the existing sharded gradient.
         if self.params_sharded_grad is None:
-            self.params_sharded_grad = self.params_sharded_grad_tmp.to(grad_dtype)
+            if new_sharded_grad.dtype == grad_dtype:
+                self.params_sharded_grad = new_sharded_grad.clone()
+            else:
+                self.params_sharded_grad = new_sharded_grad.to(grad_dtype)
         else:
-            self.params_sharded_grad.add_(self.params_sharded_grad_tmp)
+            self.params_sharded_grad.add_(new_sharded_grad)
+
+        # Deallocate the unsharded padded grad.
+        # NOTE: Since we're potentially using a separate stream here, we need to make
+        # sure `params_unsharded_grad` is not deallocated before this finishes.
+        Stream.current(self.device).record_for(self.params_unsharded_grad)
+        self.params_unsharded_grad = None
+        del new_sharded_grad
 
         # At this point each param will be sharded again, and we set the grad for each param as a view
         # into the sharded grad.
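The net effect in this file is that the `params_sharded_grad_tmp` staging buffer disappears: with NCCL, `dist.reduce_scatter` writes its result directly into this rank's chunk of the unsharded gradient buffer that already exists, and the new `post_reduce_scatter_grads_` step then casts or accumulates that chunk into `params_sharded_grad` before the unsharded buffer is released. A standalone sketch of the in-place reduce-scatter idea (launch with `torchrun`; assumes NCCL and one GPU per rank, and is not the handle's actual code):

```python
# Reduce-scatter into the caller's own chunk of the input buffer; illustrative only.
# Run with: torchrun --nproc_per_node=2 reduce_scatter_inplace.py
import torch
import torch.distributed as dist

dist.init_process_group("nccl")
rank, world_size = dist.get_rank(), dist.get_world_size()
torch.cuda.set_device(rank)

# One flat "unsharded gradient" buffer, viewed as one chunk per rank.
unsharded_grad = torch.full((world_size * 4,), float(rank + 1), device="cuda")
chunks = list(unsharded_grad.chunk(world_size))

# The output tensor aliases this rank's chunk of the input, so no temporary
# staging tensor is ever allocated.
dist.reduce_scatter(chunks[rank], chunks)

# chunks[rank] now holds the sum over ranks for this shard; the full buffer
# can be dropped afterwards, which is what post_reduce_scatter_grads_ does.
print(f"rank {rank}: {chunks[rank].tolist()}")
dist.destroy_process_group()
```

Splitting the cast/accumulate out into `post_reduce_scatter_grads_` also lets `fsdp.py` schedule it on a separate stream, as the next file shows.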
4 changes: 4 additions & 0 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -657,6 +657,10 @@ def _reduce_scatter_grads(self):
             for handle in self.state.flat_param_handles:
                 handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
 
+        with self.state.post_backward_stream(wait_stream=self.state.reduce_stream):
+            for handle in self.state.flat_param_handles:
+                handle.post_reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)
+
     def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
         count = 0
         while prefetch_queue and count < self.max_prefetch_count:
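The new block runs each handle's `post_reduce_scatter_grads_` on the post-backward stream, ordered after the reduce stream, so the cast/accumulate work does not block the stream issuing the reduce-scatters. A rough sketch of the same ordering pattern with plain `torch.cuda` streams (OLMo-core wraps this in its own `Stream` helpers, so the names below are only illustrative):

```python
# Stream-ordering sketch with raw torch.cuda streams; illustrative only.
import torch

reduce_stream = torch.cuda.Stream()
post_backward_stream = torch.cuda.Stream()

unsharded_grad = torch.randn(1 << 20, device="cuda", dtype=torch.bfloat16)

with torch.cuda.stream(reduce_stream):
    # Stand-in for the reduce-scatter kernels launched by reduce_scatter_grads_().
    reduced_chunk = unsharded_grad[: 1 << 18] * 0.25

# Order the finalize work after the reduction, then run it off the default stream
# (stand-in for post_reduce_scatter_grads_()).
post_backward_stream.wait_stream(reduce_stream)
with torch.cuda.stream(post_backward_stream):
    sharded_grad = reduced_chunk.to(torch.float32)
    # Tell the allocator this tensor is used on another stream before any
    # reference is dropped -- the same kind of hazard the Stream.record_for(...)
    # calls in flat_param_handle.py guard against.
    reduced_chunk.record_stream(post_backward_stream)
```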