Skip to content

Commit

Permalink
Add basic support for Text-to-Image diffusion #5
Browse files Browse the repository at this point in the history
  • Loading branch information
pramitchoudhary committed Nov 1, 2022
1 parent 0d2acdd commit 1d43df0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 20 deletions.
42 changes: 26 additions & 16 deletions img_styler/image_prompt/stable_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,40 @@
import gc
from io import BytesIO
from typing import Optional

import torch
from torch import autocast
from diffusers import StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
from PIL import Image
from io import BytesIO
from diffusers import StableDiffusionImg2ImgPipeline
from torch import autocast


def generate_image_with_prompt(input_img_path: Optional[str] = None, prompt_txt: str = "Face portrait",
                               output_path: Optional[str] = None) -> Optional[str]:
    """Generate an image from a text prompt with Stable Diffusion.

    When ``input_img_path`` is given, text-guided Image-to-Image diffusion is
    run with that image (converted to RGB and resized to 512x512) as the
    starting point. Otherwise plain Text-to-Image diffusion is run.

    Args:
        input_img_path: Path of the image to initialise from, or ``None``/empty
            to generate purely from the prompt.
        prompt_txt: Text prompt guiding the generation.
        output_path: Directory in which ``result.jpg`` is written. When falsy,
            nothing is saved.

    Returns:
        ``output_path + '/result.jpg'`` when the image was saved, else ``None``.
    """
    # License: https://huggingface.co/spaces/CompVis/stable-diffusion-license
    torch.cuda.empty_cache()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Half precision is only reliably supported on GPU; many CPU kernels are
    # not implemented for float16, so fall back to float32 there.
    torch_dtype = torch.float16 if device == 'cuda' else torch.float32
    model_path = "./models/stable_diffusion_v1_4"

    if input_img_path:
        pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_path, revision="fp16",
                                                              torch_dtype=torch_dtype).to(device)
        # Open image; the context manager closes the underlying file once the
        # RGB copy has been made.
        with Image.open(input_img_path) as image_input:
            init_image = image_input.convert("RGB").resize((512, 512))

        with autocast(device):
            # NOTE(review): `.images` is used for consistency with the
            # text-to-image branch below; confirm the installed diffusers
            # version exposes it on the img2img output as well.
            images = pipe(prompt=prompt_txt, init_image=init_image, strength=0.5, guidance_scale=7.5).images
    else:  # Default prompt
        pipe = StableDiffusionPipeline.from_pretrained(model_path, revision="fp16",
                                                       torch_dtype=torch_dtype).to(device)

        with autocast(device):
            images = pipe(prompt=prompt_txt).images

    # Build the destination only when a directory was supplied: concatenating
    # onto the default ``None`` would raise TypeError.
    file_name = None
    if output_path:
        file_name = output_path + '/result.jpg'
        images[0].save(file_name)

    # Drop the last reference to the pipeline so the collector and the CUDA
    # caching allocator can actually release the model's memory.
    del pipe
    gc.collect()
    torch.cuda.empty_cache()
    return file_name
3 changes: 3 additions & 0 deletions img_styler/ui/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ async def update_processed_face(q: Q, save=False):
)
if q.client.task_choice == 'D':
q.page['prompt_form'] = ui.form_card(ui.box('main', order=1, height='200px', width='900px'), items=[
ui.checkbox(name='prompt_use_source_img', label='Use source image',
value=True, tooltip='Image-to-Image text-guided diffusion is applied by default.\
If un-checked, default Text-to-Image diffusion is used.'),
ui.textbox(name='prompt_textbox', label='Prompt', multiline=True, value=q.client.prompt_textbox),
ui.button(name='prompt_apply', label='Apply')])
if save:
Expand Down
11 changes: 7 additions & 4 deletions img_styler/ui/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ async def process(q: Q):
text=f'Image by the name "{q.args.img_name}" already exists!',
type='error',
position='bottom-left',
))
))
else:
os.rename(out_path, new_img_path)
temp_img = Image.open(new_img_path)
temp_img.save(os.path.join(INPUT_PATH, 'portrait.jpg'))

q.app.source_faces = get_files_in_dir(dir_path=INPUT_PATH)
q.client.processedimg = new_img_path
q.page['meta'] = ui.meta_card(box='', notification_bar=ui.notification_bar(
Expand All @@ -114,7 +114,7 @@ async def process(q: Q):
))
await q.page.save()
del q.page['meta']

await update_controls(q)
await update_faces(q)
await update_processed_face(q)
Expand Down Expand Up @@ -265,8 +265,11 @@ async def image_upload(q: Q):
async def prompt_apply(q: Q):
logger.info(f"Enable prompt.")
logger.info(f"Prompt value: {q.args.prompt_textbox}")
res_path = generate_image_with_prompt(input_img_path=q.client.source_face, prompt_txt=q.args.prompt_textbox,
if q.args.prompt_use_source_img:
res_path = generate_image_with_prompt(input_img_path=q.client.source_face, prompt_txt=q.args.prompt_textbox,
output_path=OUTPUT_PATH)
else: # Don't initialize with source image
res_path = generate_image_with_prompt(prompt_txt=q.args.prompt_textbox, output_path=OUTPUT_PATH)

q.client.prompt_textbox = q.args.prompt_textbox
q.client.processedimg = res_path
Expand Down

0 comments on commit 1d43df0

Please sign in to comment.