From 56227f466db8fd2dec852bc61ddc7368d6af3548 Mon Sep 17 00:00:00 2001 From: ermian Date: Fri, 21 Jun 2024 15:17:59 +0800 Subject: [PATCH] fix: homophones replacer (#386) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit and improve output clarity - 原代码在应用字符映射后错误地使用了变量t而不是text[i],会导致前一步 apply_character_map 的结果被覆盖。现已更正此问题。 - 同音字替换的输出已改进为仅显示替换的字符,使结果更简洁。 --- ChatTTS/core.py | 7 ++++--- ChatTTS/utils/infer_utils.py | 7 +++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/ChatTTS/core.py b/ChatTTS/core.py index e22b2ae80..f4cfada4c 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -179,9 +179,10 @@ def _infer( self.logger.log(logging.WARNING, f'Invalid characters found! : {invalid_characters}') text[i] = apply_character_map(t) if do_homophone_replacement and self.init_homophones_replacer(): - text[i] = self.homophones_replacer.replace(t) - if t != text[i]: - self.logger.log(logging.INFO, f'Homophones replace: {t} -> {text[i]}') + text[i], replaced_words = self.homophones_replacer.replace(text[i]) + if replaced_words: + repl_res = ', '.join([f'{_[0]}->{_[1]}' for _ in replaced_words]) + self.logger.log(logging.INFO, f'Homophones replace: {repl_res}') if not skip_refine_text: text_tokens = refine_text( diff --git a/ChatTTS/utils/infer_utils.py b/ChatTTS/utils/infer_utils.py index b7d70bf52..fa498bf92 100644 --- a/ChatTTS/utils/infer_utils.py +++ b/ChatTTS/utils/infer_utils.py @@ -76,12 +76,15 @@ def load_homophones_map(self, map_file_path): def replace(self, text): result = [] + replaced_words = [] for char in text: if char in self.homophones_map: - result.append(self.homophones_map[char]) + repl_char = self.homophones_map[char] + result.append(repl_char) + replaced_words.append((char, repl_char)) else: result.append(char) - return ''.join(result) + return ''.join(result), replaced_words def count_invalid_characters(s):