More improvements to FSDP, benchmark against DDP #13

Merged
merged 30 commits on May 2, 2024
Changes from all commits (30 commits)
a0e8e84
leave root params in mem
epwalsh Apr 19, 2024
431a140
fix logic
epwalsh Apr 19, 2024
98025a3
don't show mem usage all the time
epwalsh Apr 19, 2024
60aec7f
fix
epwalsh Apr 19, 2024
69d74c4
make configurable
epwalsh Apr 19, 2024
78146a3
add alloc/free for unsharded data
epwalsh Apr 19, 2024
e6f5fc2
add alloc/free for unsharded grad
epwalsh Apr 19, 2024
c343413
fix
epwalsh Apr 19, 2024
d1cf74d
record for
epwalsh Apr 19, 2024
34765e7
revert 69d74c4 - alloc/free for unsharded grad
epwalsh Apr 19, 2024
7ea0651
revert alloc/free trick for unsharded params data
epwalsh Apr 19, 2024
ce69f24
add support for DDP in benchmark
epwalsh Apr 19, 2024
dbbf0ae
set device ids explicitly
epwalsh Apr 19, 2024
28401f1
fix
epwalsh Apr 19, 2024
811b6c3
change up how weights are initialized
epwalsh Apr 19, 2024
82378a9
fix test
epwalsh Apr 19, 2024
0386841
Add back alloc/free hack for unsharded data
epwalsh Apr 24, 2024
5f140f2
Revert "Add back alloc/free hack for unsharded data"
epwalsh Apr 24, 2024
0f408d2
Handle frozen layers with reshard-only post-backward hook
epwalsh Apr 24, 2024
3d63078
Revert "Handle frozen layers with reshard-only post-backward hook"
epwalsh Apr 24, 2024
450d8e8
add to test
epwalsh Apr 24, 2024
9c1e52f
add to test
epwalsh Apr 24, 2024
48abc4d
Fixes for frozen modules
epwalsh Apr 24, 2024
c41d2d4
Divide grad before and after reducing for stability
epwalsh Apr 24, 2024
494be77
Add support for hybrid sharding
epwalsh Apr 24, 2024
a9f8194
make grad clipping optional
epwalsh Apr 30, 2024
6f136b6
clean up
epwalsh Apr 30, 2024
d66a683
calculate grad norm more efficiently
epwalsh Apr 30, 2024
d9579b1
Revert "calculate grad norm more efficiently"
epwalsh Apr 30, 2024
a48bfb7
fix
epwalsh Apr 30, 2024
2 changes: 1 addition & 1 deletion docs/source/distributed/fsdp.rst
@@ -2,5 +2,5 @@
====================

.. automodule:: olmo_core.distributed.fsdp
:members: FSDP, FSDPPrecision
:members: FSDP, FSDPPrecision, FSDPShardingStrategy
:member-order: bysource
7 changes: 6 additions & 1 deletion src/benchmarks/fsdp/common.py
@@ -151,7 +151,7 @@ def build_components(
config: TransformerConfig,
batch_size: int,
num_batches: int = 100,
fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
wrap_blocks: bool = True,
mixed_precision: bool = True,
max_prefetch_count: int = 1,
@@ -204,6 +204,11 @@ def auto_wrap_policy(module: nn.Module, recurse: bool, *args, **kwargs) -> bool:
)

model.apply(init_function) # just in case
elif fsdp_wrapper == "ddp":
from torch.nn.parallel import DistributedDataParallel as DDP

model = DDP(model.cuda(), device_ids=[dist.get_rank()])
model.apply(init_function)
else:
raise NotImplementedError(fsdp_wrapper)

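The new `ddp` option wraps the unsharded model in PyTorch's DistributedDataParallel so the benchmark has a replication baseline to compare the FSDP wrappers against. A minimal sketch of that path, assuming one process per GPU on a single node; the standalone helper and its `init_function` argument are illustrative, not part of the benchmark's API:

```python
# Sketch of the benchmark's DDP path (hypothetical helper; mirrors the hunk above).
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_with_ddp(model: nn.Module, init_function) -> DDP:
    # DDP replicates the full model on every rank, so the whole thing must fit on
    # one GPU; passing the global rank as the device id assumes one process per GPU
    # on a single node (otherwise the local rank would be used).
    model = DDP(model.cuda(), device_ids=[dist.get_rank()])
    # Apply the same weight init used for the FSDP wrappers; DDP then keeps the
    # replicas in sync by all-reducing gradients on each backward pass.
    model.apply(init_function)
    return model
```

Compared with the FSDP wrappers, this trades memory (full parameters, gradients, and optimizer state on every rank) for a simpler single all-reduce communication pattern, which is the point of the comparison.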
24 changes: 19 additions & 5 deletions src/benchmarks/fsdp/train.py
@@ -13,6 +13,7 @@

import torch
import torch.distributed as dist
from torch.nn.utils import clip_grad_norm_

from olmo_core.distributed.checkpoint import (
load_model_and_optim_state,
@@ -28,13 +29,14 @@ def main(
config: TransformerConfig,
batch_size: int,
num_batches: int = 100,
fsdp_wrapper: Literal["torch", "olmo_core"] = "olmo_core",
fsdp_wrapper: Literal["torch", "olmo_core", "ddp"] = "olmo_core",
dry_run: bool = False,
save_path: Optional[str] = None,
load_path: Optional[str] = None,
mixed_precision: bool = True,
profile: bool = False,
trace_output: str = "/tmp/traces/olmo_core.chrome_trace.json.gz",
max_grad_norm: Optional[float] = None,
**kwargs,
):
model, optim, dataloader = build_components(
@@ -98,21 +100,28 @@ def on_trace_ready(p):
loss.backward()

# Clip gradient norms.
model.clip_grad_norm_(1.0)
norm: Optional[torch.Tensor] = None
if max_grad_norm is not None:
if hasattr(model, "clip_grad_norm_"):
norm = model.clip_grad_norm_(max_grad_norm)
else:
norm = clip_grad_norm_(model.parameters(), max_grad_norm)

# Take optimizer step.
optim.step()

batch_time = time.monotonic() - batch_start
if i > 0:
batch_times.append(batch_time)
norm_str = f"{norm.item():.3f}" if norm is not None else "n/a"
print_rank0(
f"Batch [{i+1}/{num_batches}]:\n"
f" loss={loss.item():.3f}\n"
f" throughput/seconds_per_batch={batch_time:.3f}",
f" throughput/seconds_per_batch={batch_time:.3f}\n"
f" grad/total_norm={norm_str}"
)

if i == 2:
if profile and i == 2:
print_rank0(torch.cuda.memory_summary())

if p is not None:
@@ -134,7 +143,7 @@ def on_trace_ready(p):
parser = argparse.ArgumentParser(prog="train.py", description="Train an FSDP model")
parser.add_argument(
"--fsdp",
choices=["torch", "olmo_core"],
choices=["torch", "olmo_core", "ddp"],
default="olmo_core",
help="""The FSDP implementation.""",
)
@@ -190,6 +199,10 @@ def on_trace_ready(p):
type=int,
default=1,
)
parser.add_argument(
"--max-grad-norm",
type=float,
)
parser.add_argument(
"--lr",
type=float,
@@ -237,5 +250,6 @@ def on_trace_ready(p):
mixed_precision=mixed_precision,
max_prefetch_count=args.max_prefetch_count,
learning_rate=args.lr,
max_grad_norm=args.max_grad_norm,
seed=args.seed,
)
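With this change gradient clipping is opt-in via `--max-grad-norm`, and the training loop picks the clipping routine based on the wrapper: FSDP exposes its own `clip_grad_norm_` (the total norm has to be computed across shards), while DDP and plain modules can use the stock utility. A condensed sketch of that logic; the helper function is illustrative, not part of the script:

```python
# Sketch of the optional clipping added to the training loop (hypothetical helper).
from typing import Optional

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_


def maybe_clip_grads(model: nn.Module, max_grad_norm: Optional[float]) -> Optional[torch.Tensor]:
    if max_grad_norm is None:
        return None  # skip clipping (and the extra norm reduction) unless requested
    if hasattr(model, "clip_grad_norm_"):
        # FSDP wrappers implement this to compute the total norm over sharded grads.
        return model.clip_grad_norm_(max_grad_norm)
    # DDP/plain modules hold full (already all-reduced) gradients locally.
    return clip_grad_norm_(model.parameters(), max_grad_norm)
```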
4 changes: 2 additions & 2 deletions src/olmo_core/distributed/fsdp/__init__.py
@@ -77,6 +77,6 @@
-------------
"""

from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision
from .fsdp import FSDP, FSDPDebugConfig, FSDPPrecision, FSDPShardingStrategy

__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision"]
__all__ = ["FSDP", "FSDPDebugConfig", "FSDPPrecision", "FSDPShardingStrategy"]
48 changes: 43 additions & 5 deletions src/olmo_core/distributed/fsdp/flat_param_handle.py
@@ -15,7 +15,11 @@
ShardedFlatTensor,
ShardingSpec,
)
from olmo_core.distributed.utils import get_rank, get_world_size
from olmo_core.distributed.utils import (
get_gradient_divide_factor,
get_rank,
get_world_size,
)
from olmo_core.stream import Stream
from olmo_core.utils import get_default_device

@@ -62,21 +66,41 @@ class FlatParamHandle:
"""

process_group: Optional[dist.ProcessGroup] = None
"""
Process group containing all shards.
"""

inter_group_process_group: Optional[dist.ProcessGroup] = None
"""
Process group for between-group reductions with hybrid sharding.
"""

device: Optional[torch.device] = None

requires_grad: bool = True

pre_reduce_grad_divide_factor: float = 1.0

post_reduce_grad_divide_factor: float = 1.0

_ran_pre_unshard: bool = False

_ran_pre_reduce_scatter_grads: bool = False

def __post_init__(self):
data_parallel_world_size = get_world_size(self.process_group)
if self.inter_group_process_group is not None:
data_parallel_world_size *= self.inter_group_process_group.size()
self.pre_reduce_grad_divide_factor = get_gradient_divide_factor(data_parallel_world_size)
self.post_reduce_grad_divide_factor = data_parallel_world_size / self.pre_reduce_grad_divide_factor

@classmethod
def shard_params(
cls,
params: Iterable[nn.Parameter],
param_fqns: Iterable[str],
process_group: Optional[dist.ProcessGroup] = None,
inter_group_process_group: Optional[dist.ProcessGroup] = None,
device: Optional[torch.device] = None,
) -> FlatParamHandle:
"""
@@ -183,6 +207,7 @@ def shard_params(
)
else:
flat_param = ShardedFlatParameter(torch.empty(0, device=device))
flat_param.requires_grad = param.requires_grad
flat_param.mark_as_sharded(sharding_spec, process_group=process_group)

flat_params.append(flat_param)
@@ -224,6 +249,7 @@ def shard_params(
param_fqns=list(param_fqns),
params_data=params_data,
process_group=process_group,
inter_group_process_group=inter_group_process_group,
device=device,
requires_grad=requires_grad,
)
@@ -253,7 +279,7 @@ def pre_unshard_(self, dtype: Optional[torch.dtype] = None, rank0_only: bool = F
self.params_sharded_data_lp.copy_(self.params_data.sharded_data)

# Initialize unsharded, padded gradient.
if set_grads and self.params_unsharded_grad is None:
if set_grads and self.requires_grad and self.params_unsharded_grad is None:
self.params_unsharded_grad = torch.zeros_like(all_params_unsharded_data)

def unshard_(
@@ -288,7 +314,7 @@ def unshard_(
if rank0_only or dist.get_backend() == dist.Backend.GLOO:
assert self.params_data.is_sharded
self.params_data.unshard_(dtype=dtype, rank0_only=rank0_only)
if set_grads:
if set_grads and self.requires_grad:
self.params_unsharded_grad = torch.zeros_like(self.params_data)
else:
assert not self.params_data.is_sharded
@@ -318,7 +344,7 @@ def unshard_(

param.unshard_(unsharded_data, dtype=dtype, rank0_only=rank0_only)

if set_grads:
if set_grads and self.requires_grad:
if param.grad is None and self.params_sharded_grad is not None:
self.params_sharded_grad = None
assert self.params_unsharded_grad is not None
@@ -360,6 +386,9 @@ def pre_reduce_scatter_grads_(
Stream.current(self.device).record_for(self.params_unsharded_grad)
self.params_unsharded_grad = self.params_unsharded_grad.to(dtype=grad_reduce_dtype)

if self.pre_reduce_grad_divide_factor > 1.0:
self.params_unsharded_grad.div_(self.pre_reduce_grad_divide_factor)

def reduce_scatter_grads_(
self, grad_dtype: Optional[torch.dtype] = None, grad_reduce_dtype: Optional[torch.dtype] = None
):
@@ -368,6 +397,7 @@ def reduce_scatter_grads_(
parameter as a view into the new sharded grad.
"""
if not self.requires_grad or self.params_unsharded_grad is None:
self._ran_pre_reduce_scatter_grads = False
return

if not self._ran_pre_reduce_scatter_grads:
@@ -398,12 +428,20 @@ def post_reduce_scatter_grads_(
"""
Finalize sharded gradients after the reduce-scatter.
"""
if not self.requires_grad or self.params_unsharded_grad is None:
return

grad_dtype = grad_dtype or self.params_data.dtype
grad_reduce_dtype = grad_reduce_dtype or grad_dtype

assert self.params_unsharded_grad is not None
new_sharded_grad = self.params_data.sharded_chunk(self.params_unsharded_grad)

if self.inter_group_process_group is not None:
dist.all_reduce(new_sharded_grad, group=self.inter_group_process_group)

if self.post_reduce_grad_divide_factor > 1.0:
new_sharded_grad.div_(self.post_reduce_grad_divide_factor)

# Cast the new sharded gradient to the right dtype, potentially accumulating it into
# the existing sharded gradient.
if self.params_sharded_grad is None:
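Putting the `FlatParamHandle` changes together: the division by the data-parallel world size is split into a pre-reduce and a post-reduce factor, and with hybrid sharding an extra all-reduce over the inter-group process group sums the shards held by replicas on other nodes. A self-contained sketch of that flow; the divide-factor heuristic below is a stand-in for `olmo_core.distributed.utils.get_gradient_divide_factor`, whose exact rule isn't shown in this diff, and the helper names are illustrative:

```python
# Sketch of the gradient reduce path, assuming the unsharded grad is already
# padded so it divides evenly across the shard group.
from typing import Optional, Tuple

import torch
import torch.distributed as dist


def gradient_divide_factors(world_size: int) -> Tuple[float, float]:
    # Stand-in heuristic: use roughly sqrt(world_size) as the pre-factor so both
    # divisions stay moderate; by construction pre * post == world_size.
    pre = 1
    while world_size % (pre * 2) == 0 and (pre * 2) ** 2 <= world_size:
        pre *= 2
    return float(pre), world_size / pre


def reduce_grads(
    unsharded_grad: torch.Tensor,
    shard_group: dist.ProcessGroup,
    inter_group: Optional[dist.ProcessGroup] = None,
) -> torch.Tensor:
    shard_world_size = dist.get_world_size(shard_group)
    total_world_size = shard_world_size * (inter_group.size() if inter_group is not None else 1)
    pre, post = gradient_divide_factors(total_world_size)

    if pre > 1.0:
        unsharded_grad.div_(pre)  # pre-divide before summing

    # Reduce-scatter within the shard group: each rank keeps its shard of the sum.
    sharded_grad = torch.empty(
        unsharded_grad.numel() // shard_world_size,
        dtype=unsharded_grad.dtype,
        device=unsharded_grad.device,
    )
    dist.reduce_scatter_tensor(sharded_grad, unsharded_grad, group=shard_group)

    # Hybrid sharding: replicas on other nodes hold the same shard, so sum across them.
    if inter_group is not None:
        dist.all_reduce(sharded_grad, group=inter_group)

    if post > 1.0:
        sharded_grad.div_(post)  # finish the division; result is the data-parallel mean

    return sharded_grad
```

Dividing by `pre` before the sum and by `post` after is mathematically the same as a single division by the world size, but it keeps the summed values in a safer range when reducing in low precision, which is what the "Divide grad before and after reducing for stability" commit is after.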