-
Notifications
You must be signed in to change notification settings - Fork 941
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add web ui for core ml stable diffusion (#56)
- Loading branch information
Showing
3 changed files
with
153 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
try: | ||
import gradio as gr | ||
import python_coreml_stable_diffusion.pipeline as pipeline | ||
from diffusers import StableDiffusionPipeline | ||
|
||
def init(args): | ||
pipeline.logger.info("Initializing PyTorch pipe for reference configuration") | ||
pytorch_pipe = StableDiffusionPipeline.from_pretrained(args.model_version, | ||
use_auth_token=True) | ||
|
||
user_specified_scheduler = None | ||
if args.scheduler is not None: | ||
user_specified_scheduler = pipeline.SCHEDULER_MAP[ | ||
args.scheduler].from_config(pytorch_pipe.scheduler.config) | ||
|
||
coreml_pipe = pipeline.get_coreml_pipe(pytorch_pipe=pytorch_pipe, | ||
mlpackages_dir=args.i, | ||
model_version=args.model_version, | ||
compute_unit=args.compute_unit, | ||
scheduler_override=user_specified_scheduler) | ||
|
||
|
||
def infer(prompt, steps): | ||
pipeline.logger.info("Beginning image generation.") | ||
image = coreml_pipe( | ||
prompt=prompt, | ||
height=coreml_pipe.height, | ||
width=coreml_pipe.width, | ||
num_inference_steps=steps, | ||
) | ||
images = [] | ||
images.append(image["images"][0]) | ||
return images | ||
|
||
|
||
demo = gr.Blocks() | ||
|
||
with demo: | ||
gr.Markdown( | ||
"<center><h1>Core ML Stable Diffusion</h1>Run Stable Diffusion on Apple Silicon with Core ML</center>") | ||
with gr.Group(): | ||
with gr.Box(): | ||
with gr.Row(): | ||
with gr.Column(): | ||
with gr.Row(): | ||
text = gr.Textbox( | ||
label="Prompt", | ||
lines=11, | ||
placeholder="Enter your prompt", | ||
) | ||
with gr.Row(): | ||
btn = gr.Button("Generate image") | ||
with gr.Row(): | ||
steps = gr.Slider(label="Steps", minimum=1, | ||
maximum=50, value=10, step=1) | ||
with gr.Column(): | ||
gallery = gr.Gallery( | ||
label="Generated image", elem_id="gallery" | ||
) | ||
|
||
text.submit(infer, inputs=[text, steps], outputs=gallery) | ||
btn.click(infer, inputs=[text, steps], outputs=gallery) | ||
|
||
demo.launch(debug=True, server_name="0.0.0.0") | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = pipeline.argparse.ArgumentParser() | ||
|
||
parser.add_argument( | ||
"-i", | ||
required=True, | ||
help=("Path to input directory with the .mlpackage files generated by " | ||
"python_coreml_stable_diffusion.torch2coreml")) | ||
parser.add_argument( | ||
"--model-version", | ||
default="CompVis/stable-diffusion-v1-4", | ||
help= | ||
("The pre-trained model checkpoint and configuration to restore. " | ||
"For available versions: https://huggingface.co/models?search=stable-diffusion" | ||
)) | ||
parser.add_argument( | ||
"--compute-unit", | ||
choices=pipeline.get_available_compute_units(), | ||
default="ALL", | ||
help=("The compute units to be used when executing Core ML models. " | ||
f"Options: {pipeline.get_available_compute_units()}")) | ||
parser.add_argument( | ||
"--scheduler", | ||
choices=tuple(pipeline.SCHEDULER_MAP.keys()), | ||
default=None, | ||
help=("The scheduler to use for running the reverse diffusion process. " | ||
"If not specified, the default scheduler from the diffusers pipeline is utilized")) | ||
|
||
args = parser.parse_args() | ||
init(args) | ||
|
||
except ModuleNotFoundError as moduleNotFound: | ||
print(f'Found that `gradio` is not installed, try to install it automatically') | ||
try: | ||
import subprocess | ||
import sys | ||
|
||
subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'gradio']) | ||
print(f'Successfully installed missing package `gradio`.') | ||
print(f'Now re-execute the command :D') | ||
except subprocess.CalledProcessError: | ||
print(f'Automatic package installation failed, try manually executing `pip install gradio`, then retry the command again.') |