diff --git a/src/olmo_core/distributed/fsdp/fsdp.py b/src/olmo_core/distributed/fsdp/fsdp.py
index f23dbe1b..bd39047e 100644
--- a/src/olmo_core/distributed/fsdp/fsdp.py
+++ b/src/olmo_core/distributed/fsdp/fsdp.py
@@ -84,6 +84,12 @@ class FSDPShardingStrategy(StrEnum):
     only done within a node, which can be more performant for medium to large-sized models.
     """
 
+    SHARD_GRAD_OP = "SHARD_GRAD_OP"
+    """
+    Like ``FULL_SHARD`` except that parameters are not resharded after the forward pass when
+    gradients are enabled, only after the backward pass.
+    """
+
 
 @dataclass
 class FSDPDebugConfig:
@@ -233,9 +239,11 @@ def forward(self, *args, **kwargs):
             for module in self.state.forward_execution_order:
                 self.state.forward_prefetch_queue.append(module)
 
-        keep_full_params_with_grads = False
-        if self.is_root and not self.free_root_after_forward and torch.is_grad_enabled():
-            keep_full_params_with_grads = True
+        # Determine whether to keep the full params (with grads) after the forward pass instead of resharding.
+        keep_full_params_with_grads = (
+            (self.is_root and not self.free_root_after_forward)
+            or self.sharding_strategy == FSDPShardingStrategy.SHARD_GRAD_OP
+        ) and torch.is_grad_enabled()
 
         # Unshard parameters in-place.
         self._unshard(set_grads=keep_full_params_with_grads)
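
For reference, a minimal usage sketch of the new strategy (not part of this diff). It assumes a distributed process group has already been initialized (e.g. via torchrun), that ``FSDP`` and ``FSDPShardingStrategy`` are both importable from ``olmo_core.distributed.fsdp``, and that the wrapper accepts a ``sharding_strategy`` keyword backing the ``self.sharding_strategy`` attribute checked in the hunk above; the real constructor signature may differ.

import torch
import torch.nn as nn

# Assumed import path; the diff only shows the enum defined in fsdp.py.
from olmo_core.distributed.fsdp import FSDP, FSDPShardingStrategy

# Any nn.Module works; this toy model is purely illustrative.
model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# With SHARD_GRAD_OP, the full parameters gathered for the forward pass are kept
# (when gradients are enabled) and only resharded after the backward pass,
# saving one all-gather per step at the cost of holding full params longer.
# ``sharding_strategy`` is an assumed keyword argument.
fsdp_model = FSDP(model, sharding_strategy=FSDPShardingStrategy.SHARD_GRAD_OP)

loss = fsdp_model(torch.randn(8, 1024)).sum()
loss.backward()  # resharding now happens here rather than right after forward

The trade-off mirrors PyTorch FSDP's strategy of the same name: fewer all-gathers per training step in exchange for higher peak memory between the forward and backward passes.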