Skip to content

Commit

Permalink
Merge branch 'master' into torch_load_weight_only_warning
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Oct 25, 2024
2 parents 69fc219 + 3d5cf73 commit 675dc44
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def _apply_forward_and_backward_to_tensors_only(module, forward_function, backwa

class ZeROOrderedDict(OrderedDict):

def __init__(self, parent_module=None, *args, **kwargs):
def __init__(self, parent_module, *args, **kwargs):
"""A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
Args:
Expand All @@ -49,13 +49,18 @@ def __init__(self, parent_module=None, *args, **kwargs):
self._parent_module = parent_module
self._in_forward = False

def __reduce__(self):
r0, _, *r2 = super().__reduce__()
return (r0, (self._parent_module, )) + r2

def __getitem__(self, key):
param = super().__getitem__(key)

# Params can be registered as None (e.g., bias)
if param is None:
return param

# TODO: only weaken this check during compilation
if hasattr(param, "ds_status") and param.ds_status == ZeroParamStatus.NOT_AVAILABLE:
if self._parent_module._parameters._in_forward:
register_external_parameter(FWD_MODULE_STACK[-1], param)
Expand Down

0 comments on commit 675dc44

Please sign in to comment.