Skip to content

Commit

Permalink
fix: add param ensure_non_empty (fix #511)
Browse files Browse the repository at this point in the history
  • Loading branch information
fumiama committed Jul 3, 2024
1 parent edfdec4 commit 3fbf2aa
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 7 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/unitest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,4 @@ jobs:
run: pip install -r requirements.txt

- name: Run Test
run: |
echo "TODO"
run: tests/testall.sh
1 change: 1 addition & 0 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ class RefineTextParams:
max_new_token: int = 384
min_new_token: int = 0
show_tqdm: bool = True
ensure_non_empty: bool = True

@dataclass(repr=False, eq=False)
class InferCodeParams(RefineTextParams):
Expand Down
44 changes: 39 additions & 5 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def generate(
return_hidden=False,
stream=False,
show_tqdm=True,
ensure_non_empty=True,
context=Context(),
):

Expand All @@ -376,13 +377,14 @@ def generate(
)
finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()

old_temperature = temperature

temperature = (
temperature.unsqueeze_(0)
temperature.unsqueeze(0)
.expand(inputs_ids.shape[0], -1)
.contiguous()
.view(-1, 1)
)
# temperature = rearrange(temperature, "b n -> (b n) 1")

attention_mask_cache = torch.ones(
(
Expand Down Expand Up @@ -464,9 +466,9 @@ def generate(
dtype=torch.float,
device=self.device,
)
for i in range(self.num_vq):
x: torch.Tensor = self.head_code[i](hidden_states)
logits[..., i] = x
for num_vq_iter in range(self.num_vq):
x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
logits[..., num_vq_iter ] = x
del x

# logits = logits[:, -1].float()
Expand Down Expand Up @@ -522,13 +524,45 @@ def generate(
],
1,
)

if i == 0 and finish.any():
self.logger.warn(
"unexpected end at index %s",
str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
)
if ensure_non_empty:
if show_tqdm:
pbar.close()
self.logger.warn("regenerate in order to ensure non-empty")
new_gen = self.generate(
emb,
inputs_ids,
old_temperature,
eos_token,
attention_mask,
max_new_token,
min_new_token,
logits_warpers,
logits_processors,
infer_text,
return_attn,
return_hidden,
stream,
show_tqdm,
ensure_non_empty,
context,
)
for result in new_gen:
yield result
return

del inputs_ids
inputs_ids = inputs_ids_tmp
del inputs_ids_tmp, idx_next

if stream:
minus_prev_end_index = end_idx.neg()

end_idx.add_((finish.logical_not().to(end_idx.device)).int())
if stream:
if (
Expand Down
58 changes: 58 additions & 0 deletions tests/#511.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import os, sys

if sys.platform == "darwin":
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

now_dir = os.getcwd()
sys.path.append(now_dir)

import ChatTTS

from tools.logger import get_logger

logger = get_logger("Test #511")

chat = ChatTTS.Chat(logger)
chat.load(compile=False) # Set to True for better performance

texts = ["语音太短了会造成生成音频错误, 这是占位占位, 老大爷觉得车夫的想法很有道理",
"评分只是衡量音色的稳定性,不代表音色的好坏, 可以根据自己的需求选择合适的音色",
"举个简单的例子,如果一个沙哑且结巴的音色一直很稳定,那么它的评分就会很高。",
"语音太短了会造成生成音频错误, 这是占位占位。我使用 seed id 去生成音频, 但是生成的音频不稳定",
"seed id只是一个参考ID 不同的环境下音色不一定一致。还是推荐使用 .pt 文件载入音色",
"语音太短了会造成生成音频错误, 这是占位占位。音色标的男女准确吗",
"当前第一批测试的音色有两千条, 根据声纹相似性简单打标, 准确度不高, 特别是特征一项",
"语音太短了会造成生成音频错误, 这是占位占位。仅供参考。如果大家有更好的标注方法,欢迎 PR。",
]

rand_spk = chat.sample_random_speaker()

params_infer_code = ChatTTS.Chat.InferCodeParams(
spk_emb = rand_spk, # add sampled speaker
temperature = .3, # using custom temperature
top_P = 0.005, # top P decode
top_K = 1, # top K decode
)

params_refine_text = ChatTTS.Chat.RefineTextParams(
prompt='[oral_0][laugh_0][break_4]',
)

fail = False

for i in range(4):

wavs = chat.infer(
texts,
params_refine_text=params_refine_text,
params_infer_code=params_infer_code,
)

for k, wav in enumerate(wavs):
if wav is None:
logger.warn("iter", i, "index", k, "is None")
fail = True

if fail:
import sys
sys.exit(1)
15 changes: 15 additions & 0 deletions tests/testall.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#!/bin/sh

exitcode=0

for file in tests/*.py
do
python "$file"
if [ $? -ne 0 ]
then
echo "Error: $file exited with a non-zero status."
exitcode=1
fi
done

exit $exitcode

0 comments on commit 3fbf2aa

Please sign in to comment.