Stable Cascade? #95

Open
Teriks opened this issue Jun 27, 2024 · 5 comments

Comments

Teriks commented Jun 27, 2024

It seems like it might be possible for this to work with Stable Cascade?

I am wondering if there is a working snippet for the prior + decoder setup, or if it is incompatible at the moment.
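
For context, the plain two-stage diffusers call (without compel) looks roughly like the sketch below — argument values are only illustrative, based on the diffusers Stable Cascade examples:

import torch
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

device = 'cuda'
prompt = "an image of a shiba inu donning a spacesuit"

# stage C: the prior turns the text prompt into image embeddings
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)
prior_output = prior(prompt=prompt, num_inference_steps=20, guidance_scale=4.0)
prior.to('cpu')

# stages B/A: the decoder turns those embeddings into the final image
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant='bf16',
                                                       torch_dtype=torch.float16).to(device)
decoder(image_embeddings=prior_output.image_embeddings.half(),
        prompt=prompt,
        num_inference_steps=10,
        guidance_scale=0.0).images[0].save('plain.png')

The question is whether compel's prompt_embeds / prompt_embeds_pooled can be substituted into those two calls.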

Teriks commented Jun 28, 2024

This generates a recognizable image, though given the quality, there is definitely something missing from the equation somewhere. Still, it seems possible.

import torch
from compel import Compel, ReturnedEmbeddingsType
from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

device = 'cuda'
prompt = "an image of a (shiba inu)1.5 donning a spacesuit++"

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)

prior_compel = Compel(tokenizer=prior.tokenizer,
                      text_encoder=prior.text_encoder,
                      returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                      requires_pooled=True, device=device)


conditioning, pooled = prior_compel(prompt)
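# note: the Stable Cascade pipelines appear to expect pooled embeds of shape
# (batch, 1, dim), hence the unsqueeze(1) on the pooled output below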

prior_output = prior(
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled.unsqueeze(1))

prior.to('cpu')

decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant='bf16',
                                                       torch_dtype=torch.float16).to(device)

decoder_compel = Compel(tokenizer=decoder.tokenizer,
                        text_encoder=decoder.text_encoder,
                        returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                        requires_pooled=True,
                        device=device)

conditioning, pooled = decoder_compel(prompt)

decoder(num_inference_steps=10,
        guidance_scale=0.0,
        prompt_embeds=conditioning,
        prompt_embeds_pooled=pooled.unsqueeze(1),
        image_embeddings=prior_output.image_embeddings.half()).images[0].save('test.png')

@duonglegiang

Hi @Teriks, have you resolved the issue using the prior + decoder setup in the snippet?

@damian0815 (Owner)

You might want to confirm whether returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED is correct for Stable Cascade.
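
For reference, that choice roughly maps onto which hidden state is taken from the CLIP text encoder output. A sketch of the idea only (assuming the standard tokenizer/text_encoder subfolders of the prior repo), not compel's exact internals:

import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-cascade-prior", subfolder="tokenizer")
text_encoder = CLIPTextModelWithProjection.from_pretrained("stabilityai/stable-cascade-prior",
                                                           subfolder="text_encoder")

token_ids = tokenizer("a shiba inu in a spacesuit", padding="max_length",
                      truncation=True, return_tensors="pt").input_ids

with torch.no_grad():
    output = text_encoder(token_ids, output_hidden_states=True, return_dict=True)

penultimate = output.hidden_states[-2]  # what PENULTIMATE_HIDDEN_STATES_* selects (SDXL-style)
last = output.hidden_states[-1]         # final hidden state, which the Stable Cascade pipelines appear to use
pooled = output.text_embeds             # projected pooled output (prompt_embeds_pooled)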

Teriks commented Sep 6, 2024

The embeddings provider probably needs some alternate logic to handle Stable Cascade.

I decided to sit down and mess with it a little; I think it needs something like this.

You would probably need to implement a new ReturnedEmbeddingsType flag.

Here is a monkey-patch demo that produces a decent-quality image.

I might have time for a PR next week, though it would be very simple to add.

@damian0815 @duonglegiang

from typing import List, Optional

import torch
import compel

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline


class SCascadeEmbeddingsProvider(compel.EmbeddingsProvider):

    def _encode_token_ids_to_embeddings(self, token_ids: torch.Tensor,
                                        attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        text_encoder_output = self.text_encoder(token_ids,
                                                attention_mask,
                                                output_hidden_states=True,
                                                return_dict=True)
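        # Stable Cascade appears to condition on the final hidden state
        # (hidden_states[-1]) rather than the penultimate layer used for SDXL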

        return text_encoder_output.hidden_states[-1]

    def get_pooled_embeddings(self, texts: List[str], attention_mask: Optional[torch.Tensor] = None,
                              device: Optional[str] = None) -> Optional[torch.Tensor]:
        device = device or self.device

        token_ids = self.get_token_ids(texts, padding="max_length", truncation_override=True)

        token_ids = torch.tensor(token_ids, dtype=torch.long).to(device)

        text_encoder_output = self.text_encoder(token_ids, attention_mask, return_dict=True)

        pooled = text_encoder_output.text_embeds

        return pooled.unsqueeze(1)  # add the sequence dim the pipelines expect


# monkey patch in the correct behavior for this example

def patch_compel(compel_obj: compel.Compel):
    compel_obj.conditioning_provider.__class__ = SCascadeEmbeddingsProvider


# Do generation


device = 'cuda'
prompt = "an image of a shiba inu with (blue eyes)1.4, donning a green+ spacesuit"

prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", variant='bf16',
                                                   torch_dtype=torch.bfloat16).to(device)

prior_compel = compel.Compel(tokenizer=prior.tokenizer,
                             text_encoder=prior.text_encoder,
                             requires_pooled=True, device=device)

# patch prior
patch_compel(prior_compel)

conditioning, pooled = prior_compel(prompt)

prior_output = prior(
    num_inference_steps=20,
    guidance_scale=4,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled)

prior.to('cpu')

decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", variant='bf16',
                                                       torch_dtype=torch.float16).to(device)

decoder_compel = compel.Compel(tokenizer=decoder.tokenizer,
                               text_encoder=decoder.text_encoder,
                               requires_pooled=True, device=device)

# patch decoder
patch_compel(decoder_compel)

conditioning, pooled = decoder_compel(prompt)

image = decoder(
    num_inference_steps=10,
    guidance_scale=0.0,
    prompt_embeds=conditioning,
    prompt_embeds_pooled=pooled,
    image_embeddings=prior_output.image_embeddings.half()).images[0]

image.save('test.png')

decoder.to('cpu')

Example Result:

[image: test]

Prompt: an image of a shiba inu with (blue eyes)1.4, donning a green+ spacesuit, (cartoon style)1.6

[image: test]

Teriks commented Sep 11, 2024

Stable Cascade support, new ReturnedEmbeddingsType #104
