diff --git a/deepspeed/runtime/zero/__init__.py b/deepspeed/runtime/zero/__init__.py
index 1ccca09a9e69..c5c9118dc346 100644
--- a/deepspeed/runtime/zero/__init__.py
+++ b/deepspeed/runtime/zero/__init__.py
@@ -6,7 +6,7 @@
 from .partition_parameters import ZeroParamType
 from .partition_parameters import ZeroParamStatus
 from .partition_parameters import Init
-from .partition_parameters import GatheredParameters
+from .partition_parameters import GatheredParameters, ZeRO3HybridOffload
 from .partition_parameters import register_external_parameter
 
 from .tiling import TiledLinear
diff --git a/deepspeed/runtime/zero/parameter_offload.py b/deepspeed/runtime/zero/parameter_offload.py
index e9e79c2647fb..f4fec333cc3d 100644
--- a/deepspeed/runtime/zero/parameter_offload.py
+++ b/deepspeed/runtime/zero/parameter_offload.py
@@ -279,6 +279,8 @@ def _pre_forward_module_hook(module, *args):
         @instrument_w_nvtx
         def _post_forward_module_hook(module, input, output):
+            if hasattr(module, "disable_z3_fetch") and module.disable_z3_fetch:
+                return
             global FWD_MODULE_STACK
             FWD_MODULE_STACK.pop()
@@ -440,6 +442,9 @@ def backward(ctx, *args):
     @torch.no_grad()
     def pre_sub_module_forward_function(self, sub_module):
+        if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
+            return
+
         see_memory_usage(f"Before sub module function {sub_module.__class__.__name__}", force=False)
         global FWD_MODULE_STACK
@@ -455,9 +460,11 @@ def pre_sub_module_forward_function(self, sub_module):
     @torch.no_grad()
     def post_sub_module_forward_function(self, sub_module):
+        if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
+            return
+
         see_memory_usage(f"After sub module function {sub_module.__class__.__name__} {sub_module.id} before release",
                          force=False)
-
         param_coordinator = self.get_param_coordinator(training=sub_module.training)
         param_coordinator.release_sub_module(sub_module)
@@ -466,6 +473,9 @@ def pre_sub_module_backward_function(self, sub_module):
+        if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
+            return
+
         assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         param_coordinator = self.get_param_coordinator(training=True)
         param_coordinator.trace_prologue(sub_module)
@@ -475,6 +485,9 @@ def post_sub_module_backward_function(self, sub_module):
+        if hasattr(sub_module, "disable_z3_fetch") and sub_module.disable_z3_fetch:
+            return
+
         assert sub_module.training, "backward pass is invalid for module in evaluation mode"
         see_memory_usage(
             f"After sub module backward function {sub_module.__class__.__name__} {sub_module.id} before release",
diff --git a/deepspeed/runtime/zero/partition_parameters.py b/deepspeed/runtime/zero/partition_parameters.py
index 09e72a695db3..6362ed8922a5 100755
--- a/deepspeed/runtime/zero/partition_parameters.py
+++ b/deepspeed/runtime/zero/partition_parameters.py
@@ -2239,3 +2239,93 @@ def __exit__(self, *exc):
         for h in handles:
             h.wait()
         self.params[0].partition(param_list=self.params, has_been_updated=True)
+
+
+class ZeRO3HybridOffload:
+    """NOTE: This feature works only for the forward pass.
+    It allows users to gather ZeRO-3 partitioned params and offload a part of them to host memory. The offloaded parameters are loaded to device memory in a pre-forward hook and offloaded back to host memory in a post-forward hook.
+
+    Args:
+        model (``torch.nn.Module``): The model whose parameters will be gathered and offloaded. The model must be initialized for ZeRO-3.
+        param_threshold (int): Once the cumulative number of gathered parameter elements exceeds this threshold, the remaining parameters are offloaded to host memory.
+
+    Usage:
+    ======
+    You can avoid repeated all-gathers inside a loop.
+
+    .. code-block:: python
+
+        with deepspeed.zero.ZeRO3HybridOffload(model, param_threshold=1e9):
+            for x in dataset:
+                output = model(x)
+
+    Generation with auto-regressive models is one good example where this feature can be useful.
+
+    .. code-block:: python
+
+        with deepspeed.zero.ZeRO3HybridOffload(model, param_threshold=1e9):
+            output = model.generate(input_ids)
+    """
+
+    def __init__(self, model, param_threshold, enabled=True):
+        self.enabled = enabled
+        self.model = model
+        self.param_threshold = param_threshold
+        self.device = torch.device(get_accelerator().current_device())
+
+    def __enter__(self):
+        if not self.enabled:
+            return
+
+        n_params = 0
+        self.gathered_params = []
+        self.handles = []
+
+        for m in self.model.modules():
+            offloaded_params = []
+            m.disable_z3_fetch = True
+
+            for p in m.parameters(recurse=False):
+                p.all_gather(param_list=[p])
+                n_params += p.numel()
+                if n_params > self.param_threshold:
+                    # Offload
+                    p.data = p.data.cpu()
+                    offloaded_params.append(p)
+                self.gathered_params.append(p)
+
+            if len(offloaded_params) > 0:
+
+                def wrapper_pre_hook(target_params):
+
+                    def pre_forward_hook(module, input):
+                        for p in target_params:
+                            p.data = p.data.to(self.device)
+
+                    return pre_forward_hook
+
+                self.handles.append(m.register_forward_pre_hook(wrapper_pre_hook(offloaded_params)))
+
+                def wrapper_post_hook(target_params):
+
+                    def post_forward_hook(module, input, output):
+                        for p in target_params:
+                            p.data = p.data.cpu()
+
+                    return post_forward_hook
+
+                self.handles.append(m.register_forward_hook(wrapper_post_hook(offloaded_params)))
+
+    def __exit__(self, *exc):
+        if not self.enabled:
+            return
+
+        for p in self.gathered_params:
+            p.data = p.data.to(self.device)
+            p.partition(param_list=[p], has_been_updated=False)
+
+        for h in self.handles:
+            h.remove()
+
+        for m in self.model.modules():
+            m.disable_z3_fetch = False
diff --git a/tests/unit/runtime/zero/test_zero_hybrid_offload.py b/tests/unit/runtime/zero/test_zero_hybrid_offload.py
new file mode 100644
index 000000000000..d8b7f81ac636
--- /dev/null
+++ b/tests/unit/runtime/zero/test_zero_hybrid_offload.py
@@ -0,0 +1,90 @@
+# Copyright (c) Microsoft Corporation.
+# SPDX-License-Identifier: Apache-2.0
+
+# DeepSpeed Team
+
+import deepspeed.comm as dist
+import torch
+
+from unit.common import DistributedTest, preferred_dtype
+from unit.simple_model import random_dataloader
+
+import deepspeed
+from deepspeed.accelerator import get_accelerator
+
+
+class SimpleModel(torch.nn.Module):
+
+    def __init__(self, hidden_dim):
+        super(SimpleModel, self).__init__()
+        self.linears = torch.nn.ModuleList(
+            [torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
+             torch.nn.Linear(hidden_dim, hidden_dim, bias=False)])
+        self.act = torch.nn.ReLU()
+        self.cel = torch.nn.CrossEntropyLoss()
+
+    def forward(self, x, y):
+        for m in self.linears:
+            x = self.act(m(x))
+        loss = self.cel(x, y)
+        return x, loss
+
+
+def run_model(model, config_dict, hidden_dim, dtype):
+    model, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=config_dict)
+    data_loader = random_dataloader(model=model,
+                                    total_samples=10,
+                                    hidden_dim=hidden_dim,
+                                    device=model.device,
+                                    dtype=dtype)
+    dist.barrier()
+
+    assert all(p.numel() == 0 for p in model.parameters())
+
+    with deepspeed.zero.ZeRO3HybridOffload(model, hidden_dim**2 + 100):
+        # Has params on device?
+        assert any(p.numel() > 0 for p in model.parameters()
+                   if p.device == torch.device(get_accelerator().current_device())), "No params on device"
+        # Has params on cpu?
+        assert any(p.numel() > 0 for p in model.parameters() if p.device == torch.device('cpu')), "No params on cpu"
+
+        for batch in data_loader:
+            loss = model(batch[0], batch[1])
+            loss = loss[1]
+
+    # Needed in ZeRO 3; not doing so can cause a memory leak
+    model.destroy()
+
+
+class TestZeRO3HybridOffload(DistributedTest):
+    # Multiple GPUs are needed to test possible hangs
+    world_size = 2
+    reuse_dist_env = True
+
+    def test(self):
+        hidden_dim = 128
+
+        config_dict = {
+            "train_micro_batch_size_per_gpu": 1,
+            "steps_per_print": 1,
+            "optimizer": {
+                "type": "Adam",
+                "params": {
+                    "lr": 1e-6
+                }
+            },
+            "zero_optimization": {
+                "stage": 3,
+                "stage3_prefetch_bucket_size": hidden_dim**2,
+                "stage3_param_persistence_threshold": 0,
+                "stage3_max_reuse_distance": 0,
+            }
+        }
+        if get_accelerator().is_fp16_supported():
+            config_dict["fp16"] = {"enabled": True}
+        elif get_accelerator().is_bf16_supported():
+            config_dict["bf16"] = {"enabled": True}
+
+        model = SimpleModel(hidden_dim)
+        model.eval()
+        run_model(model, config_dict, hidden_dim, preferred_dtype())
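
For reference, a minimal end-to-end sketch of how the new context manager is intended to be used, mirroring the class docstring and the unit test above. The tiny model, the ZeRO-3 config values, and the `param_threshold` of `1024 * 1024` are illustrative assumptions only and are not part of this patch; the script assumes it is launched with the `deepspeed` launcher so distributed initialization happens as usual.

```python
# Usage sketch (illustrative only): placeholder model, config, and threshold.
# Launch with e.g.: deepspeed --num_gpus 2 hybrid_offload_example.py
import torch
import deepspeed

ds_config = {
    "train_micro_batch_size_per_gpu": 1,
    "optimizer": {"type": "Adam", "params": {"lr": 1e-6}},
    "zero_optimization": {"stage": 3},
    "bf16": {"enabled": True},
}

model = torch.nn.Sequential(
    torch.nn.Linear(1024, 1024, bias=False),
    torch.nn.ReLU(),
    torch.nn.Linear(1024, 1024, bias=False),
)
engine, _, _, _ = deepspeed.initialize(model=model, model_parameters=model.parameters(), config=ds_config)
engine.eval()

# On entry, every ZeRO-3 partition is all-gathered once; after the running element
# count exceeds param_threshold, the remaining gathered parameters stay in host
# memory and are moved to the device only around each offloaded module's forward.
with deepspeed.zero.ZeRO3HybridOffload(engine, param_threshold=1024 * 1024), torch.no_grad():
    for _ in range(4):
        x = torch.randn(1, 1024, device=engine.device, dtype=torch.bfloat16)
        out = engine(x)
```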