Hybrid Offloading for ZeRO3 #5625

Draft: wants to merge 3 commits into master

2 changes: 1 addition & 1 deletion deepspeed/runtime/zero/__init__.py
@@ -6,7 +6,7 @@
from .partition_parameters import ZeroParamType
from .partition_parameters import ZeroParamStatus
from .partition_parameters import Init
from .partition_parameters import GatheredParameters
from .partition_parameters import GatheredParameters, ZeRO3HybridOffload
from .partition_parameters import register_external_parameter

from .tiling import TiledLinear
15 changes: 14 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
@@ -279,6 +279,8 @@ def _pre_forward_module_hook(module, *args):

@instrument_w_nvtx
def _post_forward_module_hook(module, input, output):
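# Params for this module are kept gathered by ZeRO3HybridOffload, so skip
# ZeRO-3's fetch/release bookkeeping (here and in the hooks below).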
if hasattr(module, "disable_z3_fetch") and module.disable_z3_fetch:
return

global FWD_MODULE_STACK
FWD_MODULE_STACK.pop()
@@ -440,6 +442,9 @@ def backward(ctx, *args):

@torch.no_grad()
def pre_sub_module_forward_function(self, sub_module):
if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
return

see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)

global FWD_MODULE_STACK
@@ -455,9 +460,11 @@ def pre_sub_module_forward_function(self, sub_module):

@torch.no_grad()
def post_sub_module_forward_function(self, sub_module):
if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
return

see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
force=False)

param_coordinator = self.get_param_coordinator(training=sub_module.training)
param_coordinator.release_sub_module(sub_module)

@@ -466,6 +473,9 @@ def post_sub_module_forward_function(self, sub_module):

@torch.no_grad()
def pre_sub_module_backward_function(self, sub_module):
if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
return

assert sub_module.training, "backward pass is invalid for module in evaluation mode"
param_coordinator = self.get_param_coordinator(training=True)
param_coordinator.trace_prologue(sub_module)
@@ -475,6 +485,9 @@ def pre_sub_module_backward_function(self, sub_module):

@torch.no_grad()
def post_sub_module_backward_function(self, sub_module):
if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
return

assert sub_module.training, "backward pass is invalid for module in evaluation mode"
see_memory_usage(
f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
90 changes: 90 additions & 0 deletions deepspeed/runtime/zero/partition_parameters.py
@@ -2239,3 +2239,93 @@ def __exit__(self, *exc):
for h in handles:
h.wait()
self.params[0].partition(param_list=self.params, has_been_updated=True)


class ZeRO3HybridOffload:
"""NOTE: This feature works only for forward pass.
This feature allows users to gather ZeRO3-partitioned params and offload a part of them to host memory. The offloaded parameters are loaded to device memory in pre-forward hook and offloaded back to host memory in post-forward hook.

Args:
model (``torch.nn.Module``): The model whose parameters will be gathered and offloaded. The model must be initialized for ZeRO3.
param_threshold (int): Maximum number of gathered parameter elements to keep in device memory; parameters gathered beyond this threshold are offloaded to host memory.
enabled (bool, optional): If ``False``, this context manager is a no-op. Defaults to ``True``.

Usage:
======
This lets you avoid repeated all-gathers when calling the forward pass in a loop.

.. code-block:: python

with deepspeed.zero.ZeRO3HybridOffload(model, param_threshold=1e9):
for x in dataset:
output = model(x)

Generation with auto-regressive models is one example where this feature is useful.

.. code-block:: python

with deepspeed.zero.ZeRO3HybridOffload(model, param_threshold=1e9):
output = model.generate(input_ids)
"""

def __init__(self, model, param_threshold, enabled=True):
self.enabled = enabled
self.model = model
self.param_threshold = param_threshold
self.device = torch.device(get_accelerator().current_device())

def __enter__(self):
if not self.enabled:
return

n_params = 0
self.gathered_params = []
self.handles = []
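# Gather all ZeRO-3 partitioned params, module by module. Params stay in
# device memory until the running element count exceeds param_threshold;
# anything gathered after that is staged in host memory instead.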

for m in self.model.modules():
offloaded_params = []
m.disable_z3_fetch = True

for p in m.parameters(recurse=False):
p.all_gather(param_list=[p])
n_params += p.numel()
if n_params > self.param_threshold:
# Offload
p.data = p.data.cpu()
offloaded_params.append(p)
self.gathered_params.append(p)

if len(offloaded_params) > 0:
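# Register hooks that move this module's offloaded params to the device
# just before its forward runs and back to host memory right after it.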

def wrapper_pre_hook(target_params):

def pre_forward_hook(module, input):
for p in target_params:
p.data = p.data.to(self.device)

return pre_forward_hook

self.handles.append(m.register_forward_pre_hook(wrapper_pre_hook(offloaded_params)))

def wrapper_post_hook(target_params):

def post_forward_hook(module, input, output):
for p in target_params:
p.data = p.data.cpu()

return post_forward_hook

self.handles.append(m.register_forward_hook(wrapper_post_hook(offloaded_params)))

def __exit__(self, *exc):
if not self.enabled:
return

for p in self.gathered_params:
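# Move each gathered param back to the device and re-partition it across ranks.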
p.data = p.data.to(self.device)
p.partition(param_list=[p], has_been_updated=False)

for h in self.handles:
h.remove()

for m in self.model.modules():
m.disable_z3_fetch = False
90 changes: 90 additions & 0 deletions tests/unit/runtime/zero/test_zero_hybrid_offload.py
@@ -0,0 +1,90 @@
# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team

import deepspeed.comm as dist
import torch

from unit.common import DistributedTest, preferred_dtype
from unit.simple_model import random_dataloader

import deepspeed
from deepspeed.accelerator import get_accelerator


class SimpleModel(torch.nn.Module):

def __init__(self, hidden_dim):
super(SimpleModel, self).__init__()
self.linears = torch.nn.ModuleList(
[torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
torch.nn.Linear(hidden_dim, hidden_dim, bias=False)])
self.act = torch.nn.ReLU()
self.cel = torch.nn.CrossEntropyLoss()

def forward(self, x, y):
for m in self.linears:
x = self.act(m(x))
loss = self.cel(x, y)
return x, loss


def run_model(model, config_dict, hidden_dim, dtype):
model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
data_loader = random_dataloader(model=model,
total_samples=10,
hidden_dim=hidden_dim,
device=model.device,
dtype=dtype)
dist.barrier()

assert all(p.numel() == 0 for p in model.parameters())

with deepspeed.zero.ZeRO3HybridOffload(model, hidden_dim**2 + 100):
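# The threshold fits exactly one hidden_dim x hidden_dim weight on the device,
# so the second linear layer's weight should be offloaded to the CPU.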
# Has params on device?
assert any(p.numel() > 0 for p in model.parameters()
if p.device == torch.device(get_accelerator().current_device())), "No params on device"
# Has params on cpu?
assert any(p.numel() > 0 for p in model.parameters() if p.device == torch.device('cpu')), "No params on cpu"

for batch in data_loader:
loss = model(batch[0], batch[1])
loss = loss[1]

# Needed for ZeRO 3; not doing so can cause a memory leak
model.destroy()


class TestZeRO3HybridOffload(DistributedTest):
# Need multiple GPUs to test for possible hangs
world_size = 2
reuse_dist_env = True

def test(self):
hidden_dim = 128

config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 1e-6
}
},
"zero_optimization": {
"stage": 3,
"stage3_prefetch_bucket_size": hidden_dim**2,
"stage3_param_persistence_threshold": 0,
"stage3_max_reuse_distance": 0,
}
}
if get_accelerator().is_fp16_supported():
config_dict["fp16"] = {"enabled": True}
elif get_accelerator().is_bf16_supported():
config_dict["bf16"] = {"enabled": True}

model = SimpleModel(hidden_dim)
model.eval()
run_model(model, config_dict, hidden_dim, preferred_dtype())