Skip to content

Commit

Permalink
Boost SDXL speed with initialized schedule step reset (huggingface#1284)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Socek <[email protected]>
  • Loading branch information
dsocek authored Sep 11, 2024
1 parent b2c29b1 commit 0027e32
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -651,6 +651,8 @@ def __call__(
t1 = t0

self._num_timesteps = len(timesteps)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index()

hb_profiler = HabanaProfile(
warmup=profiling_warmup_steps,
Expand Down Expand Up @@ -688,8 +690,6 @@ def __call__(
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
).to(device=device, dtype=latents.dtype)

self._num_timesteps = len(timesteps)

# 8.3 Denoising loop
throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
use_warmup_inference_steps = (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,8 @@ def denoising_value_valid(dnv):
).to(device=device, dtype=latents.dtype)

self._num_timesteps = len(timesteps)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index()

# 8.3 Denoising loop
throughput_warmup_steps = kwargs.get("throughput_warmup_steps", 3)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,8 @@ def denoising_value_valid(dnv):
).to(device=device, dtype=latents.dtype)

self._num_timesteps = len(timesteps)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index()

outputs = {
"images": [],
Expand Down

0 comments on commit 0027e32

Please sign in to comment.