Skip to content

Commit

Permalink
Make k_diffusion optional
Browse files Browse the repository at this point in the history
  • Loading branch information
erogol committed Nov 17, 2023
1 parent 08d11e9 commit 26efdf6
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
13 changes: 11 additions & 2 deletions TTS/tts/layers/tortoise/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
import numpy as np
import torch
import torch as th
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral
from tqdm import tqdm

from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper

K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}

try:
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral

K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m}
except ImportError:
K_DIFFUSION_SAMPLERS = None


SAMPLERS = ["dpm++2m", "p", "ddim"]


Expand Down Expand Up @@ -531,6 +538,8 @@ def sample_loop(self, *args, **kwargs):
if self.conditioning_free is not True:
raise RuntimeError("cond_free must be true")
with tqdm(total=self.num_timesteps) as pbar:
if K_DIFFUSION_SAMPLERS is None:
raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers")
return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs)
else:
raise RuntimeError("sampler not impl")
Expand Down
1 change: 0 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ bangla
bnnumerizer
bnunicodenormalizer
#deps for tortoise
k_diffusion
einops>=0.6.0
transformers>=4.33.0
#deps for bark
Expand Down

0 comments on commit 26efdf6

Please sign in to comment.