From a851b9a9fe66392064f921a11cbcfcaaca1803f2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com>
Date: Sat, 22 Jun 2024 01:47:56 +0900
Subject: [PATCH] optimize: log & webui (#398)

- move log definition out of ChatTTS
- apply colorful log level
- optimize webui logic
- split webui into 2 files for clear reading
---
 ChatTTS/core.py              |  24 +++---
 ChatTTS/model/gpt.py         |   3 +-
 ChatTTS/utils/download.py    |   4 +-
 ChatTTS/utils/gpu_utils.py   |   4 +-
 ChatTTS/utils/infer_utils.py |   1 -
 ChatTTS/utils/io.py          |   5 +-
 ChatTTS/utils/log.py         |   8 ++
 examples/cmd/run.py          |  24 +++---
 examples/web/funcs.py        | 100 +++++++++++++++++++++++
 examples/web/webui.py        | 154 ++++++++---------------------------
 tools/logger/__init__.py     |   1 +
 tools/logger/log.py          |  53 ++++++++++++
 12 files changed, 229 insertions(+), 152 deletions(-)
 create mode 100644 ChatTTS/utils/log.py
 create mode 100644 examples/web/funcs.py
 create mode 100644 tools/logger/__init__.py
 create mode 100644 tools/logger/log.py

diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index 6553ef839..a2f78d520 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -1,11 +1,9 @@
-
-import os, sys
+import os
 import json
 import logging
-from functools import partial
-from typing import Literal
 import tempfile
-from typing import Optional
+from functools import partial
+from typing import Literal, Optional
 
 import torch
 from omegaconf import OmegaConf
@@ -19,16 +17,16 @@
 from .utils.io import get_latest_modified_file, del_all
 from .infer.api import refine_text, infer_code
 from .utils.download import check_all_assets, download_all_assets
-
-logging.basicConfig(level = logging.INFO)
+from .utils.log import set_utils_logger
 
 
 class Chat:
-    def __init__(self, ):
+    def __init__(self, logger=logging.getLogger(__name__)):
         self.pretrain_models = {}
         self.normalizer = {}
         self.homophones_replacer = None
-        self.logger = logging.getLogger(__name__)
+        self.logger = logger
+        set_utils_logger(logger)
 
     def check_model(self, level = logging.INFO, use_decoder = False):
         not_finish = False
@@ -46,7 +44,7 @@ def check_model(self, level = logging.INFO, use_decoder = False):
 
         if not not_finish:
             self.logger.log(level, f'All initialized.')
-        
+
         return not not_finish
 
     def load_models(
@@ -62,7 +60,7 @@ def load_models(
                 with tempfile.TemporaryDirectory() as tmp:
                     download_all_assets(tmpdir=tmp)
                 if not check_all_assets(update=False):
-                    logging.error("counld not satisfy all assets needed.")
+                    self.logger.error("could not satisfy all assets needed.")
                     return False
         elif source == 'huggingface':
             hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
@@ -120,14 +118,14 @@ def _load(
 
         if gpt_config_path:
             cfg = OmegaConf.load(gpt_config_path)
-            gpt = GPT_warpper(**cfg, device=device).eval()
+            gpt = GPT_warpper(**cfg, device=device, logger=self.logger).eval()
             assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
             gpt.load_state_dict(torch.load(gpt_ckpt_path))
             if compile and 'cuda' in str(device):
                 try:
                     gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
                 except RuntimeError as e:
-                    logging.warning(f'Compile failed,{e}. fallback to normal mode.')
+                    self.logger.warning(f'Compile failed: {e}. Falling back to normal mode.')
         self.pretrain_models['gpt'] = gpt
         spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
         assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py
index d98ffde0c..f1617cd0b 100644
--- a/ChatTTS/model/gpt.py
+++ b/ChatTTS/model/gpt.py
@@ -41,10 +41,11 @@ def __init__(
         num_text_tokens,
         num_vq=4,
         device="cpu",
+        logger=logging.getLogger(__name__)
     ):
         super().__init__()
 
-        self.logger = logging.getLogger(__name__)
+        self.logger = logger
         self.device = device
         self.device_gpt = device if "mps" not in str(device) else "cpu"
         self.num_vq = num_vq
diff --git a/ChatTTS/utils/download.py b/ChatTTS/utils/download.py
index 7880504ce..75b14613f 100644
--- a/ChatTTS/utils/download.py
+++ b/ChatTTS/utils/download.py
@@ -3,10 +3,8 @@
 import hashlib
 import requests
 from io import BytesIO
-import logging
-
-logger = logging.getLogger(__name__)
+from .log import logger
 
 def sha256(f) -> str:
     sha256_hash = hashlib.sha256()
diff --git a/ChatTTS/utils/gpu_utils.py b/ChatTTS/utils/gpu_utils.py
index ed98cd5c8..8c3969422 100644
--- a/ChatTTS/utils/gpu_utils.py
+++ b/ChatTTS/utils/gpu_utils.py
@@ -1,9 +1,9 @@
 
 import torch
-import logging
+
+from .log import logger
 
 def select_device(min_memory=2048):
-    logger = logging.getLogger(__name__)
     if torch.cuda.is_available():
         available_gpus = []
         for i in range(torch.cuda.device_count()):
diff --git a/ChatTTS/utils/infer_utils.py b/ChatTTS/utils/infer_utils.py
index fa498bf92..ddebd60a8 100644
--- a/ChatTTS/utils/infer_utils.py
+++ b/ChatTTS/utils/infer_utils.py
@@ -2,7 +2,6 @@
 import re
 import torch
 import torch.nn.functional as F
-import os
 import json
 
diff --git a/ChatTTS/utils/io.py b/ChatTTS/utils/io.py
index ec2cb584d..8b13410cf 100644
--- a/ChatTTS/utils/io.py
+++ b/ChatTTS/utils/io.py
@@ -3,9 +3,10 @@
 import logging
 from typing import Union
 
+from .log import logger
+
 def get_latest_modified_file(directory):
-    logger = logging.getLogger(__name__)
-    
+
     files = [os.path.join(directory, f) for f in os.listdir(directory)]
     if not files:
         logger.log(logging.WARNING, f'No files found in the directory: {directory}')
diff --git a/ChatTTS/utils/log.py b/ChatTTS/utils/log.py
new file mode 100644
index 000000000..6055f8479
--- /dev/null
+++ b/ChatTTS/utils/log.py
@@ -0,0 +1,8 @@
+import logging
+from pathlib import Path
+
+logger = logging.getLogger(Path(__file__).parent.name)
+
+def set_utils_logger(l: logging.Logger):
+    global logger
+    logger = l
diff --git a/examples/cmd/run.py b/examples/cmd/run.py
index 7af340f94..9f483db3b 100644
--- a/examples/cmd/run.py
+++ b/examples/cmd/run.py
@@ -13,6 +13,10 @@
 import ChatTTS
 from IPython.display import Audio
 
+from tools.logger import get_logger
+
+logger = get_logger("Command")
+
 def save_wav_file(wav, index):
     wav_filename = f"output_audio_{index}.wav"
     # Convert numpy array to bytes and write to WAV file
@@ -22,26 +26,26 @@ def save_wav_file(wav, index):
         wf.setsampwidth(2)  # Sample width in bytes
         wf.setframerate(24000)  # Sample rate in Hz
         wf.writeframes(wav_bytes)
-    print(f"Audio saved to {wav_filename}")
+    logger.info(f"Audio saved to {wav_filename}")
 
 def main():
     # Retrieve text from command line argument
     text_input = sys.argv[1] if len(sys.argv) > 1 else ""
-    print("Received text input:", text_input)
+    logger.info("Received text input: %s", text_input)
 
-    chat = ChatTTS.Chat()
-    print("Initializing ChatTTS...")
+    chat = ChatTTS.Chat(get_logger("ChatTTS"))
+    logger.info("Initializing ChatTTS...")
     if chat.load_models():
-        print("Models loaded successfully.")
+        logger.info("Models loaded successfully.")
     else:
-        print("Models load failed.")
+        logger.error("Models load failed.")
         sys.exit(1)
 
     texts = [text_input]
-    print("Text prepared for inference:", texts)
+    logger.info("Text prepared for inference: %s", texts)
 
     wavs = chat.infer(texts, use_decoder=True)
-    print("Inference completed. Audio generation successful.")
+    logger.info("Inference completed. Audio generation successful.")
 
     # Save each generated wav file to a local file
     for index, wav in enumerate(wavs):
        save_wav_file(wav, index)
@@ -49,6 +53,6 @@ def main():
     return Audio(wavs[0], rate=24_000, autoplay=True)
 
 if __name__ == "__main__":
-    print("Starting the TTS application...")
+    logger.info("Starting the TTS application...")
     main()
-    print("TTS application finished.")
+    logger.info("TTS application finished.")
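
Note (not part of the patch): the hunks above thread a single logger through the library. Chat.__init__ keeps the injected logger as self.logger and registers it with ChatTTS.utils via set_utils_logger. A minimal sketch of that wiring from the caller's side, assuming the repo root is on PYTHONPATH; the "MyApp" name and messages are invented for illustration:

    import logging

    import ChatTTS

    # any pre-configured logger can be injected; Chat stores it as
    # self.logger and hands it to set_utils_logger for the utils helpers
    app_logger = logging.getLogger("MyApp")
    app_logger.setLevel(logging.INFO)

    chat = ChatTTS.Chat(app_logger)
    if chat.load_models():
        app_logger.info("ChatTTS models ready")
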
diff --git a/examples/web/funcs.py b/examples/web/funcs.py
new file mode 100644
index 000000000..4b25df66c
--- /dev/null
+++ b/examples/web/funcs.py
@@ -0,0 +1,100 @@
+import random
+
+import torch
+import gradio as gr
+import numpy as np
+
+from tools.logger import get_logger
+logger = get_logger(" WebUI ")
+
+import ChatTTS
+chat = ChatTTS.Chat(get_logger("ChatTTS"))
+
+# Voice options: preset seeds for suitable voices
+voices = {
+    "默认": {"seed": 2},
+    "音色1": {"seed": 1111},
+    "音色2": {"seed": 2222},
+    "音色3": {"seed": 3333},
+    "音色4": {"seed": 4444},
+    "音色5": {"seed": 5555},
+    "音色6": {"seed": 6666},
+    "音色7": {"seed": 7777},
+    "音色8": {"seed": 8888},
+    "音色9": {"seed": 9999},
+    "音色10": {"seed": 11111},
+}
+
+def generate_seed():
+    return gr.update(value=random.randint(1, 100000000))
+
+# Return the seed corresponding to the selected voice
+def on_voice_change(voice_selection):
+    return voices.get(voice_selection)['seed']
+
+def refine_text(text, audio_seed_input, text_seed_input, refine_text_flag):
+    if not refine_text_flag:
+        return text
+
+    global chat
+
+    torch.manual_seed(audio_seed_input)
+    params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}
+
+    torch.manual_seed(text_seed_input)
+
+    text = chat.infer(text,
+                      skip_refine_text=False,
+                      refine_text_only=True,
+                      params_refine_text=params_refine_text,
+                      )
+    return text[0] if isinstance(text, list) else text
+
+def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, stream):
+    if not text: return None
+
+    global chat
+
+    torch.manual_seed(audio_seed_input)
+    rand_spk = chat.sample_random_speaker()
+    params_infer_code = {
+        'spk_emb': rand_spk,
+        'temperature': temperature,
+        'top_P': top_P,
+        'top_K': top_K,
+    }
+    torch.manual_seed(text_seed_input)
+
+    wav = chat.infer(
+        text,
+        skip_refine_text=True,
+        params_infer_code=params_infer_code,
+        stream=stream,
+    )
+
+    if stream:
+        for gen in wav:
+            wavs = [np.array([[]])]
+            wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
+            audio = wavs[0][0]
+
+            # normalize to int16: attenuate only if the peak would clip
+            am = np.abs(audio).max()
+            if am > 1.0:
+                audio /= am
+            np.multiply(audio, 32767.0, audio)
+            audio = audio.astype(np.int16)
+
+            yield 24000, audio
+        return
+
+    audio_data = np.array(wav[0]).flatten()
+    # normalize to int16: attenuate only if the peak would clip
+    am = np.abs(audio_data).max()
+    if am > 1.0:
+        audio_data /= am
+    np.multiply(audio_data, 32767.0, audio_data)
+    audio_data = audio_data.astype(np.int16)
+    sample_rate = 24000
+
+    yield sample_rate, audio_data
diff --git a/examples/web/webui.py b/examples/web/webui.py
index 17cdcc0ae..135a6d673 100644
--- a/examples/web/webui.py
+++ b/examples/web/webui.py
@@ -6,141 +6,44 @@
 now_dir = os.getcwd()
 sys.path.append(now_dir)
 
-import random
 import argparse
 
-import torch
import gradio as gr -import numpy as np from dotenv import load_dotenv load_dotenv("sha256.env") -import ChatTTS - -# 音色选项:用于预置合适的音色 -voices = { - "默认": {"seed": 2}, - "音色1": {"seed": 1111}, - "音色2": {"seed": 2222}, - "音色3": {"seed": 3333}, - "音色4": {"seed": 4444}, - "音色5": {"seed": 5555}, - "音色6": {"seed": 6666}, - "音色7": {"seed": 7777}, - "音色8": {"seed": 8888}, - "音色9": {"seed": 9999}, - "音色10": {"seed": 11111}, -} - -def generate_seed(): - new_seed = random.randint(1, 100000000) - return { - "__type__": "update", - "value": new_seed - } - -# 返回选择音色对应的seed -def on_voice_change(vocie_selection): - return voices.get(vocie_selection)['seed'] - -def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag): - - torch.manual_seed(audio_seed_input) - rand_spk = chat.sample_random_speaker() - params_infer_code = { - 'spk_emb': rand_spk, - 'temperature': temperature, - 'top_P': top_P, - 'top_K': top_K, - } - params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'} - - torch.manual_seed(text_seed_input) - - if refine_text_flag: - text = chat.infer(text, - skip_refine_text=False, - refine_text_only=True, - params_refine_text=params_refine_text, - params_infer_code=params_infer_code - ) - - wav = chat.infer(text, - skip_refine_text=True, - params_refine_text=params_refine_text, - params_infer_code=params_infer_code - ) - - audio_data = np.array(wav[0]).flatten() - sample_rate = 24000 - text_data = text[0] if isinstance(text, list) else text - - return [(sample_rate, audio_data), text_data] - -def generate_audio_stream(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, refine_text_flag): - - torch.manual_seed(audio_seed_input) - rand_spk = chat.sample_random_speaker() - params_infer_code = { - 'spk_emb': rand_spk, - 'temperature': temperature, - 'top_P': top_P, - 'top_K': top_K, - } - params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'} - - torch.manual_seed(text_seed_input) - - - wavs_gen = chat.infer(text, - skip_refine_text=True, - params_refine_text=params_refine_text, - params_infer_code=params_infer_code, - stream=True) - - for gen in wavs_gen: - wavs = [np.array([[]])] - wavs[0] = np.hstack([wavs[0], np.array(gen[0])]) - audio = wavs[0][0] - - max_audio = np.abs(audio).max() # 简单防止16bit爆音 - if max_audio > 1: - audio /= max_audio - - yield 24000,(audio * 32768).astype(np.int16) - - - +from examples.web.funcs import * def main(): with gr.Blocks() as demo: gr.Markdown("# ChatTTS WebUI") - gr.Markdown("> ChatTTS Model: [2noise/ChatTTS](https://github.com/2noise/ChatTTS)") + gr.Markdown("- **GitHub Repo**: https://github.com/2noise/ChatTTS") + gr.Markdown("- **HuggingFace Repo**: https://huggingface.co/2Noise/ChatTTS") default_text = "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。" text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text) with gr.Row(): refine_text_checkbox = gr.Checkbox(label="Refine text", value=True) - temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature") - top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P") - top_k_slider = gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K") + temperature_slider = gr.Slider(minimum=0.00001, maximum=1.0, step=0.00001, value=0.3, label="Audio temperature", interactive=True) + top_p_slider = gr.Slider(minimum=0.1, maximum=0.9, step=0.05, value=0.7, label="top_P", interactive=True) + top_k_slider = 
gr.Slider(minimum=1, maximum=20, step=1, value=20, label="top_K", interactive=True)
 
         with gr.Row():
-            voice_options = {}
             voice_selection = gr.Dropdown(label="音色", choices=voices.keys(), value='默认')
             audio_seed_input = gr.Number(value=2, label="Audio Seed")
             generate_audio_seed = gr.Button("\U0001F3B2")
             text_seed_input = gr.Number(value=42, label="Text Seed")
             generate_text_seed = gr.Button("\U0001F3B2")
 
-        generate_button = gr.Button("Generate")
-        stream_generate_button = gr.Button("Streaming Generate")
+        with gr.Row():
+            auto_play_checkbox = gr.Checkbox(label="Auto Play", value=False, scale=1)
+            stream_mode_checkbox = gr.Checkbox(label="Stream Mode", value=False, scale=1)
+            generate_button = gr.Button("Generate", scale=2)
 
         text_output = gr.Textbox(label="Output Text", interactive=False)
 
-        audio_output = gr.Audio(label="Output Audio",value=None,streaming=True,autoplay=True,interactive=False,show_label=True)
 
         # 使用Gradio的回调功能来更新数值输入框
         voice_selection.change(fn=on_voice_change, inputs=voice_selection, outputs=audio_seed_input)
@@ -152,14 +55,25 @@ def main():
 
         generate_text_seed.click(generate_seed,
                                  inputs=[],
                                  outputs=text_seed_input)
-
-        generate_button.click(generate_audio,
-                              inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox],
-                              outputs=[audio_output, text_output])
-        stream_generate_button.click(generate_audio_stream,
-                                     inputs=[text_input, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, refine_text_checkbox],
-                                     outputs=[audio_output])
+        generate_button.click(fn=lambda: "", outputs=text_output)
+        generate_button.click(refine_text,
+                              inputs=[text_input, audio_seed_input, text_seed_input, refine_text_checkbox],
+                              outputs=text_output)
+
+        @gr.render(inputs=[auto_play_checkbox, stream_mode_checkbox])
+        def make_audio(autoplay, stream):
+            audio_output = gr.Audio(
+                label="Output Audio",
+                value=None,
+                autoplay=autoplay,
+                streaming=stream,
+                interactive=False,
+                show_label=True,
+            )
+            text_output.change(generate_audio,
+                               inputs=[text_output, temperature_slider, top_p_slider, top_k_slider, audio_seed_input, text_seed_input, stream_mode_checkbox],
+                               outputs=audio_output)
 
         gr.Examples(
             examples=[
@@ -177,20 +91,20 @@ def main():
     parser.add_argument('--custom_path', type=str, default=None, help='the custom model path')
     args = parser.parse_args()
 
-    print("loading ChatTTS model...")
+    logger.info("Loading ChatTTS model...")
+
     global chat
-    chat = ChatTTS.Chat()
 
     if args.custom_path == None:
         ret = chat.load_models()
     else:
-        print('local model path:', args.custom_path)
+        logger.info('local model path: %s', args.custom_path)
         ret = chat.load_models('custom', custom_path=args.custom_path)
-    
+
     if ret:
-        print("Models loaded successfully.")
+        logger.info("Models loaded successfully.")
     else:
-        print("Models load failed.")
+        logger.error("Models load failed.")
         sys.exit(1)
diff --git a/tools/logger/__init__.py b/tools/logger/__init__.py
new file mode 100644
index 000000000..7aa18cf19
--- /dev/null
+++ b/tools/logger/__init__.py
@@ -0,0 +1 @@
+from .log import get_logger
diff --git a/tools/logger/log.py b/tools/logger/log.py
new file mode 100644
index 000000000..5e5066d99
--- /dev/null
+++ b/tools/logger/log.py
@@ -0,0 +1,53 @@
+import platform
+import logging
+from datetime import datetime, timezone
+
+# from https://github.com/FloatTech/ZeroBot-Plugin/blob/c70766a989698452e60e5e48fb2f802a2444330d/console/console_windows.go#L89-L96
+colorCodePanic = "\x1b[1;31m"
+colorCodeFatal = "\x1b[1;31m"
+colorCodeError = "\x1b[31m"
+colorCodeWarn = "\x1b[33m"
+colorCodeInfo = "\x1b[37m"
+colorCodeDebug = "\x1b[32m"
+colorCodeTrace = "\x1b[36m"
+colorReset = "\x1b[0m"
+
+log_level_color_code = {
+    logging.DEBUG: colorCodeDebug,
+    logging.INFO: colorCodeInfo,
+    logging.WARN: colorCodeWarn,
+    logging.ERROR: colorCodeError,
+    logging.FATAL: colorCodeFatal,
+}
+
+log_level_msg_str = {
+    logging.DEBUG: "DEBU",
+    logging.INFO: "INFO",
+    logging.WARN: "WARN",
+    logging.ERROR: "ERRO",
+    logging.FATAL: "FATL",
+}
+
+class Formatter(logging.Formatter):
+    def __init__(self, color=platform.system().lower() != "windows"):
+        super().__init__()
+        # https://stackoverflow.com/questions/2720319/python-figure-out-local-timezone
+        self.tz = datetime.now(timezone.utc).astimezone().tzinfo
+        self.color = color
+
+    def format(self, record: logging.LogRecord):
+        logstr = "[" + datetime.now(self.tz).strftime('%z %Y%m%d %H:%M:%S') + "] ["
+        if self.color:
+            logstr += log_level_color_code.get(record.levelno, colorCodeInfo)
+        logstr += log_level_msg_str.get(record.levelno, record.levelname)
+        if self.color:
+            logstr += colorReset
+        # getMessage() interpolates %-style arguments into the final message
+        logstr += f"] {str(record.name)} | {record.getMessage()}"
+        return logstr
+
+def get_logger(name: str, lv = logging.INFO):
+    logger = logging.getLogger(name)
+    # attach the colored handler only once per logger name
+    if not logger.handlers:
+        syslog = logging.StreamHandler()
+        syslog.setFormatter(Formatter())
+        logger.addHandler(syslog)
+    logger.setLevel(lv)
+    return logger