Skip to content

Commit

Permalink
feat(dvae): expose coef for customize (#405)
Browse files Browse the repository at this point in the history
and unify coef of dvae & decoder
  • Loading branch information
fumiama authored Jun 22, 2024
1 parent b4c3cff commit 2ba7c6c
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
25 changes: 18 additions & 7 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,15 @@ def load_models(
self,
source: Literal['huggingface', 'local', 'custom']='local',
force_redownload=False,
custom_path='<LOCAL_PATH>',
**kwargs,
compile: bool = True,
custom_path: Optional[torch.serialization.FILE_LIKE]=None,
device: Optional[torch.device] = None,
coef: Optional[torch.Tensor] = None,
):
if source == 'local':
torch.load
download_path = os.getcwd()
if not check_all_assets(update=True):
if not check_all_assets(update=True) or force_redownload:
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(update=False):
Expand All @@ -77,7 +80,10 @@ def load_models(
self.logger.log(logging.INFO, f'Load from local: {custom_path}')
download_path = custom_path

return self._load(**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()}, **kwargs)
return self._load(
device=device, compile=compile, coef=coef,
**{k: os.path.join(download_path, v) for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()},
)

def _load(
self,
Expand All @@ -92,6 +98,7 @@ def _load(
tokenizer_path: str = None,
device: Optional[torch.device] = None,
compile: bool = True,
coef: Optional[str] = None
):
if device is None:
device = select_device(4096)
Expand All @@ -110,7 +117,8 @@ def _load(

if dvae_config_path:
cfg = OmegaConf.load(dvae_config_path)
dvae = DVAE(**cfg).to(device).eval()
dvae = DVAE(**cfg, coef=coef).to(device).eval()
coef = str(dvae)
assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
dvae.load_state_dict(torch.load(dvae_ckpt_path))
self.pretrain_models['dvae'] = dvae
Expand All @@ -134,7 +142,8 @@ def _load(

if decoder_config_path:
cfg = OmegaConf.load(decoder_config_path)
decoder = DVAE(**cfg).to(device).eval()
decoder = DVAE(**cfg, coef=coef).to(device).eval()
coef = str(decoder)
assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
self.pretrain_models['decoder'] = decoder
Expand All @@ -145,7 +154,9 @@ def _load(
tokenizer.padding_side = 'left'
self.pretrain_models['tokenizer'] = tokenizer
self.logger.log(logging.INFO, 'tokenizer loaded.')


self.coef = coef

return self.check_model()

def _infer(
Expand Down
23 changes: 16 additions & 7 deletions ChatTTS/model/dvae.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import math
from typing import List
from typing import List, Optional

import pybase16384 as b14
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -74,7 +76,7 @@ def __init__(self,

def _embed(self, x: torch.Tensor):
if self.transpose:
x.transpose_(1, 2)
x = x.transpose(1, 2)
"""
x = rearrange(
x, "b t (g r) -> g b t r", g = self.G, r = self.R,
Expand All @@ -84,9 +86,9 @@ def _embed(self, x: torch.Tensor):
feat = self.quantizer.get_output_from_indices(x)
return feat.transpose_(1,2) if self.transpose else feat

def forward(self, x,):
def forward(self, x):
if self.transpose:
x.transpose_(1,2)
x = x.transpose(1, 2)
feat, ind = self.quantizer(x)
"""
ind = rearrange(
Expand Down Expand Up @@ -127,7 +129,7 @@ def __init__(self, idim: int, odim: int,
for _ in range(n_layer)])
self.conv_out = nn.Conv1d(hidden, odim, kernel_size=1, bias=False)

def forward(self, input, conditioning=None):
def forward(self, input: torch.Tensor, conditioning=None) -> torch.Tensor:
# B, T, C
x = input.transpose_(1, 2)
y = self.conv_in(x)
Expand All @@ -142,17 +144,24 @@ def forward(self, input, conditioning=None):

class DVAE(nn.Module):
def __init__(
self, decoder_config, vq_config, dim=512
self, decoder_config, vq_config, dim=512, coef: Optional[str] = None,
):
super().__init__()
self.register_buffer('coef', torch.randn(1, 100, 1))
if coef is None:
coef = torch.rand(100)
else:
coef = torch.from_numpy(np.frombuffer(b14.decode_from_string(coef), dtype=np.float32))
self.register_buffer('coef', coef.unsqueeze(0).unsqueeze_(2))

self.decoder = DVAEDecoder(**decoder_config)
self.out_conv = nn.Conv1d(dim, 100, 3, 1, 1, bias=False)
if vq_config is not None:
self.vq_layer = GFSQ(**vq_config)
else:
self.vq_layer = None

def __repr__(self) -> str:
    """Return the ``coef`` buffer serialized as a base16384 string.

    The buffer is copied to the CPU, viewed as a float32 NumPy array, and
    its raw bytes are encoded with ``pybase16384`` so the value can be
    passed back as the ``coef`` constructor argument.
    """
    coef_array = self.coef.cpu().numpy().astype(np.float32)
    raw_bytes = coef_array.tobytes()
    return b14.encode_to_string(raw_bytes)

def forward(self, inp: torch.Tensor) -> torch.Tensor:

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ vocos
IPython
gradio
python-dotenv
pybase16384
pynini==2.1.5; sys_platform == 'linux'
WeTextProcessing; sys_platform == 'linux'
nemo_text_processing; sys_platform == 'linux'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
install_requires=['omegaconf>=2.3.0',
'numpy<2.0.0',
'numba',
'pybase16384',
'torch>=2.1.0',
'tqdm',
'vector_quantize_pytorch',
'transformers>=4.41.1',
'vocos',
'IPython',
], # declare which modules this package depends on
packages=find_packages(), # automatically discover packages starting from the current directory
)

0 comments on commit 2ba7c6c

Please sign in to comment.