From 26efdf6ee7feaed7a6b926d3237a393e97814754 Mon Sep 17 00:00:00 2001 From: Eren G??lge Date: Fri, 17 Nov 2023 13:42:33 +0100 Subject: [PATCH] Make k_diffusion optional --- TTS/tts/layers/tortoise/diffusion.py | 13 +++++++++++-- requirements.txt | 1 - 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/TTS/tts/layers/tortoise/diffusion.py b/TTS/tts/layers/tortoise/diffusion.py index cb350af779..fcdaa9d76e 100644 --- a/TTS/tts/layers/tortoise/diffusion.py +++ b/TTS/tts/layers/tortoise/diffusion.py @@ -13,12 +13,19 @@ import numpy as np import torch import torch as th -from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral from tqdm import tqdm from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper -K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} + +try: + from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral + + K_DIFFUSION_SAMPLERS = {"k_euler_a": sample_euler_ancestral, "dpm++2m": sample_dpmpp_2m} +except ImportError: + K_DIFFUSION_SAMPLERS = None + + SAMPLERS = ["dpm++2m", "p", "ddim"] @@ -531,6 +538,8 @@ def sample_loop(self, *args, **kwargs): if self.conditioning_free is not True: raise RuntimeError("cond_free must be true") with tqdm(total=self.num_timesteps) as pbar: + if K_DIFFUSION_SAMPLERS is None: + raise ModuleNotFoundError("Install k_diffusion for using k_diffusion samplers") return self.k_diffusion_sample_loop(K_DIFFUSION_SAMPLERS[s], pbar, *args, **kwargs) else: raise RuntimeError("sampler not impl") diff --git a/requirements.txt b/requirements.txt index 864215117e..ce0e5d9207 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,7 +46,6 @@ bangla bnnumerizer bnunicodenormalizer #deps for tortoise -k_diffusion einops>=0.6.0 transformers>=4.33.0 #deps for bark