From 6c3f28f7d2df25a66be068444a6797768203dbc7 Mon Sep 17 00:00:00 2001 From: oOraph <13552058+oOraph@users.noreply.github.com> Date: Fri, 25 Oct 2024 15:16:47 +0200 Subject: [PATCH] Update docker_images/diffusers/app/pipelines/text_to_image.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: apolinário --- docker_images/diffusers/app/pipelines/text_to_image.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/docker_images/diffusers/app/pipelines/text_to_image.py b/docker_images/diffusers/app/pipelines/text_to_image.py index 42e16452..7be61883 100644 --- a/docker_images/diffusers/app/pipelines/text_to_image.py +++ b/docker_images/diffusers/app/pipelines/text_to_image.py @@ -169,9 +169,11 @@ def _process_req(self, inputs, **kwargs): kwargs["num_inference_steps"] = 20 # Else, don't specify anything, leave the default behaviour - if int(kwargs.get("num_inference_steps", 20)) <= 4 and 'guidance_scale' not in kwargs.keys(): - kwargs["guidance_scale"] = 0 - + if "guidance_scale" not in kwargs: + default_guidance_scale = os.getenv("DEFAULT_GUIDANCE_SCALE") + if default_num_steps: + kwargs["guidance_scale"] = int(default_guidance_scale) + # Else, don't specify anything, leave the default behaviour if "seed" in kwargs: seed = int(kwargs["seed"]) generator = torch.Generator().manual_seed(seed)