diff --git a/css/style.css b/css/style.css index 010c8e7f6..c832a8537 100644 --- a/css/style.css +++ b/css/style.css @@ -197,7 +197,8 @@ display: none; } -#stylePreviewOverlay { +#stylePreviewOverlay, +#modelPreviewOverlay { opacity: 0; pointer-events: none; width: 128px; @@ -215,6 +216,20 @@ transition: transform 0.1s ease, opacity 0.3s ease; } -#stylePreviewOverlay.lower-half { +#stylePreviewOverlay.lower-half, +#modelPreviewOverlay.lower-half { transform: translate(-140px, -140px); } + +#modelPreviewOverlay { + z-index: 10000000000 !important; + justify-content: center; + display: flex; + align-items: center; + background-color: rgba(0, 0, 0, 0.7) !important; + color: white; + padding: 5px; + max-width: 100%; + overflow-wrap: break-word; + word-break: break-all; +} \ No newline at end of file diff --git a/javascript/script.js b/javascript/script.js index 8f4cac58f..864cedda6 100644 --- a/javascript/script.js +++ b/javascript/script.js @@ -120,6 +120,7 @@ document.addEventListener("DOMContentLoaded", function() { }); mutationObserver.observe(gradioApp(), {childList: true, subtree: true}); initStylePreviewOverlay(); + initModelPreviewOverlay(); }); /** @@ -146,38 +147,141 @@ document.addEventListener('keydown', function(e) { } }); -function initStylePreviewOverlay() { - let overlayVisible = false; - const samplesPath = document.querySelector("meta[name='samples-path']").getAttribute("content") +// Utility functions +function formatImagePath(name, templateImagePath, replacedValue = "fooocus_v2") { + return templateImagePath.replace(replacedValue, name.toLowerCase().replaceAll(" ", "_")).replaceAll("\\", "\\\\"); +} + +function createOverlay(id) { const overlay = document.createElement('div'); - overlay.id = 'stylePreviewOverlay'; + overlay.id = id; document.body.appendChild(overlay); - document.addEventListener('mouseover', function(e) { - const label = e.target.closest('.style_selections label'); + return overlay; +} + +function setImageBackground(overlay, url) { + unsetOverlayAsTooltip(overlay) + overlay.style.backgroundImage = `url("${url}")`; +} + +function setOverlayAsTooltip(overlay, altText) { + // Set the text content and any dynamic styles + overlay.textContent = altText; + overlay.style.width = 'fit-content'; + overlay.style.height = 'fit-content'; + // Note: Other styles are already set via CSS +} + +function unsetOverlayAsTooltip(overlay) { + // Clear the text content and reset any dynamic styles + overlay.textContent = ''; + overlay.style.width = '128px'; + overlay.style.height = '128px'; + // Note: Other styles are managed via CSS +} + +function handleMouseMove(overlay) { + return function(e) { + if (overlay.style.opacity !== "1") return; + overlay.style.left = `${e.clientX}px`; + overlay.style.top = `${e.clientY}px`; + overlay.className = e.clientY > window.innerHeight / 2 ? "lower-half" : "upper-half"; + }; +} + +// Image path retrieval for models +const getModelImagePath = selectedItemText => { + selectedItemText = selectedItemText.replace("✓\n", "") + + let imagePath = null; + + if (previewsCheckpoint) + imagePath = previewsCheckpoint[selectedItemText] + + if (previewsLora && !imagePath) + imagePath = previewsLora[selectedItemText] + + return imagePath; +}; + +// Mouse over handlers for different overlays +function handleMouseOverModelPreviewOverlay(overlay, elementSelector, templateImagePath) { + return function(e) { + const targetElement = e.target.closest(elementSelector); + if (!targetElement) return; + + targetElement.removeEventListener("mouseout", onMouseLeave); + targetElement.addEventListener("mouseout", onMouseLeave); + + overlay.style.opacity = "1"; + const selectedItemText = targetElement.innerText; + if (selectedItemText) { + let imagePath = getModelImagePath(selectedItemText); + if (imagePath) { + imagePath = formatImagePath(imagePath, templateImagePath, "sdxl_styles/samples/fooocus_v2.jpg"); + setImageBackground(overlay, imagePath); + } else { + setOverlayAsTooltip(overlay, selectedItemText); + } + } + + function onMouseLeave() { + overlay.style.opacity = "0"; + overlay.style.backgroundImage = ""; + targetElement.removeEventListener("mouseout", onMouseLeave); + } + }; +} + +function handleMouseOverStylePreviewOverlay(overlay, elementSelector, templateImagePath) { + return function(e) { + const label = e.target.closest(elementSelector); if (!label) return; + label.removeEventListener("mouseout", onMouseLeave); label.addEventListener("mouseout", onMouseLeave); - overlayVisible = true; + overlay.style.opacity = "1"; + const originalText = label.querySelector("span").getAttribute("data-original-text"); - const name = originalText || label.querySelector("span").textContent; - overlay.style.backgroundImage = `url("${samplesPath.replace( - "fooocus_v2", - name.toLowerCase().replaceAll(" ", "_") - ).replaceAll("\\", "\\\\")}")`; + let name = originalText || label.querySelector("span").textContent; + let imagePath = formatImagePath(name, templateImagePath); + + overlay.style.backgroundImage = `url("${imagePath}")`; + function onMouseLeave() { - overlayVisible = false; overlay.style.opacity = "0"; overlay.style.backgroundImage = ""; label.removeEventListener("mouseout", onMouseLeave); } - }); - document.addEventListener('mousemove', function(e) { - if(!overlayVisible) return; - overlay.style.left = `${e.clientX}px`; - overlay.style.top = `${e.clientY}px`; - overlay.className = e.clientY > window.innerHeight / 2 ? "lower-half" : "upper-half"; - }); + }; +} + +// Initialization functions for different overlays +function initModelPreviewOverlay() { + const templateImagePath = document.querySelector("meta[name='samples-path']").getAttribute("content"); + const modelOverlay = createOverlay('modelPreviewOverlay'); + + document.addEventListener('mouseover', handleMouseOverModelPreviewOverlay( + modelOverlay, + '.model_selections .item', + templateImagePath + )); + + document.addEventListener('mousemove', handleMouseMove(modelOverlay)); +} + +function initStylePreviewOverlay() { + const templateImagePath = document.querySelector("meta[name='samples-path']").getAttribute("content"); + const styleOverlay = createOverlay('stylePreviewOverlay'); + + document.addEventListener('mouseover', handleMouseOverStylePreviewOverlay( + styleOverlay, + '.style_selections label', + templateImagePath + )); + + document.addEventListener('mousemove', handleMouseMove(styleOverlay)); } /** diff --git a/modules/async_worker.py b/modules/async_worker.py index b2af67126..15dbf0776 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -42,6 +42,7 @@ def worker(): from modules.util import remove_empty_str, HWC3, resize_image, \ get_image_shape_ceil, set_image_shape_ceil, get_shape_ceil, resample_image, erode_or_dilate from modules.upscaler import perform_upscale + from modules.model_previewer import add_preview_by_attempt try: async_gradio_app = shared.gradio_root @@ -799,7 +800,10 @@ def callback(step, x0, x, total_steps, y): if n != 'None': d.append((f'LoRA {li + 1}', f'{n} : {w}')) d.append(('Version', 'v' + fooocus_version.version)) - log(x, d) + image_location = log(x, d) + + if modules.config.use_add_model_previews: + add_preview_by_attempt(base_model_name, refiner_model_name, loras, image_location) yield_result(async_task, imgs, do_not_show_finished_images=len(tasks) == 1) except ldm_patched.modules.model_management.InterruptProcessingException as e: diff --git a/modules/config.py b/modules/config.py index 58107806c..45c591f92 100644 --- a/modules/config.py +++ b/modules/config.py @@ -8,7 +8,7 @@ from modules.model_loader import load_file_from_url from modules.util import get_files_from_folder - +from modules.model_previewer import cleanup as cleanup_model_previews config_path = os.path.abspath("./config.txt") config_example_path = os.path.abspath("config_modification_tutorial.txt") @@ -316,6 +316,16 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ default_value=-1, validator=lambda x: isinstance(x, int) ) +use_cleanup_model_previews = get_config_item_or_set_default( + key='use_cleanup_model_previews', + default_value=False, + validator=lambda x: x == False or x == True +) +use_add_model_previews = get_config_item_or_set_default( + key='use_add_model_previews', + default_value=True, + validator=lambda x: x == False or x == True +) example_inpaint_prompts = get_config_item_or_set_default( key='example_inpaint_prompts', default_value=[ @@ -342,6 +352,9 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ "default_prompt_negative", "default_styles", "default_aspect_ratio", + "default_aspect_ratio", + "use_cleanup_model_previews" + "use_add_model_previews", "checkpoint_downloads", "embeddings_downloads", "lora_downloads", @@ -514,5 +527,7 @@ def downloading_upscale_model(): ) return os.path.join(path_upscale_models, 'fooocus_upscaler_s409985e5.bin') - update_all_model_names() + +if use_cleanup_model_previews: + cleanup_model_previews() diff --git a/modules/model_previewer.py b/modules/model_previewer.py new file mode 100644 index 000000000..a11b29e88 --- /dev/null +++ b/modules/model_previewer.py @@ -0,0 +1,148 @@ +import os +import json + +# Constants +CHECKPOINTS_DIR = 'models/checkpoints' +LORAS_DIR = 'models/loras' +OUTPUT_FOLDER = 'outputs' +PREVIEW_LOG_FILE = 'preview_log.json' + +def read_json_file(file_path): + """ Reads a JSON file and returns its contents, or creates a new file if it doesn't exist. """ + try: + if not file_exists(file_path): + with open(file_path, 'w') as file: + json.dump({}, file) + return {} + + with open(file_path, 'r') as file: + return json.load(file) + except IOError as e: + print(f"Error reading file {file_path}: {e}") + return {} + +def update_json_file(file_path, data): + """ Writes updated data to a JSON file. """ + try: + with open(file_path, 'w') as file: + json.dump(data, file, indent=4) + except IOError as e: + print(f"Error writing to file {file_path}: {e}") + +def file_exists(file_path): + """ Checks if a file exists at the given path. """ + return os.path.exists(file_path) + +def verify_and_cleanup_data(json_data, base_folder): + """ Verifies the existence of files and cleans up JSON data. """ + cleaned_data = {} + for safetensor, images in json_data.items(): + safetensor_path = os.path.join(base_folder, safetensor) + if file_exists(safetensor_path): + existing_images = [img for img in images if file_exists(os.path.join(OUTPUT_FOLDER, img))] + if existing_images: + cleaned_data[safetensor] = existing_images + return cleaned_data + +def get_cleaned_data(json_path, base_folder): + data = read_json_file(json_path) + cleaned_data = verify_and_cleanup_data(data, base_folder) + return cleaned_data + +def process_directory(directory): + """ Process a single directory (checkpoints or loras). """ + json_path = os.path.join(directory, PREVIEW_LOG_FILE) + cleaned_data = get_cleaned_data(json_path, directory) + try: + with open(json_path, 'w') as f: + json.dump(cleaned_data, f, indent=4) + except IOError as e: + print(f"Error writing to file {json_path}: {e}") + +def cleanup(): + """ Cleans up the JSON files in both checkpoints and loras directories. """ + process_directory(CHECKPOINTS_DIR) + process_directory(LORAS_DIR) + +def add_preview(model_name, image_location, directory): + """ Adds a new image location to the preview list of a given model file. """ + print(f"Adding new preview '{image_location}' for '{directory}/{model_name}'") + json_path = os.path.join(directory, PREVIEW_LOG_FILE) + data = read_json_file(json_path) + + if model_name not in data: + data[model_name] = [] + if image_location not in data[model_name]: + data[model_name].append(image_location) + update_json_file(json_path, data) + +def add_preview_for_checkpoint(model_name, image_location): + """ Adds a new image location for the given model file in checkpoints. """ + add_preview(model_name, image_location, CHECKPOINTS_DIR) + +def add_preview_image_for_lora(model_name, image_location): + """ Adds a new image location for the given model file in loras. """ + add_preview(model_name, image_location, LORAS_DIR) + +def add_preview_by_attempt(base_model_name, refiner_model_name, loras, image_location): + print(f"Attempting to add new preview for base model '{base_model_name}', refiner model '{refiner_model_name}' or for lora model '{loras}' to image location '{image_location}'") + + # Add preview based on the only one lora name + active_loras = [lora for lora in loras if lora[0] != 'None'] + if len(active_loras) == 1: + active_lora_name = active_loras[0][0] + add_preview_image_for_lora(active_lora_name, image_location) + + # Add preview based on only one model name if possible + if len(active_loras) == 0: + if refiner_model_name == "None": + add_preview_for_checkpoint(base_model_name, image_location) + elif "_SD_" in refiner_model_name: + add_preview_for_checkpoint(refiner_model_name, image_location) + +def get_preview(model_name, directory): + json_path = os.path.join(directory, PREVIEW_LOG_FILE) + cleaned_data = get_cleaned_data(json_path, directory) + return get_preview_from_data(model_name, cleaned_data) + +def get_preview_from_data(model_name, data): + """ Retrieves the latest available image for the given model file. """ + images = data.get(model_name, []) + if images: + latest_image = sorted(images, reverse=True)[0] + latest_image_path = OUTPUT_FOLDER + "/" + latest_image + if file_exists(latest_image_path): + return latest_image_path + print(f"Verbose Debug: File exists for model '{model_name}' at path '{latest_image_path}'.") + else: + print(f"Verbose Debug: File does not exist for model '{model_name}' at path '{latest_image_path}'.") + else: + print(f"Verbose Debug: No images found for model '{model_name}' in data.") + return None + +def get_all_previews(directory): + """ Retrieves the latest available image for all. """ + json_path = os.path.join(directory, PREVIEW_LOG_FILE) + print(f"Verbose Debug: Get previews from '{json_path}'.") + data = read_json_file(json_path) + + valid_previews = {} + + # Find all files in the specified directory (only first level) + for filename in os.listdir(directory): + image_path = get_preview_from_data(filename, data) + if image_path is not None: + print(f"Verbose Debug: Valid preview found for '{filename}'.") + valid_previews[filename] = image_path + else: + print(f"Verbose Debug: No valid preview found for '{filename}'.") + + return valid_previews + +def get_all_previews_for_checkpoints(): + """ Retrieves the available images for a list of all model names in checkpoints. """ + return get_all_previews(CHECKPOINTS_DIR) + +def get_all_previews_for_loras(): + """ Retrieves the available images for a list of all model names in loras. """ + return get_all_previews(LORAS_DIR) \ No newline at end of file diff --git a/modules/private_logger.py b/modules/private_logger.py index 968bd4f5d..8182d4ed6 100644 --- a/modules/private_logger.py +++ b/modules/private_logger.py @@ -105,4 +105,5 @@ def log(img, dic): log_cache[html_name] = middle_part - return + image_location = date_string + "/" + only_name + return image_location diff --git a/modules/ui_gradio_extensions.py b/modules/ui_gradio_extensions.py index bebf9f8ca..615ced201 100644 --- a/modules/ui_gradio_extensions.py +++ b/modules/ui_gradio_extensions.py @@ -3,6 +3,8 @@ import os import gradio as gr import args_manager +import json +from modules.model_previewer import get_all_previews_for_checkpoints, get_all_previews_for_loras from modules.localization import localization_js @@ -40,12 +42,29 @@ def javascript_html(): head += f'\n' head += f'\n' head += f'\n' + + js_code = get_js_code_from_updated_previews() + head += f"\n" if args_manager.args.theme: head += f'\n' return head +def get_js_code_from_updated_previews(): + # Fetch the updated previews data + updated_previews_checkpoint = get_all_previews_for_checkpoints() + updated_previews_lora = get_all_previews_for_loras() + + # Convert to JSON strings + updated_previews_checkpoint_json = json.dumps(updated_previews_checkpoint) + updated_previews_lora_json = json.dumps(updated_previews_lora) + + # Inject updated data into JavaScript + return f""" + previewsCheckpoint = {updated_previews_checkpoint_json}; + previewsLora = {updated_previews_lora_json}; + """ def css_html(): style_css_path = webpath('css/style.css') diff --git a/webui.py b/webui.py index fadd852af..835273694 100644 --- a/webui.py +++ b/webui.py @@ -294,8 +294,8 @@ def refresh_seed(r, seed_string): with gr.Tab(label='Model'): with gr.Group(): with gr.Row(): - base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True) - refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True) + base_model = gr.Dropdown(label='Base Model (SDXL only)', choices=modules.config.model_filenames, value=modules.config.default_base_model_name, show_label=True, elem_classes=['model_selections']) + refiner_model = gr.Dropdown(label='Refiner (SDXL or SD 1.5)', choices=['None'] + modules.config.model_filenames, value=modules.config.default_refiner_model_name, show_label=True, elem_classes=['model_selections']) refiner_switch = gr.Slider(label='Refiner Switch At', minimum=0.1, maximum=1.0, step=0.0001, info='Use 0.4 for SD1.5 realistic models; ' @@ -314,7 +314,7 @@ def refresh_seed(r, seed_string): for i, (n, v) in enumerate(modules.config.default_loras): with gr.Row(): lora_model = gr.Dropdown(label=f'LoRA {i + 1}', - choices=['None'] + modules.config.lora_filenames, value=n) + choices=['None'] + modules.config.lora_filenames, value=n, elem_classes=['model_selections']) lora_weight = gr.Slider(label='Weight', minimum=-2, maximum=2, step=0.01, value=v, elem_classes='lora_weight') lora_ctrls += [lora_model, lora_weight]