diff --git a/modules/sd_disable_initialization.py b/modules/sd_disable_initialization.py index 273a7edd8b4..e521780e2d0 100644 --- a/modules/sd_disable_initialization.py +++ b/modules/sd_disable_initialization.py @@ -188,7 +188,7 @@ def load_from_state_dict(original, module, state_dict, prefix, *args, **kwargs): if param.is_meta: dtype = sd_param.dtype if sd_param is not None else param.dtype - module._parameters[name] = torch.nn.parameter.Parameter(torch.zeros_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) + module._parameters[name] = torch.nn.parameter.Parameter(torch.empty_like(param, device=device, dtype=dtype), requires_grad=param.requires_grad) for name in module._buffers: key = prefix + name