More FSDP optimizations #10

Merged · 9 commits · Apr 17, 2024
src/benchmarks/fsdp/train.py: 58 additions & 22 deletions
@@ -4,8 +4,8 @@
"""

import argparse
import contextlib
import logging
import os
import time
from pathlib import Path
from typing import Literal, Optional
@@ -32,6 +32,8 @@ def main(
save_path: Optional[str] = None,
load_path: Optional[str] = None,
mixed_precision: bool = True,
profile: bool = False,
trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
**kwargs,
):
model, optim, dataloader = build_components(
@@ -56,33 +58,56 @@
print_rank0(f"Saving checkpoint to {checkpoint_dir}...")
save_model_and_optim_state(checkpoint_dir, model, optim)

profiler = contextlib.nullcontext()
if profile:
from torch.profiler import ProfilerActivity, schedule

def on_trace_ready(p):
trace_path = Path(trace_output).expanduser()
trace_path.parent.mkdir(exist_ok=True, parents=True)
p.export_chrome_trace(str(trace_path))
print_rank0(f"Tracing complete, saved to '{trace_path}'")

profiler = torch.profiler.profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
record_shapes=False,
profile_memory=False,
with_stack=True,
schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
on_trace_ready=on_trace_ready,
)

print_rank0("Starting training...")
for i, batch in enumerate(iter(dataloader)):
log.debug("Batch: %s", batch)
batch_start = time.monotonic()
with profiler as p:
for i, batch in enumerate(iter(dataloader)):
log.debug("Batch: %s", batch)
batch_start = time.monotonic()

# Zero-gradients.
optim.zero_grad()
# Zero-gradients.
optim.zero_grad()

# Run forward pass.
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
loss = compute_loss(model, batch)
# Run forward pass.
with torch.autocast("cuda", dtype=torch.bfloat16, enabled=mixed_precision):
loss = compute_loss(model, batch)

# Trigger backward pass.
loss.backward()
# Trigger backward pass.
loss.backward()

# Clip gradient norms.
model.clip_grad_norm_(1.0)
# Clip gradient norms.
model.clip_grad_norm_(1.0)

# Take optimizer step.
optim.step()
# Take optimizer step.
optim.step()

batch_end = time.monotonic()
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
)
batch_end = time.monotonic()
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_end-batch_start:.3f}",
)

if p is not None:
p.step()

if save_path is not None:
checkpoint_dir = Path(save_path) / "final"
@@ -126,6 +151,15 @@ def main(
"--debug",
action="store_true",
)
parser.add_argument(
"--profile",
action="store_true",
)
parser.add_argument(
"--trace-output",
type=str,
default="/tmp/traces/olmo_core.chrome_trace.json.gz",
)
parser.add_argument(
"--save-path",
type=str,
@@ -168,7 +202,7 @@ def main(
raise NotImplementedError(args.model_size)

if args.debug:
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
config.debug = True

dist.init_process_group(backend="nccl")
@@ -185,6 +219,8 @@
dry_run=args.dry_run,
save_path=args.save_path,
load_path=args.load_path,
profile=args.profile,
trace_output=args.trace_output,
mixed_precision=mixed_precision,
max_prefetch_count=args.max_prefetch_count,
learning_rate=args.lr,
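
Note on the profiling changes above: `torch.profiler.profile` with `schedule(wait=1, warmup=5, active=3, repeat=1)` skips one batch, warms up for five, records three, then fires `on_trace_ready` once; when `--profile` is off, `contextlib.nullcontext()` yields `None`, which is why the per-batch `p.step()` call is guarded. A minimal standalone sketch of that pattern (illustrative names and paths, not code from this PR):

```python
import contextlib

import torch
from torch.profiler import ProfilerActivity, schedule


def run(profile: bool = False, num_steps: int = 12):
    # Either a real profiler or a no-op context manager that yields None.
    profiler = contextlib.nullcontext()
    if profile:
        profiler = torch.profiler.profile(
            activities=[ProfilerActivity.CPU],
            # Skip 1 step, warm up for 5, record 3, then call on_trace_ready once.
            schedule=schedule(wait=1, warmup=5, active=3, repeat=1),
            on_trace_ready=lambda p: p.export_chrome_trace("/tmp/trace.json.gz"),
        )

    with profiler as p:
        for step in range(num_steps):
            torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in for one training batch
            if p is not None:
                p.step()  # advance the profiler's schedule once per batch
```
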
src/olmo_core/distributed/fsdp/fsdp.py: 57 additions & 24 deletions
@@ -26,6 +26,7 @@
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.autograd import Variable

from olmo_core.distributed.tensors import ShardedFlatParameter
from olmo_core.stream import Stream
@@ -322,7 +323,7 @@ def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
nonsharded_params: Set[nn.Parameter] = set()
grads: List[torch.Tensor] = []
for param in self.parameters():
if param.grad is None:
if param.grad is None or param.grad.numel() == 0:
continue

if isinstance(param, ShardedFlatParameter):
@@ -394,7 +395,7 @@ def _lazy_init(self):
self.state.forward_execution_order.append(self)
return

log.debug("Completing lazy initialization from root FSDP for %s...", self.module.__class__.__name__)
log.debug(
"Completing lazy initialization from root FSDP for %s (%s)...",
self.module.__class__.__name__,
id(self.module),
)

# Initialize streams.
self.state.compute_stream = Stream.default(self.device)
@@ -494,7 +499,7 @@ def _shard(self):

This should only be called once at initialization.
"""
log.debug("Sharding %s...", self.module.__class__.__name__)
log.debug("Sharding %s (%s)...", self.module.__class__.__name__, id(self.module))

params_with_grads: List[nn.Parameter] = []
params_with_grads_fqns: List[str] = []
Expand Down Expand Up @@ -568,7 +573,7 @@ def _unshard(

kwargs = dict(cast=cast, set_grads=set_grads, recurse=recurse, rank0_only=rank0_only)

log.debug("Unsharding %s...", self.module.__class__.__name__)
log.debug("Unsharding %s (%s)...", self.module.__class__.__name__, id(self.module))
self.state.params_prefetched = True

# NOTE: `unshard_stream` should wait on current stream (usually `compute_stream` / `default_stream`)
@@ -600,7 +605,11 @@
def _prefetch(self, prefetch_from: deque[FSDP], **kwargs):
for module in self._deque_from(prefetch_from):
log.debug(
"Prefetching %s from %s...", module.module.__class__.__name__, self.module.__class__.__name__
"Prefetching %s (%s) from %s (%s)...",
module.module.__class__.__name__,
id(module.module),
self.module.__class__.__name__,
id(self.module),
)
module._unshard(**kwargs)

@@ -611,7 +620,7 @@ def _reshard(self, writeback: bool = False, recurse: bool = False):
"""
kwargs = dict(writeback=writeback, recurse=recurse)

log.debug("Resharding %s...", self.module.__class__.__name__)
log.debug("Resharding %s (%s)...", self.module.__class__.__name__, id(self.module))
self.state.params_prefetched = False

for handle in self.state.flat_param_handles:
@@ -637,7 +646,7 @@ def _reduce_scatter_grads(self):

grad_reduce_dtype: Optional[torch.dtype] = self.precision.reduce_dtype or self.precision.param_dtype
with self.state.reduce_stream(wait_stream=self.state.current_stream):
log.debug("Reduce-scattering grads for %s", self.module.__class__.__name__)
log.debug("Reduce-scattering grads for %s (%s)", self.module.__class__.__name__, id(self.module))
for handle in self.state.flat_param_handles:
handle.reduce_scatter_grads_(grad_reduce_dtype=grad_reduce_dtype)

@@ -659,13 +668,16 @@ def _deque_from(self, prefetch_queue: deque[FSDP]) -> Generator[FSDP, None, None]:
@torch.no_grad()
def _pre_backward_hook(self, *unused: Any):
del unused
log.debug("Running pre-backward hook for %s...", self.module.__class__.__name__)
log.debug("Running pre-backward hook for %s (%s)...", self.module.__class__.__name__, id(self.module))

# Remove all pre backward hooks for this FSDP instance since they all do the same thing.
for handle in self.state.pre_backward_hook_handles:
handle.remove()
self.state.pre_backward_hook_handles.clear()

if self.is_root:
self._register_post_backward_final_hook()

# Unshard parameters in place.
self._unshard(set_grads=True)

@@ -684,10 +696,12 @@ def _register_pre_backward_hook(self, x: torch.Tensor):
self.state.pre_backward_hook_handles.append(handle)

def _register_pre_backward_hooks(self, output: Any):
log.debug("Registering pre-backward hooks for %s...", self.module.__class__.__name__)
log.debug("Registering pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module))
# Clear existing hooks if there are any.
if self.state.pre_backward_hook_handles:
log.debug("Removing old pre-backward hooks for %s...", self.module.__class__.__name__)
log.debug(
"Removing old pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
for handle in self.state.pre_backward_hook_handles:
handle.remove()
self.state.pre_backward_hook_handles.clear()
@@ -699,29 +713,19 @@ def _register_pre_backward_hooks(self, output: Any):
@torch.no_grad()
def _post_backward_hook(self, param_name: str, *unused: Any):
del unused
log.debug("Running post-backward hook for %s.%s...", self.module.__class__.__name__, param_name)
self.state.post_backward_hook_handles.pop(param_name).remove()

# If there are still more handles then there are still more post-backward hooks to be ran
# in the current FSDP node. Only the last handle should do the work.
if self.state.post_backward_hook_handles:
return

log.debug("Running post-backward hook for %s (%s)", self.module.__class__.__name__, id(self.module))

# NOTE: reshard *before* reducing grads to correctly handle precision settings.
self._reshard()
self._reduce_scatter_grads()

# The root FSDP instance needs to do some final cleanup.
if not self.is_root:
return

# Mark backward execution order as finalized.
self.state.backward_execution_order_finalized = True

# Wait for unsharding and reducing streams to complete so the model is not left in a bad
# state before grad clipping, optimizer step, or whatever else.
self.state.current_stream.wait_stream(self.state.reduce_stream)

def _register_post_backward_hook(self, param_name: str, param: ShardedFlatParameter):
# Force creation of a `grad_fn` in order to register a hook that will run *after* this param's
# backward pass.
@@ -733,13 +737,42 @@ def _register_post_backward_hooks(self):
self.state.post_backward_hook_handles[param_name] = handle

def _register_post_backward_hooks(self):
log.debug("Registering post-backward hooks for %s...", self.module.__class__.__name__)
log.debug(
"Registering post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
# Clear existing hooks if there are any.
if self.state.post_backward_hook_handles:
log.debug("Removing old post-backward hooks for %s...", self.module.__class__.__name__)
log.debug(
"Removing old post-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
for handle in self.state.post_backward_hook_handles.values():
handle.remove()
self.state.post_backward_hook_handles.clear()
for param_name, param in self._managed_named_parameters():
if param.requires_grad:
self._register_post_backward_hook(param_name, param)

@torch.no_grad()
def _post_backward_final_hook(self):
if not self.is_root:
return

log.debug("Running post-backward final hook for %s (%s)", self.module.__class__.__name__, id(self.module))

# Mark backward execution order as finalized.
self.state.backward_execution_order_finalized = True
for child in self._fsdp_children(recurse=True):
child.state.backward_execution_order_finalized = True

# Wait for unsharding and reducing streams to complete so the model is not left in a bad
# state before grad clipping, optimizer step, or whatever else.
self.state.current_stream.wait_stream(self.state.reduce_stream)

def _register_post_backward_final_hook(self):
if not self.is_root:
return

log.debug(
"Registering post-backward final hook for %s (%s)...", self.module.__class__.__name__, id(self.module)
)
Variable._execution_engine.queue_callback(self._post_backward_final_hook)
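
Note on the final hook above: `Variable._execution_engine.queue_callback` uses autograd's end-of-backward callback queue. A callback queued from inside the backward pass runs once after the whole backward graph has executed, which is what lets the root FSDP instance defer its finalization (marking execution order finalized, waiting on the reduce stream) until every per-parameter post-backward hook has fired. A minimal, self-contained sketch of that mechanism (not this repo's API):

```python
import torch
from torch.autograd import Variable

x = torch.randn(4, requires_grad=True)
loss = (x * 2).sum()


def final_callback():
    # Runs once, after the entire backward pass (including grad accumulation) finishes.
    print("backward complete, grad ready:", x.grad is not None)


def pre_backward_hook(grad):
    # We are inside the backward pass here, so it is safe to queue the
    # end-of-backward callback, mirroring FSDP._pre_backward_hook above.
    Variable._execution_engine.queue_callback(final_callback)
    return grad


x.register_hook(pre_backward_hook)
loss.backward()  # prints: backward complete, grad ready: True
```
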
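Likewise, the "force creation of a `grad_fn`" comment in `_register_post_backward_hook` refers to the usual AccumulateGrad trick: `param.expand_as(param)` produces a node whose `next_functions[0][0]` is the parameter's `AccumulateGrad` node, and a hook registered there fires right after that parameter's gradient has been produced during backward. A minimal, simplified illustration (the real hook also reshards and reduce-scatters grads):

```python
import torch

p = torch.nn.Parameter(torch.randn(3))

# expand_as() forces a grad_fn; its next_functions[0][0] is p's AccumulateGrad node.
acc_grad = p.expand_as(p).grad_fn.next_functions[0][0]


def post_backward_hook(*unused):
    # Fires after p's gradient has been computed for this backward pass.
    print("post-backward for p, grad:", p.grad)


handle = acc_grad.register_hook(post_backward_hook)

(p * 2).sum().backward()
handle.remove()
```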