diff --git a/scripts/fabric.py b/scripts/fabric.py index 069b69e..7032ba9 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -461,7 +461,6 @@ def process(self, p, *args): print("[FABRIC] Skipping U-Net forward pass patching") def postprocess(self, p, processed, *args): - print("[FABRIC] Restoring original U-Net forward pass") unpatch_unet_forward_pass(p.sd_model.model.diffusion_model) images = processed.images[processed.index_of_first_image:] diff --git a/scripts/patching.py b/scripts/patching.py index ab3ad94..595594d 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -118,7 +118,7 @@ def new_forward(self, x, timesteps=None, context=None, **kwargs): # save original forward pass for module in self.modules(): - if isinstance(module, BasicTransformerBlock): + if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"): module.attn1._fabric_old_forward = module.attn1.forward # fix for medvram option @@ -200,7 +200,7 @@ def patched_attn1_forward(attn1, idx, x, context=None, **kwargs): # restore original pass for module in self.modules(): - if isinstance(module, BasicTransformerBlock): + if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"): module.attn1.forward = module.attn1._fabric_old_forward del module.attn1._fabric_old_forward @@ -212,5 +212,6 @@ def patched_attn1_forward(attn1, idx, x, context=None, **kwargs): def unpatch_unet_forward_pass(unet): if hasattr(unet, "_fabric_old_forward"): + print("[FABRIC] Restoring original U-Net forward pass") unet.forward = unet._fabric_old_forward del unet._fabric_old_forward