Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A prototype for Vits 2 / Yourtts 2 #137

Draft
wants to merge 7 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 179 additions & 0 deletions TTS/tts/configs/vits2_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from dataclasses import dataclass, field
from typing import List

from TTS.tts.configs.shared_configs import BaseTTSConfig
from TTS.tts.models.vits2 import Vits2Args, Vits2AudioConfig


@dataclass
class Vits2Config(BaseTTSConfig):
"""Defines parameters for VITS End2End TTS model.

Args:
model (str):
Model name. Do not change unless you know what you are doing.

model_args (Vits2Args):
Model architecture arguments. Defaults to `Vits2Args()`.

audio (Vits2AudioConfig):
Audio processing configuration. Defaults to `Vits2AudioConfig()`.

grad_clip (List):
Gradient clipping thresholds for each optimizer. Defaults to `[1000.0, 1000.0]`.

lr_gen (float):
Initial learning rate for the generator. Defaults to 0.0002.

lr_disc (float):
Initial learning rate for the discriminator. Defaults to 0.0002.

lr_scheduler_gen (str):
Name of the learning rate scheduler for the generator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.

lr_scheduler_gen_params (dict):
Parameters for the learning rate scheduler of the generator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

lr_scheduler_disc (str):
Name of the learning rate scheduler for the discriminator. One of the `torch.optim.lr_scheduler.*`. Defaults to
`ExponentialLR`.

lr_scheduler_disc_params (dict):
Parameters for the learning rate scheduler of the discriminator. Defaults to `{'gamma': 0.999875, "last_epoch":-1}`.

scheduler_after_epoch (bool):
If true, step the schedulers after each epoch else after each step. Defaults to `False`.

optimizer (str):
Name of the optimizer to use with both the generator and the discriminator networks. One of the
`torch.optim.*`. Defaults to `AdamW`.

kl_loss_alpha (float):
Loss weight for KL loss. Defaults to 1.0.

disc_loss_alpha (float):
Loss weight for the discriminator loss. Defaults to 1.0.

gen_loss_alpha (float):
Loss weight for the generator loss. Defaults to 1.0.

feat_loss_alpha (float):
Loss weight for the feature matching loss. Defaults to 1.0.

mel_loss_alpha (float):
Loss weight for the mel loss. Defaults to 45.0.

return_wav (bool):
If true, data loader returns the waveform as well as the other outputs. Do not change. Defaults to `True`.

compute_linear_spec (bool):
If true, the linear spectrogram is computed and returned alongside the mel output. Do not change. Defaults to `True`.

use_weighted_sampler (bool):
If true, use weighted sampler with bucketing for balancing samples between datasets used in training. Defaults to `False`.

weighted_sampler_attrs (dict):
Key retuned by the formatter to be used for weighted sampler. For example `{"root_path": 2.0, "speaker_name": 1.0}` sets sample probabilities
by overweighting `root_path` by 2.0. Defaults to `{}`.

weighted_sampler_multipliers (dict):
Weight each unique value of a key returned by the formatter for weighted sampling.
For example `{"root_path":{"/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-100/":1.0, "/raid/datasets/libritts-clean-16khz-bwe-coqui_44khz/LibriTTS/train-clean-360/": 0.5}`.
It will sample instances from `train-clean-100` 2 times more than `train-clean-360`. Defaults to `{}`.

r (int):
Number of spectrogram frames to be generated at a time. Do not change. Defaults to `1`.

add_blank (bool):
If true, a blank token is added in between every character. Defaults to `True`.

test_sentences (List[List]):
List of sentences with speaker and language information to be used for testing.

language_ids_file (str):
Path to the language ids file.

use_language_embedding (bool):
If true, language embedding is used. Defaults to `False`.

Note:
Check :class:`TTS.tts.configs.shared_configs.BaseTTSConfig` for the inherited parameters.

Example:

>>> from TTS.tts.configs.vits2_config import Vits2Config
>>> config = Vits2Config()
"""

model: str = "vits2"
# model specific params
model_args: Vits2Args = field(default_factory=Vits2Args)
audio: Vits2AudioConfig = field(default_factory=Vits2AudioConfig)

# optimizer
grad_clip: List[float] = field(default_factory=lambda: [1000, 1000, 1000])
lr_gen: float = 0.0002
lr_disc: float = 0.0002
lr_dur: float = 0.0002
lr_scheduler_gen: str = "ExponentialLR"
lr_scheduler_gen_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_disc: str = "ExponentialLR"
lr_scheduler_disc_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
lr_scheduler_dur: str = "ExponentialLR"
lr_scheduler_dur_params: dict = field(default_factory=lambda: {"gamma": 0.999875, "last_epoch": -1})
scheduler_after_epoch: bool = True
optimizer: str = "AdamW"
optimizer_params: dict = field(default_factory=lambda: {"betas": [0.8, 0.99], "eps": 1e-9, "weight_decay": 0.01})

# loss params
kl_loss_alpha: float = 1.0
disc_loss_alpha: float = 1.0
gen_loss_alpha: float = 1.0
feat_loss_alpha: float = 1.0
mel_loss_alpha: float = 45.0
dur_loss_alpha: float = 1.0
speaker_encoder_loss_alpha: float = 1.0

# data loader params
return_wav: bool = True
compute_linear_spec: bool = True

# sampler params
use_weighted_sampler: bool = False # TODO: move it to the base config
weighted_sampler_attrs: dict = field(default_factory=lambda: {})
weighted_sampler_multipliers: dict = field(default_factory=lambda: {})

