Skip to content

Commit

Permalink
Merge pull request #248 from mrhan1993/dev
Browse files Browse the repository at this point in the history
Support image custom format
  • Loading branch information
mrhan1993 authored Mar 18, 2024
2 parents 0fbb004 + f2dd06a commit b2c2377
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 15 deletions.
1 change: 1 addition & 0 deletions fooocusapi/api_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def req_to_params(req: Text2ImgRequest) -> ImageGenerationParams:
inpaint_additional_prompt=inpaint_additional_prompt,
image_prompts=image_prompts,
advanced_params=advanced_params,
save_extension=req.save_extension,
require_base64=req.require_base64,
)

Expand Down
38 changes: 34 additions & 4 deletions fooocusapi/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import numpy as np
from PIL import Image
import uuid
import json
from pathlib import Path
from PIL.PngImagePlugin import PngInfo



output_dir = os.path.abspath(os.path.join(
os.path.dirname(__file__), '..', 'outputs', 'files'))
Expand All @@ -13,16 +18,41 @@
static_serve_base_url = 'http://127.0.0.1:8888/files/'


def save_output_file(img: np.ndarray) -> str:
def save_output_file(img: np.ndarray, image_meta: dict = None,
image_name: str = '', extension: str = 'png') -> str:
"""
Save np image to file
Args:
img: np.ndarray image to save
image_meta: dict of image metadata
image_name: str of image name
extension: str of image extension
Returns:
str of file name
"""
current_time = datetime.datetime.now()
date_string = current_time.strftime("%Y-%m-%d")

filename = os.path.join(date_string, str(uuid.uuid4()) + '.png')
image_name = str(uuid.uuid4()) if image_name == '' else image_name

filename = os.path.join(date_string, image_name + '.' + extension)
file_path = os.path.join(output_dir, filename)

if extension not in ['png', 'jpg', 'webp']:
extension = 'png'

if image_meta is None:
image_meta = {}

meta = None
if extension == 'png':
meta = PngInfo()
meta.add_text("params", json.dumps(image_meta))

os.makedirs(os.path.dirname(file_path), exist_ok=True)
Image.fromarray(img).save(file_path)
return filename
Image.fromarray(img).save(file_path, format=extension,
pnginfo=meta, optimize=True)
return Path(filename).as_posix()


def delete_output_file(filename: str):
Expand Down
11 changes: 8 additions & 3 deletions fooocusapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ class Text2ImgRequest(BaseModel):
refiner_switch: float = Field(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0)
loras: List[Lora] = Field(default=default_loras_model)
advanced_params: AdvancedParams | None = AdvancedParams()
save_extension: str = Field(default='png', description="Save extension, one of [png, jpg, webp]")
require_base64: bool = Field(default=False, description="Return base64 data of generated image")
async_process: bool = Field(default=False, description="Set to true will run async and return job info for retrieve generataion result later")
webhook_url: str | None = Field(default='', description="Optional URL for a webhook callback. If provided, the system will send a POST request to this URL upon task completion or failure."
Expand Down Expand Up @@ -214,6 +215,7 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for upsa
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
advanced_params: str | None = Form(default=None, description="Advanced parameters in JSON"),
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generataion result later"),
):
Expand All @@ -226,7 +228,8 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for upsa
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
loras=loras_model, advanced_params=advanced_params_obj, require_base64=require_base64, async_process=async_process)
loras=loras_model, advanced_params=advanced_params_obj, save_extension=save_extension,
require_base64=require_base64, async_process=async_process)


class ImgInpaintOrOutpaintRequest(Text2ImgRequest):
Expand Down Expand Up @@ -262,6 +265,7 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for inpa
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
advanced_params: str| None = Form(default=None, description="Advanced parameters in JSON"),
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generataion result later"),
):
Expand All @@ -281,7 +285,7 @@ def as_form(cls, input_image: UploadFile = Form(description="Init image for inpa
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
loras=loras_model, advanced_params=advanced_params_obj, require_base64=require_base64, async_process=async_process)
loras=loras_model, advanced_params=advanced_params_obj, save_extension=save_extension, require_base64=require_base64, async_process=async_process)


