v0.5.3 (various bugfixes)
dvruette committed Aug 26, 2023
1 parent 676d3d1 commit 2b55811
Showing 4 changed files with 53 additions and 36 deletions.
22 changes: 8 additions & 14 deletions scripts/fabric.py
@@ -1,11 +1,11 @@
 import os
 import dataclasses
 import functools
-import hashlib
 import json
 import traceback
 from pathlib import Path
 from dataclasses import dataclass, asdict
+from typing import Optional
 
 import gradio as gr
 from PIL import Image
@@ -15,7 +15,7 @@
 from modules.ui_components import FormGroup, FormRow
 from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
 
-from scripts.helpers import WebUiComponents
+from scripts.helpers import WebUiComponents, image_hash
 from scripts.patching import patch_unet_forward_pass, unpatch_unet_forward_pass
 
 # Compatibility with WebUI v1.3.0 and earlier versions
@@ -27,7 +27,7 @@
 from modules.ui import create_refresh_button
 
 
-__version__ = "0.5.2"
+__version__ = "0.5.3"
 
 DEBUG = os.getenv("DEBUG", "false").lower() in ("true", "1")
 
@@ -62,15 +62,6 @@ def use_feedback(params):
     return True
 
 
-def image_hash(img, length=16):
-    hash_sha256 = hashlib.sha256()
-    hash_sha256.update(img.tobytes())
-    img_hash = hash_sha256.hexdigest()
-    if length and length > 0:
-        img_hash = img_hash[:length]
-    return img_hash
-
-
 def save_feedback_image(img, filename=None, base_path=OUTPUT_PATH):
     if filename is None:
         filename = image_hash(img) + ".png"
@@ -138,8 +129,11 @@ class FabricParams:
     neg_scale: float = 0.5
     pos_images: list = dataclasses.field(default_factory=list)
     neg_images: list = dataclasses.field(default_factory=list)
-    pos_latents: list = None
-    neg_latents: list = None
+    pos_latents: Optional[list] = None
+    neg_latents: Optional[list] = None
+    pos_latent_cache: Optional[dict] = None
+    neg_latent_cache: Optional[dict] = None
+
     feedback_during_high_res_fix: bool = False
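
Note on the FabricParams hunk: the old annotations (pos_latents: list = None) promised a list but defaulted to None, so Optional[list] is the accurate type, and the two new *_latent_cache fields likewise stay None until first populated. A minimal sketch of the two dataclass-default idioms involved (hypothetical names, not from this repo):

from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Params:
    images: list = field(default_factory=list)  # mutable default needs a factory
    latent_cache: Optional[dict] = None         # stays None until first use

p = Params()
if p.latent_cache is None:  # created lazily, keyed by image hash
    p.latent_cache = {}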
13 changes: 13 additions & 0 deletions scripts/helpers.py
@@ -1,3 +1,16 @@
+import hashlib
+
+from PIL import Image
+
+
+def image_hash(img: Image.Image, length: int = 16):
+    hash_sha256 = hashlib.sha256()
+    hash_sha256.update(img.tobytes())
+    img_hash = hash_sha256.hexdigest()
+    if length and length > 0:
+        img_hash = img_hash[:length]
+    return img_hash
+
 
 class WebUiComponents:
     txt2img_gallery = None
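
image_hash moves here unchanged (apart from gaining type hints) so that both fabric.py and patching.py can import it. A quick usage sketch, assuming only the helper above:

from PIL import Image

img = Image.new("RGB", (64, 64), color="white")
key = image_hash(img)                  # 16-char hex prefix of the SHA-256 digest
assert key == image_hash(img.copy())   # same pixel data -> same key

Because the hash depends only on pixel content, it makes a stable cache key across reruns with the same feedback images.
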
7 changes: 6 additions & 1 deletion scripts/marking.py
@@ -1,6 +1,7 @@
 import torch
 
 from modules.prompt_parser import MulticondLearnedConditioning, ComposableScheduledPromptConditioning, ScheduledPromptConditioning
+from modules.processing import StableDiffusionProcessing
 
 
 """
@@ -76,10 +77,14 @@ def unmark_prompt_context(x):
     mark_batch = mark[:, None, None, None].to(x.dtype).to(x.device)
     uc_indices = mark.detach().cpu().numpy().tolist()
     uc_indices = [i for i, item in enumerate(uc_indices) if item < 0.5]
+
+    StableDiffusionProcessing.cached_c = [None, None]
+    StableDiffusionProcessing.cached_uc = [None, None]
+
     return mark_batch, uc_indices, context
 
 
-def patch_process_sample(process):
+def apply_marking_patch(process):
     if getattr(process, 'sample_before_CN_hack', None) is None:
         process.sample_before_CN_hack = process.sample
         process.sample = process_sample.__get__(process)
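
Two things happen in this file. First, unmark_prompt_context now clears WebUI's cached conditioning (cached_c / cached_uc); plausibly this stops a FABRIC-marked prompt context from being reused on a later, unrelated generation, though the commit message only says "various bugfixes". Second, patch_process_sample is renamed to apply_marking_patch; it keeps the function.__get__(instance) idiom for grafting a plain function onto an object as a bound method. The idiom in isolation:

class Process:
    def sample(self):
        return "original"

def process_sample(self):
    # stand-in for the real patched sampler
    return "patched -> " + self.sample_before_CN_hack()

proc = Process()
proc.sample_before_CN_hack = proc.sample    # stash the original bound method
proc.sample = process_sample.__get__(proc)  # bind the replacement to this instance
print(proc.sample())                        # patched -> original
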
47 changes: 26 additions & 21 deletions scripts/patching.py
@@ -8,7 +8,8 @@
 
 from ldm.modules.attention import BasicTransformerBlock
 
-from scripts.marking import patch_process_sample, unmark_prompt_context
+from scripts.marking import apply_marking_patch, unmark_prompt_context
+from scripts.helpers import image_hash
 from scripts.weighted_attention import weighted_attention
 
 
@@ -33,18 +34,21 @@ def get_latents_from_params(p, params, width, height):
     def get_latents(images, cached_latents=None):
         # check if latents need to be computed or recomputed (if image size changed e.g. due to high-res fix)
         if cached_latents is None:
-            return [encode_to_latent(p, img, w, h) for img in images]
-        else:
-            ls = []
-            for latent, img in zip(cached_latents, images):
-                if latent.shape[-2:] != (w_latent, h_latent):
-                    print(f"[FABRIC] Recomputing latent for image of size {img.size}")
-                    latent = encode_to_latent(p, img, w, h)
-                ls.append(latent)
-            return ls
+            cached_latents = {}
+
+        latents = []
+        for img in images:
+            img_hash = image_hash(img)
+            if img_hash not in cached_latents:
+                cached_latents[img_hash] = encode_to_latent(p, img, w, h)
+            elif cached_latents[img_hash].shape[-2:] != (w_latent, h_latent):
+                print(f"[FABRIC] Recomputing latent for image of size {img.size}")
+                cached_latents[img_hash] = encode_to_latent(p, img, w, h)
+            latents.append(cached_latents[img_hash])
+        return latents, cached_latents
 
-    params.pos_latents = get_latents(params.pos_images, params.pos_latents)
-    params.neg_latents = get_latents(params.neg_images, params.neg_latents)
+    params.pos_latents, params.pos_latent_cache = get_latents(params.pos_images, params.pos_latent_cache)
+    params.neg_latents, params.neg_latent_cache = get_latents(params.neg_images, params.neg_latent_cache)
     return params.pos_latents, params.neg_latents
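
The rewritten get_latents replaces positional reuse (zipping cached latents against the current image list, which goes stale as soon as feedback images are added, removed, or reordered) with a dict keyed by image content hash. The same pattern in isolation, with encode standing in for the real encode_to_latent and latent_hw for the target latent size (both placeholders):

def cached_encode(images, cache, encode, latent_hw):
    # Encode each image at most once; recompute only if the latent size changed.
    cache = {} if cache is None else cache
    latents = []
    for img in images:
        key = image_hash(img)
        if key not in cache or cache[key].shape[-2:] != latent_hw:
            cache[key] = encode(img)
        latents.append(cache[key])
    return latents, cache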


@@ -106,6 +110,15 @@ def new_forward(self, x, timesteps=None, context=None, **kwargs):
         pos_latents = pos_latents if has_cond else []
         neg_latents = neg_latents if has_uncond else []
         all_latents = pos_latents + neg_latents
+
+        # Note: calls to the VAE with `--medvram` will move the U-Net to CPU, so we need to move it back to GPU
+        if shared.cmd_opts.medvram:
+            try:
+                # Trigger register_forward_pre_hook to move the model to correct device
+                p.sd_model.model()
+            except:
+                pass
+
         if len(all_latents) == 0:
             return self._fabric_old_forward(x, timesteps, context, **kwargs)
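
The --medvram workaround also moves up, from after the attention-patching loop to before the early return: computing feedback latents calls the VAE, which under --medvram offloads the U-Net to CPU, so the device has to be restored before either forward path (FABRIC or the original fallback) runs. WebUI's lowvram machinery is built on forward pre-hooks, which is presumably why the argument-less p.sd_model.model() call suffices: the hook fires and moves the model, and the try/except swallows the error the incomplete call itself then raises. The mechanism in miniature (sketch assumes a CUDA device is available):

import torch

def move_to_gpu(module, inputs):
    # lowvram-style pre-hook: relocate the module right before forward() runs
    module.to("cuda")

net = torch.nn.Linear(4, 4)  # starts on CPU
net.register_forward_pre_hook(move_to_gpu)
y = net(torch.zeros(1, 4, device="cuda"))  # hook fires first, so devices match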

@@ -121,14 +134,6 @@ def new_forward(self, x, timesteps=None, context=None, **kwargs):
             if isinstance(module, BasicTransformerBlock) and not hasattr(module.attn1, "_fabric_old_forward"):
                 module.attn1._fabric_old_forward = module.attn1.forward
 
-        # fix for medvram option
-        if shared.cmd_opts.medvram:
-            try:
-                # Trigger register_forward_pre_hook to move the model to correct device
-                p.sd_model.model()
-            except:
-                pass
-
         ## cache hidden states
 
         cached_hiddens = {}
@@ -208,7 +213,7 @@ def patched_attn1_forward(attn1, idx, x, context=None, **kwargs):
 
     unet.forward = new_forward.__get__(unet)
 
-    patch_process_sample(p)
+    apply_marking_patch(p)
 
 def unpatch_unet_forward_pass(unet):
     if hasattr(unet, "_fabric_old_forward"):