# overrides
r: int = 1 # DO NOT CHANGE
add_blank: bool = True

# testing
test_sentences: List[List] = field(
default_factory=lambda: [
["It took me quite a long time to develop a voice, and now that I have it I'm not going to be silent."],
["Be a voice, not an echo."],
["I'm sorry Dave. I'm afraid I can't do that."],
["This cake is great. It's so delicious and moist."],
["Prior to November 22, 1963."],
]
)

# multi-speaker settings
# use speaker embedding layer
num_speakers: int = 0
use_speaker_embedding: bool = False
speakers_file: str = None
speaker_embedding_channels: int = 256
language_ids_file: str = None
use_language_embedding: bool = False

# use d-vectors
use_d_vector_file: bool = False
d_vector_file: List[str] = None
d_vector_dim: int = None

def __post_init__(self):
for key, val in self.model_args.items():
if hasattr(self, key):
self[key] = val
34 changes: 34 additions & 0 deletions TTS/tts/layers/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,40 @@ def forward(self, scores_disc_real, scores_disc_fake):
for i, ldr in enumerate(loss_disc_real):
return_dict[f"loss_disc_real_{i}"] = ldr
return return_dict

class Vits2DurationLoss(nn.Module):
def __init__(self, c: Coqpit):
super().__init__()
self.disc_loss_alpha = c.disc_loss_alpha

@staticmethod
def discriminator_loss(scores_real, scores_fake):
loss = 0
real_losses = []
fake_losses = []
for dr, dg in zip(scores_real, scores_fake):
dr = dr.float()
dg = dg.float()
real_loss = torch.mean((1 - dr) ** 2)
fake_loss = torch.mean(dg**2)
loss += real_loss + fake_loss
real_losses.append(real_loss.item())
fake_losses.append(fake_loss.item())
return loss, real_losses, fake_losses

def forward(self, scores_disc_real, scores_disc_fake):
loss = 0.0
return_dict = {}
loss_disc, loss_disc_real, _ = self.discriminator_loss(
scores_real=scores_disc_real, scores_fake=scores_disc_fake
)
return_dict["loss_dur_disc"] = loss_disc * self.disc_loss_alpha
loss = loss + return_dict["loss_dur_disc"]
return_dict["loss"] = loss

for i, ldr in enumerate(loss_disc_real):
return_dict[f"loss_dur_disc_real_{i}"] = ldr
return return_dict


class ForwardTTSLoss(nn.Module):
Expand Down
89 changes: 89 additions & 0 deletions TTS/tts/layers/vits2/discriminator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import torch
from torch import nn
from torch.nn.modules.conv import Conv1d

from TTS.vocoder.models.hifigan_discriminator import DiscriminatorP


class DiscriminatorS(torch.nn.Module):
"""HiFiGAN Scale Discriminator. Channel sizes are different from the original HiFiGAN.

Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""

def __init__(self, use_spectral_norm=False):
super().__init__()
norm_f = nn.utils.spectral_norm if use_spectral_norm else nn.utils.parametrizations.weight_norm
self.convs = nn.ModuleList(
[
norm_f(Conv1d(1, 16, 15, 1, padding=7)),
norm_f(Conv1d(16, 64, 41, 4, groups=4, padding=20)),
norm_f(Conv1d(64, 256, 41, 4, groups=16, padding=20)),
norm_f(Conv1d(256, 1024, 41, 4, groups=64, padding=20)),
norm_f(Conv1d(1024, 1024, 41, 4, groups=256, padding=20)),
norm_f(Conv1d(1024, 1024, 5, 1, padding=2)),
]
)
self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

def forward(self, x):
"""
Args:
x (Tensor): input waveform.

Returns:
Tensor: discriminator scores.
List[Tensor]: list of features from the convolutiona layers.
"""
feat = []
for l in self.convs:
x = l(x)
x = torch.nn.functional.leaky_relu(x, 0.1)
feat.append(x)
x = self.conv_post(x)
feat.append(x)
x = torch.flatten(x, 1, -1)
return x, feat


class Vits2Discriminator(nn.Module):
"""VITS discriminator wrapping one Scale Discriminator and a stack of Period Discriminator.

::
waveform -> ScaleDiscriminator() -> scores_sd, feats_sd --> append() -> scores, feats
|--> MultiPeriodDiscriminator() -> scores_mpd, feats_mpd ^

Args:
use_spectral_norm (bool): if `True` swith to spectral norm instead of weight norm.
"""

def __init__(self, periods=(2, 3, 5, 7, 11), use_spectral_norm=False):
super().__init__()
self.nets = nn.ModuleList()
self.nets.append(DiscriminatorS(use_spectral_norm=use_spectral_norm))
self.nets.extend([DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods])

def forward(self, x, x_hat=None):
"""
Args:
x (Tensor): ground truth waveform.
x_hat (Tensor): predicted waveform.

Returns:
List[Tensor]: discriminator scores.
List[List[Tensor]]: list of list of features from each layers of each discriminator.
"""
x_scores = []
x_hat_scores = [] if x_hat is not None else None
x_feats = []
x_hat_feats = [] if x_hat is not None else None
for net in self.nets:
x_score, x_feat = net(x)
x_scores.append(x_score)
x_feats.append(x_feat)
if x_hat is not None:
x_hat_score, x_hat_feat = net(x_hat)
x_hat_scores.append(x_hat_score)
x_hat_feats.append(x_hat_feat)
return x_scores, x_feats, x_hat_scores, x_hat_feats
Loading