Skip to content

Commit

Permalink
Add a skip layer guidance node that can also skip single layers.
Browse files Browse the repository at this point in the history
This one should work for skipping the single layers of models like Flux
and Auraflow.

If you want to see how these models work and how many double/single layers
they have see the "ModelMerge*" nodes for the specific model.
  • Loading branch information
comfyanonymous committed Nov 18, 2024
1 parent d9f9096 commit 9a0a5d3
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 41 deletions.
51 changes: 10 additions & 41 deletions comfy_extras/nodes_sd3.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import comfy.model_management
import nodes
import torch
import re
import comfy_extras.nodes_slg


class TripleCLIPLoader:
@classmethod
def INPUT_TYPES(s):
Expand All @@ -23,6 +25,7 @@ def load_clip(self, clip_name1, clip_name2, clip_name3):
clip = comfy.sd.load_clip(ckpt_paths=[clip_path1, clip_path2, clip_path3], embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (clip,)


class EmptySD3LatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
Expand All @@ -41,6 +44,7 @@ def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=self.device)
return ({"samples":latent}, )


class CLIPTextEncodeSD3:
@classmethod
def INPUT_TYPES(s):
Expand Down Expand Up @@ -97,7 +101,8 @@ def INPUT_TYPES(s):
CATEGORY = "conditioning/controlnet"
DEPRECATED = True

class SkipLayerGuidanceSD3:

class SkipLayerGuidanceSD3(comfy_extras.nodes_slg.SkipLayerGuidanceDiT):
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Expand All @@ -112,48 +117,12 @@ def INPUT_TYPES(s):
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
FUNCTION = "skip_guidance_sd3"

CATEGORY = "advanced/guidance"


def skip_guidance(self, model, layers, scale, start_percent, end_percent):
if layers == "" or layers == None:
return (model, )
# check if layer is comma separated integers
def skip(args, extra_args):
return args

model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)

layers = re.findall(r'\d+', layers)
layers = [int(i) for i in layers]

def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()

for layer in layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)
model_sampling.percent_to_sigma(start_percent)

sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result

m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)

return (m, )
def skip_guidance_sd3(self, model, layers, scale, start_percent, end_percent):
return self.skip_guidance(model=model, scale=scale, start_percent=start_percent, end_percent=end_percent, double_layers=layers)


NODE_CLASS_MAPPINGS = {
Expand Down
78 changes: 78 additions & 0 deletions comfy_extras/nodes_slg.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import comfy.model_patcher
import comfy.samplers
import re


class SkipLayerGuidanceDiT:
'''
Enhance guidance towards detailed dtructure by having another set of CFG negative with skipped layers.
Inspired by Perturbed Attention Guidance (https://arxiv.org/abs/2403.17377)
Original experimental implementation for SD3 by Dango233@StabilityAI.
'''
@classmethod
def INPUT_TYPES(s):
return {"required": {"model": ("MODEL", ),
"double_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"single_layers": ("STRING", {"default": "7, 8, 9", "multiline": False}),
"scale": ("FLOAT", {"default": 3.0, "min": 0.0, "max": 10.0, "step": 0.1}),
"start_percent": ("FLOAT", {"default": 0.01, "min": 0.0, "max": 1.0, "step": 0.001}),
"end_percent": ("FLOAT", {"default": 0.15, "min": 0.0, "max": 1.0, "step": 0.001})
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "skip_guidance"
EXPERIMENTAL = True

DESCRIPTION = "Generic version of SkipLayerGuidance node that can be used on every DiT model."

CATEGORY = "advanced/guidance"

def skip_guidance(self, model, scale, start_percent, end_percent, double_layers="", single_layers=""):
# check if layer is comma separated integers
def skip(args, extra_args):
return args

model_sampling = model.get_model_object("model_sampling")
sigma_start = model_sampling.percent_to_sigma(start_percent)
sigma_end = model_sampling.percent_to_sigma(end_percent)

double_layers = re.findall(r'\d+', double_layers)
double_layers = [int(i) for i in double_layers]

single_layers = re.findall(r'\d+', single_layers)
single_layers = [int(i) for i in single_layers]

if len(double_layers) == 0 and len(single_layers) == 0:
return (model, )

def post_cfg_function(args):
model = args["model"]
cond_pred = args["cond_denoised"]
cond = args["cond"]
cfg_result = args["denoised"]
sigma = args["sigma"]
x = args["input"]
model_options = args["model_options"].copy()

for layer in double_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "double_block", layer)

for layer in single_layers:
model_options = comfy.model_patcher.set_model_options_patch_replace(model_options, skip, "dit", "single_block", layer)

model_sampling.percent_to_sigma(start_percent)

sigma_ = sigma[0].item()
if scale > 0 and sigma_ >= sigma_end and sigma_ <= sigma_start:
(slg,) = comfy.samplers.calc_cond_batch(model, [cond], x, sigma, model_options)
cfg_result = cfg_result + (cond_pred - slg) * scale
return cfg_result

m = model.clone()
m.set_model_sampler_post_cfg_function(post_cfg_function)

return (m, )


NODE_CLASS_MAPPINGS = {
"SkipLayerGuidanceDiT": SkipLayerGuidanceDiT,
}
1 change: 1 addition & 0 deletions nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2133,6 +2133,7 @@ def init_builtin_extra_nodes():
"nodes_lora_extract.py",
"nodes_torch_compile.py",
"nodes_mochi.py",
"nodes_slg.py",
]

import_failed = []
Expand Down

0 comments on commit 9a0a5d3

Please sign in to comment.