Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add support for ip-adapter #56

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 73 additions & 3 deletions lib_omost/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
import torch
import numpy as np
import copy

from tqdm.auto import trange
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import *
from diffusers.models.transformers import Transformer2DModel
from diffusers.models.embeddings import ImageProjection
from diffusers.utils import load_image


original_Transformer2DModel_forward = Transformer2DModel.forward
Expand Down Expand Up @@ -173,7 +176,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states, hidden_states_ori
return h


class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline):
class StableDiffusionXLOmostPipeline(StableDiffusionXLImg2ImgPipeline, IPAdapterMixin):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.k_model = KModel(unet=self.unet)
Expand All @@ -188,6 +191,58 @@ def __init__(self, *args, **kwargs):
self.unet.set_attn_processor(attn_procs)
return


def prepare_ip_adapter_image_embeds(
    self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
):
    """Return one image-embedding tensor per IP-Adapter.

    Two input modes are supported:

    * ``ip_adapter_image_embeds`` is given: each tensor is tiled
      ``num_images_per_prompt`` times along dim 0. With classifier-free
      guidance enabled, each tensor is first split in two along dim 0
      (negative half first, then positive half), each half is tiled
      separately, and the halves are concatenated back in that order.
    * ``ip_adapter_image_embeds`` is ``None``: ``ip_adapter_image`` (a single
      image or a list of images) is encoded through ``self.encode_image``,
      one image per projection layer found in
      ``self.unet.encoder_hid_proj.image_projection_layers``.

    Args:
        ip_adapter_image: A single image or a list of images; only used when
            ``ip_adapter_image_embeds`` is ``None``.
        ip_adapter_image_embeds: Optional iterable of precomputed embedding
            tensors.
        device: Passed to ``self.encode_image``; also the target of ``.to()``
            for the guidance-concatenated embeddings.
        num_images_per_prompt: Repetition count applied along dim 0.
        do_classifier_free_guidance: Whether negative embeddings are
            prepended to the positive ones.

    Returns:
        A list of tensors, one per input image / precomputed embedding.

    Raises:
        ValueError: If the number of images differs from the number of
            image projection layers.
    """
    prepared = []

    if ip_adapter_image_embeds is not None:
        # Precomputed embeddings: nothing is encoded, tensors are only tiled.
        def _tile(tensor):
            # Repeat along dim 0 only; every trailing dim keeps its size.
            trailing = [1] * len(tensor.shape[1:])
            return tensor.repeat(num_images_per_prompt, *trailing)

        for embeds in ip_adapter_image_embeds:
            if not do_classifier_free_guidance:
                prepared.append(_tile(embeds))
                continue

            # The first chunk is treated as the negative half, the second as
            # the positive half.
            negative_half, positive_half = embeds.chunk(2)
            positive_half = _tile(positive_half)
            negative_half = _tile(negative_half)
            prepared.append(torch.cat([negative_half, positive_half]))

        return prepared

    # Raw images: normalise to a list so a single image is handled uniformly.
    images = ip_adapter_image if isinstance(ip_adapter_image, list) else [ip_adapter_image]
    projection_layers = self.unet.encoder_hid_proj.image_projection_layers

    if len(images) != len(projection_layers):
        raise ValueError(
            f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(images)} images and {len(projection_layers)} IP Adapters."
        )

    for image, projection_layer in zip(images, projection_layers):
        # Any projection layer that is not a plain ImageProjection gets the
        # flag set to True when calling encode_image.
        wants_hidden_state = not isinstance(projection_layer, ImageProjection)
        positive, negative = self.encode_image(image, device, 1, wants_hidden_state)

        positive = torch.stack([positive] * num_images_per_prompt, dim=0)
        negative = torch.stack([negative] * num_images_per_prompt, dim=0)

        if do_classifier_free_guidance:
            # Negative embeddings go first, then the result is moved to `device`.
            positive = torch.cat([negative, positive]).to(device)

        prepared.append(positive)

    return prepared

@torch.inference_mode()
def encode_bag_of_subprompts_greedy(self, prefixes: list[str], suffixes: list[str]):
device = self.text_encoder.device
Expand Down Expand Up @@ -370,12 +425,15 @@ def __call__(
num_inference_steps: int = 25,
guidance_scale: float = 5.0,
batch_size: Optional[int] = 1,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
cross_attention_kwargs: Optional[dict] = None,
ip_adapter_image: Optional[PipelineImageInput] = None,
ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
):

device = self.unet.device
Expand Down Expand Up @@ -412,18 +470,30 @@ def __call__(
pooled_prompt_embeds = pooled_prompt_embeds.repeat(batch_size, 1).to(noise)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1).to(noise)

# TODO: Replace with input argument
ip_adapter_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_adapter_diner.png")

if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
image_embeds = self.prepare_ip_adapter_image_embeds(
ip_adapter_image,
ip_adapter_image_embeds,
device,
batch_size * num_images_per_prompt,
False # TODO: Add self.do_classifier_free_guidance,
)

# Feeds

sampler_kwargs = dict(
cfg_scale=guidance_scale,
positive=dict(
encoder_hidden_states=prompt_embeds,
added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids, "image_embeds": image_embeds if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None},
cross_attention_kwargs=cross_attention_kwargs
),
negative=dict(
encoder_hidden_states=negative_prompt_embeds,
added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids, "image_embeds": image_embeds if ip_adapter_image is not None or ip_adapter_image_embeds is not None else None},
cross_attention_kwargs=cross_attention_kwargs
)
)
Expand Down