diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 92f6f3ab3c..62420e9958 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -2,6 +2,8 @@ repos: - repo: "https://github.com/pre-commit/pre-commit-hooks" rev: v5.0.0 hooks: + - id: check-json + files: "TTS/.models.json" - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace diff --git a/README.md b/README.md index 5ca825b6ba..7dddf3a37b 100644 --- a/README.md +++ b/README.md @@ -1,13 +1,12 @@ ## 🐸Coqui TTS News - 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts) +- 📣 [OpenVoice](https://github.com/myshell-ai/OpenVoice) models now available for voice conversion. - 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms. -- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board. +- 📣 ⓍTTSv2 is here with 17 languages and better performance across the board. ⓍTTS can stream with <200ms latency. - 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech). -- 📣 ⓍTTS can now stream with <200ms latency. -- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html) - 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html) -- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. +- 📣 You can use [Fairseq models in ~1100 languages](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS. ## @@ -121,6 +120,7 @@ repository are also still a useful source of information. ### Voice Conversion - FreeVC: [paper](https://arxiv.org/abs/2210.15418) +- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479) You can also help us implement more models. @@ -244,8 +244,14 @@ tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progress_bar=False).to("cuda") tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav") ``` -#### Example voice cloning together with the voice conversion model. -This way, you can clone voices by using any model in 🐸TTS. +Other available voice conversion models: +- `voice_conversion_models/multilingual/multi-dataset/openvoice_v1` +- `voice_conversion_models/multilingual/multi-dataset/openvoice_v2` + +#### Example voice cloning together with the default voice conversion model. + +This way, you can clone voices by using any model in 🐸TTS. The FreeVC model is +used for voice conversion after synthesizing speech. ```python @@ -412,4 +418,6 @@ $ tts --out_path output/path/speech.wav --model_name "<model_type>/<language>/<dataset>/<model_name>" diff --git a/TTS/api.py b/TTS/api.py @@ … @@ def voice_conversion(self, source_wav: str, target_wav: str): """Voice conversion with FreeVC. Convert source wav to target speaker.
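For context, the two OpenVoice checkpoints listed in the README above plug into the same Python API as the FreeVC example; a minimal sketch, swapping only the model name (wav paths are placeholders):

```python
from TTS.api import TTS

# Voice conversion with one of the newly added OpenVoice checkpoints.
tts = TTS(model_name="voice_conversion_models/multilingual/multi-dataset/openvoice_v2", progress_bar=False)
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```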
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py index 20e429df04..454f528ab4 100755 --- a/TTS/bin/synthesize.py +++ b/TTS/bin/synthesize.py @@ -407,18 +407,18 @@ def main(): # load models synthesizer = Synthesizer( - tts_path, - tts_config_path, - speakers_file_path, - language_ids_file_path, - vocoder_path, - vocoder_config_path, - encoder_path, - encoder_config_path, - vc_path, - vc_config_path, - model_dir, - args.voice_dir, + tts_checkpoint=tts_path, + tts_config_path=tts_config_path, + tts_speakers_file=speakers_file_path, + tts_languages_file=language_ids_file_path, + vocoder_checkpoint=vocoder_path, + vocoder_config=vocoder_config_path, + encoder_checkpoint=encoder_path, + encoder_config=encoder_config_path, + vc_checkpoint=vc_path, + vc_config=vc_config_path, + model_dir=model_dir, + voice_dir=args.voice_dir, ).to(device) # query speaker ids of a multi-speaker model. diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py index 50ed1024de..ab2ca5667a 100644 --- a/TTS/tts/layers/vits/networks.py +++ b/TTS/tts/layers/vits/networks.py @@ -256,7 +256,7 @@ def __init__( ) self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1) - def forward(self, x, x_lengths, g=None): + def forward(self, x, x_lengths, g=None, tau=1.0): """ Shapes: - x: :math:`[B, C, T]` @@ -268,5 +268,5 @@ def forward(self, x, x_lengths, g=None): x = self.enc(x, x_mask, g=g) stats = self.proj(x) * x_mask mean, log_scale = torch.split(stats, self.out_channels, dim=1) - z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask + z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask return z, mean, log_scale, x_mask diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py index bd445b3a2f..38fcfd60e9 100644 --- a/TTS/utils/manage.py +++ b/TTS/utils/manage.py @@ -424,7 +424,7 @@ def _find_files(output_path: str) -> Tuple[str, str]: model_file = None config_file = None for file_name in os.listdir(output_path): - if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]: + if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]: model_file = os.path.join(output_path, file_name) elif file_name == "config.json": config_file = os.path.join(output_path, file_name) diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py index 90af4f48f9..a9b9feffc1 100644 --- a/TTS/utils/synthesizer.py +++ b/TTS/utils/synthesizer.py @@ -1,6 +1,7 @@ import logging import os import time +from pathlib import Path from typing import List import numpy as np @@ -15,7 +16,9 @@ from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence from TTS.utils.audio import AudioProcessor from TTS.utils.audio.numpy_transforms import save_wav +from TTS.vc.configs.openvoice_config import OpenVoiceConfig from TTS.vc.models import setup_model as setup_vc_model +from TTS.vc.models.openvoice import OpenVoice from TTS.vocoder.models import setup_model as setup_vocoder_model from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input @@ -25,6 +28,7 @@ class Synthesizer(nn.Module): def __init__( self, + *, tts_checkpoint: str = "", tts_config_path: str = "", tts_speakers_file: str = "", @@ -91,23 +95,20 @@ def __init__( if tts_checkpoint: self._load_tts(tts_checkpoint, tts_config_path, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] if vocoder_checkpoint: self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda) - self.output_sample_rate = self.vocoder_config.audio["sample_rate"] - if 
vc_checkpoint: + if vc_checkpoint and model_dir is None: self._load_vc(vc_checkpoint, vc_config, use_cuda) - self.output_sample_rate = self.vc_config.audio["output_sample_rate"] if model_dir: if "fairseq" in model_dir: self._load_fairseq_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["sample_rate"] + elif "openvoice" in model_dir: + self._load_openvoice_from_dir(Path(model_dir), use_cuda) else: self._load_tts_from_dir(model_dir, use_cuda) - self.output_sample_rate = self.tts_config.audio["output_sample_rate"] @staticmethod def _get_segmenter(lang: str): @@ -136,6 +137,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N """ # pylint: disable=global-statement self.vc_config = load_config(vc_config_path) + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] self.vc_model = setup_vc_model(config=self.vc_config) self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint) if use_cuda: @@ -150,9 +152,24 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None: self.tts_model = Vits.init_from_config(self.tts_config) self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True) self.tts_config = self.tts_model.config + self.output_sample_rate = self.tts_config.audio["sample_rate"] if use_cuda: self.tts_model.cuda() + def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None: + """Load the OpenVoice model from a directory. + + We assume the model knows how to load itself from the directory and + there is a config.json file in the directory. + """ + self.vc_config = OpenVoiceConfig() + self.vc_model = OpenVoice.init_from_config(self.vc_config) + self.vc_model.load_checkpoint(self.vc_config, checkpoint, eval=True) + self.vc_config = self.vc_model.config + self.output_sample_rate = self.vc_config.audio["output_sample_rate"] + if use_cuda: + self.vc_model.cuda() + def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """Load the TTS model from a directory. @@ -160,6 +177,7 @@ def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None: """ config = load_config(os.path.join(model_dir, "config.json")) self.tts_config = config + self.output_sample_rate = self.tts_config.audio["output_sample_rate"] self.tts_model = setup_tts_model(config) self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True) if use_cuda: @@ -181,6 +199,7 @@ def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) - """ # pylint: disable=global-statement self.tts_config = load_config(tts_config_path) + self.output_sample_rate = self.tts_config.audio["sample_rate"] if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None: raise ValueError("Phonemizer is not defined in the TTS config.") @@ -218,6 +237,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N use_cuda (bool): enable/disable CUDA use. 
""" self.vocoder_config = load_config(model_config) + self.output_sample_rate = self.vocoder_config.audio["sample_rate"] self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio) self.vocoder_model = setup_vocoder_model(self.vocoder_config) self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True) diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py index 207181b303..d600bfb1f4 100644 --- a/TTS/vc/configs/freevc_config.py +++ b/TTS/vc/configs/freevc_config.py @@ -229,7 +229,7 @@ class FreeVCConfig(BaseVCConfig): If true, language embedding is used. Defaults to `False`. Note: - Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + Check :class:`TTS.tts.configs.shared_configs.BaseVCConfig` for the inherited parameters. Example: diff --git a/TTS/vc/configs/openvoice_config.py b/TTS/vc/configs/openvoice_config.py new file mode 100644 index 0000000000..261cdd6f47 --- /dev/null +++ b/TTS/vc/configs/openvoice_config.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import Optional + +from coqpit import Coqpit + +from TTS.vc.configs.shared_configs import BaseVCConfig + + +@dataclass +class OpenVoiceAudioConfig(Coqpit): + """Audio configuration + + Args: + input_sample_rate (int): + The sampling rate of the input waveform. + + output_sample_rate (int): + The sampling rate of the output waveform. + + fft_size (int): + The length of the filter. + + hop_length (int): + The hop length. + + win_length (int): + The window length. + """ + + input_sample_rate: int = field(default=22050) + output_sample_rate: int = field(default=22050) + fft_size: int = field(default=1024) + hop_length: int = field(default=256) + win_length: int = field(default=1024) + + +@dataclass +class OpenVoiceArgs(Coqpit): + """OpenVoice model arguments. + + zero_g (bool): + Whether to zero the gradients. + + inter_channels (int): + The number of channels in the intermediate layers. + + hidden_channels (int): + The number of channels in the hidden layers. + + filter_channels (int): + The number of channels in the filter layers. + + n_heads (int): + The number of attention heads. + + n_layers (int): + The number of layers. + + kernel_size (int): + The size of the kernel. + + p_dropout (float): + The dropout probability. + + resblock (str): + The type of residual block. + + resblock_kernel_sizes (List[int]): + The kernel sizes for the residual blocks. + + resblock_dilation_sizes (List[List[int]]): + The dilation sizes for the residual blocks. + + upsample_rates (List[int]): + The upsample rates. + + upsample_initial_channel (int): + The number of channels in the initial upsample layer. + + upsample_kernel_sizes (List[int]): + The kernel sizes for the upsample layers. + + n_layers_q (int): + The number of layers in the quantization network. + + use_spectral_norm (bool): + Whether to use spectral normalization. + + gin_channels (int): + The number of channels in the global conditioning vector. 
diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py index 207181b303..d600bfb1f4 100644 --- a/TTS/vc/configs/freevc_config.py +++ b/TTS/vc/configs/freevc_config.py @@ -229,7 +229,7 @@ class FreeVCConfig(BaseVCConfig): If true, language embedding is used. Defaults to `False`. Note: - Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters. + Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters. Example: diff --git a/TTS/vc/configs/openvoice_config.py b/TTS/vc/configs/openvoice_config.py new file mode 100644 index 0000000000..261cdd6f47 --- /dev/null +++ b/TTS/vc/configs/openvoice_config.py @@ -0,0 +1,201 @@ +from dataclasses import dataclass, field +from typing import Optional + +from coqpit import Coqpit + +from TTS.vc.configs.shared_configs import BaseVCConfig + + +@dataclass +class OpenVoiceAudioConfig(Coqpit): + """Audio configuration. + + Args: + input_sample_rate (int): + The sampling rate of the input waveform. + + output_sample_rate (int): + The sampling rate of the output waveform. + + fft_size (int): + The FFT size (length of the STFT filter). + + hop_length (int): + The hop length. + + win_length (int): + The window length. + """ + + input_sample_rate: int = field(default=22050) + output_sample_rate: int = field(default=22050) + fft_size: int = field(default=1024) + hop_length: int = field(default=256) + win_length: int = field(default=1024) + + +@dataclass +class OpenVoiceArgs(Coqpit): + """OpenVoice model arguments. + + zero_g (bool): + Whether to zero out the speaker embedding `g` passed to the posterior encoder and the decoder (the embedding is still used in the flow). Defaults to `True`. + + inter_channels (int): + The number of channels in the intermediate layers. + + hidden_channels (int): + The number of channels in the hidden layers. + + filter_channels (int): + The number of channels in the filter layers. + + n_heads (int): + The number of attention heads. + + n_layers (int): + The number of layers. + + kernel_size (int): + The size of the kernel. + + p_dropout (float): + The dropout probability. + + resblock (str): + The type of residual block. + + resblock_kernel_sizes (List[int]): + The kernel sizes for the residual blocks. + + resblock_dilation_sizes (List[List[int]]): + The dilation sizes for the residual blocks. + + upsample_rates (List[int]): + The upsample rates. + + upsample_initial_channel (int): + The number of channels in the initial upsample layer. + + upsample_kernel_sizes (List[int]): + The kernel sizes for the upsample layers. + + n_layers_q (int): + The number of layers in the quantization network. + + use_spectral_norm (bool): + Whether to use spectral normalization. + + gin_channels (int): + The number of channels in the global conditioning vector. + + tau (float): + Sampling temperature of the posterior encoder noise. Defaults to `0.3`. + """ + + zero_g: bool = field(default=True) + inter_channels: int = field(default=192) + hidden_channels: int = field(default=192) + filter_channels: int = field(default=768) + n_heads: int = field(default=2) + n_layers: int = field(default=6) + kernel_size: int = field(default=3) + p_dropout: float = field(default=0.1) + resblock: str = field(default="1") + resblock_kernel_sizes: list[int] = field(default_factory=lambda: [3, 7, 11]) + resblock_dilation_sizes: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]) + upsample_rates: list[int] = field(default_factory=lambda: [8, 8, 2, 2]) + upsample_initial_channel: int = field(default=512) + upsample_kernel_sizes: list[int] = field(default_factory=lambda: [16, 16, 4, 4]) + n_layers_q: int = field(default=3) + use_spectral_norm: bool = field(default=False) + gin_channels: int = field(default=256) + tau: float = field(default=0.3)
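The `tau` default above is forwarded to the modified `PosteriorEncoder.forward` in `TTS/tts/layers/vits/networks.py` earlier in this patch, where it scales the sampling noise. A self-contained sketch of that sampling line (tensor shapes are illustrative):

```python
import torch

# tau acts as a sampling temperature on the posterior:
# tau=0 returns the deterministic mean, tau=1 the original full-variance sample.
mean = torch.zeros(2, 192, 100)  # [B, C, T]
log_scale = torch.zeros(2, 192, 100)
tau = 0.3
z = mean + torch.randn_like(mean) * tau * torch.exp(log_scale)
```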
+ + +@dataclass +class OpenVoiceConfig(BaseVCConfig): + """Defines parameters for OpenVoice VC model. + + Args: + model (str): + Model name. Do not change unless you know what you are doing. + + model_args (OpenVoiceArgs): + Model architecture arguments. Defaults to `OpenVoiceArgs()`. + + audio (OpenVoiceAudioConfig): + Audio processing configuration. Defaults to `OpenVoiceAudioConfig()`. + + return_wav (bool): + If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`. + + compute_linear_spec (bool): + If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`. + + use_weighted_sampler (bool): + If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`. + + weighted_sampler_attrs (dict): + Key returned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities by overweighting `root_path` by 2.0. Defaults to `{}`. + + weighted_sampler_multipliers (dict): + Weight each unique value of a key returned by the formatter for weighted sampling. + For example `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`. + This samples instances from `train-clean-100` twice as often as from `train-clean-360`. Defaults to `{}`. + + r (int): + Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`. + + add_blank (bool): + If true, a blank token is added in between every character. Defaults to `True`. + + Note: + Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters. + + Example: + + >>> from TTS.vc.configs.openvoice_config import OpenVoiceConfig + >>> config = OpenVoiceConfig() + """ + + model: str = "openvoice" + # model specific params + model_args: OpenVoiceArgs = field(default_factory=OpenVoiceArgs) + audio: OpenVoiceAudioConfig = field(default_factory=OpenVoiceAudioConfig) + + # optimizer + # TODO with training support + + # loss params + # TODO with training support + + # data loader params + return_wav: bool = True + compute_linear_spec: bool = True + + # sampler params + use_weighted_sampler: bool = False # TODO: move it to the base config + weighted_sampler_attrs: dict = field(default_factory=lambda: {}) + weighted_sampler_multipliers: dict = field(default_factory=lambda: {}) + + # overrides + r: int = 1 # DO NOT CHANGE + add_blank: bool = True + + # multi-speaker settings + # use speaker embedding layer + num_speakers: int = 0 + speakers_file: Optional[str] = None + speaker_embedding_channels: int = 256 + + # use d-vectors + use_d_vector_file: bool = False + d_vector_file: Optional[list[str]] = None + d_vector_dim: Optional[int] = None + + def __post_init__(self) -> None: + for key, val in self.model_args.items(): + if hasattr(self, key): + self[key] = val diff --git a/TTS/vc/modules/__init__.py b/TTS/vc/layers/__init__.py similarity index 100% rename from TTS/vc/modules/__init__.py rename to TTS/vc/layers/__init__.py diff --git a/TTS/vc/modules/freevc/__init__.py b/TTS/vc/layers/freevc/__init__.py similarity index 100% rename from TTS/vc/modules/freevc/__init__.py rename to TTS/vc/layers/freevc/__init__.py diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/layers/freevc/commons.py similarity index 100% rename from TTS/vc/modules/freevc/commons.py rename to TTS/vc/layers/freevc/commons.py diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/layers/freevc/mel_processing.py similarity index 100% rename from TTS/vc/modules/freevc/mel_processing.py rename to TTS/vc/layers/freevc/mel_processing.py diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/layers/freevc/modules.py similarity index 99% rename from TTS/vc/modules/freevc/modules.py rename to TTS/vc/layers/freevc/modules.py index ea17be24d6..c34f22d701 100644 --- a/TTS/vc/modules/freevc/modules.py +++ b/TTS/vc/layers/freevc/modules.py @@ -7,7 +7,7 @@ from TTS.tts.layers.generic.normalization import LayerNorm2 from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply -from TTS.vc.modules.freevc.commons import init_weights +from TTS.vc.layers.freevc.commons import init_weights from TTS.vocoder.models.hifigan_generator import get_padding LRELU_SLOPE = 0.1 diff --git a/TTS/vc/modules/freevc/speaker_encoder/__init__.py b/TTS/vc/layers/freevc/speaker_encoder/__init__.py similarity index 100% rename from TTS/vc/modules/freevc/speaker_encoder/__init__.py rename to TTS/vc/layers/freevc/speaker_encoder/__init__.py diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/layers/freevc/speaker_encoder/audio.py similarity index 97% rename from TTS/vc/modules/freevc/speaker_encoder/audio.py rename to TTS/vc/layers/freevc/speaker_encoder/audio.py index 5b23a4dbb6..5fa317ce45 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/audio.py +++ b/TTS/vc/layers/freevc/speaker_encoder/audio.py @@ -5,7 +5,7 @@ import librosa import numpy as np -from TTS.vc.modules.freevc.speaker_encoder.hparams import ( +from TTS.vc.layers.freevc.speaker_encoder.hparams import ( audio_norm_target_dBFS, mel_n_channels, mel_window_length, diff --git 
a/TTS/vc/modules/freevc/speaker_encoder/hparams.py b/TTS/vc/layers/freevc/speaker_encoder/hparams.py similarity index 100% rename from TTS/vc/modules/freevc/speaker_encoder/hparams.py rename to TTS/vc/layers/freevc/speaker_encoder/hparams.py diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py similarity index 98% rename from TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py rename to TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py index 294bf322cb..a6d5bcf942 100644 --- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py +++ b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py @@ -7,8 +7,8 @@ from torch import nn from trainer.io import load_fsspec -from TTS.vc.modules.freevc.speaker_encoder import audio -from TTS.vc.modules.freevc.speaker_encoder.hparams import ( +from TTS.vc.layers.freevc.speaker_encoder import audio +from TTS.vc.layers.freevc.speaker_encoder.hparams import ( mel_n_channels, mel_window_step, model_embedding_size, diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/layers/freevc/wavlm/__init__.py similarity index 94% rename from TTS/vc/modules/freevc/wavlm/__init__.py rename to TTS/vc/layers/freevc/wavlm/__init__.py index 4046e137f5..62f7e74aaf 100644 --- a/TTS/vc/modules/freevc/wavlm/__init__.py +++ b/TTS/vc/layers/freevc/wavlm/__init__.py @@ -6,7 +6,7 @@ from trainer.io import get_user_data_dir from TTS.utils.generic_utils import is_pytorch_at_least_2_4 -from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig +from TTS.vc.layers.freevc.wavlm.wavlm import WavLM, WavLMConfig logger = logging.getLogger(__name__) diff --git a/TTS/vc/modules/freevc/wavlm/config.json b/TTS/vc/layers/freevc/wavlm/config.json similarity index 100% rename from TTS/vc/modules/freevc/wavlm/config.json rename to TTS/vc/layers/freevc/wavlm/config.json diff --git a/TTS/vc/modules/freevc/wavlm/modules.py b/TTS/vc/layers/freevc/wavlm/modules.py similarity index 100% rename from TTS/vc/modules/freevc/wavlm/modules.py rename to TTS/vc/layers/freevc/wavlm/modules.py diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/layers/freevc/wavlm/wavlm.py similarity index 99% rename from TTS/vc/modules/freevc/wavlm/wavlm.py rename to TTS/vc/layers/freevc/wavlm/wavlm.py index 10dd09ed0c..775f3e5979 100644 --- a/TTS/vc/modules/freevc/wavlm/wavlm.py +++ b/TTS/vc/layers/freevc/wavlm/wavlm.py @@ -17,7 +17,7 @@ import torch.nn.functional as F from torch.nn import LayerNorm -from TTS.vc.modules.freevc.wavlm.modules import ( +from TTS.vc.layers.freevc.wavlm.modules import ( Fp32GroupNorm, Fp32LayerNorm, GLU_Linear, diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py index 62559de534..c654219c39 100644 --- a/TTS/vc/models/freevc.py +++ b/TTS/vc/models/freevc.py @@ -12,17 +12,16 @@ from torch.nn.utils.parametrize import remove_parametrizations from trainer.io import load_fsspec -import TTS.vc.modules.freevc.commons as commons -import TTS.vc.modules.freevc.modules as modules +import TTS.vc.layers.freevc.modules as modules from TTS.tts.layers.vits.discriminator import DiscriminatorS from TTS.tts.utils.helpers import sequence_mask from TTS.tts.utils.speakers import SpeakerManager from TTS.vc.configs.freevc_config import FreeVCConfig +from TTS.vc.layers.freevc.commons import init_weights, rand_slice_segments +from TTS.vc.layers.freevc.mel_processing import mel_spectrogram_torch +from TTS.vc.layers.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx +from 
TTS.vc.layers.freevc.wavlm import get_wavlm from TTS.vc.models.base_vc import BaseVC -from TTS.vc.modules.freevc.commons import init_weights -from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch -from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx -from TTS.vc.modules.freevc.wavlm import get_wavlm from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP logger = logging.getLogger(__name__) @@ -385,7 +384,7 @@ def forward( z_p = self.flow(z, spec_mask, g=g) # Randomly slice z and compute o using dec - z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size) + z_slice, ids_slice = rand_slice_segments(z, spec_lengths, self.segment_size) o = self.dec(z_slice, g=g) return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q) diff --git a/TTS/vc/models/openvoice.py b/TTS/vc/models/openvoice.py new file mode 100644 index 0000000000..135b0861b9 --- /dev/null +++ b/TTS/vc/models/openvoice.py @@ -0,0 +1,320 @@ +import json +import logging +import os +from pathlib import Path +from typing import Any, Mapping, Optional, Union + +import librosa +import numpy as np +import numpy.typing as npt +import torch +from coqpit import Coqpit +from torch import nn +from torch.nn import functional as F +from trainer.io import load_fsspec + +from TTS.tts.layers.vits.networks import PosteriorEncoder +from TTS.tts.utils.speakers import SpeakerManager +from TTS.utils.audio.torch_transforms import wav_to_spec +from TTS.vc.configs.openvoice_config import OpenVoiceConfig +from TTS.vc.models.base_vc import BaseVC +from TTS.vc.models.freevc import Generator, ResidualCouplingBlock + +logger = logging.getLogger(__name__) + + +class ReferenceEncoder(nn.Module): + """NN module creating a fixed size prosody embedding from a spectrogram. 
+ + inputs: mel spectrograms [batch_size, num_spec_frames, num_mel] + outputs: [batch_size, embedding_dim] + """ + + def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None: + super().__init__() + self.spec_channels = spec_channels + ref_enc_filters = [32, 32, 64, 64, 128, 128] + K = len(ref_enc_filters) + filters = [1] + ref_enc_filters + convs = [ + torch.nn.utils.parametrizations.weight_norm( + nn.Conv2d( + in_channels=filters[i], + out_channels=filters[i + 1], + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + ) + for i in range(K) + ] + self.convs = nn.ModuleList(convs) + + out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K) + self.gru = nn.GRU( + input_size=ref_enc_filters[-1] * out_channels, + hidden_size=256 // 2, + batch_first=True, + ) + self.proj = nn.Linear(128, embedding_dim) + self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + N = inputs.size(0) + + out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs] + if self.layernorm is not None: + out = self.layernorm(out) + + for conv in self.convs: + out = conv(out) + out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K] + + out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K] + T = out.size(1) + N = out.size(0) + out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K] + + self.gru.flatten_parameters() + _memory, out = self.gru(out) # out --- [1, N, 128] + + return self.proj(out.squeeze(0)) + + def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int: + for _ in range(n_convs): + L = (L - kernel_size + 2 * pad) // stride + 1 + return L + + +class OpenVoice(BaseVC): + """ + OpenVoice voice conversion model (inference only). + + Source: https://github.com/myshell-ai/OpenVoice + Paper: https://arxiv.org/abs/2312.01479 + + Paper abstract: + We introduce OpenVoice, a versatile voice cloning approach that requires + only a short audio clip from the reference speaker to replicate their voice and + generate speech in multiple languages. OpenVoice represents a significant + advancement in addressing the following open challenges in the field: 1) + Flexible Voice Style Control. OpenVoice enables granular control over voice + styles, including emotion, accent, rhythm, pauses, and intonation, in addition + to replicating the tone color of the reference speaker. The voice styles are not + directly copied from and constrained by the style of the reference speaker. + Previous approaches lacked the ability to flexibly manipulate voice styles after + cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot + cross-lingual voice cloning for languages not included in the massive-speaker + training set. Unlike previous approaches, which typically require extensive + massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can + clone voices into a new language without any massive-speaker training data for + that language. OpenVoice is also computationally efficient, costing tens of + times less than commercially available APIs that offer even inferior + performance. To foster further research in the field, we have made the source + code and trained model publicly accessible. We also provide qualitative results + in our demo website. 
Prior to its public release, our internal version of + OpenVoice was used tens of millions of times by users worldwide between May and + October 2023, serving as the backend of MyShell. + """ + + def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None: + super().__init__(config, None, speaker_manager, None) + + self.init_multispeaker(config) + + self.zero_g = self.args.zero_g + self.inter_channels = self.args.inter_channels + self.hidden_channels = self.args.hidden_channels + self.filter_channels = self.args.filter_channels + self.n_heads = self.args.n_heads + self.n_layers = self.args.n_layers + self.kernel_size = self.args.kernel_size + self.p_dropout = self.args.p_dropout + self.resblock = self.args.resblock + self.resblock_kernel_sizes = self.args.resblock_kernel_sizes + self.resblock_dilation_sizes = self.args.resblock_dilation_sizes + self.upsample_rates = self.args.upsample_rates + self.upsample_initial_channel = self.args.upsample_initial_channel + self.upsample_kernel_sizes = self.args.upsample_kernel_sizes + self.n_layers_q = self.args.n_layers_q + self.use_spectral_norm = self.args.use_spectral_norm + self.gin_channels = self.args.gin_channels + self.tau = self.args.tau + + self.spec_channels = config.audio.fft_size // 2 + 1 + + self.dec = Generator( + self.inter_channels, + self.resblock, + self.resblock_kernel_sizes, + self.resblock_dilation_sizes, + self.upsample_rates, + self.upsample_initial_channel, + self.upsample_kernel_sizes, + gin_channels=self.gin_channels, + ) + self.enc_q = PosteriorEncoder( + self.spec_channels, + self.inter_channels, + self.hidden_channels, + kernel_size=5, + dilation_rate=1, + num_layers=16, + cond_channels=self.gin_channels, + ) + + self.flow = ResidualCouplingBlock( + self.inter_channels, + self.hidden_channels, + kernel_size=5, + dilation_rate=1, + n_layers=4, + gin_channels=self.gin_channels, + ) + + self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @staticmethod + def init_from_config(config: OpenVoiceConfig) -> "OpenVoice": + return OpenVoice(config) + + def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None: + """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer + or with external `d_vectors` computed from a speaker encoder model. + + You must provide a `speaker_manager` at initialization to set up the multi-speaker modules. + + Args: + config (Coqpit): Model configuration. + data (list, optional): Dataset items to infer number of speakers. Defaults to None. 
+ """ + self.num_spks = config.num_speakers + if self.speaker_manager: + self.num_spks = self.speaker_manager.num_speakers + + def load_checkpoint( + self, + config: OpenVoiceConfig, + checkpoint_path: Union[str, os.PathLike[Any]], + eval: bool = False, + strict: bool = True, + cache: bool = False, + ) -> None: + """Map from OpenVoice's config structure.""" + config_path = Path(checkpoint_path).parent / "config.json" + with open(config_path, encoding="utf-8") as f: + config_org = json.load(f) + self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"] + self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"] + self.config.audio.fft_size = config_org["data"]["filter_length"] + self.config.audio.hop_length = config_org["data"]["hop_length"] + self.config.audio.win_length = config_org["data"]["win_length"] + state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache) + self.load_state_dict(state["model"], strict=strict) + if eval: + self.eval() + + def forward(self) -> None: ... + def train_step(self) -> None: ... + def eval_step(self) -> None: ... + + @staticmethod + def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor: + if "x_lengths" in aux_input and aux_input["x_lengths"] is not None: + return aux_input["x_lengths"] + return torch.tensor(x.shape[1:2]).to(x.device) + + @torch.no_grad() + def inference( + self, + x: torch.Tensor, + aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None}, + ) -> dict[str, torch.Tensor]: + """ + Inference pass of the model + + Args: + x (torch.Tensor): Input tensor. Shape: (batch_size, c_seq_len). + x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,). + g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim). + + Returns: + o_hat: Output spectrogram tensor. Shape: (batch_size, spec_seq_len, spec_dim). + x_mask: Spectrogram mask. Shape: (batch_size, spec_seq_len). + (z, z_p, z_hat): A tuple of latent variables. 
+ """ + x_lengths = self._set_x_lengths(x, aux_input) + if "g_src" in aux_input and aux_input["g_src"] is not None: + g_src = aux_input["g_src"] + else: + raise ValueError("aux_input must define g_src") + if "g_tgt" in aux_input and aux_input["g_tgt"] is not None: + g_tgt = aux_input["g_tgt"] + else: + raise ValueError("aux_input must define g_tgt") + z, _m_q, _logs_q, y_mask = self.enc_q( + x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau + ) + z_p = self.flow(z, y_mask, g=g_src) + z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True) + o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt)) + return { + "model_outputs": o_hat, + "y_mask": y_mask, + "z": z, + "z_p": z_p, + "z_hat": z_hat, + } + + def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor: + """Read and format the input audio.""" + if isinstance(wav, str): + out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0]) + elif isinstance(wav, np.ndarray): + out = torch.from_numpy(wav) + elif isinstance(wav, list): + out = torch.from_numpy(np.array(wav)) + else: + out = wav + return out.to(self.device).float() + + def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + audio_ref = self.load_audio(audio) + y = torch.FloatTensor(audio_ref) + y = y.to(self.device) + y = y.unsqueeze(0) + spec = wav_to_spec( + y, + n_fft=self.config.audio.fft_size, + hop_length=self.config.audio.hop_length, + win_length=self.config.audio.win_length, + center=False, + ).to(self.device) + with torch.no_grad(): + g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1) + + return g, spec + + @torch.inference_mode() + def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]: + """ + Voice conversion pass of the model. + + Args: + src (str or torch.Tensor): Source utterance. + tgt (str or torch.Tensor): Target utterance. + + Returns: + Output numpy array. 
+ """ + src_se, src_spec = self.extract_se(src) + tgt_se, _ = self.extract_se(tgt) + + aux_input = {"g_src": src_se, "g_tgt": tgt_se} + audio = self.inference(src_spec, aux_input) + return audio["model_outputs"][0, 0].data.cpu().float().numpy() diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py index ce4fc751c2..21cc194131 100644 --- a/tests/inference_tests/test_synthesizer.py +++ b/tests/inference_tests/test_synthesizer.py @@ -23,7 +23,7 @@ def test_in_out(self): tts_root_path = get_tests_input_path() tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth") tts_config = os.path.join(tts_root_path, "dummy_model_config.json") - synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None) + synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config) synthesizer.tts("Better this test works!!") def test_split_into_sentences(self): diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py index c90551b494..fe07b2723c 100644 --- a/tests/vc_tests/test_freevc.py +++ b/tests/vc_tests/test_freevc.py @@ -22,31 +22,19 @@ class TestFreeVC(unittest.TestCase): def _create_inputs(self, config, batch_size=2): - input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device) - input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device) - input_lengths[-1] = 30 * config.audio["hop_length"] spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device) mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device) spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device) spec_lengths[-1] = spec.size(2) waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device) - return input_dummy, input_lengths, mel, spec, spec_lengths, waveform + return mel, spec, spec_lengths, waveform @staticmethod def _create_inputs_inference(): - source_wav = torch.rand(16000) + source_wav = torch.rand(15999) target_wav = torch.rand(16000) return source_wav, target_wav - @staticmethod - def _check_parameter_changes(model, model_ref): - count = 0 - for param, param_ref in zip(model.parameters(), model_ref.parameters()): - assert (param != param_ref).any(), "param {} with shape {} not updated!! 
\n{}\n{}".format( - count, param.shape, param, param_ref - ) - count += 1 - def test_methods(self): config = FreeVCConfig() model = FreeVC(config).to(device) @@ -69,7 +57,7 @@ def _test_forward(self, batch_size): model.train() print(" > Num parameters for FreeVC model:%s" % (count_parameters(model))) - _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size) + mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size) wavlm_vec = model.extract_wavlm_features(waveform) wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long) @@ -86,7 +74,7 @@ def _test_inference(self, batch_size): model = FreeVC(config).to(device) model.eval() - _, _, mel, _, _, waveform = self._create_inputs(config, batch_size) + mel, _, _, waveform = self._create_inputs(config, batch_size) wavlm_vec = model.extract_wavlm_features(waveform) wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long) @@ -108,8 +96,8 @@ def test_voice_conversion(self): source_wav, target_wav = self._create_inputs_inference() output_wav = model.voice_conversion(source_wav, target_wav) assert ( - output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0] - ), f"{output_wav.shape} != {source_wav.shape}" + output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length + ), f"{output_wav.shape} != {source_wav.shape}, {config.audio.hop_length}" def test_train_step(self): ... diff --git a/tests/vc_tests/test_openvoice.py b/tests/vc_tests/test_openvoice.py new file mode 100644 index 0000000000..c9f7ae3931 --- /dev/null +++ b/tests/vc_tests/test_openvoice.py @@ -0,0 +1,42 @@ +import os +import unittest + +import torch + +from tests import get_tests_input_path +from TTS.vc.models.openvoice import OpenVoice, OpenVoiceConfig + +torch.manual_seed(1) +use_cuda = torch.cuda.is_available() +device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + +c = OpenVoiceConfig() + +WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav") + + +class TestOpenVoice(unittest.TestCase): + + @staticmethod + def _create_inputs_inference(): + source_wav = torch.rand(16100) + target_wav = torch.rand(16000) + return source_wav, target_wav + + def test_load_audio(self): + config = OpenVoiceConfig() + model = OpenVoice(config).to(device) + wav = model.load_audio(WAV_FILE) + wav2 = model.load_audio(wav) + assert all(torch.isclose(wav, wav2)) + + def test_voice_conversion(self): + config = OpenVoiceConfig() + model = OpenVoice(config).to(device) + model.eval() + + source_wav, target_wav = self._create_inputs_inference() + output_wav = model.voice_conversion(source_wav, target_wav) + assert ( + output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length + ), f"{output_wav.shape} != {source_wav.shape}"