Document FSDP, add tests with act checkpointing
epwalsh committed Apr 3, 2024
1 parent 1ea4d0b commit 767a803
Showing 6 changed files with 208 additions and 25 deletions.
6 changes: 6 additions & 0 deletions docs/source/distributed/fsdp.rst
@@ -0,0 +1,6 @@
``distributed.fsdp``
====================

.. automodule:: olmo_core.distributed.fsdp
:members: FSDP, FSDPPrecision
:member-order: bysource
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -17,6 +17,7 @@
utils.rst
io.rst
distributed/checkpoint.rst
distributed/fsdp.rst

.. toctree::
:hidden:
4 changes: 2 additions & 2 deletions src/olmo_core/distributed/checkpoint.py
@@ -7,8 +7,8 @@
Features
--------
- Sharded distributed models, such as PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
are supported out-of-the-box.
- Sharded distributed models, such as OLMo-core's :class:`~olmo_core.distributed.fsdp.FSDP` or PyTorch's
:class:`~torch.distributed.fsdp.FullyShardedDataParallel` are supported out-of-the-box.
- Utilizes `safetensors <https://huggingface.co/docs/safetensors/>`_ under the hood for fast, efficient, and
safe serialization/deserialization.
- Save with one distributed topology, seamlessly load with a different one. For example,
12 changes: 12 additions & 0 deletions src/olmo_core/distributed/fsdp/__init__.py
@@ -1,3 +1,15 @@
"""
This is a light-weight rewrite of PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel`
with a few improvements, including:
- Well-defined "hands off" handling of buffers. FSDP never shards buffers; they are left as-is.
- Well-defined handling of frozen params. You can mix and match within an FSDP instance as long
as the frozen/non-frozen params are consistent across the process group.
API Reference
-------------
"""

from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision

__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision"]
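
For reference, a minimal usage sketch of the wrapper described by this docstring (not part of the diff). It assumes the default process group is already initialized, e.g. under torchrun; the Model class below is illustrative.

import torch
import torch.nn as nn

from olmo_core.distributed.fsdp import FSDP


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.ff = nn.Linear(8, 8)
        # Buffers are never sharded; FSDP leaves them as-is.
        self.register_buffer("scale", torch.ones(8))
        # Frozen params are fine, as long as the frozen/non-frozen split is
        # consistent across the process group.
        self.ff.bias.requires_grad = False

    def forward(self, x):
        return self.ff(x) * self.scale


fsdp = FSDP(Model())
loss = fsdp(torch.rand(2, 8, device=fsdp.device)).sum()
loss.backward()
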
61 changes: 38 additions & 23 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -38,6 +38,10 @@

@dataclass
class FSDPPrecision:
"""
Mixed precision settings for :class:`FSDP`.
"""

param_dtype: Optional[torch.dtype] = None
"""
The data type to cast full model parameters to during the forward and backward pass.
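
A short sketch of how these settings might be passed to the wrapper (not part of the diff). Only the param_dtype field is visible in this hunk, and the precision keyword argument is taken from the mixed-precision test at the bottom of this commit.

import torch
import torch.nn as nn

from olmo_core.distributed.fsdp import FSDP, FSDPPrecision

# Gather and compute in bfloat16 during the forward/backward pass; any other
# FSDPPrecision fields are out of scope for this sketch.
fsdp_bf16 = FSDP(nn.Linear(8, 8), precision=FSDPPrecision(param_dtype=torch.bfloat16))
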
@@ -67,15 +71,7 @@ class FSDPDebugConfig:

class FSDP(Generic[M], nn.Module):
"""
This is a complete rewrite of PyTorch's ``FullyShardedDataParallel``, a ZeRO-3 model wrapper,
with a number of improvements such as:
* Well-defined "hands off" handling of buffers. FSDP never shards buffers, they are left as-is.
* Well-defined handling of frozen params. You can mix and match within an FSDP instance as long
as the frozen/non-frozen params are consistent across the process group.
* A much better checkpointing mechanism (:mod:`olmo_core.distributed.checkpoint`) which has virtually
no memory overhead, works seamlessly when restarting with a different distributed topology,
and works for both local and remote files.
FSDP, a.k.a. Fully Sharded Data Parallel, a ZeRO-3 model wrapper.
:param module: The module to wrap.
:param process_group: The distributed process group.
@@ -162,6 +158,9 @@ def module(self) -> M:
return self._fsdp_wrapped_module

def forward(self, *args, **kwargs):
"""
Run the forward pass on the wrapped module, gathering full parameters when necessary.
"""
self._lazy_init()

log.debug("Running forward pass for %s...", self.module.__class__.__name__)
@@ -188,11 +187,13 @@ def forward(self, *args, **kwargs):

if self.is_root:
# At the end of the first forward pass, execution order is now finalized, meaning
# we can use 'self.state.forward_execution_order' to start prefetching unshards.
# we can use 'self.state.forward_execution_order' to start prefetching unshards during
# the next forward pass.
self.state.forward_execution_order_finalized = True
assert not self.state.forward_prefetch_queue

if torch.is_grad_enabled():
# Prepare for backward pass.
if self.is_root and self.state.backward_execution_order_finalized:
# Fill backward-pass prefetch queue for unsharding.
for module in self.state.backward_execution_order:
@@ -209,18 +210,28 @@ def forward(self, *args, **kwargs):

def state_dict(self, *args, **kwargs):
"""
Return the state dict. The keys in the state dict will always correspond to the original keys
in the wrapped model.
Return the state dict.
.. seealso::
For saving and loading :class:`FSDP` checkpoints, see :mod:`olmo_core.distributed.checkpoint`.
The data in the state dict will be sharded flat data unless you're within the :meth:`summon_full_params()`
context or have gathered the full parameters another way.
.. tip::
The data in the state dict will be sharded flat data unless you're within the :meth:`summon_full_params()`
context or have gathered the full parameters another way.
.. tip::
The parameter names will be the original parameter names of the wrapped module, i.e.
without the :data:`WRAPPED_MODULE_PREFIX`.
"""
return self.module.state_dict(*args, **kwargs)

def load_state_dict(self, state_dict, *args, **kwargs):
"""
Load a state dict. The data in the state dict should correspond to the current state of the
FSDP wrapper, either sharded or unsharded.
.. seealso::
For saving and loading :class:`FSDP` checkpoints, see :mod:`olmo_core.distributed.checkpoint`.
"""
# Fix keys to include the right prefix.
key_mapping = self._get_key_mapping() # maps original key to wrapped key
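
In practice (continuing the fsdp instance from the earlier sketch; the summon_full_params() keyword names are those visible in apply() further down, with illustrative values):

# Rank-local round trip: keys use the original module names, data stays sharded and flat.
sharded_sd = fsdp.state_dict()
fsdp.load_state_dict(sharded_sd)

# To read full (unsharded) parameters instead, gather them first.
with fsdp.summon_full_params(recurse=True, writeback=False, rank0_only=False):
    full_sd = fsdp.state_dict()
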
@@ -230,7 +241,9 @@ def named_buffers(self, *args, **kwargs):
"""
Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself.
The buffer names will be the original buffer names.
.. tip::
The parameter names will be the original parameter names of the wrapped module, i.e.
without the :data:`WRAPPED_MODULE_PREFIX`.
"""
key_mapping = self._get_key_mapping(reverse=True)
for name, buffer in super().named_buffers(*args, **kwargs):
@@ -241,7 +254,9 @@ def named_parameters(self, *args, **kwargs):
Return an iterator over module parameters, yielding both the name of the parameter as well
as the parameter itself.
The parameter names will be the original parameter names.
.. tip::
The parameter names will be the original parameter names of the wrapped module, i.e.
without the :data:`WRAPPED_MODULE_PREFIX`.
"""
key_mapping = self._get_key_mapping(reverse=True)
for name, param in super().named_parameters(*args, **kwargs):
@@ -271,9 +286,9 @@ def apply(self, fn):
Typical use includes initializing the parameters of a model.
Compared to ``torch.nn.Module.apply``, this version additionally gathers the full parameters
Compared to :meth:`torch.nn.Module.apply`, this version additionally gathers the full parameters
for all sharded parameters that are *directly managed* by the given FSDP instance before applying ``fn``.
This should not be called from within another ``summon_full_params`` context.
This should not be called from within another :meth:`summon_full_params()` context.
"""
with self.summon_full_params(recurse=False, writeback=True, rank0_only=False):
ret = super().apply(fn)
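
A sketch of the typical use mentioned above, re-initializing parameters through the wrapper; the init scheme is illustrative and fsdp continues the first sketch.

import torch.nn as nn


def init_weights(m: nn.Module):
    # Runs while the directly managed parameters are gathered (full, unsharded).
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)


fsdp.apply(init_weights)
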
@@ -283,7 +298,7 @@ def apply(self, fn):
@torch.no_grad()
def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tensor:
"""
Clip the gradient norm of all parameters.
Clip the gradient norm of all parameters, returning the norm prior to clipping.
The norm is computed over all parameters’ gradients as viewed as a single vector, and the
gradients are modified in-place.
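
A sketch of where this fits in a training step; the optimizer and input are illustrative and fsdp continues the first sketch.

import torch

optim = torch.optim.AdamW(fsdp.parameters(), lr=1e-4)

loss = fsdp(torch.rand(2, 8, device=fsdp.device)).sum()
loss.backward()
grad_norm = fsdp.clip_grad_norm_(max_norm=1.0)  # norm of the gradients before clipping
optim.step()
optim.zero_grad(set_to_none=True)
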
@@ -399,16 +414,16 @@ def _get_key_mapping(self, reverse: bool = False, modules: bool = False) -> Dict
key_mapping: Dict[str, str] = {} # maps original key to wrapped key

def collect_key_mappings(module: nn.Module, og_prefix: str, wrapped_prefix: str):
if isinstance(module, FSDP):
wrapped_prefix = f"{wrapped_prefix}{self.WRAPPED_MODULE_PREFIX}."
module = module.module

if not modules:
for param_name, _ in chain(
module.named_parameters(recurse=False), module.named_buffers(recurse=False)
):
key_mapping[f"{og_prefix}{param_name}"] = f"{wrapped_prefix}{param_name}"

if isinstance(module, FSDP):
wrapped_prefix = f"{wrapped_prefix}{self.WRAPPED_MODULE_PREFIX}."
module = module.module

for child_name, child in module.named_children():
if modules:
key_mapping[og_prefix.strip(".")] = wrapped_prefix.strip(".")
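
To illustrate what the key mapping above translates between (not part of the diff): externally the wrapper reports original names, while internally the wrapped module sits under a prefix. The prefix value is inferred from the _fsdp_wrapped_module attribute earlier in this file and is an assumption, not verified against the full source.

import torch.nn as nn

from olmo_core.distributed.fsdp import FSDP

seq = FSDP(nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 4)))

# Public API reports the original names, e.g. ['0.weight', '0.bias', '1.weight', '1.bias'] ...
print(list(seq.state_dict().keys()))
# ... while internally those keys live under a prefix (assumed to be '_fsdp_wrapped_module.'),
# which _get_key_mapping() translates to and from.
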
149 changes: 149 additions & 0 deletions src/test/distributed/fsdp/fsdp_test.py
@@ -1,3 +1,5 @@
from functools import partial

import pytest
import torch
import torch.distributed as dist
@@ -357,6 +359,153 @@ def test_nested_fsdp_api(backend, tiny_model_factory, tiny_model_data_factory):
    )


def run_fsdp_with_frozen_params():
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.ff1 = nn.Linear(8, 8)
            self.ff2 = nn.Linear(8, 8)
            self.ff2.weight.requires_grad = False
            self.ff2.bias.requires_grad = False

        def forward(self, x):
            return self.ff2(self.ff1(x))

    fsdp = FSDP(Model())

    # Run forward pass
    loss = fsdp(torch.rand(2, 8, device=fsdp.device)).sum()

    # Trigger backward pass.
    loss.backward()
    assert fsdp.module.ff1.weight.grad is not None
    assert fsdp.module.ff1.bias.grad is not None


@pytest.mark.parametrize("backend", BACKENDS)
def test_fsdp_with_frozen_params(backend):
    run_distributed_test(
        run_fsdp_with_frozen_params,
        backend=backend,
        start_method="spawn",
    )


def run_fsdp_with_node_activation_checkpointing():
    from torch.utils.checkpoint import checkpoint

    checkpoint_fn = partial(checkpoint, use_reentrant=False)

    class Model(nn.Module):
        def __init__(self, do_activation_checkpointing: bool = True):
            super().__init__()
            self.ff1 = FSDP(nn.Linear(4, 8))
            self.ff2 = FSDP(nn.Linear(8, 8))
            self.ff3 = FSDP(nn.Linear(8, 4))
            self.do_activation_checkpointing = do_activation_checkpointing

        def forward(self, x):
            x = self.ff1(x)
            if self.do_activation_checkpointing:
                x = checkpoint_fn(self.ff2, x)
            else:
                x = self.ff2(x)
            x = self.ff3(x)
            return x

    fsdp_ckpt = FSDP(Model(do_activation_checkpointing=True))
    fsdp = FSDP(Model(do_activation_checkpointing=False))  # baseline without checkpointing

    # Synchronize weights.
    fsdp.load_state_dict(fsdp_ckpt.state_dict())

    # Run forward pass
    inputs = torch.rand(2, 4, device=fsdp.device)
    loss_ckpt = fsdp_ckpt(inputs).sum()
    loss = fsdp(inputs).sum()
    torch.testing.assert_close(loss_ckpt, loss)

    # Run backward pass.
    loss_ckpt.backward()
    loss.backward()
    for p1, p2 in zip(fsdp_ckpt.parameters(), fsdp.parameters()):
        assert p1.grad is not None
        assert p2.grad is not None
        assert p1.grad.shape == p2.grad.shape
        torch.testing.assert_close(p1.grad, p2.grad)


@pytest.mark.parametrize("backend", BACKENDS)
def test_fsdp_with_node_activation_checkpointing(backend):
    run_distributed_test(
        run_fsdp_with_node_activation_checkpointing,
        backend=backend,
        start_method="spawn",
    )


def run_fsdp_with_intra_node_activation_checkpointing():
    from torch.utils.checkpoint import checkpoint

    checkpoint_fn = partial(checkpoint, use_reentrant=False)

    class SubModel(nn.Module):
        def __init__(self, do_activation_checkpointing: bool = True):
            super().__init__()
            self.ff1 = nn.Linear(8, 4)
            self.ff2 = nn.Linear(4, 8)
            self.do_activation_checkpointing = do_activation_checkpointing

        def forward(self, x):
            x = self.ff1(x)
            if self.do_activation_checkpointing:
                x = checkpoint_fn(self.ff2, x)
            else:
                x = self.ff2(x)
            return x

    class Model(nn.Module):
        def __init__(self, do_activation_checkpointing: bool = True):
            super().__init__()
            self.ff1 = nn.Linear(4, 8)
            self.ff2 = FSDP(SubModel(do_activation_checkpointing=do_activation_checkpointing))

        def forward(self, x):
            x = self.ff1(x)
            x = self.ff2(x)
            return x

    fsdp_ckpt = FSDP(Model(do_activation_checkpointing=True))
    fsdp = FSDP(Model(do_activation_checkpointing=False))  # baseline without checkpointing

    # Synchronize weights.
    fsdp.load_state_dict(fsdp_ckpt.state_dict())

    # Run forward pass
    inputs = torch.rand(2, 4, device=fsdp.device)
    loss_ckpt = fsdp_ckpt(inputs).sum()
    loss = fsdp(inputs).sum()
    torch.testing.assert_close(loss_ckpt, loss)

    # Run backward pass.
    loss_ckpt.backward()
    loss.backward()
    for p1, p2 in zip(fsdp_ckpt.parameters(), fsdp.parameters()):
        assert p1.grad is not None
        assert p2.grad is not None
        assert p1.grad.shape == p2.grad.shape
        torch.testing.assert_close(p1.grad, p2.grad)


@pytest.mark.parametrize("backend", BACKENDS)
def test_fsdp_with_intra_node_activation_checkpointing(backend):
    run_distributed_test(
        run_fsdp_with_intra_node_activation_checkpointing,
        backend=backend,
        start_method="spawn",
    )


def run_fsdp_with_mixed_precision(model_factory, model_data_factory, precision):
    fsdp = FSDP(model_factory(), precision=precision)

