diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 92f6f3ab3c..62420e9958 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,6 +2,8 @@ repos:
- repo: "https://github.com/pre-commit/pre-commit-hooks"
rev: v5.0.0
hooks:
+ - id: check-json
+ files: "TTS/.models.json"
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
diff --git a/README.md b/README.md
index 5ca825b6ba..7dddf3a37b 100644
--- a/README.md
+++ b/README.md
@@ -1,13 +1,12 @@
## 🐸Coqui TTS News
- 📣 Fork of the [original, unmaintained repository](https://github.com/coqui-ai/TTS). New PyPI package: [coqui-tts](https://pypi.org/project/coqui-tts)
+- 📣 [OpenVoice](https://github.com/myshell-ai/OpenVoice) models now available for voice conversion.
- 📣 Prebuilt wheels are now also published for Mac and Windows (in addition to Linux as before) for easier installation across platforms.
-- 📣 ⓍTTSv2 is here with 16 languages and better performance across the board.
+- 📣 ⓍTTSv2 is here with 17 languages and better performance across the board. ⓍTTS can stream with <200ms latency.
- 📣 ⓍTTS fine-tuning code is out. Check the [example recipes](https://github.com/idiap/coqui-ai-TTS/tree/dev/recipes/ljspeech).
-- 📣 ⓍTTS can now stream with <200ms latency.
-- 📣 ⓍTTS, our production TTS model that can speak 13 languages, is released [Blog Post](https://coqui.ai/blog/tts/open_xtts), [Demo](https://huggingface.co/spaces/coqui/xtts), [Docs](https://coqui-tts.readthedocs.io/en/latest/models/xtts.html)
- 📣 [🐶Bark](https://github.com/suno-ai/bark) is now available for inference with unconstrained voice cloning. [Docs](https://coqui-tts.readthedocs.io/en/latest/models/bark.html)
-- 📣 You can use [~1100 Fairseq models](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
+- 📣 You can use [Fairseq models in ~1100 languages](https://github.com/facebookresearch/fairseq/tree/main/examples/mms) with 🐸TTS.
##
@@ -121,6 +120,7 @@ repository are also still a useful source of information.
### Voice Conversion
- FreeVC: [paper](https://arxiv.org/abs/2210.15418)
+- OpenVoice: [technical report](https://arxiv.org/abs/2312.01479)
You can also help us implement more models.
@@ -244,8 +244,14 @@ tts = TTS(model_name="voice_conversion_models/multilingual/vctk/freevc24", progr
tts.voice_conversion_to_file(source_wav="my/source.wav", target_wav="my/target.wav", file_path="output.wav")
```
-#### Example voice cloning together with the voice conversion model.
-This way, you can clone voices by using any model in 🐸TTS.
+Other available voice conversion models:
+- `voice_conversion_models/multilingual/multi-dataset/openvoice_v1`
+- `voice_conversion_models/multilingual/multi-dataset/openvoice_v2`
+
+#### Example of voice cloning together with the default voice conversion model.
+
+This way, you can clone voices using any model in 🐸TTS: speech is first
+synthesized with the chosen model and then converted with the FreeVC model.
```python
@@ -412,4 +418,6 @@ $ tts --out_path output/path/speech.wav --model_name "// str:
"""Voice conversion with FreeVC. Convert source wav to target speaker.
Args:
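
The OpenVoice checkpoints listed in the README hunk above go through the same high-level voice conversion API as the existing FreeVC example. A minimal usage sketch (not part of this patch; the model names are the two registered above, the file paths are placeholders):

```python
from TTS.api import TTS

# Voice conversion with one of the newly listed OpenVoice models; the call
# signature mirrors the FreeVC example in the README.
tts = TTS(
    model_name="voice_conversion_models/multilingual/multi-dataset/openvoice_v2",
    progress_bar=False,
)
tts.voice_conversion_to_file(
    source_wav="my/source.wav",  # utterance whose speech content is kept
    target_wav="my/target.wav",  # utterance whose voice is cloned
    file_path="output.wav",
)
```
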
diff --git a/TTS/bin/synthesize.py b/TTS/bin/synthesize.py
index 20e429df04..454f528ab4 100755
--- a/TTS/bin/synthesize.py
+++ b/TTS/bin/synthesize.py
@@ -407,18 +407,18 @@ def main():
# load models
synthesizer = Synthesizer(
- tts_path,
- tts_config_path,
- speakers_file_path,
- language_ids_file_path,
- vocoder_path,
- vocoder_config_path,
- encoder_path,
- encoder_config_path,
- vc_path,
- vc_config_path,
- model_dir,
- args.voice_dir,
+ tts_checkpoint=tts_path,
+ tts_config_path=tts_config_path,
+ tts_speakers_file=speakers_file_path,
+ tts_languages_file=language_ids_file_path,
+ vocoder_checkpoint=vocoder_path,
+ vocoder_config=vocoder_config_path,
+ encoder_checkpoint=encoder_path,
+ encoder_config=encoder_config_path,
+ vc_checkpoint=vc_path,
+ vc_config=vc_config_path,
+ model_dir=model_dir,
+ voice_dir=args.voice_dir,
).to(device)
# query speaker ids of a multi-speaker model.
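
This call-site is updated because `Synthesizer.__init__` becomes keyword-only (the bare `*` added in the `TTS/utils/synthesizer.py` hunk below). A minimal sketch of the resulting API, with placeholder paths:

```python
from TTS.utils.synthesizer import Synthesizer

# Positional construction such as Synthesizer("model.pth", "config.json") now
# raises TypeError; every argument must be named, which prevents silently
# passing a config where a speakers file is expected.
synthesizer = Synthesizer(
    tts_checkpoint="/path/to/model.pth",
    tts_config_path="/path/to/config.json",
)
wav = synthesizer.tts("Keyword-only arguments prevent silent mix-ups.")
```
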
diff --git a/TTS/tts/layers/vits/networks.py b/TTS/tts/layers/vits/networks.py
index 50ed1024de..ab2ca5667a 100644
--- a/TTS/tts/layers/vits/networks.py
+++ b/TTS/tts/layers/vits/networks.py
@@ -256,7 +256,7 @@ def __init__(
)
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
- def forward(self, x, x_lengths, g=None):
+ def forward(self, x, x_lengths, g=None, tau=1.0):
"""
Shapes:
- x: :math:`[B, C, T]`
@@ -268,5 +268,5 @@ def forward(self, x, x_lengths, g=None):
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
mean, log_scale = torch.split(stats, self.out_channels, dim=1)
- z = (mean + torch.randn_like(mean) * torch.exp(log_scale)) * x_mask
+ z = (mean + torch.randn_like(mean) * tau * torch.exp(log_scale)) * x_mask
return z, mean, log_scale, x_mask
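
For intuition on the new `tau` argument: it scales only the sampling noise, so `tau=1.0` reproduces the previous behaviour and `tau=0.0` collapses the posterior to its mean. A standalone sketch of the sampling line:

```python
import torch

mean = torch.zeros(2, 192, 50)
log_scale = torch.zeros_like(mean)  # std = exp(0) = 1
for tau in (1.0, 0.3, 0.0):  # OpenVoice passes tau=0.3 (see OpenVoiceArgs below)
    z = mean + torch.randn_like(mean) * tau * torch.exp(log_scale)
    print(f"tau={tau}: empirical std ~ {z.std().item():.2f}")
```
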
diff --git a/TTS/utils/manage.py b/TTS/utils/manage.py
index bd445b3a2f..38fcfd60e9 100644
--- a/TTS/utils/manage.py
+++ b/TTS/utils/manage.py
@@ -424,7 +424,7 @@ def _find_files(output_path: str) -> Tuple[str, str]:
model_file = None
config_file = None
for file_name in os.listdir(output_path):
- if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth"]:
+ if file_name in ["model_file.pth", "model_file.pth.tar", "model.pth", "checkpoint.pth"]:
model_file = os.path.join(output_path, file_name)
elif file_name == "config.json":
config_file = os.path.join(output_path, file_name)
diff --git a/TTS/utils/synthesizer.py b/TTS/utils/synthesizer.py
index 90af4f48f9..a9b9feffc1 100644
--- a/TTS/utils/synthesizer.py
+++ b/TTS/utils/synthesizer.py
@@ -1,6 +1,7 @@
import logging
import os
import time
+from pathlib import Path
from typing import List
import numpy as np
@@ -15,7 +16,9 @@
from TTS.tts.utils.synthesis import synthesis, transfer_voice, trim_silence
from TTS.utils.audio import AudioProcessor
from TTS.utils.audio.numpy_transforms import save_wav
+from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models import setup_model as setup_vc_model
+from TTS.vc.models.openvoice import OpenVoice
from TTS.vocoder.models import setup_model as setup_vocoder_model
from TTS.vocoder.utils.generic_utils import interpolate_vocoder_input
@@ -25,6 +28,7 @@
class Synthesizer(nn.Module):
def __init__(
self,
+ *,
tts_checkpoint: str = "",
tts_config_path: str = "",
tts_speakers_file: str = "",
@@ -91,23 +95,20 @@ def __init__(
if tts_checkpoint:
self._load_tts(tts_checkpoint, tts_config_path, use_cuda)
- self.output_sample_rate = self.tts_config.audio["sample_rate"]
if vocoder_checkpoint:
self._load_vocoder(vocoder_checkpoint, vocoder_config, use_cuda)
- self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
- if vc_checkpoint:
+ if vc_checkpoint and model_dir is None:
self._load_vc(vc_checkpoint, vc_config, use_cuda)
- self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
if model_dir:
if "fairseq" in model_dir:
self._load_fairseq_from_dir(model_dir, use_cuda)
- self.output_sample_rate = self.tts_config.audio["sample_rate"]
+ elif "openvoice" in model_dir:
+ self._load_openvoice_from_dir(Path(model_dir), use_cuda)
else:
self._load_tts_from_dir(model_dir, use_cuda)
- self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
@staticmethod
def _get_segmenter(lang: str):
@@ -136,6 +137,7 @@ def _load_vc(self, vc_checkpoint: str, vc_config_path: str, use_cuda: bool) -> N
"""
# pylint: disable=global-statement
self.vc_config = load_config(vc_config_path)
+ self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
self.vc_model = setup_vc_model(config=self.vc_config)
self.vc_model.load_checkpoint(self.vc_config, vc_checkpoint)
if use_cuda:
@@ -150,9 +152,24 @@ def _load_fairseq_from_dir(self, model_dir: str, use_cuda: bool) -> None:
self.tts_model = Vits.init_from_config(self.tts_config)
self.tts_model.load_fairseq_checkpoint(self.tts_config, checkpoint_dir=model_dir, eval=True)
self.tts_config = self.tts_model.config
+ self.output_sample_rate = self.tts_config.audio["sample_rate"]
if use_cuda:
self.tts_model.cuda()
+ def _load_openvoice_from_dir(self, checkpoint: Path, use_cuda: bool) -> None:
+ """Load the OpenVoice model from a directory.
+
+        We assume the directory contains a checkpoint.pth model file and a
+        config.json next to it.
+ """
+ self.vc_config = OpenVoiceConfig()
+ self.vc_model = OpenVoice.init_from_config(self.vc_config)
+        self.vc_model.load_checkpoint(self.vc_config, checkpoint / "checkpoint.pth", eval=True)
+ self.vc_config = self.vc_model.config
+ self.output_sample_rate = self.vc_config.audio["output_sample_rate"]
+ if use_cuda:
+ self.vc_model.cuda()
+
def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""Load the TTS model from a directory.
@@ -160,6 +177,7 @@ def _load_tts_from_dir(self, model_dir: str, use_cuda: bool) -> None:
"""
config = load_config(os.path.join(model_dir, "config.json"))
self.tts_config = config
+ self.output_sample_rate = self.tts_config.audio["output_sample_rate"]
self.tts_model = setup_tts_model(config)
self.tts_model.load_checkpoint(config, checkpoint_dir=model_dir, eval=True)
if use_cuda:
@@ -181,6 +199,7 @@ def _load_tts(self, tts_checkpoint: str, tts_config_path: str, use_cuda: bool) -
"""
# pylint: disable=global-statement
self.tts_config = load_config(tts_config_path)
+ self.output_sample_rate = self.tts_config.audio["sample_rate"]
if self.tts_config["use_phonemes"] and self.tts_config["phonemizer"] is None:
raise ValueError("Phonemizer is not defined in the TTS config.")
@@ -218,6 +237,7 @@ def _load_vocoder(self, model_file: str, model_config: str, use_cuda: bool) -> N
use_cuda (bool): enable/disable CUDA use.
"""
self.vocoder_config = load_config(model_config)
+ self.output_sample_rate = self.vocoder_config.audio["sample_rate"]
self.vocoder_ap = AudioProcessor(**self.vocoder_config.audio)
self.vocoder_model = setup_vocoder_model(self.vocoder_config)
self.vocoder_model.load_checkpoint(self.vocoder_config, model_file, eval=True)
diff --git a/TTS/vc/configs/freevc_config.py b/TTS/vc/configs/freevc_config.py
index 207181b303..d600bfb1f4 100644
--- a/TTS/vc/configs/freevc_config.py
+++ b/TTS/vc/configs/freevc_config.py
@@ -229,7 +229,7 @@ class FreeVCConfig(BaseVCConfig):
If true, language embedding is used. Defaults to `False`.
Note:
- Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.
+        Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.
Example:
diff --git a/TTS/vc/configs/openvoice_config.py b/TTS/vc/configs/openvoice_config.py
new file mode 100644
index 0000000000..261cdd6f47
--- /dev/null
+++ b/TTS/vc/configs/openvoice_config.py
@@ -0,0 +1,201 @@
+from dataclasses import dataclass, field
+from typing import Optional
+
+from coqpit import Coqpit
+
+from TTS.vc.configs.shared_configs import BaseVCConfig
+
+
+@dataclass
+class OpenVoiceAudioConfig(Coqpit):
+ """Audio configuration
+
+ Args:
+ input_sample_rate (int):
+ The sampling rate of the input waveform.
+
+ output_sample_rate (int):
+ The sampling rate of the output waveform.
+
+ fft_size (int):
+ The length of the filter.
+
+ hop_length (int):
+ The hop length.
+
+ win_length (int):
+ The window length.
+ """
+
+ input_sample_rate: int = field(default=22050)
+ output_sample_rate: int = field(default=22050)
+ fft_size: int = field(default=1024)
+ hop_length: int = field(default=256)
+ win_length: int = field(default=1024)
+
+
+@dataclass
+class OpenVoiceArgs(Coqpit):
+ """OpenVoice model arguments.
+
+ zero_g (bool):
+        Whether to pass a zeroed speaker embedding to the posterior encoder and decoder instead of `g` (the flow still receives the real embedding).
+
+ inter_channels (int):
+ The number of channels in the intermediate layers.
+
+ hidden_channels (int):
+ The number of channels in the hidden layers.
+
+ filter_channels (int):
+ The number of channels in the filter layers.
+
+ n_heads (int):
+ The number of attention heads.
+
+ n_layers (int):
+ The number of layers.
+
+ kernel_size (int):
+ The size of the kernel.
+
+ p_dropout (float):
+ The dropout probability.
+
+ resblock (str):
+ The type of residual block.
+
+ resblock_kernel_sizes (List[int]):
+ The kernel sizes for the residual blocks.
+
+ resblock_dilation_sizes (List[List[int]]):
+ The dilation sizes for the residual blocks.
+
+ upsample_rates (List[int]):
+ The upsample rates.
+
+ upsample_initial_channel (int):
+ The number of channels in the initial upsample layer.
+
+ upsample_kernel_sizes (List[int]):
+ The kernel sizes for the upsample layers.
+
+ n_layers_q (int):
+ The number of layers in the quantization network.
+
+ use_spectral_norm (bool):
+ Whether to use spectral normalization.
+
+ gin_channels (int):
+ The number of channels in the global conditioning vector.
+
+ tau (float):
+        Temperature parameter for the posterior encoder, scaling the noise added to the sampled latent.
+ """
+
+ zero_g: bool = field(default=True)
+ inter_channels: int = field(default=192)
+ hidden_channels: int = field(default=192)
+ filter_channels: int = field(default=768)
+ n_heads: int = field(default=2)
+ n_layers: int = field(default=6)
+ kernel_size: int = field(default=3)
+ p_dropout: float = field(default=0.1)
+ resblock: str = field(default="1")
+ resblock_kernel_sizes: list[int] = field(default_factory=lambda: [3, 7, 11])
+ resblock_dilation_sizes: list[list[int]] = field(default_factory=lambda: [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
+ upsample_rates: list[int] = field(default_factory=lambda: [8, 8, 2, 2])
+ upsample_initial_channel: int = field(default=512)
+ upsample_kernel_sizes: list[int] = field(default_factory=lambda: [16, 16, 4, 4])
+ n_layers_q: int = field(default=3)
+ use_spectral_norm: bool = field(default=False)
+ gin_channels: int = field(default=256)
+ tau: float = field(default=0.3)
+
+
+@dataclass
+class OpenVoiceConfig(BaseVCConfig):
+ """Defines parameters for OpenVoice VC model.
+
+ Args:
+ model (str):
+ Model name. Do not change unless you know what you are doing.
+
+ model_args (OpenVoiceArgs):
+ Model architecture arguments. Defaults to `OpenVoiceArgs()`.
+
+ audio (OpenVoiceAudioConfig):
+ Audio processing configuration. Defaults to `OpenVoiceAudioConfig()`.
+
+ return_wav (bool):
+ If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.
+
+ compute_linear_spec (bool):
+ If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.
+
+ use_weighted_sampler (bool):
+ If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.
+
+ weighted_sampler_attrs (dict):
+        Key returned by the formatter to be used for the weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
+ by overweighting `root_path` by 2.0. Defaults to `{}`.
+
+ weighted_sampler_multipliers (dict):
+ Weight each unique value of a key returned by the formatter for weighted sampling.
+        For example `{"root_path": {"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/": 1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}}`.
+        This samples instances from `train-clean-100` twice as often as from `train-clean-360`. Defaults to `{}`.
+
+ r (int):
+ Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.
+
+ add_blank (bool):
+ If true, a blank token is added in between every character. Defaults to `True`.
+
+ Note:
+        Check :class:`TTS.vc.configs.shared_configs.BaseVCConfig` for the inherited parameters.
+
+ Example:
+
+ >>> from TTS.vc.configs.openvoice_config import OpenVoiceConfig
+ >>> config = OpenVoiceConfig()
+ """
+
+ model: str = "openvoice"
+ # model specific params
+ model_args: OpenVoiceArgs = field(default_factory=OpenVoiceArgs)
+ audio: OpenVoiceAudioConfig = field(default_factory=OpenVoiceAudioConfig)
+
+ # optimizer
+ # TODO with training support
+
+ # loss params
+ # TODO with training support
+
+ # data loader params
+ return_wav: bool = True
+ compute_linear_spec: bool = True
+
+ # sampler params
+ use_weighted_sampler: bool = False # TODO: move it to the base config
+ weighted_sampler_attrs: dict = field(default_factory=lambda: {})
+ weighted_sampler_multipliers: dict = field(default_factory=lambda: {})
+
+ # overrides
+ r: int = 1 # DO NOT CHANGE
+ add_blank: bool = True
+
+ # multi-speaker settings
+ # use speaker embedding layer
+ num_speakers: int = 0
+ speakers_file: Optional[str] = None
+ speaker_embedding_channels: int = 256
+
+ # use d-vectors
+ use_d_vector_file: bool = False
+ d_vector_file: Optional[list[str]] = None
+ d_vector_dim: Optional[int] = None
+
+ def __post_init__(self) -> None:
+ for key, val in self.model_args.items():
+ if hasattr(self, key):
+ self[key] = val
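
A short usage sketch for the new config (illustrative only; `OpenVoice.load_checkpoint` further down in this patch overwrites the audio block from the checkpoint's own config.json):

```python
from TTS.vc.configs.openvoice_config import OpenVoiceAudioConfig, OpenVoiceConfig

config = OpenVoiceConfig()
assert config.model == "openvoice"
assert config.audio.output_sample_rate == 22050  # defaults from OpenVoiceAudioConfig

# Fields can be overridden at construction time, e.g. for a 24 kHz checkpoint;
# the values here are purely illustrative.
config = OpenVoiceConfig(audio=OpenVoiceAudioConfig(input_sample_rate=24000, output_sample_rate=24000))
```
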
diff --git a/TTS/vc/modules/__init__.py b/TTS/vc/layers/__init__.py
similarity index 100%
rename from TTS/vc/modules/__init__.py
rename to TTS/vc/layers/__init__.py
diff --git a/TTS/vc/modules/freevc/__init__.py b/TTS/vc/layers/freevc/__init__.py
similarity index 100%
rename from TTS/vc/modules/freevc/__init__.py
rename to TTS/vc/layers/freevc/__init__.py
diff --git a/TTS/vc/modules/freevc/commons.py b/TTS/vc/layers/freevc/commons.py
similarity index 100%
rename from TTS/vc/modules/freevc/commons.py
rename to TTS/vc/layers/freevc/commons.py
diff --git a/TTS/vc/modules/freevc/mel_processing.py b/TTS/vc/layers/freevc/mel_processing.py
similarity index 100%
rename from TTS/vc/modules/freevc/mel_processing.py
rename to TTS/vc/layers/freevc/mel_processing.py
diff --git a/TTS/vc/modules/freevc/modules.py b/TTS/vc/layers/freevc/modules.py
similarity index 99%
rename from TTS/vc/modules/freevc/modules.py
rename to TTS/vc/layers/freevc/modules.py
index ea17be24d6..c34f22d701 100644
--- a/TTS/vc/modules/freevc/modules.py
+++ b/TTS/vc/layers/freevc/modules.py
@@ -7,7 +7,7 @@
from TTS.tts.layers.generic.normalization import LayerNorm2
from TTS.tts.layers.generic.wavenet import fused_add_tanh_sigmoid_multiply
-from TTS.vc.modules.freevc.commons import init_weights
+from TTS.vc.layers.freevc.commons import init_weights
from TTS.vocoder.models.hifigan_generator import get_padding
LRELU_SLOPE = 0.1
diff --git a/TTS/vc/modules/freevc/speaker_encoder/__init__.py b/TTS/vc/layers/freevc/speaker_encoder/__init__.py
similarity index 100%
rename from TTS/vc/modules/freevc/speaker_encoder/__init__.py
rename to TTS/vc/layers/freevc/speaker_encoder/__init__.py
diff --git a/TTS/vc/modules/freevc/speaker_encoder/audio.py b/TTS/vc/layers/freevc/speaker_encoder/audio.py
similarity index 97%
rename from TTS/vc/modules/freevc/speaker_encoder/audio.py
rename to TTS/vc/layers/freevc/speaker_encoder/audio.py
index 5b23a4dbb6..5fa317ce45 100644
--- a/TTS/vc/modules/freevc/speaker_encoder/audio.py
+++ b/TTS/vc/layers/freevc/speaker_encoder/audio.py
@@ -5,7 +5,7 @@
import librosa
import numpy as np
-from TTS.vc.modules.freevc.speaker_encoder.hparams import (
+from TTS.vc.layers.freevc.speaker_encoder.hparams import (
audio_norm_target_dBFS,
mel_n_channels,
mel_window_length,
diff --git a/TTS/vc/modules/freevc/speaker_encoder/hparams.py b/TTS/vc/layers/freevc/speaker_encoder/hparams.py
similarity index 100%
rename from TTS/vc/modules/freevc/speaker_encoder/hparams.py
rename to TTS/vc/layers/freevc/speaker_encoder/hparams.py
diff --git a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
similarity index 98%
rename from TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py
rename to TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
index 294bf322cb..a6d5bcf942 100644
--- a/TTS/vc/modules/freevc/speaker_encoder/speaker_encoder.py
+++ b/TTS/vc/layers/freevc/speaker_encoder/speaker_encoder.py
@@ -7,8 +7,8 @@
from torch import nn
from trainer.io import load_fsspec
-from TTS.vc.modules.freevc.speaker_encoder import audio
-from TTS.vc.modules.freevc.speaker_encoder.hparams import (
+from TTS.vc.layers.freevc.speaker_encoder import audio
+from TTS.vc.layers.freevc.speaker_encoder.hparams import (
mel_n_channels,
mel_window_step,
model_embedding_size,
diff --git a/TTS/vc/modules/freevc/wavlm/__init__.py b/TTS/vc/layers/freevc/wavlm/__init__.py
similarity index 94%
rename from TTS/vc/modules/freevc/wavlm/__init__.py
rename to TTS/vc/layers/freevc/wavlm/__init__.py
index 4046e137f5..62f7e74aaf 100644
--- a/TTS/vc/modules/freevc/wavlm/__init__.py
+++ b/TTS/vc/layers/freevc/wavlm/__init__.py
@@ -6,7 +6,7 @@
from trainer.io import get_user_data_dir
from TTS.utils.generic_utils import is_pytorch_at_least_2_4
-from TTS.vc.modules.freevc.wavlm.wavlm import WavLM, WavLMConfig
+from TTS.vc.layers.freevc.wavlm.wavlm import WavLM, WavLMConfig
logger = logging.getLogger(__name__)
diff --git a/TTS/vc/modules/freevc/wavlm/config.json b/TTS/vc/layers/freevc/wavlm/config.json
similarity index 100%
rename from TTS/vc/modules/freevc/wavlm/config.json
rename to TTS/vc/layers/freevc/wavlm/config.json
diff --git a/TTS/vc/modules/freevc/wavlm/modules.py b/TTS/vc/layers/freevc/wavlm/modules.py
similarity index 100%
rename from TTS/vc/modules/freevc/wavlm/modules.py
rename to TTS/vc/layers/freevc/wavlm/modules.py
diff --git a/TTS/vc/modules/freevc/wavlm/wavlm.py b/TTS/vc/layers/freevc/wavlm/wavlm.py
similarity index 99%
rename from TTS/vc/modules/freevc/wavlm/wavlm.py
rename to TTS/vc/layers/freevc/wavlm/wavlm.py
index 10dd09ed0c..775f3e5979 100644
--- a/TTS/vc/modules/freevc/wavlm/wavlm.py
+++ b/TTS/vc/layers/freevc/wavlm/wavlm.py
@@ -17,7 +17,7 @@
import torch.nn.functional as F
from torch.nn import LayerNorm
-from TTS.vc.modules.freevc.wavlm.modules import (
+from TTS.vc.layers.freevc.wavlm.modules import (
Fp32GroupNorm,
Fp32LayerNorm,
GLU_Linear,
diff --git a/TTS/vc/models/freevc.py b/TTS/vc/models/freevc.py
index 62559de534..c654219c39 100644
--- a/TTS/vc/models/freevc.py
+++ b/TTS/vc/models/freevc.py
@@ -12,17 +12,16 @@
from torch.nn.utils.parametrize import remove_parametrizations
from trainer.io import load_fsspec
-import TTS.vc.modules.freevc.commons as commons
-import TTS.vc.modules.freevc.modules as modules
+import TTS.vc.layers.freevc.modules as modules
from TTS.tts.layers.vits.discriminator import DiscriminatorS
from TTS.tts.utils.helpers import sequence_mask
from TTS.tts.utils.speakers import SpeakerManager
from TTS.vc.configs.freevc_config import FreeVCConfig
+from TTS.vc.layers.freevc.commons import init_weights, rand_slice_segments
+from TTS.vc.layers.freevc.mel_processing import mel_spectrogram_torch
+from TTS.vc.layers.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
+from TTS.vc.layers.freevc.wavlm import get_wavlm
from TTS.vc.models.base_vc import BaseVC
-from TTS.vc.modules.freevc.commons import init_weights
-from TTS.vc.modules.freevc.mel_processing import mel_spectrogram_torch
-from TTS.vc.modules.freevc.speaker_encoder.speaker_encoder import SpeakerEncoder as SpeakerEncoderEx
-from TTS.vc.modules.freevc.wavlm import get_wavlm
from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP
logger = logging.getLogger(__name__)
@@ -385,7 +384,7 @@ def forward(
z_p = self.flow(z, spec_mask, g=g)
# Randomly slice z and compute o using dec
- z_slice, ids_slice = commons.rand_slice_segments(z, spec_lengths, self.segment_size)
+ z_slice, ids_slice = rand_slice_segments(z, spec_lengths, self.segment_size)
o = self.dec(z_slice, g=g)
return o, ids_slice, spec_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
diff --git a/TTS/vc/models/openvoice.py b/TTS/vc/models/openvoice.py
new file mode 100644
index 0000000000..135b0861b9
--- /dev/null
+++ b/TTS/vc/models/openvoice.py
@@ -0,0 +1,320 @@
+import json
+import logging
+import os
+from pathlib import Path
+from typing import Any, Mapping, Optional, Union
+
+import librosa
+import numpy as np
+import numpy.typing as npt
+import torch
+from coqpit import Coqpit
+from torch import nn
+from torch.nn import functional as F
+from trainer.io import load_fsspec
+
+from TTS.tts.layers.vits.networks import PosteriorEncoder
+from TTS.tts.utils.speakers import SpeakerManager
+from TTS.utils.audio.torch_transforms import wav_to_spec
+from TTS.vc.configs.openvoice_config import OpenVoiceConfig
+from TTS.vc.models.base_vc import BaseVC
+from TTS.vc.models.freevc import Generator, ResidualCouplingBlock
+
+logger = logging.getLogger(__name__)
+
+
+class ReferenceEncoder(nn.Module):
+ """NN module creating a fixed size prosody embedding from a spectrogram.
+
+    inputs: spectrograms [batch_size, num_spec_frames, spec_channels]
+ outputs: [batch_size, embedding_dim]
+ """
+
+ def __init__(self, spec_channels: int, embedding_dim: int = 0, layernorm: bool = True) -> None:
+ super().__init__()
+ self.spec_channels = spec_channels
+ ref_enc_filters = [32, 32, 64, 64, 128, 128]
+ K = len(ref_enc_filters)
+ filters = [1] + ref_enc_filters
+ convs = [
+ torch.nn.utils.parametrizations.weight_norm(
+ nn.Conv2d(
+ in_channels=filters[i],
+ out_channels=filters[i + 1],
+ kernel_size=(3, 3),
+ stride=(2, 2),
+ padding=(1, 1),
+ )
+ )
+ for i in range(K)
+ ]
+ self.convs = nn.ModuleList(convs)
+
+ out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
+ self.gru = nn.GRU(
+ input_size=ref_enc_filters[-1] * out_channels,
+ hidden_size=256 // 2,
+ batch_first=True,
+ )
+ self.proj = nn.Linear(128, embedding_dim)
+ self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None
+
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+ N = inputs.size(0)
+
+ out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
+ if self.layernorm is not None:
+ out = self.layernorm(out)
+
+ for conv in self.convs:
+ out = conv(out)
+ out = F.relu(out) # [N, 128, Ty//2^K, n_mels//2^K]
+
+ out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
+ T = out.size(1)
+ N = out.size(0)
+ out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
+
+ self.gru.flatten_parameters()
+ _memory, out = self.gru(out) # out --- [1, N, 128]
+
+ return self.proj(out.squeeze(0))
+
+ def calculate_channels(self, L: int, kernel_size: int, stride: int, pad: int, n_convs: int) -> int:
+ for _ in range(n_convs):
+ L = (L - kernel_size + 2 * pad) // stride + 1
+ return L
+
+
+class OpenVoice(BaseVC):
+ """
+ OpenVoice voice conversion model (inference only).
+
+ Source: https://github.com/myshell-ai/OpenVoice
+ Paper: https://arxiv.org/abs/2312.01479
+
+ Paper abstract:
+ We introduce OpenVoice, a versatile voice cloning approach that requires
+ only a short audio clip from the reference speaker to replicate their voice and
+ generate speech in multiple languages. OpenVoice represents a significant
+ advancement in addressing the following open challenges in the field: 1)
+ Flexible Voice Style Control. OpenVoice enables granular control over voice
+ styles, including emotion, accent, rhythm, pauses, and intonation, in addition
+ to replicating the tone color of the reference speaker. The voice styles are not
+ directly copied from and constrained by the style of the reference speaker.
+ Previous approaches lacked the ability to flexibly manipulate voice styles after
+ cloning. 2) Zero-Shot Cross-Lingual Voice Cloning. OpenVoice achieves zero-shot
+ cross-lingual voice cloning for languages not included in the massive-speaker
+ training set. Unlike previous approaches, which typically require extensive
+ massive-speaker multi-lingual (MSML) dataset for all languages, OpenVoice can
+ clone voices into a new language without any massive-speaker training data for
+ that language. OpenVoice is also computationally efficient, costing tens of
+ times less than commercially available APIs that offer even inferior
+ performance. To foster further research in the field, we have made the source
+ code and trained model publicly accessible. We also provide qualitative results
+ in our demo website. Prior to its public release, our internal version of
+ OpenVoice was used tens of millions of times by users worldwide between May and
+ October 2023, serving as the backend of MyShell.
+ """
+
+ def __init__(self, config: Coqpit, speaker_manager: Optional[SpeakerManager] = None) -> None:
+ super().__init__(config, None, speaker_manager, None)
+
+ self.init_multispeaker(config)
+
+ self.zero_g = self.args.zero_g
+ self.inter_channels = self.args.inter_channels
+ self.hidden_channels = self.args.hidden_channels
+ self.filter_channels = self.args.filter_channels
+ self.n_heads = self.args.n_heads
+ self.n_layers = self.args.n_layers
+ self.kernel_size = self.args.kernel_size
+ self.p_dropout = self.args.p_dropout
+ self.resblock = self.args.resblock
+ self.resblock_kernel_sizes = self.args.resblock_kernel_sizes
+ self.resblock_dilation_sizes = self.args.resblock_dilation_sizes
+ self.upsample_rates = self.args.upsample_rates
+ self.upsample_initial_channel = self.args.upsample_initial_channel
+ self.upsample_kernel_sizes = self.args.upsample_kernel_sizes
+ self.n_layers_q = self.args.n_layers_q
+ self.use_spectral_norm = self.args.use_spectral_norm
+ self.gin_channels = self.args.gin_channels
+ self.tau = self.args.tau
+
+ self.spec_channels = config.audio.fft_size // 2 + 1
+
+ self.dec = Generator(
+ self.inter_channels,
+ self.resblock,
+ self.resblock_kernel_sizes,
+ self.resblock_dilation_sizes,
+ self.upsample_rates,
+ self.upsample_initial_channel,
+ self.upsample_kernel_sizes,
+ gin_channels=self.gin_channels,
+ )
+ self.enc_q = PosteriorEncoder(
+ self.spec_channels,
+ self.inter_channels,
+ self.hidden_channels,
+ kernel_size=5,
+ dilation_rate=1,
+ num_layers=16,
+ cond_channels=self.gin_channels,
+ )
+
+ self.flow = ResidualCouplingBlock(
+ self.inter_channels,
+ self.hidden_channels,
+ kernel_size=5,
+ dilation_rate=1,
+ n_layers=4,
+ gin_channels=self.gin_channels,
+ )
+
+ self.ref_enc = ReferenceEncoder(self.spec_channels, self.gin_channels)
+
+ @property
+ def device(self) -> torch.device:
+ return next(self.parameters()).device
+
+ @staticmethod
+ def init_from_config(config: OpenVoiceConfig) -> "OpenVoice":
+ return OpenVoice(config)
+
+ def init_multispeaker(self, config: Coqpit, data: Optional[list[Any]] = None) -> None:
+ """Initialize multi-speaker modules of a model. A model can be trained either with a speaker embedding layer
+ or with external `d_vectors` computed from a speaker encoder model.
+
+ You must provide a `speaker_manager` at initialization to set up the multi-speaker modules.
+
+ Args:
+ config (Coqpit): Model configuration.
+ data (list, optional): Dataset items to infer number of speakers. Defaults to None.
+ """
+ self.num_spks = config.num_speakers
+ if self.speaker_manager:
+ self.num_spks = self.speaker_manager.num_speakers
+
+ def load_checkpoint(
+ self,
+ config: OpenVoiceConfig,
+ checkpoint_path: Union[str, os.PathLike[Any]],
+ eval: bool = False,
+ strict: bool = True,
+ cache: bool = False,
+ ) -> None:
+ """Map from OpenVoice's config structure."""
+ config_path = Path(checkpoint_path).parent / "config.json"
+ with open(config_path, encoding="utf-8") as f:
+ config_org = json.load(f)
+ self.config.audio.input_sample_rate = config_org["data"]["sampling_rate"]
+ self.config.audio.output_sample_rate = config_org["data"]["sampling_rate"]
+ self.config.audio.fft_size = config_org["data"]["filter_length"]
+ self.config.audio.hop_length = config_org["data"]["hop_length"]
+ self.config.audio.win_length = config_org["data"]["win_length"]
+ state = load_fsspec(str(checkpoint_path), map_location=torch.device("cpu"), cache=cache)
+ self.load_state_dict(state["model"], strict=strict)
+ if eval:
+ self.eval()
+
+ def forward(self) -> None: ...
+ def train_step(self) -> None: ...
+ def eval_step(self) -> None: ...
+
+ @staticmethod
+ def _set_x_lengths(x: torch.Tensor, aux_input: Mapping[str, Optional[torch.Tensor]]) -> torch.Tensor:
+ if "x_lengths" in aux_input and aux_input["x_lengths"] is not None:
+ return aux_input["x_lengths"]
+        return torch.tensor(x.shape[-1:]).to(x.device)  # length = time axis of the [B, C, T] input
+
+ @torch.no_grad()
+ def inference(
+ self,
+ x: torch.Tensor,
+ aux_input: Mapping[str, Optional[torch.Tensor]] = {"x_lengths": None, "g_src": None, "g_tgt": None},
+ ) -> dict[str, torch.Tensor]:
+ """
+ Inference pass of the model
+
+ Args:
+            x (torch.Tensor): Input spectrogram tensor. Shape: (batch_size, spec_channels, spec_seq_len).
+            x_lengths (torch.Tensor): Lengths of the input tensor. Shape: (batch_size,).
+            g_src (torch.Tensor): Source speaker embedding tensor. Shape: (batch_size, spk_emb_dim, 1).
+            g_tgt (torch.Tensor): Target speaker embedding tensor. Shape: (batch_size, spk_emb_dim, 1).
+
+        Returns:
+            o_hat: Output waveform tensor. Shape: (batch_size, 1, T_wav).
+            y_mask: Spectrogram mask. Shape: (batch_size, 1, spec_seq_len).
+ (z, z_p, z_hat): A tuple of latent variables.
+ """
+ x_lengths = self._set_x_lengths(x, aux_input)
+ if "g_src" in aux_input and aux_input["g_src"] is not None:
+ g_src = aux_input["g_src"]
+ else:
+ raise ValueError("aux_input must define g_src")
+ if "g_tgt" in aux_input and aux_input["g_tgt"] is not None:
+ g_tgt = aux_input["g_tgt"]
+ else:
+ raise ValueError("aux_input must define g_tgt")
+ z, _m_q, _logs_q, y_mask = self.enc_q(
+ x, x_lengths, g=g_src if not self.zero_g else torch.zeros_like(g_src), tau=self.tau
+ )
+ z_p = self.flow(z, y_mask, g=g_src)
+ z_hat = self.flow(z_p, y_mask, g=g_tgt, reverse=True)
+ o_hat = self.dec(z_hat * y_mask, g=g_tgt if not self.zero_g else torch.zeros_like(g_tgt))
+ return {
+ "model_outputs": o_hat,
+ "y_mask": y_mask,
+ "z": z,
+ "z_p": z_p,
+ "z_hat": z_hat,
+ }
+
+ def load_audio(self, wav: Union[str, npt.NDArray[np.float32], torch.Tensor, list[float]]) -> torch.Tensor:
+ """Read and format the input audio."""
+ if isinstance(wav, str):
+ out = torch.from_numpy(librosa.load(wav, sr=self.config.audio.input_sample_rate)[0])
+ elif isinstance(wav, np.ndarray):
+ out = torch.from_numpy(wav)
+ elif isinstance(wav, list):
+ out = torch.from_numpy(np.array(wav))
+ else:
+ out = wav
+ return out.to(self.device).float()
+
+ def extract_se(self, audio: Union[str, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]:
+ audio_ref = self.load_audio(audio)
+ y = torch.FloatTensor(audio_ref)
+ y = y.to(self.device)
+ y = y.unsqueeze(0)
+ spec = wav_to_spec(
+ y,
+ n_fft=self.config.audio.fft_size,
+ hop_length=self.config.audio.hop_length,
+ win_length=self.config.audio.win_length,
+ center=False,
+ ).to(self.device)
+ with torch.no_grad():
+ g = self.ref_enc(spec.transpose(1, 2)).unsqueeze(-1)
+
+ return g, spec
+
+ @torch.inference_mode()
+ def voice_conversion(self, src: Union[str, torch.Tensor], tgt: Union[str, torch.Tensor]) -> npt.NDArray[np.float32]:
+ """
+ Voice conversion pass of the model.
+
+ Args:
+ src (str or torch.Tensor): Source utterance.
+ tgt (str or torch.Tensor): Target utterance.
+
+ Returns:
+ Output numpy array.
+ """
+ src_se, src_spec = self.extract_se(src)
+ tgt_se, _ = self.extract_se(tgt)
+
+ aux_input = {"g_src": src_se, "g_tgt": tgt_se}
+ audio = self.inference(src_spec, aux_input)
+ return audio["model_outputs"][0, 0].data.cpu().float().numpy()
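
An end-to-end sketch of the new model with randomly initialized weights (it mirrors the unit test added below; real conversions require loading a checkpoint via `load_checkpoint` or going through the `Synthesizer` / `TTS` API with a registered model name):

```python
import torch

from TTS.vc.configs.openvoice_config import OpenVoiceConfig
from TTS.vc.models.openvoice import OpenVoice

config = OpenVoiceConfig()
model = OpenVoice(config)
model.eval()

source = torch.rand(16100)  # utterance providing the speech content
target = torch.rand(16000)  # utterance providing the target voice
out = model.voice_conversion(source, target)
print(out.shape)  # source length rounded down to a multiple of hop_length
```
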
diff --git a/tests/inference_tests/test_synthesizer.py b/tests/inference_tests/test_synthesizer.py
index ce4fc751c2..21cc194131 100644
--- a/tests/inference_tests/test_synthesizer.py
+++ b/tests/inference_tests/test_synthesizer.py
@@ -23,7 +23,7 @@ def test_in_out(self):
tts_root_path = get_tests_input_path()
tts_checkpoint = os.path.join(tts_root_path, "checkpoint_10.pth")
tts_config = os.path.join(tts_root_path, "dummy_model_config.json")
- synthesizer = Synthesizer(tts_checkpoint, tts_config, None, None)
+ synthesizer = Synthesizer(tts_checkpoint=tts_checkpoint, tts_config_path=tts_config)
synthesizer.tts("Better this test works!!")
def test_split_into_sentences(self):
diff --git a/tests/vc_tests/test_freevc.py b/tests/vc_tests/test_freevc.py
index c90551b494..fe07b2723c 100644
--- a/tests/vc_tests/test_freevc.py
+++ b/tests/vc_tests/test_freevc.py
@@ -22,31 +22,19 @@
class TestFreeVC(unittest.TestCase):
def _create_inputs(self, config, batch_size=2):
- input_dummy = torch.rand(batch_size, 30 * config.audio["hop_length"]).to(device)
- input_lengths = torch.randint(100, 30 * config.audio["hop_length"], (batch_size,)).long().to(device)
- input_lengths[-1] = 30 * config.audio["hop_length"]
spec = torch.rand(batch_size, 30, config.audio["filter_length"] // 2 + 1).to(device)
mel = torch.rand(batch_size, 30, config.audio["n_mel_channels"]).to(device)
spec_lengths = torch.randint(20, 30, (batch_size,)).long().to(device)
spec_lengths[-1] = spec.size(2)
waveform = torch.rand(batch_size, spec.size(2) * config.audio["hop_length"]).to(device)
- return input_dummy, input_lengths, mel, spec, spec_lengths, waveform
+ return mel, spec, spec_lengths, waveform
@staticmethod
def _create_inputs_inference():
- source_wav = torch.rand(16000)
+ source_wav = torch.rand(15999)
target_wav = torch.rand(16000)
return source_wav, target_wav
- @staticmethod
- def _check_parameter_changes(model, model_ref):
- count = 0
- for param, param_ref in zip(model.parameters(), model_ref.parameters()):
- assert (param != param_ref).any(), "param {} with shape {} not updated!! \n{}\n{}".format(
- count, param.shape, param, param_ref
- )
- count += 1
-
def test_methods(self):
config = FreeVCConfig()
model = FreeVC(config).to(device)
@@ -69,7 +57,7 @@ def _test_forward(self, batch_size):
model.train()
print(" > Num parameters for FreeVC model:%s" % (count_parameters(model)))
- _, _, mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
+ mel, spec, spec_lengths, waveform = self._create_inputs(config, batch_size)
wavlm_vec = model.extract_wavlm_features(waveform)
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
@@ -86,7 +74,7 @@ def _test_inference(self, batch_size):
model = FreeVC(config).to(device)
model.eval()
- _, _, mel, _, _, waveform = self._create_inputs(config, batch_size)
+ mel, _, _, waveform = self._create_inputs(config, batch_size)
wavlm_vec = model.extract_wavlm_features(waveform)
wavlm_vec_lengths = torch.ones(batch_size, dtype=torch.long)
@@ -108,8 +96,8 @@ def test_voice_conversion(self):
source_wav, target_wav = self._create_inputs_inference()
output_wav = model.voice_conversion(source_wav, target_wav)
assert (
- output_wav.shape[0] + config.audio.hop_length == source_wav.shape[0]
- ), f"{output_wav.shape} != {source_wav.shape}"
+ output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
+        ), f"{output_wav.shape} != {source_wav.shape}, hop_length={config.audio.hop_length}"
def test_train_step(self): ...
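
The reworked assertion states that FreeVC returns the source length rounded down to a whole number of hops. A worked example, assuming FreeVC's default `hop_length` of 320 (check `FreeVCConfig().audio["hop_length"]`):

```python
hop_length = 320    # assumed FreeVC default
source_len = 15999  # from _create_inputs_inference above
print(source_len - source_len % hop_length)  # 15680 expected output samples
```
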
diff --git a/tests/vc_tests/test_openvoice.py b/tests/vc_tests/test_openvoice.py
new file mode 100644
index 0000000000..c9f7ae3931
--- /dev/null
+++ b/tests/vc_tests/test_openvoice.py
@@ -0,0 +1,42 @@
+import os
+import unittest
+
+import torch
+
+from tests import get_tests_input_path
+from TTS.vc.models.openvoice import OpenVoice, OpenVoiceConfig
+
+torch.manual_seed(1)
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+c = OpenVoiceConfig()
+
+WAV_FILE = os.path.join(get_tests_input_path(), "example_1.wav")
+
+
+class TestOpenVoice(unittest.TestCase):
+
+ @staticmethod
+ def _create_inputs_inference():
+ source_wav = torch.rand(16100)
+ target_wav = torch.rand(16000)
+ return source_wav, target_wav
+
+ def test_load_audio(self):
+ config = OpenVoiceConfig()
+ model = OpenVoice(config).to(device)
+ wav = model.load_audio(WAV_FILE)
+ wav2 = model.load_audio(wav)
+        assert torch.allclose(wav, wav2)
+
+ def test_voice_conversion(self):
+ config = OpenVoiceConfig()
+ model = OpenVoice(config).to(device)
+ model.eval()
+
+ source_wav, target_wav = self._create_inputs_inference()
+ output_wav = model.voice_conversion(source_wav, target_wav)
+ assert (
+ output_wav.shape[0] == source_wav.shape[0] - source_wav.shape[0] % config.audio.hop_length
+ ), f"{output_wav.shape} != {source_wav.shape}"