diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index 3da975962..dee22a1f4 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -1,16 +1,17 @@ import math -from vector_quantize_pytorch import GroupedResidualFSQ +from typing import List import torch import torch.nn as nn import torch.nn.functional as F +from vector_quantize_pytorch import GroupedResidualFSQ class ConvNeXtBlock(nn.Module): def __init__( self, dim: int, intermediate_dim: int, - kernel, dilation, + kernel: int, dilation: int, layer_scale_init_value: float = 1e-6, ): # ConvNeXt Block copied from Vocos. @@ -32,25 +33,31 @@ def __init__( def forward(self, x: torch.Tensor, cond = None) -> torch.Tensor: residual = x - x = self.dwconv(x) - x = x.transpose(1, 2) # (B, C, T) -> (B, T, C) - x = self.norm(x) - x = self.pwconv1(x) - x = self.act(x) - x = self.pwconv2(x) + + y = self.dwconv(x) + y.transpose_(1, 2) # (B, C, T) -> (B, T, C) + x = self.norm(y) + del y + y = self.pwconv1(x) + del x + x = self.act(y) + del y + y = self.pwconv2(x) + del x if self.gamma is not None: - x = self.gamma * x - x = x.transpose(1, 2) # (B, T, C) -> (B, C, T) + y *= self.gamma + y.transpose_(1, 2) # (B, T, C) -> (B, C, T) + + x = y + residual + del y - x = residual + x return x - class GFSQ(nn.Module): def __init__(self, - dim, levels, G, R, eps=1e-5, transpose = True + dim: int, levels: List[int], G: int, R: int, eps=1e-5, transpose = True ): super(GFSQ, self).__init__() self.quantizer = GroupedResidualFSQ( @@ -67,19 +74,19 @@ def __init__(self, def _embed(self, x: torch.Tensor): if self.transpose: - x = x.transpose(1,2) + x.transpose_(1, 2) """ x = rearrange( x, "b t (g r) -> g b t r", g = self.G, r = self.R, ) """ - x.view(-1, self.G, self.R).permute(2, 0, 1, 3) + x = x.view(x.size(0), x.size(1), self.G, self.R).permute(2, 0, 1, 3) feat = self.quantizer.get_output_from_indices(x) - return feat.transpose(1,2) if self.transpose else feat + return feat.transpose_(1,2) if self.transpose else feat def forward(self, x,): if self.transpose: - x = x.transpose(1,2) + x.transpose_(1,2) feat, ind = self.quantizer(x) """ ind = rearrange( @@ -92,19 +99,20 @@ def forward(self, x,): embed_onehot = embed_onehot_tmp.to(x.dtype) del embed_onehot_tmp e_mean = torch.mean(embed_onehot, dim=[0,1]) - e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) + # e_mean = e_mean / (e_mean.sum(dim=1) + self.eps).unsqueeze(1) + torch.div(e_mean, (e_mean.sum(dim=1) + self.eps).unsqueeze(1), out=e_mean) perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + self.eps), dim=1)) return ( torch.zeros(perplexity.shape, dtype=x.dtype, device=x.device), - feat.transpose(1,2) if self.transpose else feat, + feat.transpose_(1,2) if self.transpose else feat, perplexity, None, - ind.transpose(1,2) if self.transpose else ind, + ind.transpose_(1,2) if self.transpose else ind, ) - + class DVAEDecoder(nn.Module): - def __init__(self, idim, odim, + def __init__(self, idim: int, odim: int, n_layer = 12, bn_dim = 64, hidden = 256, kernel = 7, dilation = 2, up = False ): @@ -121,14 +129,16 @@ def __init__(self, idim, odim, def forward(self, input, conditioning=None): # B, T, C - x = input.transpose(1, 2) - x = self.conv_in(x) + x = input.transpose_(1, 2) + y = self.conv_in(x) + del x for f in self.decoder_block: - x = f(x, conditioning) - - x = self.conv_out(x) - return x.transpose(1, 2) - + y = f(y, conditioning) + + x = self.conv_out(y) + del y + return x.transpose_(1, 2) + class DVAE(nn.Module): def __init__( @@ -144,20 +154,21 @@ def __init__( else: self.vq_layer = None - def forward(self, inp): + def forward(self, inp: torch.Tensor) -> torch.Tensor: if self.vq_layer is not None: vq_feats = self.vq_layer._embed(inp) else: vq_feats = inp.detach().clone() - + vq_feats = vq_feats.view( (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)), ).permute(0, 2, 3, 1).flatten(2) - vq_feats = vq_feats.transpose(1, 2) - dec_out = self.decoder(input=vq_feats) - dec_out = self.out_conv(dec_out.transpose(1, 2)) - mel = dec_out * self.coef + dec_out = self.out_conv( + self.decoder( + input=vq_feats.transpose_(1, 2), + ).transpose_(1, 2), + ) - return mel + return torch.mul(dec_out, self.coef, out=dec_out)