class ImgPromptRequest(ImgInpaintOrOutpaintRequest):
Expand Down Expand Up @@ -343,6 +347,7 @@ def as_form(cls, input_image: UploadFile = Form(File(None), description="Init im
refiner_switch: float = Form(default=default_refiner_switch, description="Refiner Switch At", ge=0.1, le=1.0),
loras: str | None = Form(default=default_loras_json, description='Lora config in JSON. Format as [{"model_name": "sd_xl_offset_example-lora_1.0.safetensors", "weight": 0.5}]'),
advanced_params: str| None = Form(default=None, description="Advanced parameters in JSON"),
save_extension: str = Form(default="png", description="Save extension, png, jpg or webp"),
require_base64: bool = Form(default=False, description="Return base64 data of generated image"),
async_process: bool = Form(default=False, description="Set to true will run async and return job info for retrieve generataion result later"),
):
Expand Down Expand Up @@ -376,7 +381,7 @@ def as_form(cls, input_image: UploadFile = Form(File(None), description="Init im
performance_selection=performance_selection, aspect_ratios_selection=aspect_ratios_selection,
image_number=image_number, image_seed=image_seed, sharpness=sharpness, guidance_scale=guidance_scale,
base_model_name=base_model_name, refiner_model_name=refiner_model_name, refiner_switch=refiner_switch,
loras=loras_model, advanced_params=advanced_params_obj, require_base64=require_base64, async_process=async_process)
loras=loras_model, advanced_params=advanced_params_obj, save_extension=save_extension, require_base64=require_base64, async_process=async_process)


class GeneratedImageResult(BaseModel):
Expand Down
2 changes: 2 additions & 0 deletions fooocusapi/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(self, prompt: str,
inpaint_additional_prompt: str | None,
image_prompts: List[Tuple[np.ndarray, float, float, str]],
advanced_params: List[any] | None,
save_extension: str,
require_base64: bool):
self.prompt = prompt
self.negative_prompt = negative_prompt
Expand All @@ -126,6 +127,7 @@ def __init__(self, prompt: str,
self.inpaint_input_image = inpaint_input_image
self.inpaint_additional_prompt = inpaint_additional_prompt
self.image_prompts = image_prompts
self.save_extension = save_extension
self.require_base64 = require_base64

if advanced_params is None:
Expand Down
17 changes: 9 additions & 8 deletions fooocusapi/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,14 +100,14 @@ def progressbar(_, number, text):
print(f'[Fooocus] {text}')
outputs.append(['preview', (number, text, None)])

def yield_result(_, imgs, tasks):
def yield_result(_, imgs, tasks, extension='png'):
if not isinstance(imgs, list):
imgs = [imgs]

results = []
for i, im in enumerate(imgs):
seed = -1 if len(tasks) == 0 else tasks[i]['task_seed']
img_filename = save_output_file(im)
img_filename = save_output_file(img=im, extension=extension)
results.append(ImageGenerationResult(im=img_filename, seed=str(seed), finish_reason=GenerationFinishReason.success))
async_task.set_result(results, False)
worker_queue.finish_task(async_task.job_id)
Expand Down Expand Up @@ -150,6 +150,7 @@ def yield_result(_, imgs, tasks):
inpaint_input_image = params.inpaint_input_image
inpaint_additional_prompt = params.inpaint_additional_prompt
inpaint_mask_image_upload = None
save_extension = params.save_extension

if inpaint_additional_prompt is None:
inpaint_additional_prompt = ''
Expand Down Expand Up @@ -547,7 +548,7 @@ def yield_result(_, imgs, tasks):
if direct_return:
d = [('Upscale (Fast)', '2x')]
log(uov_input_image, d)
yield_result(async_task, uov_input_image, tasks)
yield_result(async_task, uov_input_image, tasks, save_extension)
return

tiled = True
Expand Down Expand Up @@ -693,7 +694,7 @@ def yield_result(_, imgs, tasks):
cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(async_task, cn_img, tasks)
yield_result(async_task, cn_img, tasks, save_extension)
return
for task in cn_tasks[flags.cn_cpds]:
cn_img, cn_stop, cn_weight = task
Expand All @@ -705,7 +706,7 @@ def yield_result(_, imgs, tasks):
cn_img = HWC3(cn_img)
task[0] = core.numpy_to_pytorch(cn_img)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(async_task, cn_img, tasks)
yield_result(async_task, cn_img, tasks, save_extension)
return
for task in cn_tasks[flags.cn_ip]:
cn_img, cn_stop, cn_weight = task
Expand All @@ -716,7 +717,7 @@ def yield_result(_, imgs, tasks):

task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_path)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(async_task, cn_img, tasks)
yield_result(async_task, cn_img, tasks, save_extension)
return
for task in cn_tasks[flags.cn_ip_face]:
cn_img, cn_stop, cn_weight = task
Expand All @@ -730,7 +731,7 @@ def yield_result(_, imgs, tasks):

task[0] = ip_adapter.preprocess(cn_img, ip_adapter_path=ip_adapter_face_path)
if advanced_parameters.debugging_cn_preprocessor:
yield_result(async_task, cn_img, tasks)
yield_result(async_task, cn_img, tasks, save_extension)
return

all_ip_tasks = cn_tasks[flags.cn_ip] + cn_tasks[flags.cn_ip_face]
Expand Down Expand Up @@ -877,7 +878,7 @@ def callback(step, x0, x, total_steps, y):
if async_task.finish_with_error:
worker_queue.finish_task(async_task.job_id)
return async_task.task_result
yield_result(None, results, tasks)
yield_result(None, results, tasks, save_extension)
return
except Exception as e:
print('Worker error:', e)
Expand Down

0 comments on commit b2c2377

Please sign in to comment.