diff --git a/docs/source/distributed/fsdp.rst b/docs/source/distributed/fsdp.rst
index 6ad9e2d7..5d781a4e 100644
--- a/docs/source/distributed/fsdp.rst
+++ b/docs/source/distributed/fsdp.rst
@@ -2,5 +2,5 @@
 ====================

 .. automodule:: olmo_core.distributed.fsdp
-   :members: FSDP, FSDPPrecision
+   :members: FSDP, FSDPPrecision, FSDPShardingStrategy
    :member-order: bysource
diff --git a/src/benchmarks/fsdp/common.py b/src/benchmarks/fsdp/common.py
index 6eff8641..1d02ee96 100644
--- a/src/benchmarks/fsdp/common.py
+++ b/src/benchmarks/fsdp/common.py
@@ -151,7 +151,7 @@ def build_components(
     config: TransformerConfig,
     batch_size: int,
     num_batches: int = 100,
-    fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
+    fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
     wrap_blocks: bool = True,
     mixed_precision: bool = True,
     max_prefetch_count: int = 1,
@@ -204,6 +204,11 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
         )
         model.apply(init_function)  # just in case
+    elif fsdp_wrapper == "ddp":
+        from torch.nn.parallel import DistributedDataParallel as DDP
+
+        model = DDP(model.cuda(), device_ids=[dist.get_rank()])
+        model.apply(init_function)
     else:
         raise NotImplementedError(fsdp_wrapper)
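# A minimal sketch (not part of the patch) of what the new "ddp" wrapper option in
# build_components() does, assuming torch.distributed is already initialized with one GPU per
# process. The benchmark uses dist.get_rank() as the device id, which works for the single-node
# runs it targets; the LOCAL_RANK-based variant below is the more general pattern.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_ddp_baseline(model: nn.Module) -> DDP:
    local_rank = int(os.environ.get("LOCAL_RANK", dist.get_rank()))
    torch.cuda.set_device(local_rank)
    # DDP replicates (rather than shards) the model, giving a non-sharded baseline to compare
    # the FSDP implementations against.
    return DDP(model.cuda(), device_ids=[local_rank])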
diff --git a/src/benchmarks/fsdp/train.py b/src/benchmarks/fsdp/train.py
index 245ba9d2..cff5ecc9 100644
--- a/src/benchmarks/fsdp/train.py
+++ b/src/benchmarks/fsdp/train.py
@@ -13,6 +13,7 @@

 import torch
 import torch.distributed as dist
+from torch.nn.utils import clip_grad_norm_

 from olmo_core.distributed.checkpoint import (
     load_model_and_optim_state,
@@ -28,13 +29,14 @@ def main(
     config: TransformerConfig,
     batch_size: int,
     num_batches: int = 100,
-    fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
+    fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
     dry_run: bool = False,
     save_path: Optional[str] = None,
     load_path: Optional[str] = None,
     mixed_precision: bool = True,
     profile: bool = False,
     trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
+    max_grad_norm: Optional[float] = None,
     **kwargs,
 ):
     model, optim, dataloader = build_components(
@@ -98,7 +100,12 @@ def on_trace_ready(p):
             loss.backward()

             # Clip gradient norms.
-            model.clip_grad_norm_(1.0)
+            norm: Optional[torch.Tensor] = None
+            if max_grad_norm is not None:
+                if hasattr(model, "clip_grad_norm_"):
+                    norm = model.clip_grad_norm_(max_grad_norm)
+                else:
+                    norm = clip_grad_norm_(model.parameters(), max_grad_norm)

             # Take optimizer step.
             optim.step()
@@ -106,13 +113,15 @@ def on_trace_ready(p):
             batch_time = time.monotonic() - batch_start
             if i > 0:
                 batch_times.append(batch_time)
+            norm_str = f"{norm.item():.3f}" if norm is not None else "n/a"
             print_rank0(
                 f"Batch [{i+1}/{num_batches}]:\n"
                 f"  loss={loss.item():.3f}\n"
-                f"  throughput/seconds_per_batch={batch_time:.3f}",
+                f"  throughput/seconds_per_batch={batch_time:.3f}\n"
+                f"  grad/total_norm={norm_str}"
             )

-            if i == 2:
+            if profile and i == 2:
                 print_rank0(torch.cuda.memory_summary())

             if p is not None:
@@ -134,7 +143,7 @@ def on_trace_ready(p):
     parser = argparse.ArgumentParser(prog="train.py", description="Train an FSDP model")
     parser.add_argument(
         "--fsdp",
-        choices=["torch", "olmo_core"],
+        choices=["torch", "olmo_core", "ddp"],
         default="olmo_core",
         help="""The FSDP implementation.""",
     )
@@ -190,6 +199,10 @@ def on_trace_ready(p):
         type=int,
         default=1,
     )
+    parser.add_argument(
+        "--max-grad-norm",
+        type=float,
+    )
     parser.add_argument(
         "--lr",
         type=float,
@@ -237,5 +250,6 @@ def on_trace_ready(p):
         mixed_precision=mixed_precision,
         max_prefetch_count=args.max_prefetch_count,
         learning_rate=args.lr,
+        max_grad_norm=args.max_grad_norm,
         seed=args.seed,
     )
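# A small sketch of the gradient-clipping dispatch the benchmark now uses: olmo_core's FSDP
# exposes a clip_grad_norm_() method that knows about sharded gradients, while a plain DDP
# module falls back to torch.nn.utils.clip_grad_norm_. Both paths return the total norm, which
# the training loop logs as grad/total_norm.
from typing import Optional

import torch
import torch.nn as nn
from torch.nn.utils import clip_grad_norm_


def clip_grads(model: nn.Module, max_grad_norm: Optional[float]) -> Optional[torch.Tensor]:
    if max_grad_norm is None:
        return None  # clipping (and norm logging) is opt-in via --max-grad-norm
    if hasattr(model, "clip_grad_norm_"):
        return model.clip_grad_norm_(max_grad_norm)
    return clip_grad_norm_(model.parameters(), max_grad_norm)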
+ """ device: Optional[torch.device] = None requires_grad: bool = True + pre_reduce_grad_divide_factor: float = 1.0 + + post_reduce_grad_divide_factor: float = 1.0 + _ran_pre_unshard: bool = False _ran_pre_reduce_scatter_grads: bool = False + def __post_init__(self): + data_parallel_world_size = get_world_size(self.process_group) + if self.inter_group_process_group is not None: + data_parallel_world_size *= self.inter_group_process_group.size() + self.pre_reduce_grad_divide_factor = get_gradient_divide_factor(data_parallel_world_size) + self.post_reduce_grad_divide_factor = data_parallel_world_size / self.pre_reduce_grad_divide_factor + @classmethod def shard_params( cls, params: Iterable[nn.Parameter], param_fqns: Iterable[str], process_group: Optional[dist.ProcessGroup] = None, + inter_group_process_group: Optional[dist.ProcessGroup] = None, device: Optional[torch.device] = None, ) -> FlatParamHandle: """ @@ -183,6 +207,7 @@ def shard_params( ) else: flat_param = ShardedFlatParameter(torch.empty(0, device=device)) + flat_param.requires_grad = param.requires_grad flat_param.mark_as_sharded(sharding_spec, process_group=process_group) flat_params.append(flat_param) @@ -224,6 +249,7 @@ def shard_params( param_fqns=list(param_fqns), params_data=params_data, process_group=process_group, + inter_group_process_group=inter_group_process_group, device=device, requires_grad=requires_grad, ) @@ -253,7 +279,7 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F self.params_sharded_data_lp.copy_(self.params_data.sharded_data) # Initialize unsharded, padded gradient. - if set_grads and self.params_unsharded_grad is None: + if set_grads and self.requires_grad and self.params_unsharded_grad is None: self.params_unsharded_grad = torch.zeros_like(all_params_unsharded_data) def unshard_( @@ -288,7 +314,7 @@ def unshard_( if rank0_only or dist.get_backend() == dist.Backend.GLOO: assert self.params_data.is_sharded self.params_data.unshard_(dtype=dtype, rank0_only=rank0_only) - if set_grads: + if set_grads and self.requires_grad: self.params_unsharded_grad = torch.zeros_like(self.params_data) else: assert not self.params_data.is_sharded @@ -318,7 +344,7 @@ def unshard_( param.unshard_(unsharded_data, dtype=dtype, rank0_only=rank0_only) - if set_grads: + if set_grads and self.requires_grad: if param.grad is None and self.params_sharded_grad is not None: self.params_sharded_grad = None assert self.params_unsharded_grad is not None @@ -360,6 +386,9 @@ def pre_reduce_scatter_grads_( Stream.current(self.device).record_for(self.params_unsharded_grad) self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype) + if self.pre_reduce_grad_divide_factor > 1.0: + self.params_unsharded_grad.div_(self.pre_reduce_grad_divide_factor) + def reduce_scatter_grads_( self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None ): @@ -368,6 +397,7 @@ def reduce_scatter_grads_( parameter as a view into the new sharded grad. """ if not self.requires_grad or self.params_unsharded_grad is None: + self._ran_pre_reduce_scatter_grads = False return if not self._ran_pre_reduce_scatter_grads: @@ -398,12 +428,20 @@ def post_reduce_scatter_grads_( """ Finalize sharded gradients after the reduce-scatter. 
""" + if not self.requires_grad or self.params_unsharded_grad is None: + return + grad_dtype = grad_dtype or self.params_data.dtype grad_reduce_dtype = grad_reduce_dtype or grad_dtype - assert self.params_unsharded_grad is not None new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad) + if self.inter_group_process_group is not None: + dist.all_reduce(new_sharded_grad, group=self.inter_group_process_group) + + if self.post_reduce_grad_divide_factor > 1.0: + new_sharded_grad.div_(self.post_reduce_grad_divide_factor) + # Cast the new sharded gradient to the right dtype, potentially accumulating it into # the existing sharded gradient. if self.params_sharded_grad is None: diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py index ae191658..01487856 100644 --- a/src/olmo_core/distributed/fsdp/fsdp.py +++ b/src/olmo_core/distributed/fsdp/fsdp.py @@ -21,16 +21,24 @@ Type, TypeVar, Union, + cast, ) import torch import torch.distributed as dist import torch.nn as nn from torch.autograd import Variable +from torch.distributed.device_mesh import DeviceMesh from olmo_core.distributed.tensors import ShardedFlatParameter from olmo_core.stream import Stream -from olmo_core.utils import apply_to_tensors, gc_cuda, get_default_device, get_grad_norm +from olmo_core.utils import ( + StrEnum, + apply_to_tensors, + gc_cuda, + get_default_device, + get_grad_norm, +) from .flat_param_handle import FlatParamHandle from .state import FSDPState @@ -55,6 +63,28 @@ class FSDPPrecision: """ +class FSDPShardingStrategy(StrEnum): + """ + Defines the sharding strategy used by :class:`FSDP`. + """ + + FULL_SHARD = "FULL_SHARD" + """ + Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards + (via all-gather) before the forward, reshards after the forward (except potentially for the root FSDP instance), + unshards before the backward computation, and reshards after the backward computation. + For gradients, it synchronizes and shards them (via reduce-scatter) after the backward + computation. The sharded optimizer states are updated locally per rank. + """ + + HYBRID_SHARD = "HYBRID_SHARD" + """ + Apply ``FULL_SHARD`` within a process group, and replicate parameters across process groups. + This results in reduced communication volume as expensive all-gathers and reduce-scatters are + only done within a node, which can be more performant for medium to large-sized models. + """ + + @dataclass class FSDPDebugConfig: no_reduce_grads: bool = False @@ -70,27 +100,62 @@ class FSDP(Generic[M], nn.Module): FSDP, a.k.a. Fully Sharded Data Parallel, a ZeRO-3 model wrapper. :param module: The module to wrap. - :param process_group: The distributed process group. + :param process_group: The distributed process group to shard across. + :param device_mesh: Mutually exclusive with ``process_group``. + This is required for :data:`FSDPShardingStrategy.HYBRID_SHARD`, in which case the first + dimension should specify the number of model replicas (hybrid groups), and the second + dimension should specify the number of shards within each replica. + If you're not using :data:`FSDPShardingStrategy.HYBRID_SHARD` and you specify ``device_mesh``, + the process group in the first dimension will be used. :param precision: Mixed precision settings. + :param sharding_strategy: The sharding strategy to use. :param max_prefetch_count: The number of nested FSDP modules that can be prefetched during the forward and backward passes. 
diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py
index ae191658..01487856 100644
--- a/src/olmo_core/distributed/fsdp/fsdp.py
+++ b/src/olmo_core/distributed/fsdp/fsdp.py
@@ -21,16 +21,24 @@
     Type,
     TypeVar,
     Union,
+    cast,
 )

 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.autograd import Variable
+from torch.distributed.device_mesh import DeviceMesh

 from olmo_core.distributed.tensors import ShardedFlatParameter
 from olmo_core.stream import Stream
-from olmo_core.utils import apply_to_tensors, gc_cuda, get_default_device, get_grad_norm
+from olmo_core.utils import (
+    StrEnum,
+    apply_to_tensors,
+    gc_cuda,
+    get_default_device,
+    get_grad_norm,
+)

 from .flat_param_handle import FlatParamHandle
 from .state import FSDPState
@@ -55,6 +63,28 @@ class FSDPPrecision:
     """


+class FSDPShardingStrategy(StrEnum):
+    """
+    Defines the sharding strategy used by :class:`FSDP`.
+    """
+
+    FULL_SHARD = "FULL_SHARD"
+    """
+    Parameters, gradients, and optimizer states are sharded. For the parameters, this strategy unshards
+    (via all-gather) before the forward pass, reshards after the forward pass (except potentially for the
+    root FSDP instance), unshards before the backward computation, and reshards after the backward
+    computation. For gradients, it synchronizes and shards them (via reduce-scatter) after the backward
+    computation. The sharded optimizer states are updated locally per rank.
+    """
+
+    HYBRID_SHARD = "HYBRID_SHARD"
+    """
+    Apply ``FULL_SHARD`` within each process group and replicate parameters across process groups.
+    This reduces communication volume since the expensive all-gathers and reduce-scatters are
+    only done within a node, which can be more performant for medium to large models.
+    """
+
+
 @dataclass
 class FSDPDebugConfig:
     no_reduce_grads: bool = False
@@ -70,27 +100,62 @@ class FSDP(Generic[M], nn.Module):
     FSDP, a.k.a. Fully Sharded Data Parallel, a ZeRO-3 model wrapper.

     :param module: The module to wrap.
-    :param process_group: The distributed process group.
+    :param process_group: The distributed process group to shard across.
+    :param device_mesh: Mutually exclusive with ``process_group``. This is required for
+        :data:`FSDPShardingStrategy.HYBRID_SHARD`, in which case the first dimension should specify
+        the number of model replicas (hybrid groups) and the second dimension should specify the
+        number of shards within each replica.
+        If you're not using :data:`FSDPShardingStrategy.HYBRID_SHARD` and you specify ``device_mesh``,
+        the process group in the first dimension will be used.
     :param precision: Mixed precision settings.
+    :param sharding_strategy: The sharding strategy to use.
     :param max_prefetch_count: The number of nested FSDP modules that can be prefetched during the forward
         and backward passes. This is like PyTorch's ``limit_all_gathers`` except it allows more control.
+    :param free_root_after_forward: By default the root FSDP instance keeps its full params in memory after
+        the forward pass when grads are enabled, to avoid immediately regathering them during the backward
+        pass. Setting this to ``True`` can save some memory at the expense of throughput.
     """

     WRAPPED_MODULE_PREFIX = "_fsdp_wrapped_module"
+    """
+    The prefix the wrapped module is stored under. In general you don't need to know this, since the
+    wrapping FSDP instance behaves like the wrapped module itself for most APIs; otherwise you should
+    access the wrapped module through the :data:`module` property.
+    """

     def __init__(
         self,
         module: M,
         process_group: Optional[dist.ProcessGroup] = None,
+        device_mesh: Optional[DeviceMesh] = None,
         precision: Optional[FSDPPrecision] = None,
+        sharding_strategy: FSDPShardingStrategy = FSDPShardingStrategy.FULL_SHARD,
         max_prefetch_count: int = 1,
+        free_root_after_forward: bool = False,
         _debug_config: Optional[FSDPDebugConfig] = None,
     ):
         super().__init__()
+
+        # Validate the process group and device mesh given the sharding strategy.
+        inter_group_process_group: Optional[dist.ProcessGroup] = None
+        if process_group is not None and device_mesh is not None:
+            raise ValueError("'process_group' and 'device_mesh' are mutually exclusive")
+        elif device_mesh is not None:
+            if sharding_strategy == FSDPShardingStrategy.HYBRID_SHARD:
+                inter_group_process_group = cast(dist.ProcessGroup, device_mesh.get_group(mesh_dim=0))
+                process_group = cast(dist.ProcessGroup, device_mesh.get_group(mesh_dim=1))
+            else:
+                process_group = cast(dist.ProcessGroup, device_mesh.get_group(mesh_dim=0))
+        elif sharding_strategy == FSDPShardingStrategy.HYBRID_SHARD:
+            raise ValueError("'device_mesh' is required for `HYBRID_SHARD`")
+
         self._fsdp_wrapped_module = module
         self.process_group = process_group
+        self.inter_group_process_group = inter_group_process_group
         self.precision = precision or FSDPPrecision()
+        self.sharding_strategy = sharding_strategy
         self.max_prefetch_count = max_prefetch_count
+        self.free_root_after_forward = free_root_after_forward
         self.debug_config = _debug_config or FSDPDebugConfig()
         self.device = get_default_device()
         self.state = FSDPState(device=self.device)
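# A hedged usage sketch of the new hybrid sharding path, assuming an already-initialized default
# process group spanning 2 nodes x 8 GPUs. The first mesh dimension is the number of model
# replicas (gradients are all-reduced across it) and the second is the number of shards per
# replica (params are all-gathered and grads reduce-scattered within it).
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from olmo_core.distributed.fsdp import FSDP, FSDPShardingStrategy

mesh = init_device_mesh("cuda", (2, 8), mesh_dim_names=("replica", "shard"))
model = nn.Sequential(nn.Linear(1024, 1024), nn.GELU(), nn.Linear(1024, 1024))
fsdp_model = FSDP(
    model,
    device_mesh=mesh,
    sharding_strategy=FSDPShardingStrategy.HYBRID_SHARD,
)
# Plain FULL_SHARD over the default process group remains the default: FSDP(model).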
@@ -164,8 +229,12 @@ def forward(self, *args, **kwargs):
             for module in self.state.forward_execution_order:
                 self.state.forward_prefetch_queue.append(module)

+        keep_full_params_with_grads = False
+        if self.is_root and not self.free_root_after_forward and torch.is_grad_enabled():
+            keep_full_params_with_grads = True
+
         # Unshard parameters in-place.
-        self._unshard()
+        self._unshard(set_grads=keep_full_params_with_grads)

         try:
             # Wait for unsharding stream before running the wrapped module's forward pass.
@@ -194,8 +263,10 @@ def forward(self, *args, **kwargs):
                 # Register post-backward hooks to reshard the parameters in place and reduce gradients.
                 self._register_post_backward_hooks()
         finally:
-            # Reshard parameters in-place.
-            self._reshard()
+            # Reshard parameters in-place, except potentially the root instance to avoid
+            # immediately regathering in the backward pass.
+            if not keep_full_params_with_grads:
+                self._reshard()

         if self.is_root:
             # At the end of the first forward pass, execution order is now finalized, meaning
@@ -528,7 +599,11 @@ def _shard(self):
         if params_with_grads:
             handles.append(
                 FlatParamHandle.shard_params(
-                    params_with_grads, params_with_grads_fqns, process_group=self.process_group, device=self.device
+                    params_with_grads,
+                    params_with_grads_fqns,
+                    process_group=self.process_group,
+                    inter_group_process_group=self.inter_group_process_group,
+                    device=self.device,
                 )
             )
         if params_without_grads:
@@ -537,6 +612,7 @@ def _shard(self):
                     params_without_grads,
                     params_without_grads_fqns,
                     process_group=self.process_group,
+                    inter_group_process_group=self.inter_group_process_group,
                     device=self.device,
                 )
             )
@@ -703,8 +779,9 @@ def _pre_backward_hook(self, *unused: Any):
             self.state.backward_execution_order.append(self)

     def _register_pre_backward_hook(self, x: torch.Tensor):
-        handle = x.register_hook(self._pre_backward_hook)
-        self.state.pre_backward_hook_handles.append(handle)
+        if x.requires_grad:
+            hook_handle = x.register_hook(self._pre_backward_hook)
+            self.state.pre_backward_hook_handles.append(hook_handle)

     def _register_pre_backward_hooks(self, output: Any):
         log.debug("Registering pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module))
@@ -713,8 +790,8 @@ def _register_pre_backward_hooks(self, output: Any):
             log.debug(
                 "Removing old pre-backward hooks for %s (%s)...", self.module.__class__.__name__, id(self.module)
             )
-            for handle in self.state.pre_backward_hook_handles:
-                handle.remove()
+            for hook_handle in self.state.pre_backward_hook_handles:
+                hook_handle.remove()
             self.state.pre_backward_hook_handles.clear()
         apply_to_tensors(self._register_pre_backward_hook, output)
@@ -731,7 +808,7 @@ def _post_backward_hook(self, param_name: str, *unused: Any):
         if self.state.post_backward_hook_handles:
             return

-        log.debug("Running post-backward hook for %s (%s)", self.module.__class__.__name__, id(self.module))
+        log.debug("Running post-backward hook for %s (%s)...", self.module.__class__.__name__, id(self.module))

         # NOTE: reshard *before* reducing grads to correctly handle precision settings.
         self._reshard()
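# A small sketch (hypothetical module sizes) of the other knobs touched in this file: nested
# FSDP children with a prefetch limit, plus free_root_after_forward=True to reshard the root's
# params right after the forward pass, trading an extra all-gather at the start of the backward
# pass for lower memory in between.
import torch.nn as nn

from olmo_core.distributed.fsdp import FSDP

blocks = [FSDP(nn.Linear(1024, 1024)) for _ in range(4)]  # nested FSDP instances
root = FSDP(
    nn.Sequential(*blocks),
    max_prefetch_count=2,          # how many nested instances may be prefetched ahead of use
    free_root_after_forward=True,  # don't keep the root's full params between forward and backward
)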
diff --git a/src/olmo_core/distributed/utils.py b/src/olmo_core/distributed/utils.py
index b21179c4..a63a46d7 100644
--- a/src/olmo_core/distributed/utils.py
+++ b/src/olmo_core/distributed/utils.py
@@ -76,3 +76,10 @@ def all_gather_object(obj: T, group: Optional[dist.ProcessGroup] = None) -> List
     output_list = [obj] * get_world_size(group)
     dist.all_gather_object(output_list, obj, group=group)
     return output_list
+
+
+def get_gradient_divide_factor(world_size: int) -> float:
+    factor: int = 1
+    while world_size % factor == 0 and world_size / factor > factor:
+        factor *= 2
+    return float(factor)
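# Traced values for the helper above: for power-of-two world sizes the loop stops at the first
# power of two whose square is at least the world size, so the pre-divide factor is roughly
# sqrt(world_size) rounded up to a power of two, and the post-divide factor (world_size / pre)
# covers the rest.
from olmo_core.distributed.utils import get_gradient_divide_factor

for world_size, expected in [(1, 1.0), (2, 2.0), (4, 2.0), (8, 4.0), (16, 4.0), (64, 8.0)]:
    assert get_gradient_divide_factor(world_size) == expected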
{m}", ) @@ -377,14 +380,92 @@ def test_nested_fsdp_api(backend, tiny_model_factory, tiny_model_data_factory): ) -def run_fsdp_with_frozen_params(): +def run_fsdp_with_mix_of_frozen_and_non_frozen_params(case: int): class Model(nn.Module): def __init__(self): super().__init__() self.ff1 = nn.Linear(8, 8) self.ff2 = nn.Linear(8, 8) - self.ff2.weight.requires_grad = False - self.ff2.bias.requires_grad = False + if case == 1: + self.ff1.weight.requires_grad = False + self.ff1.bias.requires_grad = False + elif case == 2: + self.ff2.weight.requires_grad = False + self.ff2.bias.requires_grad = False + else: + raise NotImplementedError + + def forward(self, x): + return self.ff2(self.ff1(x)) + + fsdp = FSDP(Model()) + + # Check handles. + assert len(fsdp.state.flat_param_handles) == 2 + assert fsdp.state.flat_param_handles[0].requires_grad + assert not fsdp.state.flat_param_handles[1].requires_grad + + # Check params. + for name, param in fsdp.named_parameters(): + assert param.grad is None, f"param {param} already has a grad!" + + # Run forward pass + loss = fsdp(torch.rand(2, 8, device=fsdp.device)).sum() + + # Trigger backward pass. + loss.backward() + + # Check grads. + if case == 1: + assert fsdp.module.ff1.weight.grad is None + assert fsdp.module.ff1.bias.grad is None + assert fsdp.module.ff2.weight.grad is not None + assert fsdp.module.ff2.bias.grad is not None + elif case == 2: + assert fsdp.module.ff1.weight.grad is not None + assert fsdp.module.ff1.bias.grad is not None + assert fsdp.module.ff2.weight.grad is None + assert fsdp.module.ff2.bias.grad is None + else: + raise NotImplementedError + + # Make sure every param has been resharded. + for name, param in fsdp.named_parameters(): + assert isinstance(param, ShardedFlatParameter) + assert param.is_sharded, f"param {name} has not been resharded!" + + +@pytest.mark.parametrize("backend", BACKENDS) +@pytest.mark.parametrize("case", (1, 2)) +def test_fsdp_with_mix_of_frozen_and_non_frozen_params(backend, case): + run_distributed_test( + run_fsdp_with_mix_of_frozen_and_non_frozen_params, + backend=backend, + start_method="spawn", + func_args=(case,), + ) + + +def run_fsdp_with_frozen_fsdp_child(case: int): + class Model(nn.Module): + def __init__(self): + super().__init__() + ff1 = nn.Linear(8, 8) + ff2 = nn.Linear(8, 8) + + if case == 1: + ff1.weight.requires_grad = False + ff1.bias.requires_grad = False + ff1 = FSDP(ff1) + elif case == 2: + ff2.weight.requires_grad = False + ff2.bias.requires_grad = False + ff2 = FSDP(ff2) + else: + raise NotImplementedError + + self.ff1 = ff1 + self.ff2 = ff2 def forward(self, x): return self.ff2(self.ff1(x)) @@ -396,16 +477,35 @@ def forward(self, x): # Trigger backward pass. loss.backward() - assert fsdp.module.ff1.weight.grad is not None - assert fsdp.module.ff1.bias.grad is not None + + # Check grads. + if case == 1: + assert fsdp.module.ff1.weight.grad is None + assert fsdp.module.ff1.bias.grad is None + assert fsdp.module.ff2.weight.grad is not None + assert fsdp.module.ff2.bias.grad is not None + elif case == 2: + assert fsdp.module.ff1.weight.grad is not None + assert fsdp.module.ff1.bias.grad is not None + assert fsdp.module.ff2.weight.grad is None + assert fsdp.module.ff2.bias.grad is None + else: + raise NotImplementedError + + # Make sure every param has been resharded. + for name, param in fsdp.named_parameters(): + assert isinstance(param, ShardedFlatParameter) + assert param.is_sharded, f"param {name} has not been resharded!" 
 @pytest.mark.parametrize("backend", BACKENDS)
-def test_fsdp_with_frozen_params(backend):
+@pytest.mark.parametrize("case", (1, 2))
+def test_fsdp_with_frozen_fsdp_child(backend, case):
     run_distributed_test(
-        run_fsdp_with_frozen_params,
+        run_fsdp_with_frozen_fsdp_child,
         backend=backend,
         start_method="spawn",
+        func_args=(case,),
     )
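# The test changes above reflect a behavioral change: gradients reduced by olmo_core's FSDP are
# now averaged over the data-parallel group (matching DDP) rather than summed, which is why the
# `* dist.get_world_size()` factors were dropped. A sketch of the equivalence being asserted,
# assuming both models saw the same batch:
import torch


def check_grads_match(fsdp_param, full_ddp_grad: torch.Tensor) -> None:
    # The sharded FSDP param holds only its rank's chunk of the (already averaged) gradient.
    torch.testing.assert_close(fsdp_param.grad, fsdp_param.sharded_chunk(full_ddp_grad))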