diff --git a/docs/source/distributed/fsdp.rst b/docs/source/distributed/fsdp.rst new file mode 100644 index 00000000..6ad9e2d7 --- /dev/null +++ b/docs/source/distributed/fsdp.rst @@ -0,0 +1,6 @@ +``distributed.fsdp`` +==================== + +.. automodule:: olmo_core.distributed.fsdp + :members: FSDP, FSDPPrecision + :member-order: bysource diff --git a/docs/source/index.rst b/docs/source/index.rst index 7f0a2a00..fecda678 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -17,6 +17,7 @@ utils.rst io.rst distributed/checkpoint.rst + distributed/fsdp.rst .. toctree:: :hidden: diff --git a/src/olmo_core/distributed/checkpoint.py b/src/olmo_core/distributed/checkpoint.py index 3aca67d4..7d5380d4 100644 --- a/src/olmo_core/distributed/checkpoint.py +++ b/src/olmo_core/distributed/checkpoint.py @@ -7,8 +7,8 @@ Features -------- -- Sharded distributed models, such as PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel` - are supported out-of-the-box. +- Sharded distributed models, such as OLMo-core's :class:`~olmo_core.distributed.fsdp.FSDP` or PyTorch's + :class:`~torch.distributed.fsdp.FullyShardedDataParallel` are supported out-of-the-box. - Utilizes `safetensors `_ under the hood for fast, efficient, and safe serialization/deserialization. - Save with one distributed topology, seamlessly load with a different one. For example, diff --git a/src/olmo_core/distributed/fsdp/__init__.py b/src/olmo_core/distributed/fsdp/__init__.py index ec88d260..0bfb0720 100644 --- a/src/olmo_core/distributed/fsdp/__init__.py +++ b/src/olmo_core/distributed/fsdp/__init__.py @@ -1,3 +1,15 @@ +""" +This is a light-weight rewrite of PyTorch's :class:`~torch.distributed.fsdp.FullyShardedDataParallel` +with a few improvements, including: + +- Well-defined "hands off" handling of buffers. FSDP never shards buffers, they are left as-is. +- Well-defined handling of frozen params. You can mix and match within an FSDP instance as long + as the frozen/non-frozen params are consistent across the process group. + +API Reference +------------- +""" + from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision __all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision"] diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py index 1fafc127..6da8b935 100644 --- a/src/olmo_core/distributed/fsdp/fsdp.py +++ b/src/olmo_core/distributed/fsdp/fsdp.py @@ -38,6 +38,10 @@ @dataclass class FSDPPrecision: + """ + Mixed precision settings for :class:`FSDP`. + """ + param_dtype: Optional[torch.dtype] = None """ The data type to cast full model parameters to during the forward and backward pass. @@ -67,15 +71,7 @@ class FSDPDebugConfig: class FSDP(Generic[M], nn.Module): """ - This is a complete rewrite of PyTorch's ``FullyShardedDataParallel``, a ZeRO-3 model wrapper, - with a number of improvements such as: - - * Well-defined "hands off" handling of buffers. FSDP never shards buffers, they are left as-is. - * Well-defined handling of frozen params. You can mix and match within an FSDP instance as long - as the frozen/non-frozen params are consistent across the process group. - * A much better checkpointing mechanism (:mod:`olmo_core.distributed.checkpoint`) which has virtually - no memory overhead, works seamlessly when restarting with a different distributed topology, - and works for both local and remote files. + FSDP, a.k.a. Fully Sharded Data Parallel, a ZeRO-3 model wrapper. :param module: The module to wrap. :param process_group: The distributed process group.
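Since the module docstring and ``FSDPPrecision`` docs above are the first user-facing documentation for this wrapper, here is a minimal usage sketch pieced together from the API surface visible in this diff (``FSDP(module, precision=...)``, ``FSDPPrecision.param_dtype``, ``fsdp.device``, and ``clip_grad_norm_`` all appear in the hunks or tests in this PR). The optimizer choice and the process-group setup are illustrative assumptions, not something this PR prescribes:

```python
import torch
import torch.distributed as dist
import torch.nn as nn

from olmo_core.distributed.fsdp import FSDP, FSDPPrecision

# Assumes the usual torchrun env vars and one GPU per rank (illustrative setup).
dist.init_process_group(backend="nccl")

# Wrap the model. Buffers are left unsharded, and frozen params are fine as long
# as the frozen/non-frozen split is consistent across the process group.
model = nn.Sequential(nn.Linear(8, 8), nn.GELU(), nn.Linear(8, 8))
fsdp = FSDP(model, precision=FSDPPrecision(param_dtype=torch.bfloat16))

optim = torch.optim.AdamW(fsdp.parameters(), lr=1e-3)  # any optimizer; AdamW is just an example

for _ in range(10):
    batch = torch.rand(4, 8, device=fsdp.device)
    loss = fsdp(batch).sum()     # forward pass gathers full params as needed
    loss.backward()              # backward pass; gradients land on the sharded params
    fsdp.clip_grad_norm_(1.0)    # returns the pre-clip total gradient norm
    optim.step()
    optim.zero_grad(set_to_none=True)
```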
@@ -162,6 +158,9 @@ def module(self) -> M: return self._fsdp_wrapped_module def forward(self, *args, **kwargs): + """ + Run the forward pass on the wrapped module, gathering full parameters when necessary. + """ self._lazy_init() log.debug("Running forward pass for %s...", self.module.__class__.__name__) @@ -188,11 +187,13 @@ def forward(self, *args, **kwargs): if self.is_root: # At the end of the first forward pass, execution order is now finalized, meaning - # we can use 'self.state.forward_execution_order' to start prefetching unshards. + # we can use 'self.state.forward_execution_order' to start prefetching unshards during + # the next forward pass. self.state.forward_execution_order_finalized = True assert not self.state.forward_prefetch_queue if torch.is_grad_enabled(): + # Prepare for backward pass. if self.is_root and self.state.backward_execution_order_finalized: # Fill backward-pass prefetch queue for unsharding. for module in self.state.backward_execution_order: @@ -209,11 +210,18 @@ def forward(self, *args, **kwargs): def state_dict(self, *args, **kwargs): """ - Return the state dict. The keys in the state dict will always correspond to the original keys - in the wrapped model. + Return the state dict. + + .. seealso:: + For saving and loading :class:`FSDP` checkpoints, see :mod:`olmo_core.distributed.checkpoint`. - The data in the state dict will be sharded flat data unless you're within the :meth:`summon_full_params()` - context or have gathered the full parameters another way. + .. tip:: + The data in the state dict will be sharded flat data unless you're within the :meth:`summon_full_params()` + context or have gathered the full parameters another way. + + .. tip:: + The parameter names will be the original parameter names of the wrapped module, i.e. + without the :data:`WRAPPED_MODULE_PREFIX`. """ return self.module.state_dict(*args, **kwargs) @@ -221,6 +229,9 @@ def load_state_dict(self, state_dict, *args, **kwargs): """ Load a state dict. The data in the state dict should correspond to the current state of the FSDP wrapper, either sharded or unsharded. + + .. seealso:: + For saving and loading :class:`FSDP` checkpoints, see :mod:`olmo_core.distributed.checkpoint`. """ # Fix keys to include the right prefix. key_mapping = self._get_key_mapping() # maps original key to wrapped key @@ -230,7 +241,9 @@ def named_buffers(self, *args, **kwargs): """ Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself. - The buffer names will be the original buffer names. + .. tip:: + The buffer names will be the original buffer names of the wrapped module, i.e. + without the :data:`WRAPPED_MODULE_PREFIX`. """ key_mapping = self._get_key_mapping(reverse=True) for name, buffer in super().named_buffers(*args, **kwargs): @@ -241,7 +254,9 @@ def named_parameters(self, *args, **kwargs): """ Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself. - The parameter names will be the original parameter names. + .. tip:: + The parameter names will be the original parameter names of the wrapped module, i.e. + without the :data:`WRAPPED_MODULE_PREFIX`. """ key_mapping = self._get_key_mapping(reverse=True) for name, param in super().named_parameters(*args, **kwargs): @@ -271,9 +286,9 @@ def apply(self, fn): Typical use includes initializing the parameters of a model.
- Compared to ``torch.nn.Module.apply``, this version additionally gathers the full parameters + Compared to :meth:`torch.nn.Module.apply`, this version additionally gathers the full parameters for all sharded parameters that are *directly managed* by the given FSDP instance before applying ``fn``. - This should not be called from within another ``summon_full_params`` context. + This should not be called from within another :meth:`summon_full_params()` context. """ with self.summon_full_params(recurse=False, writeback=True, rank0_only=False): ret = super().apply(fn) @@ -283,7 +298,7 @@ def apply(self, fn): @torch.no_grad() def clip_grad_norm_(self, max_norm: float, norm_type: float = 2.0) -> torch.Tensor: """ - Clip the gradient norm of all parameters. + Clip the gradient norm of all parameters, returning the norm prior to clipping. The norm is computed over all parameters’ gradients as viewed as a single vector, and the gradients are modified in-place. @@ -399,16 +414,16 @@ def _get_key_mapping(self, reverse: bool = False, modules: bool = False) -> Dict key_mapping: Dict[str, str] = {} # maps original key to wrapped key def collect_key_mappings(module: nn.Module, og_prefix: str, wrapped_prefix: str): + if isinstance(module, FSDP): + wrapped_prefix = f"{wrapped_prefix}{self.WRAPPED_MODULE_PREFIX}." + module = module.module + if not modules: for param_name, _ in chain( module.named_parameters(recurse=False), module.named_buffers(recurse=False) ): key_mapping[f"{og_prefix}{param_name}"] = f"{wrapped_prefix}{param_name}" - if isinstance(module, FSDP): - wrapped_prefix = f"{wrapped_prefix}{self.WRAPPED_MODULE_PREFIX}." - module = module.module - for child_name, child in module.named_children(): if modules: key_mapping[og_prefix.strip(".")] = wrapped_prefix.strip(".") diff --git a/src/test/distributed/fsdp/fsdp_test.py b/src/test/distributed/fsdp/fsdp_test.py index ec0b7736..07a3f70b 100644 --- a/src/test/distributed/fsdp/fsdp_test.py +++ b/src/test/distributed/fsdp/fsdp_test.py @@ -1,3 +1,5 @@ +from functools import partial + import pytest import torch import torch.distributed as dist @@ -357,6 +359,153 @@ def test_nested_fsdp_api(backend, tiny_model_factory, tiny_model_data_factory): ) + +def run_fsdp_with_frozen_params(): + class Model(nn.Module): + def __init__(self): + super().__init__() + self.ff1 = nn.Linear(8, 8) + self.ff2 = nn.Linear(8, 8) + self.ff2.weight.requires_grad = False + self.ff2.bias.requires_grad = False + + def forward(self, x): + return self.ff2(self.ff1(x)) + + fsdp = FSDP(Model()) + + # Run forward pass + loss = fsdp(torch.rand(2, 8, device=fsdp.device)).sum() + + # Trigger backward pass.
+ loss.backward() + assert fsdp.module.ff1.weight.grad is not None + assert fsdp.module.ff1.bias.grad is not None + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_fsdp_with_frozen_params(backend): + run_distributed_test( + run_fsdp_with_frozen_params, + backend=backend, + start_method="spawn", + ) + + +def run_fsdp_with_node_activation_checkpointing(): + from torch.utils.checkpoint import checkpoint + + checkpoint_fn = partial(checkpoint, use_reentrant=False) + + class Model(nn.Module): + def __init__(self, do_activation_checkpointing: bool = True): + super().__init__() + self.ff1 = FSDP(nn.Linear(4, 8)) + self.ff2 = FSDP(nn.Linear(8, 8)) + self.ff3 = FSDP(nn.Linear(8, 4)) + self.do_activation_checkpointing = do_activation_checkpointing + + def forward(self, x): + x = self.ff1(x) + if self.do_activation_checkpointing: + x = checkpoint_fn(self.ff2, x) + else: + x = self.ff2(x) + x = self.ff3(x) + return x + + fsdp_ckpt = FSDP(Model(do_activation_checkpointing=True)) + fsdp = FSDP(Model(do_activation_checkpointing=False)) + + # Synchronize weights. + fsdp.load_state_dict(fsdp_ckpt.state_dict()) + + # Run forward pass + inputs = torch.rand(2, 4, device=fsdp.device) + loss_ckpt = fsdp_ckpt(inputs).sum() + loss = fsdp(inputs).sum() + torch.testing.assert_close(loss_ckpt, loss) + + # Run backward pass. + loss_ckpt.backward() + loss.backward() + for p1, p2 in zip(fsdp_ckpt.parameters(), fsdp.parameters()): + assert p1.grad is not None + assert p2.grad is not None + assert p1.grad.shape == p2.grad.shape + torch.testing.assert_close(p1.grad, p2.grad) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_fsdp_with_node_activation_checkpointing(backend): + run_distributed_test( + run_fsdp_with_node_activation_checkpointing, + backend=backend, + start_method="spawn", + ) + + +def run_fsdp_with_intra_node_activation_checkpointing(): + from torch.utils.checkpoint import checkpoint + + checkpoint_fn = partial(checkpoint, use_reentrant=False) + + class SubModel(nn.Module): + def __init__(self, do_activation_checkpointing: bool = True): + super().__init__() + self.ff1 = nn.Linear(8, 4) + self.ff2 = nn.Linear(4, 8) + self.do_activation_checkpointing = do_activation_checkpointing + + def forward(self, x): + x = self.ff1(x) + if self.do_activation_checkpointing: + x = checkpoint_fn(self.ff2, x) + else: + x = self.ff2(x) + return x + + class Model(nn.Module): + def __init__(self, do_activation_checkpointing: bool = True): + super().__init__() + self.ff1 = nn.Linear(4, 8) + self.ff2 = FSDP(SubModel(do_activation_checkpointing=do_activation_checkpointing)) + + def forward(self, x): + x = self.ff1(x) + x = self.ff2(x) + return x + + fsdp_ckpt = FSDP(Model(do_activation_checkpointing=True)) + fsdp = FSDP(Model(do_activation_checkpointing=False)) + + # Synchronize weights. + fsdp.load_state_dict(fsdp_ckpt.state_dict()) + + # Run forward pass + inputs = torch.rand(2, 4, device=fsdp.device) + loss_ckpt = fsdp_ckpt(inputs).sum() + loss = fsdp(inputs).sum() + torch.testing.assert_close(loss_ckpt, loss) + + # Run backward pass.
+ loss_ckpt.backward() + loss.backward() + for p1, p2 in zip(fsdp_ckpt.parameters(), fsdp.parameters()): + assert p1.grad is not None + assert p2.grad is not None + assert p1.grad.shape == p2.grad.shape + torch.testing.assert_close(p1.grad, p2.grad) + + +@pytest.mark.parametrize("backend", BACKENDS) +def test_fsdp_with_intra_node_activation_checkpointing(backend): + run_distributed_test( + run_fsdp_with_intra_node_activation_checkpointing, + backend=backend, + start_method="spawn", + ) + + def run_fsdp_with_mixed_precision(model_factory, model_data_factory, precision): fsdp = FSDP(model_factory(), precision=precision)
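One more sketch to go with the ``state_dict()`` / ``load_state_dict()`` / ``apply()`` docstring changes and the weight-sync pattern the new tests rely on. Everything used here is shown elsewhere in this diff (``apply`` gathering full params, original-name state-dict keys, ``summon_full_params``); the model shape, the init function, and the assumption that a process group is already set up on every rank (e.g. inside ``run_distributed_test``) are illustrative, not part of the PR:

```python
import torch
import torch.nn as nn

from olmo_core.distributed.fsdp import FSDP


def init_weights(m: nn.Module):
    # Runs against the gathered (full) parameters inside FSDP.apply().
    if isinstance(m, nn.Linear):
        nn.init.trunc_normal_(m.weight, std=0.02)
        nn.init.zeros_(m.bias)


# Assumes dist.init_process_group(...) has already been called on every rank.
fsdp_a = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)))
fsdp_b = FSDP(nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)))

# apply() gathers the full params directly managed by this FSDP instance,
# runs the function, then writes the results back to the shards.
fsdp_a.apply(init_weights)

# State-dict keys use the original (unwrapped) names and the values are the
# sharded flat data, so the dict can be loaded straight into another
# identically wrapped instance -- the same trick the new activation
# checkpointing tests use to synchronize weights.
fsdp_b.load_state_dict(fsdp_a.state_dict())

# Inside summon_full_params() the state dict contains full, unsharded tensors.
with fsdp_a.summon_full_params(rank0_only=False):
    full_state = {k: v.clone() for k, v in fsdp_a.state_dict().items()}
```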