Skip to content

Commit

Permalink
Test Strudel for logging
Browse files Browse the repository at this point in the history
  • Loading branch information
benrugg committed Dec 5, 2024
1 parent 211b93e commit 5492888
Show file tree
Hide file tree
Showing 2 changed files with 366 additions and 5 deletions.
7 changes: 2 additions & 5 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
imp.reload(automatic1111_api)
imp.reload(stability_api)
imp.reload(stablehorde_api)
imp.reload(test_api)
else:
from . import (
addon_updater_ops,
Expand All @@ -47,11 +48,7 @@
ui_panels,
ui_preset_styles,
)
from .sd_backends import (
automatic1111_api,
stability_api,
stablehorde_api,
)
from .sd_backends import automatic1111_api, stability_api, stablehorde_api, test_api

import bpy

Expand Down
364 changes: 364 additions & 0 deletions sd_backends/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,364 @@
import bpy
import base64
import requests
from .. import (
config,
operators,
utils,
)


# CORE FUNCTIONS:


def generate(params, img_file, filename_prefix, props):
# validate the params, specifically for the Stability API
if not validate_params(params, props):
return False

# map the generic params to the specific ones for the Stability API
mapped_params = map_params(params)

# create the headers
headers = create_headers()

# prepare the URL (specifically setting the engine id)
api_url = f"{config.STABILITY_API_V1_URL}{props.sd_model}/image-to-image"

# prepare the file input
files = {
"init_image": img_file,
}

# send the API request
try:
response = requests.post(
api_url,
headers=headers,
files=files,
data=mapped_params,
timeout=request_timeout(),
)
img_file.close()
except requests.exceptions.ReadTimeout:
img_file.close()
return operators.handle_error(
f"The server timed out. Try again in a moment, or get help. [Get help with timeouts]({config.HELP_WITH_TIMEOUTS_URL})",
"timeout",
)

# print log info for debugging
# debug_log(response)

# handle the response
if response.status_code == 200:
return handle_success(response, filename_prefix)
else:
return handle_error(response)


def upscale(img_file, filename_prefix, props):
# create the headers
headers = create_headers()

# prepare the URL
api_url = f"{config.STABILITY_API_V2_URL}upscale/fast"

# prepare the file input
files = {
"image": img_file,
}

# prepare the params
data = {"output_format": get_image_format().lower()}

# send the API request
try:
response = requests.post(
api_url, headers=headers, files=files, data=data, timeout=request_timeout()
)
img_file.close()
except requests.exceptions.ReadTimeout:
img_file.close()
return operators.handle_error(
f"The server timed out during upscaling. Try again in a moment, or turn off upscaling.",
"timeout",
)

# print log info for debugging
# debug_log(response)

# handle the response
if response.status_code == 200:
return handle_success(response, filename_prefix)
else:
return handle_error(response)


def handle_success(response, filename_prefix):
try:
data = response.json()
output_file = utils.create_temp_file(filename_prefix + "-")
except:
return operators.handle_error(
f"Couldn't create a temp file to save image", "temp_file"
)

try:
if "image" in data:
with open(output_file, "wb") as file:
file.write(base64.b64decode(data["image"]))
elif "artifacts" in data:
for i, image in enumerate(data["artifacts"]):
with open(output_file, "wb") as file:
file.write(base64.b64decode(image["base64"]))
else:
return operators.handle_error(
f"DreamStudio returned an unexpected response", "unexpected_response"
)

return output_file
except:
return operators.handle_error(
f"DreamStudio returned an unexpected response", "unexpected_response"
)


def handle_error(response):
import json

error_key = ""

try:
# convert the response to JSON (hopefully)
response_obj = response.json()

# get the message key from the response, if it exists
message = response_obj.get("message", str(response.content))

# handle the different types of errors
if response_obj.get("timeout", False):
error_message = f"The server timed out. Try again in a moment, or get help. [Get help with timeouts]({config.HELP_WITH_TIMEOUTS_URL})"
error_key = "timeout"
else:
error_message, error_key = parse_message_for_error(message)
except:
error_message = f"(Server Error) An unknown error occurred in the Stability API. Full server response: {str(response.content)}"
error_key = "unknown_error_response"

return operators.handle_error(error_message, error_key)


# PRIVATE SUPPORT FUNCTIONS:


def create_headers():
return {
"User-Agent": f"Blender/{bpy.app.version_string}",
"Accept": "application/json",
"Authorization": f"Bearer {utils.get_dream_studio_api_key()}",
}


def map_params(params):
# create a new dict so we don't overwrite the original
mapped_params = {}

# copy the params
mapped_params["seed"] = params["seed"]
mapped_params["cfg_scale"] = params["cfg_scale"]
mapped_params["steps"] = params["steps"]

