Skip to content

Commit

Permalink
optimize: log & webui (#398)
Browse files Browse the repository at this point in the history
- move log definition out of ChatTTS
- apply colorful log level
- optimize webui logic
- split webui into 2 files for clear reading
  • Loading branch information
fumiama authored Jun 21, 2024
1 parent e58fe48 commit a851b9a
Show file tree
Hide file tree
Showing 12 changed files with 229 additions and 152 deletions.
24 changes: 11 additions & 13 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@

import os, sys
import os
import json
import logging
from functools import partial
from typing import Literal
import tempfile
from typing import Optional
from functools import partial
from typing import Literal, Optional

import torch
from omegaconf import OmegaConf
Expand All @@ -19,16 +17,16 @@
from .utils.io import get_latest_modified_file, del_all
from .infer.api import refine_text, infer_code
from .utils.download import check_all_assets, download_all_assets

logging.basicConfig(level = logging.INFO)
from .utils.log import set_utils_logger


class Chat:
def __init__(self, ):
def __init__(self, logger=logging.getLogger(__name__)):
self.pretrain_models = {}
self.normalizer = {}
self.homophones_replacer = None
self.logger = logging.getLogger(__name__)
self.logger = logger
set_utils_logger(logger)

def check_model(self, level = logging.INFO, use_decoder = False):
not_finish = False
Expand All @@ -46,7 +44,7 @@ def check_model(self, level = logging.INFO, use_decoder = False):

if not not_finish:
self.logger.log(level, f'All initialized.')

return not not_finish

def load_models(
Expand All @@ -62,7 +60,7 @@ def load_models(
with tempfile.TemporaryDirectory() as tmp:
download_all_assets(tmpdir=tmp)
if not check_all_assets(update=False):
logging.error("counld not satisfy all assets needed.")
self.logger.error("counld not satisfy all assets needed.")
return False
elif source == 'huggingface':
hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
Expand Down Expand Up @@ -120,14 +118,14 @@ def _load(

if gpt_config_path:
cfg = OmegaConf.load(gpt_config_path)
gpt = GPT_warpper(**cfg, device=device).eval()
gpt = GPT_warpper(**cfg, device=device, logger=self.logger).eval()
assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
gpt.load_state_dict(torch.load(gpt_ckpt_path))
if compile and 'cuda' in str(device):
try:
gpt.gpt.forward = torch.compile(gpt.gpt.forward, backend='inductor', dynamic=True)
except RuntimeError as e:
logging.warning(f'Compile failed,{e}. fallback to normal mode.')
self.logger.warning(f'Compile failed,{e}. fallback to normal mode.')
self.pretrain_models['gpt'] = gpt
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), 'spk_stat.pt')
assert os.path.exists(spk_stat_path), f'Missing spk_stat.pt: {spk_stat_path}'
Expand Down
3 changes: 2 additions & 1 deletion ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,11 @@ def __init__(
num_text_tokens,
num_vq=4,
device="cpu",
logger=logging.getLogger(__name__)
):
super().__init__()

self.logger = logging.getLogger(__name__)
self.logger = logger
self.device = device
self.device_gpt = device if "mps" not in str(device) else "cpu"
self.num_vq = num_vq
Expand Down
4 changes: 1 addition & 3 deletions ChatTTS/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,8 @@
import hashlib
import requests
from io import BytesIO
import logging

logger = logging.getLogger(__name__)

from .log import logger

def sha256(f) -> str:
sha256_hash = hashlib.sha256()
Expand Down
4 changes: 2 additions & 2 deletions ChatTTS/utils/gpu_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@

import torch
import logging

from .log import logger

def select_device(min_memory=2048):
logger = logging.getLogger(__name__)
if torch.cuda.is_available():
available_gpus = []
for i in range(torch.cuda.device_count()):
Expand Down
1 change: 0 additions & 1 deletion ChatTTS/utils/infer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import re
import torch
import torch.nn.functional as F
import os
import json


Expand Down
5 changes: 3 additions & 2 deletions ChatTTS/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import logging
from typing import Union

from .log import logger

def get_latest_modified_file(directory):
logger = logging.getLogger(__name__)


files = [os.path.join(directory, f) for f in os.listdir(directory)]
if not files:
logger.log(logging.WARNING, f'No files found in the directory: {directory}')
Expand Down
8 changes: 8 additions & 0 deletions ChatTTS/utils/log.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import logging
from pathlib import Path

logger = logging.getLogger(Path(__file__).parent.name)

def set_utils_logger(l: logging.Logger):
global logger
logger = l
24 changes: 14 additions & 10 deletions examples/cmd/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@
import ChatTTS
from IPython.display import Audio

from tools.logger import get_logger

logger = get_logger("Command")

def save_wav_file(wav, index):
wav_filename = f"output_audio_{index}.wav"
# Convert numpy array to bytes and write to WAV file
Expand All @@ -22,33 +26,33 @@ def save_wav_file(wav, index):
wf.setsampwidth(2) # Sample width in bytes
wf.setframerate(24000) # Sample rate in Hz
wf.writeframes(wav_bytes)
print(f"Audio saved to {wav_filename}")
logger.info(f"Audio saved to {wav_filename}")

def main():
# Retrieve text from command line argument
text_input = sys.argv[1] if len(sys.argv) > 1 else "<YOUR TEXT HERE>"
print("Received text input:", text_input)
logger.info("Received text input: %s", text_input)

chat = ChatTTS.Chat()
print("Initializing ChatTTS...")
chat = ChatTTS.Chat(get_logger("ChatTTS"))
logger.info("Initializing ChatTTS...")
if chat.load_models():
print("Models loaded successfully.")
logger.info("Models loaded successfully.")
else:
print("Models load failed.")
logger.error("Models load failed.")
sys.exit(1)

texts = [text_input]
print("Text prepared for inference:", texts)
logger.info("Text prepared for inference: %s", texts)

wavs = chat.infer(texts, use_decoder=True)
print("Inference completed. Audio generation successful.")
logger.info("Inference completed. Audio generation successful.")
# Save each generated wav file to a local file
for index, wav in enumerate(wavs):
save_wav_file(wav, index)

return Audio(wavs[0], rate=24_000, autoplay=True)

if __name__ == "__main__":
print("Starting the TTS application...")
logger.info("Starting the TTS application...")
main()
print("TTS application finished.")
logger.info("TTS application finished.")
100 changes: 100 additions & 0 deletions examples/web/funcs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import random

import torch
import gradio as gr
import numpy as np

from tools.logger import get_logger
logger = get_logger(" WebUI ")

import ChatTTS
chat = ChatTTS.Chat(get_logger("ChatTTS"))

# 音色选项:用于预置合适的音色
voices = {
"默认": {"seed": 2},
"音色1": {"seed": 1111},
"音色2": {"seed": 2222},
"音色3": {"seed": 3333},
"音色4": {"seed": 4444},
"音色5": {"seed": 5555},
"音色6": {"seed": 6666},
"音色7": {"seed": 7777},
"音色8": {"seed": 8888},
"音色9": {"seed": 9999},
"音色10": {"seed": 11111},
}

def generate_seed():
return gr.update(value=random.randint(1, 100000000))

# 返回选择音色对应的seed
def on_voice_change(vocie_selection):
return voices.get(vocie_selection)['seed']

def refine_text(text, audio_seed_input, text_seed_input, refine_text_flag):
if not refine_text_flag:
return text

global chat

torch.manual_seed(audio_seed_input)
params_refine_text = {'prompt': '[oral_2][laugh_0][break_6]'}

torch.manual_seed(text_seed_input)

text = chat.infer(text,
skip_refine_text=False,
refine_text_only=True,
params_refine_text=params_refine_text,
)
return text[0] if isinstance(text, list) else text

def generate_audio(text, temperature, top_P, top_K, audio_seed_input, text_seed_input, stream):
if not text: return None

global chat

torch.manual_seed(audio_seed_input)
rand_spk = chat.sample_random_speaker()
params_infer_code = {
'spk_emb': rand_spk,
'temperature': temperature,
'top_P': top_P,
'top_K': top_K,
}
torch.manual_seed(text_seed_input)

wav = chat.infer(
text,
skip_refine_text=True,
params_infer_code=params_infer_code,
stream=stream,
)

if stream:
for gen in wav:
wavs = [np.array([[]])]
wavs[0] = np.hstack([wavs[0], np.array(gen[0])])
audio = wavs[0][0]

# normalize
am = np.abs(audio).max() * 32768
if am > 32768:
am = 32768 * 32768 / am
np.multiply(audio, am, audio)
audio = audio.astype(np.int16)

yield 24000, audio
return

audio_data = np.array(wav[0]).flatten()
# normalize
am = np.abs(audio_data).max() * 32768
if am > 32768:
am = 32768 * 32768 / am
np.multiply(audio_data, am, audio_data)
audio_data = audio_data.astype(np.int16)
sample_rate = 24000

yield sample_rate, audio_data
Loading

0 comments on commit a851b9a

Please sign in to comment.