Skip to content

Commit

Permalink
fix(examples): ipynb infer (#392)
Browse files: browse the repository at this point in the history
  • Loading branch information
fumiama authored Jun 21, 2024
1 parent f17bd31 commit f9e6d2d
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 9 deletions.
8 changes: 5 additions & 3 deletions ChatTTS/model/gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def get_emb(self, input_ids, text_mask):
emb_code = [self.emb_code[i](masked_input_ids[:, i]) for i in range(self.num_vq)]
emb_code = torch.stack(emb_code, 2).sum(2)

emb = torch.cat((emb_text, emb_code)).unsqueeze_(0)
emb = torch.zeros((input_ids.shape[:-1])+(emb_text.shape[-1],), device=emb_text.device, dtype=emb_text.dtype)
emb[text_mask] = emb_text
emb[~text_mask] = emb_code.to(emb.dtype)

del emb_text, emb_code

Expand Down Expand Up @@ -263,10 +265,10 @@ def generate(
if not infer_text:
# logits = rearrange(logits, "b c n -> (b n) c")
logits = logits.permute(0, 2, 1)
logits = logits.view(-1, logits.size(2))
logits = logits.reshape(-1, logits.size(2))
# logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1)
logits_token = inputs_ids_sliced.view(
logits_token = inputs_ids_sliced.reshape(
inputs_ids_sliced.size(0)*inputs_ids_sliced.size(1), -1,
)
else:
Expand Down
19 changes: 16 additions & 3 deletions examples/ipynb/example.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,19 @@
"metadata": {},
"outputs": [],
"source": [
"import os, sys\n",
"\n",
"if sys.platform == \"darwin\":\n",
" os.environ[\"PYTORCH_ENABLE_MPS_FALLBACK\"] = \"1\"\n",
"\n",
"if not \"root_dir\" in globals():\n",
" now_dir = os.getcwd() # skip examples/ipynb\n",
" root_dir = os.path.join(now_dir, \"../../\")\n",
" sys.path.append(root_dir)\n",
" print(\"init root dir to\", root_dir)\n",
"\n",
"from dotenv import load_dotenv\n",
"load_dotenv(\"sha256.env\")\n",
"load_dotenv(os.path.join(root_dir, \"sha256.env\"))\n",
"\n",
"import torch\n",
"torch._dynamo.config.cache_size_limit = 64\n",
Expand All @@ -38,6 +49,8 @@
"metadata": {},
"outputs": [],
"source": [
"os.chdir(root_dir)\n",
"\n",
"chat = ChatTTS.Chat()\n",
"chat.load_models()\n",
"\n",
Expand Down Expand Up @@ -70,7 +83,7 @@
"source": [
"texts = [\"So we found being competitive and collaborative was a huge way of staying motivated towards our goals, so one person to call when you fall off, one person who gets you back on then one person to actually do the activity with.\",]*3 \\\n",
" + [\"我觉得像我们这些写程序的人,他,我觉得多多少少可能会对开源有一种情怀在吧我觉得开源是一个很好的形式。现在其实最先进的技术掌握在一些公司的手里的话,就他们并不会轻易的开放给所有的人用。\"]*3 \n",
" \n",
"\n",
"wavs = chat.infer(texts)"
]
},
Expand Down Expand Up @@ -239,7 +252,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
"version": "3.9.6"
}
},
"nbformat": 4,
Expand Down
6 changes: 3 additions & 3 deletions examples/web/webui.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,10 @@ def generate_audio_stream(text, temperature, top_P, top_K, audio_seed_input, tex
def main():

with gr.Blocks() as demo:
gr.Markdown("# ChatTTS Webui")
gr.Markdown("ChatTTS Model: [2noise/ChatTTS](https://github.com/2noise/ChatTTS)")
gr.Markdown("# ChatTTS WebUI")
gr.Markdown("> ChatTTS Model: [2noise/ChatTTS](https://github.com/2noise/ChatTTS)")

default_text = "四川美食确实以辣闻名,但也有不辣的选择。[uv_break]比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。[laugh]"
default_text = "四川美食确实以辣闻名,但也有不辣的选择。比如甜水面、赖汤圆、蛋烘糕、叶儿粑等,这些小吃口味温和,甜而不腻,也很受欢迎。"
text_input = gr.Textbox(label="Input Text", lines=4, placeholder="Please Input Text...", value=default_text)

with gr.Row():
Expand Down

0 comments on commit f9e6d2d

Please sign in to comment.