From 2ba7c6c1388a7a99f10d366a6e036cd8cdb40f2a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Sat, 22 Jun 2024 19:59:53 +0900 Subject: [PATCH] feat(dvae): expose coef for customize (#405) and unify coef of dvae & decoder --- ChatTTS/core.py | 25 ++++++++++++++++++------- ChatTTS/model/dvae.py | 23 ++++++++++++++++------- requirements.txt | 1 + setup.py | 2 +- 4 files changed, 36 insertions(+), 15 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 7204a55ba..e7f83f84b 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -51,12 +51,15 @@ def load_models( self, source: Literal['huggingface', 'local', 'custom']='local', force_redownload=False, - custom_path='', - **kwargs, + compile: bool = True, + custom_path: Optional[torch.serialization.FILE_LIKE]=None, + device: Optional[torch.device] = None, + coef: Optional[torch.Tensor] = None, ): if source == 'local': + torch.load download_path = os.getcwd() - if not check_all_assets(update=True): + if not check_all_assets(update=True) or force_redownload: with tempfile.TemporaryDirectory() as tmp: download_all_assets(tmpdir=tmp) if not check_all_assets(update=False): @@ -77,7 +80,10 @@ def load_models( self.logger.log(logging.INFO, f'Load from local: {custom_path}') download_path = custom_path - return self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs) + return self._load( + device=device, compile=compile, coef=coef, + **{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, + ) def _load( self, @@ -92,6 +98,7 @@ def _load( tokenizer_path: str = None, device: Optional[torch.device] = None, compile: bool = True, + coef: Optional[str] = None ): if device is None: device = select_device(4096) @@ -110,7 +117,8 @@ def _load( if dvae_config_path: cfg = OmegaConf.load(dvae_config_path) - dvae = DVAE(**cfg).to(device).eval() + dvae = DVAE(**cfg, coef=coef).to(device).eval() + coef = str(dvae) assert dvae_ckpt_path, 'dvae_ckpt_path should not be None' dvae.load_state_dict(torch.load(dvae_ckpt_path)) self.pretrain_models['dvae'] = dvae @@ -134,7 +142,8 @@ def _load( if decoder_config_path: cfg = OmegaConf.load(decoder_config_path) - decoder = DVAE(**cfg).to(device).eval() + decoder = DVAE(**cfg, coef=coef).to(device).eval() + coef = str(decoder) assert decoder_ckpt_path, 'decoder_ckpt_path should not be None' decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu')) self.pretrain_models['decoder'] = decoder @@ -145,7 +154,9 @@ def _load( tokenizer.padding_side = 'left' self.pretrain_models['tokenizer'] = tokenizer self.logger.log(logging.INFO, 'tokenizer loaded.') - + + self.coef = coef + return self.check_model() def _infer( diff --git a/ChatTTS/model/dvae.py b/ChatTTS/model/dvae.py index dee22a1f4..0fd4c2873 100644 --- a/ChatTTS/model/dvae.py +++ b/ChatTTS/model/dvae.py @@ -1,6 +1,8 @@ import math -from typing import List +from typing import List, Optional +import pybase16384 as b14 +import numpy as np import torch import torch.nn as nn import torch.nn.functional as F @@ -74,7 +76,7 @@ def __init__(self, def _embed(self, x: torch.Tensor): if self.transpose: - x.transpose_(1, 2) + x = x.transpose(1, 2) """ x = rearrange( x, "b t (g r) -> g b t r", g = self.G, r = self.R, @@ -84,9 +86,9 @@ def _embed(self, x: torch.Tensor): feat = self.quantizer.get_output_from_indices(x) return feat.transpose_(1,2) if self.transpose else feat - def forward(self, x,): + def forward(self, x): if self.transpose: - x.transpose_(1,2) + x = x.transpose(1, 2) feat, ind = self.quantizer(x) """ ind = rearrange( @@ -127,7 +129,7 @@ def __init__(self, idim: int, odim: int, for _ in range(n_layer)]) self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False) - def forward(self, input, conditioning=None): + def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor: # B, T, C x = input.transpose_(1, 2) y = self.conv_in(x) @@ -142,10 +144,14 @@ def forward(self, input, conditioning=None): class DVAE(nn.Module): def __init__( - self, decoder_config, vq_config, dim=512 + self, decoder_config, vq_config, dim=512, coef: Optional[str] = None, ): super().__init__() - self.register_buffer('coef', torch.randn(1, 100, 1)) + if coef is None: + coef = torch.rand(100) + else: + coef = torch.from_numpy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32)) + self.register_buffer('coef', coef.unsqueeze(0).unsqueeze_(2)) self.decoder = DVAEDecoder(**decoder_config) self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False) @@ -153,6 +159,9 @@ def __init__( self.vq_layer = GFSQ(**vq_config) else: self.vq_layer = None + + def __repr__(self) -> str: + return b14.encode_to_string(self.coef.cpu().numpy().astype(np.float32).tobytes()) def forward(self, inp: torch.Tensor) -> torch.Tensor: diff --git a/requirements.txt b/requirements.txt index 1b9bba011..18745dd64 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ vocos IPython gradio python-dotenv +pybase16384 pynini==2.1.5; sys_platform == 'linux' WeTextProcessing; sys_platform == 'linux' nemo_text_processing; sys_platform == 'linux' diff --git a/setup.py b/setup.py index ecd0b3afd..7d21b7d2c 100644 --- a/setup.py +++ b/setup.py @@ -6,12 +6,12 @@ install_requires=['omegaconf>=2.3.0', 'numpy<2.0.0', 'numba', + 'pybase16384', 'torch>=2.1.0', 'tqdm', 'vector_quantize_pytorch', 'transformers>=4.41.1', 'vocos', - 'IPython', ], # 定义依赖哪些模块 packages=find_packages(), # 系统自动从当前目录开始找包 )