Support SDXL #26
ljzycmd committed Aug 20, 2023
1 parent 02f5344 commit 2a7861d
Showing 3 changed files with 158 additions and 11 deletions.
14 changes: 12 additions & 2 deletions README.md
@@ -13,6 +13,7 @@ Pytorch implementation of [MasaCtrl: Tuning-free Mutual Self-Attention Control f
[![Project page](https://img.shields.io/badge/Project-Page-brightgreen)](https://ljzycmd.github.io/projects/MasaCtrl/)
[![demo](https://img.shields.io/badge/Demo-Hugging%20Face-brightgreen)](https://huggingface.co/spaces/TencentARC/MasaCtrl)
[![demo](https://img.shields.io/badge/Demo-Colab-brightgreen)](https://colab.research.google.com/drive/1DZeQn2WvRBsNg4feS1bJrwWnIzw1zLJq?usp=sharing)
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/MingDengCao/MasaCtrl)

---

@@ -24,6 +25,7 @@ Pytorch implementation of [MasaCtrl: Tuning-free Mutual Self-Attention Control f

## Updates

- [2023/8/20] MasaCtrl now supports SDXL (and its variants); see the new `run_synthesis_sdxl.py` script added in this commit. ![sdxl_example](https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/sdxl_example.jpg)
- [2023/5/13] The inference code of MasaCtrl with T2I-Adapter is available.
- [2023/4/28] [Hugging Face demo](https://huggingface.co/spaces/TencentARC/MasaCtrl) released.
- [2023/4/25] Code released.
@@ -33,7 +35,7 @@ Pytorch implementation of [MasaCtrl: Tuning-free Mutual Self-Attention Control f

## Introduction

- We propose MasaCtrl, a tuning-free method for non-rigid consistent image synthesis and editing. The key idea is to combine the `contents` from the *source image* and the `layout` synthesized from *text prompt and additional controls* into the desired synthesized or edited image, with **Mutual Self-Attention Control**.
+ We propose MasaCtrl, a tuning-free method for non-rigid consistent image synthesis and editing. The key idea is to combine the `contents` from the *source image* and the `layout` synthesized from *text prompt and additional controls* into the desired synthesized or edited image, by querying semantically correlated features with **Mutual Self-Attention Control**.
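
At its core, **Mutual Self-Attention Control** lets the target branch keep its own queries (which encode the edited layout) while attending to the keys and values cached from the source branch (which carry the source contents). A minimal single-head sketch of this idea (an illustration under simplified assumptions, not the repository's exact implementation):

```python
import torch

def mutual_self_attention(q_target, k_source, v_source):
    # q_target: (n, d) queries from the target (edited) branch
    # k_source, v_source: (n, d) keys/values cached from the source branch
    scale = q_target.shape[-1] ** -0.5
    sim = (q_target @ k_source.transpose(-2, -1)) * scale  # (n, n) similarities
    attn = sim.softmax(dim=-1)
    return attn @ v_source  # source contents queried into the target layout
```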


## Main Features
@@ -44,13 +46,16 @@ MasaCtrl can perform prompt-based image synthesis and editing that changes the l

>*The target layout is synthesized directly from the target prompt.*
<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_synthetic.png">
<i>Consistent synthesis results</i>

<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_real.png">
<i>Real image editing results</i>
</div>
</details>



### 2 Integration into Controllable Diffusion Models
@@ -59,30 +64,35 @@ Directly modifying the text prompts often cannot generate target layout of desir

>*The target layout controlled by additional guidance.*
<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_w_adapter.png">
<i>Synthesis (left part) and editing (right part) results with T2I-Adapter</i>
</div>

</details>

### 3 Generalization to Other Models: Anything-V4

Our method also generalizes well to other Stable-Diffusion-based models; a loading sketch follows the results below.

<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/anythingv4_synthetic.png">
<i>Results on Anything-V4</i>
</div>
</details>
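
Swapping in a different Stable-Diffusion-based checkpoint is typically all that is needed. A hedged sketch (the checkpoint id here is an assumption for illustration; the scheduler settings mirror the SDXL script below):

```python
from diffusers import DDIMScheduler
from masactrl.diffuser_utils import MasaCtrlPipeline

scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                          clip_sample=False, set_alpha_to_one=False)
# assumption: any SD-1.x-based checkpoint on the Hub works here, e.g. an Anything-V4 variant
model = MasaCtrlPipeline.from_pretrained("andite/anything-v4.0", scheduler=scheduler).to("cuda")
```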


### 4 Extension to Video Synthesis

With dense consistent guidance, MasaCtrl enables video synthesis.

<details><summary>View visual results</summary>
<div align="center">
<img src="https://huggingface.co/TencentARC/MasaCtrl/resolve/main/assets/results_w_adapter_consistent.png">
<i>Video Synthesis Results (with keypose and canny guidance)</i>
</div>
</details>


## Usage
72 changes: 63 additions & 9 deletions masactrl/masactrl.py
@@ -12,7 +12,12 @@


class MutualSelfAttentionControl(AttentionBase):
-    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50):
+    MODEL_TYPE = {
+        "SD": 16,    # number of self-attention layers registered in the SD U-Net
+        "SDXL": 70   # the larger SDXL U-Net exposes 70 such layers
+    }
+
+    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
"""
Mutual self-attention control for Stable-Diffusion model
Args:
@@ -21,17 +26,22 @@ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None,
            layer_idx: list of the layers to apply mutual self-attention control
            step_idx: list of the steps to apply mutual self-attention control
            total_steps: the total number of steps
            model_type: the model type, SD or SDXL
        """
        super().__init__()
        self.total_steps = total_steps
        self.total_layers = self.MODEL_TYPE.get(model_type, 16)
        self.start_step = start_step
        self.start_layer = start_layer
-        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, 16))
+        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
-        print("step_idx: ", self.step_idx)
-        print("layer_idx: ", self.layer_idx)
+        print("MasaCtrl at denoising steps: ", self.step_idx)
+        print("MasaCtrl at U-Net layers: ", self.layer_idx)

    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Performing attention for a batch of queries, keys, and values
        """
        b = q.shape[0] // num_heads
        # fold the batch into the token axis so that a single attention call can
        # mix features across the batch (e.g. target queries over source keys/values)
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
@@ -62,8 +72,47 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar
        return out


class MutualSelfAttentionControlUnion(MutualSelfAttentionControl):
    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, model_type="SD"):
        """
        Mutual self-attention control for Stable-Diffusion models with united source and target [K, V]
        Args:
            start_step: the step to start mutual self-attention control
            start_layer: the layer to start mutual self-attention control
            layer_idx: list of the layers to apply mutual self-attention control
            step_idx: list of the steps to apply mutual self-attention control
            total_steps: the total number of steps
            model_type: the model type, SD or SDXL
        """
        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function
        """
        # fall back to vanilla attention for cross-attention and for steps/layers
        # outside the MasaCtrl schedule
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

        qu_s, qu_t, qc_s, qc_t = q.chunk(4)
        ku_s, ku_t, kc_s, kc_t = k.chunk(4)
        vu_s, vu_t, vc_s, vc_t = v.chunk(4)
        attnu_s, attnu_t, attnc_s, attnc_t = attn.chunk(4)

        # source image branch
        out_u_s = super().forward(qu_s, ku_s, vu_s, sim, attnu_s, is_cross, place_in_unet, num_heads, **kwargs)
        out_c_s = super().forward(qc_s, kc_s, vc_s, sim, attnc_s, is_cross, place_in_unet, num_heads, **kwargs)

        # target image branch, concatenating source and target [K, V]
        out_u_t = self.attn_batch(qu_t, torch.cat([ku_s, ku_t]), torch.cat([vu_s, vu_t]), sim[:num_heads], attnu_t, is_cross, place_in_unet, num_heads, **kwargs)
        out_c_t = self.attn_batch(qc_t, torch.cat([kc_s, kc_t]), torch.cat([vc_s, vc_t]), sim[:num_heads], attnc_t, is_cross, place_in_unet, num_heads, **kwargs)

        out = torch.cat([out_u_s, out_u_t, out_c_s, out_c_t], dim=0)

        return out
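
# Note on the 4-way chunking above (an added explanation inferred from the code):
# with two prompts batched as [source, target], classifier-free guidance doubles
# the batch to [uncond_src, uncond_tgt, cond_src, cond_tgt]; each target branch
# then attends over the concatenated source+target keys and values.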


class MutualSelfAttentionControlMask(MutualSelfAttentionControl):
-    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None):
+    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, mask_s=None, mask_t=None, mask_save_dir=None, model_type="SD"):
        """
        Mask-guided MasaCtrl to alleviate the problem of fore- and background confusion
        Args:
@@ -74,8 +123,10 @@ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None,
            total_steps: the total number of steps
            mask_s: source mask with shape (h, w)
            mask_t: target mask with same shape as source mask
            mask_save_dir: the path to save the mask image
            model_type: the model type, SD or SDXL
        """
-        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps)
+        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
        self.mask_s = mask_s  # source mask with shape (h, w)
        self.mask_t = mask_t  # target mask with same shape as source mask
        print("Using mask-guided MasaCtrl")
@@ -143,7 +194,7 @@ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwar


class MutualSelfAttentionControlMaskAuto(MutualSelfAttentionControl):
-    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None):
+    def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None, total_steps=50, thres=0.1, ref_token_idx=[1], cur_token_idx=[1], mask_save_dir=None, model_type="SD"):
        """
        MasaCtrl with mask auto generation from cross-attention map
        Args:
@@ -157,8 +208,8 @@ def __init__(self, start_step=4, start_layer=10, layer_idx=None, step_idx=None,
            cur_token_idx: the token index list for cross-attention map aggregation
            mask_save_dir: the path to save the mask image
        """
-        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps)
-        print("using MutualSelfAttentionControlMaskAuto")
+        super().__init__(start_step, start_layer, layer_idx, step_idx, total_steps, model_type)
+        print("Using MutualSelfAttentionControlMaskAuto")
        self.thres = thres
        self.ref_token_idx = ref_token_idx
        self.cur_token_idx = cur_token_idx
@@ -178,6 +229,9 @@ def after_step(self):
        self.cross_attns = []

    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Performing attention for a batch of queries, keys, and values
        """
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
83 changes: 83 additions & 0 deletions run_synthesis_sdxl.py
@@ -0,0 +1,83 @@
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

import numpy as np

from tqdm import tqdm
from einops import rearrange, repeat
from omegaconf import OmegaConf

from diffusers import DDIMScheduler, DiffusionPipeline

from masactrl.diffuser_utils import MasaCtrlPipeline
from masactrl.masactrl_utils import AttentionBase
from masactrl.masactrl_utils import regiter_attention_editor_diffusers
from masactrl.masactrl import MutualSelfAttentionControl

from torchvision.utils import save_image
from torchvision.io import read_image
from pytorch_lightning import seed_everything

torch.cuda.set_device(0) # set the GPU device

# Note that you may add your Hugging Face token to get access to the models
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

model_path = "stabilityai/stable-diffusion-xl-base-1.0"
# model_path = "Linaqruf/animagine-xl"
scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False)
model = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)
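
# Optional (an added suggestion, not part of the original script): load the
# pipeline in half precision to reduce GPU memory, e.g.
# model = DiffusionPipeline.from_pretrained(model_path, scheduler=scheduler, torch_dtype=torch.float16).to(device)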


def consistent_synthesis():
    seed = 42
    seed_everything(seed)

    out_dir_ori = "./workdir/masactrl_exp/oldman_smiling"
    os.makedirs(out_dir_ori, exist_ok=True)

    prompts = [
        "A portrait of an old man, facing camera, best quality",
        "A portrait of an old man, facing camera, smiling, best quality",
    ]

    # inference of the synthesized images with MasaCtrl
    # TODO: note that the hyperparameters of MasaCtrl for SDXL may not be optimal
    STEP = 4
    LAYER_LIST = [44, 54, 64]  # run the synthesis with MasaCtrl at three different layer configs

    # initialize the noise map
    start_code = torch.randn([1, 4, 128, 128], device=device)
    # start_code = None
    start_code = start_code.expand(len(prompts), -1, -1, -1)
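    # Added note: expanding (rather than re-sampling) the noise map gives every
    # prompt the same initial latent, which is what keeps the synthesized images
    # consistent with each other.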

    # inference of the synthesized images without MasaCtrl
    editor = AttentionBase()
    regiter_attention_editor_diffusers(model, editor)
    image_ori = model(prompts, latents=start_code, guidance_scale=7.5).images

    for LAYER in LAYER_LIST:
        # hijack the attention module
        editor = MutualSelfAttentionControl(STEP, LAYER, model_type="SDXL")
        regiter_attention_editor_diffusers(model, editor)

        # inference of the synthesized images with MasaCtrl
        image_masactrl = model(prompts, latents=start_code, guidance_scale=7.5).images

        sample_count = len(os.listdir(out_dir_ori))
        out_dir = os.path.join(out_dir_ori, f"sample_{sample_count}")
        os.makedirs(out_dir, exist_ok=True)
        image_ori[0].save(os.path.join(out_dir, f"source_step{STEP}_layer{LAYER}.png"))
        image_ori[1].save(os.path.join(out_dir, f"without_step{STEP}_layer{LAYER}.png"))
        image_masactrl[-1].save(os.path.join(out_dir, f"masactrl_step{STEP}_layer{LAYER}.png"))
        with open(os.path.join(out_dir, "prompts.txt"), "w") as f:
            for p in prompts:
                f.write(p + "\n")
            f.write(f"seed: {seed}\n")
        print("Synthesized images are saved in", out_dir)


if __name__ == "__main__":
    consistent_synthesis()
