More improvements to FSDP, benchmark against DDP (#13)

* leave root params in mem

* fix logic

* don't show mem usage all the time

* fix

* make configurable

* add alloc/free for unsharded data

* add alloc/free for unsharded grad

* fix

* record for

* revert 69d74c4 - alloc/free for unsharded grad

* revert alloc/free trick for unsharded params data

* add support for DDP in benchmark

* set device ids explicitly

* fix

* change up how weights are initialized

* fix test

* Add back alloc/free hack for unsharded data

* Revert "Add back alloc/free hack for unsharded data"

This reverts commit 0386841.

* Handle frozen layers with reshard-only post-backward hook

* Revert "Handle frozen layers with reshard-only post-backward hook"

This reverts commit 0f408d2.

* add to test

* add to test

* Fixes for frozen modules

* Divide grad before and after reducing for stability

* Add support for hybrid sharding

* make grad clipping optional

* clean up

* calculate grad norm more efficiently

* Revert "calculate grad norm more efficiently"

This reverts commit d66a683.

* fix
epwalsh authored May 2, 2024
1 parent a4e0ccf commit 1847b8e
Showing 8 changed files with 280 additions and 39 deletions.
docs/source/distributed/fsdp.rst (1 addition, 1 deletion)
@@ -2,5 +2,5 @@
====================

.. automodule:: olmo_core.distributed.fsdp
:members: FSDP, FSDPPrecision
:members: FSDP, FSDPPrecision, FSDPShardingStrategy
:member-order: bysource
src/benchmarks/fsdp/common.py (6 additions, 1 deletion)
@@ -151,7 +151,7 @@ def build_components(
config: TransformerConfig,
batch_size: int,
num_batches: int = 100,
fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
wrap_blocks: bool = True,
mixed_precision: bool = True,
max_prefetch_count: int = 1,
@@ -204,6 +204,11 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
)

model.apply(init_function) # just in case
elif fsdp_wrapper == "ddp":
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model.cuda(), device_ids=[dist.get_rank()])
model.apply(init_function)
else:
raise NotImplementedError(fsdp_wrapper)

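The new "ddp" branch above moves the model to the current rank's GPU and wraps it in DistributedDataParallel, giving the benchmark a full-replica baseline (all-reduce over complete gradients) to compare against FSDP's shard-and-reduce-scatter path. The sketch below reproduces that wrapping pattern in isolation; the single-process gloo group and the wrap_ddp helper name are illustrative only and not part of the repository.

import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_ddp(model: nn.Module) -> nn.Module:
    """Mirror of the benchmark's "ddp" branch: put the model on this rank's device
    and let DDP all-reduce full gradients across replicas."""
    if torch.cuda.is_available():
        return DDP(model.cuda(), device_ids=[dist.get_rank()])
    # CPU fallback so the sketch also runs without a GPU.
    return DDP(model)


if __name__ == "__main__":
    # Single-process group purely for illustration; the real benchmark launches one rank per GPU.
    dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
    ddp_model = wrap_ddp(nn.Linear(8, 8))
    print(type(ddp_model.module).__name__)  # -> Linear
    dist.destroy_process_group()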
src/benchmarks/fsdp/train.py (19 additions, 5 deletions)
@@ -13,6 +13,7 @@

import torch
import torch.distributed as dist
from torch.nn.utils import clip_grad_norm_

from olmo_core.distributed.checkpoint import (
load_model_and_optim_state,
@@ -28,13 +29,14 @@ def main(
config: TransformerConfig,
batch_size: int,
num_batches: int = 100,
fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
dry_run: bool = False,
save_path: Optional[str] = None,
load_path: Optional[str] = None,
mixed_precision: bool = True,
profile: bool = False,
trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
max_grad_norm: Optional[float] = None,
**kwargs,
):
model, optim, dataloader = build_components(
@@ -98,21 +100,28 @@ def on_trace_ready(p):
loss.backward()

# Clip gradient norms.
model.clip_grad_norm_(1.0)
norm: Optional[torch.Tensor] = None
if max_grad_norm is not None:
if hasattr(model, "clip_grad_norm_"):
norm = model.clip_grad_norm_(max_grad_norm)
else:
norm = clip_grad_norm_(model.parameters(), max_grad_norm)

# Take optimizer step.
optim.step()

batch_time = time.monotonic() - batch_start
if i > 0:
batch_times.append(batch_time)
norm_str = f"{norm.item():.3f}" if norm is not None else "n/a"
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_time:.3f}",
f" throughput/seconds_per_batch={batch_time:.3f}\n"
f" grad/total_norm={norm_str}"
)

if i == 2:
if profile and i == 2:
print_rank0(torch.cuda.memory_summary())

if p is not None:
@@ -134,7 +143,7 @@ def on_trace_ready(p):
parser = argparse.ArgumentParser(prog="train.py", description="Train an FSDP model")
parser.add_argument(
"--fsdp",
choices=["torch", "olmo_core"],
choices=["torch", "olmo_core", "ddp"],
default="olmo_core",
help="""The FSDP implementation.""",
)
@@ -190,6 +199,10 @@ def on_trace_ready(p):
type=int,
default=1,
)
parser.add_argument(
"--max-grad-norm",
type=float,
)
parser.add_argument(
"--lr",
type=float,
@@ -237,5 +250,6 @@ def on_trace_ready(p):
mixed_precision=mixed_precision,
max_prefetch_count=args.max_prefetch_count,
learning_rate=args.lr,
max_grad_norm=args.max_grad_norm,
seed=args.seed,
)
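Gradient clipping in the loop above is now opt-in via --max-grad-norm, and the loop dispatches on whether the wrapped model exposes its own clip_grad_norm_ (which the FSDP wrappers implement for sharded gradients) or needs the stock torch.nn.utils.clip_grad_norm_ (the DDP case). Below is a condensed sketch of that dispatch; the helper name maybe_clip_grads is hypothetical.

from typing import Optional

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_


def maybe_clip_grads(model: nn.Module, max_grad_norm: Optional[float]) -> Optional[torch.Tensor]:
    """Condensed form of the clipping logic in the training loop above.

    FSDP wrappers expose a clip_grad_norm_ method that understands sharded gradients;
    a plain DDP module does not, so the standard utility is used as a fallback.
    Returns the total gradient norm, or None when clipping is disabled."""
    if max_grad_norm is None:
        return None
    if hasattr(model, "clip_grad_norm_"):
        return model.clip_grad_norm_(max_grad_norm)
    return clip_grad_norm_(model.parameters(), max_grad_norm)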
src/olmo_core/distributed/fsdp/__init__.py (2 additions, 2 deletions)
@@ -77,6 +77,6 @@
-------------
"""

from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision
from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision, FSDPShardingStrategy

__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision"]
__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision", "FSDPShardingStrategy"]
src/olmo_core/distributed/fsdp/flat_param_handle.py (43 additions, 5 deletions)
@@ -15,7 +15,11 @@
ShardedFlatTensor,
ShardingSpec,
)
from olmo_core.distributed.utils import get_rank, get_world_size
from olmo_core.distributed.utils import (
get_gradient_divide_factor,
get_rank,
get_world_size,
)
from olmo_core.stream import Stream
from olmo_core.utils import get_default_device

@@ -62,21 +66,41 @@ class FlatParamHandle:
"""

process_group: Optional[dist.ProcessGroup] = None
"""
Process group containing all shards.
"""

inter_group_process_group: Optional[dist.ProcessGroup] = None
"""
Process group for between-group reductions with hybrid sharding.
"""

device: Optional[torch.device] = None

requires_grad: bool = True

pre_reduce_grad_divide_factor: float = 1.0

post_reduce_grad_divide_factor: float = 1.0

_ran_pre_unshard: bool = False

_ran_pre_reduce_scatter_grads: bool = False

def __post_init__(self):
data_parallel_world_size = get_world_size(self.process_group)
if self.inter_group_process_group is not None:
data_parallel_world_size *= self.inter_group_process_group.size()
self.pre_reduce_grad_divide_factor = get_gradient_divide_factor(data_parallel_world_size)
self.post_reduce_grad_divide_factor = data_parallel_world_size / self.pre_reduce_grad_divide_factor

@classmethod
def shard_params(
cls,
params: Iterable[nn.Parameter],
param_fqns: Iterable[str],
process_group: Optional[dist.ProcessGroup] = None,
inter_group_process_group: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
) -> FlatParamHandle:
"""
@@ -183,6 +207,7 @@ def shard_params(
)
else:
flat_param = ShardedFlatParameter(torch.empty(0, device=device))
flat_param.requires_grad = param.requires_grad
flat_param.mark_as_sharded(sharding_spec, process_group=process_group)

flat_params.append(flat_param)
@@ -224,6 +249,7 @@ def shard_params(
param_fqns=list(param_fqns),
params_data=params_data,
process_group=process_group,
inter_group_process_group=inter_group_process_group,
device=device,
requires_grad=requires_grad,
)
@@ -253,7 +279,7 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
self.params_sharded_data_lp.copy_(self.params_data.sharded_data)

# Initialize unsharded, padded gradient.
if set_grads and self.params_unsharded_grad is None:
if set_grads and self.requires_grad and self.params_unsharded_grad is None:
self.params_unsharded_grad = torch.zeros_like(all_params_unsharded_data)

def unshard_(
@@ -288,7 +314,7 @@ def unshard_(
if rank0_only or dist.get_backend() == dist.Backend.GLOO:
assert self.params_data.is_sharded
self.params_data.unshard_(dtype=dtype, rank0_only=rank0_only)
if set_grads:
if set_grads and self.requires_grad:
self.params_unsharded_grad = torch.zeros_like(self.params_data)
else:
assert not self.params_data.is_sharded
@@ -318,7 +344,7 @@ def unshard_(

param.unshard_(unsharded_data, dtype=dtype, rank0_only=rank0_only)

if set_grads:
if set_grads and self.requires_grad:
if param.grad is None and self.params_sharded_grad is not None:
self.params_sharded_grad = None
assert self.params_unsharded_grad is not None
@@ -360,6 +386,9 @@ def pre_reduce_scatter_grads_(
Stream.current(self.device).record_for(self.params_unsharded_grad)
self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)

if self.pre_reduce_grad_divide_factor > 1.0:
self.params_unsharded_grad.div_(self.pre_reduce_grad_divide_factor)

def reduce_scatter_grads_(
self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
):
@@ -368,6 +397,7 @@ def reduce_scatter_grads_(
parameter as a view into the new sharded grad.
"""
if not self.requires_grad or self.params_unsharded_grad is None:
self._ran_pre_reduce_scatter_grads = False
return

if not self._ran_pre_reduce_scatter_grads:
@@ -398,12 +428,20 @@ def post_reduce_scatter_grads_(
"""
Finalize sharded gradients after the reduce-scatter.
"""
if not self.requires_grad or self.params_unsharded_grad is None:
return

grad_dtype = grad_dtype or self.params_data.dtype
grad_reduce_dtype = grad_reduce_dtype or grad_dtype

assert self.params_unsharded_grad is not None
new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)

if self.inter_group_process_group is not None:
dist.all_reduce(new_sharded_grad, group=self.inter_group_process_group)

if self.post_reduce_grad_divide_factor > 1.0:
new_sharded_grad.div_(self.post_reduce_grad_divide_factor)

# Cast the new sharded gradient to the right dtype, potentially accumulating it into
# the existing sharded gradient.
if self.params_sharded_grad is None:
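The __post_init__ added to FlatParamHandle above splits the usual divide-by-world-size gradient averaging into two steps: gradients are divided by pre_reduce_grad_divide_factor before the reduce-scatter and by post_reduce_grad_divide_factor afterwards, with the two factors multiplying back to the full data-parallel world size. Keeping each individual division modest is what the "Divide grad before and after reducing for stability" commit is about: it lowers the risk of overflow or underflow when the reduction runs in fp16 or bf16. The implementation of get_gradient_divide_factor is not part of this diff; the version below is an assumption modeled on PyTorch/fairscale FSDP's gradient pre-divide heuristic (roughly sqrt of the world size, as a power of two).

def get_gradient_divide_factor(world_size: int) -> float:
    """Assumed implementation, following PyTorch/fairscale FSDP's heuristic: keep doubling
    a power-of-two factor while it evenly divides world_size and the remaining post-divide
    factor would still be larger, which lands near sqrt(world_size) for power-of-two sizes."""
    factor = 1
    while world_size % factor == 0 and world_size / factor > factor:
        factor *= 2
    return float(factor)


if __name__ == "__main__":
    for world_size in (2, 8, 64, 128):
        pre = get_gradient_divide_factor(world_size)
        post = world_size / pre
        # Dividing by `pre` before the reduce-scatter and by `post` after is equivalent to a
        # single division by world_size, but keeps intermediate values closer to 1, which
        # matters when the reduction itself runs in fp16/bf16.
        assert pre * post == world_size
        print(f"world_size={world_size}: pre={pre}, post={post}")

With hybrid sharding, data_parallel_world_size in __post_init__ is the product of the shard group size and the inter-group size, so the two divisions still amount to averaging over every data-parallel rank.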
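Hybrid sharding itself shows up as the new inter_group_process_group: parameters and gradients are sharded only within process_group, and post_reduce_scatter_grads_ then all-reduces the already-sharded gradients across groups. How olmo_core actually builds these two groups is not shown in this diff; the following is a hedged sketch, assuming one sharding group per node and using only standard torch.distributed calls (the helper name build_hybrid_groups is hypothetical).

import torch.distributed as dist


def build_hybrid_groups(local_world_size: int):
    """Hypothetical helper: construct the two groups FlatParamHandle.shard_params accepts,
    assuming parameters are sharded within a node and replicated across nodes.

    dist.new_group must be called by every rank for every group, including groups the
    rank does not belong to, so all groups are created up front on all ranks."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    num_nodes = world_size // local_world_size

    # Sharding groups: ranks on the same node split the flat parameter between them.
    shard_groups = [
        dist.new_group(list(range(n * local_world_size, (n + 1) * local_world_size)))
        for n in range(num_nodes)
    ]
    # Inter-node groups: the same local rank on every node owns the same shard, so these
    # groups carry the cross-node all-reduce of sharded gradients.
    inter_groups = [
        dist.new_group(list(range(r, world_size, local_world_size)))
        for r in range(local_world_size)
    ]

    process_group = shard_groups[rank // local_world_size]
    inter_group_process_group = inter_groups[rank % local_world_size]
    return process_group, inter_group_process_group

The two groups would then be passed straight through to FlatParamHandle.shard_params(..., process_group=process_group, inter_group_process_group=inter_group_process_group), matching the new keyword argument shown in the diff.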