# convert the params to the Stability API format
mapped_params["image_strength"] = round(params["image_similarity"], 2)
mapped_params["sampler"] = params["sampler"].upper()
mapped_params["text_prompts[0][text]"] = params["prompt"]
mapped_params["text_prompts[0][weight]"] = 1.0

if params["negative_prompt"]:
mapped_params["text_prompts[1][text]"] = params["negative_prompt"]
mapped_params["text_prompts[1][weight]"] = -1.0

return mapped_params


def validate_params(params, props):
# validate the dimensions (the sdxl 1024 model only supports a few specific image sizes)
if props.sd_model.startswith(
"stable-diffusion-xl-1024"
) and not utils.are_sdxl_1024_dimensions_valid(params["width"], params["height"]):
return operators.handle_error(
f"The SDXL model only supports these image sizes: {', '.join(utils.sdxl_1024_valid_dimensions)}. Please change your image size and try again.",
"invalid_dimensions",
)
elif params["steps"] < 10:
return operators.handle_error(
"Steps must be set to at least 10.", "steps_too_small"
)
else:
return True


def parse_message_for_error(message):
if '"Authorization" is missing' in message:
return "Your DreamStudio API key is missing. Please enter it above.", "api_key"
elif (
"Incorrect API key" in message
or "Unauthenticated" in message
or "Unable to find corresponding account" in message
):
return (
f"Your DreamStudio API key is incorrect. Please find it on the DreamStudio website, and re-enter it above. [DreamStudio website]({config.DREAM_STUDIO_URL})",
"api_key",
)
elif "not have enough balance" in message:
return (
f"You don't have enough DreamStudio credits. Please purchase credits on the DreamStudio website or switch to a different backend in the AI Render add-on preferences. [DreamStudio website]({config.DREAM_STUDIO_URL})",
"credits",
)
elif "invalid_prompts" in message:
return (
"Invalid prompt. Your prompt includes filtered words. Please change your prompt and try again.",
"prompt",
)
elif "image too large" in message:
return (
"Image size is too large. Please decrease width/height.",
"dimensions_too_large",
)
elif "invalid_height_or_width" in message:
return (
"Invalid width or height. They must be in the range 128-2048 in multiples of 64.",
"invalid_dimensions",
)
elif "body.sampler must be" in message:
return (
"Invalid sampler. Please choose a new Sampler under 'Advanced Options'.",
"sampler",
)
elif "body.cfg_scale must be" in message:
return (
"Invalid prompt strength. 'Prompt Strength' must be in the range 0-35.",
"prompt_strength",
)
elif "body.seed must be" in message:
return "Invalid seed value. Please choose a new 'Seed'.", "seed"
elif "body.steps must be" in message:
return "Invalid number of steps. 'Steps' must be in the range 10-150.", "steps"
return (
f"(Server Error) An error occurred in the Stability API. Full server response: {message}",
"unknown_error",
)


def debug_log(response):
print("request body:")
print(response.request.body)
print("\n")

print("response body:")
print(response.content)

try:
print(response.json())
except:
print("body not json")


# PUBLIC SUPPORT FUNCTIONS:


def get_samplers():
# NOTE: Keep the number values (fourth item in the tuples) in sync with the other
# backends, like Automatic1111. These act like an internal unique ID for Blender
# to use when switching between the lists.
return [
("k_euler", "Euler", "", 10),
("k_euler_ancestral", "Euler a", "", 20),
("k_heun", "Heun", "", 30),
("k_dpm_2", "DPM2", "", 40),
("k_dpm_2_ancestral", "DPM2 a", "", 50),
("k_lms", "LMS", "", 60),
("K_DPMPP_2S_ANCESTRAL", "DPM++ 2S a", "", 110),
("K_DPMPP_2M", "DPM++ 2M", "", 120),
("ddim", "DDIM", "", 210),
("ddpm", "DDPM", "", 220),
]


def default_sampler():
return "K_DPMPP_2M"


def get_upscaler_models(context):
return [
("fast", "fast", ""),
]


def is_upscaler_model_list_loaded(context=None):
return True


def default_upscaler_model():
return "fast"


def request_timeout():
return 55


def get_image_format():
return "PNG"


def supports_negative_prompts():
return True


def supports_choosing_model():
return True


def supports_upscaling():
return True


def supports_choosing_upscaler_model():
return False


def supports_reloading_upscaler_models():
return False


def supports_choosing_upscale_factor():
return False


def fixed_upscale_factor():
return 4.0


def supports_inpainting():
return False


def supports_outpainting():
return False


def min_image_size():
return 640 * 1536


def max_image_size():
return 1024 * 1024


def max_upscaled_image_size():
return 4096 * 4096


def is_using_sdxl_1024_model(props):
return props.sd_model.startswith("stable-diffusion-xl-1024")

0 comments on commit 5492888

Please sign in to comment.