From 2f207ff37b9d2186b4bb5e6a0b157afa55a05032 Mon Sep 17 00:00:00 2001 From: h3rmit Date: Sat, 26 Aug 2023 10:21:38 +0300 Subject: [PATCH 1/2] Fixed an infinite loop after getting an error in the image generation --- scripts/patching.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/patching.py b/scripts/patching.py index ab3ad94..b4abcec 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -118,7 +118,7 @@ def new_forward(self, x, timesteps=None, context=None, **kwargs): # save original forward pass for module in self.modules(): - if isinstance(module, BasicTransformerBlock): + if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"): module.attn1._fabric_old_forward = module.attn1.forward # fix for medvram option @@ -200,7 +200,7 @@ def patched_attn1_forward(attn1, idx, x, context=None, **kwargs): # restore original pass for module in self.modules(): - if isinstance(module, BasicTransformerBlock): + if isinstance(module, BasicTransformerBlock) and hasattr(module.attn1, "_fabric_old_forward"): module.attn1.forward = module.attn1._fabric_old_forward del module.attn1._fabric_old_forward From eafc798271b9bf32903bc1c5d9f1fb1086b9b836 Mon Sep 17 00:00:00 2001 From: h3rmit Date: Sat, 26 Aug 2023 10:47:48 +0300 Subject: [PATCH 2/2] Limited the U-Net restoration message to only display when it happens --- scripts/fabric.py | 1 - scripts/patching.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/fabric.py b/scripts/fabric.py index 069b69e..7032ba9 100644 --- a/scripts/fabric.py +++ b/scripts/fabric.py @@ -461,7 +461,6 @@ def process(self, p, *args): print("[FABRIC] Skipping U-Net forward pass patching") def postprocess(self, p, processed, *args): - print("[FABRIC] Restoring original U-Net forward pass") unpatch_unet_forward_pass(p.sd_model.model.diffusion_model) images = processed.images[processed.index_of_first_image:] diff --git a/scripts/patching.py b/scripts/patching.py index b4abcec..595594d 100644 --- a/scripts/patching.py +++ b/scripts/patching.py @@ -212,5 +212,6 @@ def patched_attn1_forward(attn1, idx, x, context=None, **kwargs): def unpatch_unet_forward_pass(unet): if hasattr(unet, "_fabric_old_forward"): + print("[FABRIC] Restoring original U-Net forward pass") unet.forward = unet._fabric_old_forward del unet._fabric_old_forward