
Add SHARD_GRAD_OP FSDP sharding strategy
epwalsh committed May 8, 2024
1 parent b919c9b commit 95bcc38
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/olmo_core/distributed/fsdp/fsdp.py
@@ -84,6 +84,12 @@ class FSDPShardingStrategy(StrEnum):
     only done within a node, which can be more performant for medium to large-sized models.
     """
+
+    SHARD_GRAD_OP = "SHARD_GRAD_OP"
+    """
+    Like ``FULL_SHARD`` except parameters are not resharded after the forward pass when gradients
+    are enabled; they are only resharded after the backward pass.
+    """
 
 
 @dataclass
 class FSDPDebugConfig:
@@ -233,9 +239,11 @@ def forward(self, *args, **kwargs):
         for module in self.state.forward_execution_order:
             self.state.forward_prefetch_queue.append(module)
 
-        keep_full_params_with_grads = False
-        if self.is_root and not self.free_root_after_forward and torch.is_grad_enabled():
-            keep_full_params_with_grads = True
+        # Determine whether to reshard after the forward pass.
+        keep_full_params_with_grads = (
+            (self.is_root and not self.free_root_after_forward)
+            or self.sharding_strategy == FSDPShardingStrategy.SHARD_GRAD_OP
+        ) and torch.is_grad_enabled()
 
         # Unshard parameters in-place.
         self._unshard(set_grads=keep_full_params_with_grads)
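The rewritten condition above folds the new strategy into the decision of whether full parameters (with their gradient views) stay resident after the forward pass: they are kept when gradients are enabled and either the module is a root that is not configured to free itself after forward, or SHARD_GRAD_OP is in use. As a minimal standalone sketch, the same decision could be written as a pure function; the function name and its arguments below are hypothetical stand-ins for the FSDP instance's attributes, not part of olmo_core:

import torch


def should_keep_full_params_after_forward(
    is_root: bool,
    free_root_after_forward: bool,
    sharding_strategy: str,
) -> bool:
    # Full (unsharded) params and their gradient views are kept after forward when
    # gradients are being tracked and either:
    #   - this is the root module and it isn't configured to free itself after forward, or
    #   - the SHARD_GRAD_OP strategy is in use, which defers resharding to after backward.
    return (
        (is_root and not free_root_after_forward)
        or sharding_strategy == "SHARD_GRAD_OP"
    ) and torch.is_grad_enabled()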

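For context, SHARD_GRAD_OP here mirrors the strategy of the same name in PyTorch's built-in FSDP: parameters stay unsharded between the forward and backward passes, so the backward pass skips a second all-gather at the cost of holding full parameters for longer. A short sketch of selecting that strategy with the upstream wrapper (not the FSDP class defined in this file) looks like this; it assumes a distributed process group has already been initialized, e.g. via torchrun:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# Wrap a toy model; parameters are sharded, but after each forward pass they are
# only freed once the backward pass has run (SHARD_GRAD_OP), unlike FULL_SHARD.
model = nn.Linear(1024, 1024)
fsdp_model = FSDP(model, sharding_strategy=ShardingStrategy.SHARD_GRAD_OP)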