diff --git a/README.md b/README.md
index 10fbb41..9d0a64b 100644
--- a/README.md
+++ b/README.md
@@ -11,7 +11,7 @@ Next-generation TTS model using flow-matching and DiT, inspired by [Stable Diffu
 
 As the first open-source TTS model that tried to combine flow-matching and DiT, **StableTTS** is a fast and lightweight TTS model for chinese, english and japanese speech generation. It has 31M parameters.
 
-✨ **Huggingface demo:** [🤗](https://huggingface.co/spaces/KdaiP/StableTTS1.1)
+✨ **Hugging Face demo:** [🤗](https://huggingface.co/spaces/KdaiP/StableTTS1.1)
 
 ## News
 
@@ -36,7 +36,7 @@ As the first open-source TTS model that tried to combine flow-matching and DiT,
 
 ### Text-To-Mel model
 
-Download and place the model in the `./checkpoints` directory, it is ready for inference, finetuning and webui.
+By default, the inference scripts will automatically download the pretrained models from Hugging Face. For training, download the models and place them in the `checkpoints` directory.
 
 | Model Name | Task Details | Dataset | Download Link |
 |:----------:|:------------:|:-------------:|:-------------:|
diff --git a/api.py b/api.py
index 96809f9..84c8154 100644
--- a/api.py
+++ b/api.py
@@ -16,6 +16,8 @@ from datas.dataset import intersperse
 from utils.audio import load_and_resample_audio
 
+from cached_path import cached_path
+
 
 def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
     if model_name == 'ffgan':
         # training or changing ffgan config is not supported in this repo
@@ -83,8 +85,8 @@ def get_params(self):
 if __name__ == '__main__':
     device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-    tts_model_path = './checkpoints/checkpoint_0.pt'
-    vocoder_model_path = './vocoders/pretrained/vocos.pt'
+    tts_model_path = str(cached_path('hf://KdaiP/StableTTS1.1/StableTTS/checkpoint_0.pt'))
+    vocoder_model_path = str(cached_path('hf://KdaiP/StableTTS1.1/vocoders/vocos.pt'))
 
     model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
     model.to(device)
diff --git a/inference.ipynb b/inference.ipynb
index b079a3c..d116c1b 100644
--- a/inference.ipynb
+++ b/inference.ipynb
@@ -10,11 +10,12 @@
 "import torch\n",
 "\n",
 "from api import StableTTSAPI\n",
+"from cached_path import cached_path\n",
 "\n",
 "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
 "\n",
-"tts_model_path = './checkpoints/checkpoint_0.pt' # path to StableTTS checkpoint\n",
-"vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt' # path to vocoder checkpoint\n",
+"tts_model_path = str(cached_path('hf://KdaiP/StableTTS1.1/StableTTS/checkpoint_0.pt')) # or path to StableTTS checkpoint\n",
+"vocoder_model_path = str(cached_path('hf://KdaiP/StableTTS1.1/vocoders/firefly-gan-base-generator.ckpt')) # or path to vocoder checkpoint\n",
 "vocoder_type = 'ffgan' # ffgan or vocos\n",
 "\n",
 "# vocoder_model_path = './vocoders/pretrained/vocos.pt'\n",
diff --git a/requirements.txt b/requirements.txt
index 39dbf7d..0327dd9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,8 @@ soundfile # to make sure that torchaudio has at least one valid backend
 
 tensorboard
 
+cached_path
+
 # for monotonic_align
 numba
 
diff --git a/webui.py b/webui.py
index 8532256..f857ee3 100644
--- a/webui.py
+++ b/webui.py
@@ -11,10 +11,12 @@
 
 from api import StableTTSAPI
 
+from cached_path import cached_path
+
 device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
-tts_model_path = './checkpoints/checkpoint_0.pt'
-vocoder_model_path = './vocoders/pretrained/firefly-gan-base-generator.ckpt'
+tts_model_path = str(cached_path('hf://KdaiP/StableTTS1.1/StableTTS/checkpoint_0.pt'))
+vocoder_model_path = str(cached_path('hf://KdaiP/StableTTS1.1/vocoders/firefly-gan-base-generator.ckpt'))
 vocoder_type = 'ffgan'
 
 model = StableTTSAPI(tts_model_path, vocoder_model_path, vocoder_type).to(device)
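For context on the new dependency: `cached_path` is what lets the inference entry points fetch the pretrained checkpoints on demand, matching the updated README wording. Below is a minimal sketch (not part of the patch) of what the new calls do, assuming the `cached_path` package added to requirements.txt is installed; the `ckpt` variable is just for illustration.

```python
from cached_path import cached_path

# 'hf://<repo_id>/<path-in-repo>' resolves to a file hosted on the Hugging Face Hub.
# The first call downloads the file into a local cache and returns the cached file
# as a pathlib.Path, hence the str(...) wrappers in api.py / webui.py / inference.ipynb.
ckpt = str(cached_path('hf://KdaiP/StableTTS1.1/StableTTS/checkpoint_0.pt'))
print(ckpt)  # local filesystem path to the cached checkpoint

# Plain local paths are also accepted and returned unchanged, so manually
# downloaded checkpoints (e.g. './checkpoints/checkpoint_0.pt') keep working.
```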