Make style
erogol committed Nov 17, 2023
1 parent 26efdf6 commit 44880f0
Showing 5 changed files with 22 additions and 24 deletions.
1 change: 0 additions & 1 deletion TTS/tts/layers/tortoise/diffusion.py
@@ -17,7 +17,6 @@

from TTS.tts.layers.tortoise.dpm_solver import DPM_Solver, NoiseScheduleVP, model_wrapper


try:
from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral

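The k_diffusion import shown above sits inside a try block so that k-diffusion stays an optional dependency. A minimal sketch of that guard pattern; the except branch and the dictionary name below are assumptions, since they are not visible in this diff:

try:
    from k_diffusion.sampling import sample_dpmpp_2m, sample_euler_ancestral

    # Map sampler names to the optional k-diffusion implementations (illustrative name).
    K_DIFFUSION_SAMPLERS = {"dpm++2m": sample_dpmpp_2m, "euler_ancestral": sample_euler_ancestral}
except ImportError:
    # k-diffusion is not installed; callers fall back to the built-in samplers.
    K_DIFFUSION_SAMPLERS = None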
4 changes: 3 additions & 1 deletion TTS/tts/layers/xtts/gpt.py
@@ -441,7 +441,9 @@ def forward(
audio_codes = F.pad(audio_codes[:, :max_mel_len], (0, 1), value=self.stop_audio_token)

# Pad mel codes with stop_audio_token
audio_codes = self.set_mel_padding(audio_codes, code_lengths - 3)  # -3 to get the real code lengths, not counting the start and stop tokens that have not been added yet
audio_codes = self.set_mel_padding(
audio_codes, code_lengths - 3
)  # -3 to get the real code lengths, not counting the start and stop tokens that have not been added yet

# Build input and target tensors
# Prepend start token to inputs and append stop token to targets
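For context on the hunk above: code_lengths - 3 drops the start and stop tokens that are only appended afterwards, and set_mel_padding pads each sequence with the stop token beyond its real length. A rough standalone sketch of that idea, assuming that is all the helper does (the function below is illustrative, not the library's implementation):

import torch

def pad_with_stop_token(audio_codes, code_lengths, stop_audio_token):
    # Overwrite every position at or beyond a sequence's true length with the stop token.
    padded = audio_codes.clone()
    for i, length in enumerate(code_lengths.tolist()):
        padded[i, length:] = stop_audio_token
    return padded

codes = torch.tensor([[5, 9, 2, 0, 0], [7, 3, 0, 0, 0]])
print(pad_with_stop_token(codes, torch.tensor([3, 2]), stop_audio_token=1024))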
23 changes: 12 additions & 11 deletions TTS/tts/layers/xtts/tokenizer.py
@@ -1,23 +1,22 @@
import os
import re
import torch
import pypinyin
import textwrap

from functools import cached_property

import pypinyin
import torch
from hangul_romanize import Transliter
from hangul_romanize.rule import academic
from num2words import num2words
from spacy.lang.ar import Arabic
from spacy.lang.en import English
from spacy.lang.es import Spanish
from spacy.lang.ja import Japanese
from spacy.lang.zh import Chinese
from tokenizers import Tokenizer

from TTS.tts.layers.xtts.zh_num2words import TextNorm as zh_num2words

from spacy.lang.en import English
from spacy.lang.zh import Chinese
from spacy.lang.ja import Japanese
from spacy.lang.ar import Arabic
from spacy.lang.es import Spanish


def get_spacy_lang(lang):
if lang == "zh":
@@ -32,6 +31,7 @@ def get_spacy_lang(lang):
# For most languages, English does the job
return English()


def split_sentence(text, lang, text_split_length=250):
"""Preprocess the input text"""
text_splits = []
@@ -67,6 +67,7 @@ def split_sentence(text, lang, text_split_length=250):

return text_splits


_whitespace_re = re.compile(r"\s+")

# List of (regular expression, replacement) pairs for abbreviations:
@@ -619,7 +620,7 @@ def katsu(self):
return cutlet.Cutlet()

def check_input_length(self, txt, lang):
lang = lang.split("-")[0] # remove the region
lang = lang.split("-")[0] # remove the region
limit = self.char_limits.get(lang, 250)
if len(txt) > limit:
print(
@@ -640,7 +641,7 @@ def preprocess_text(self, txt, lang):
return txt

def encode(self, txt, lang):
lang = lang.split("-")[0] # remove the region
lang = lang.split("-")[0] # remove the region
self.check_input_length(txt, lang)
txt = self.preprocess_text(txt, lang)
lang = "zh-cn" if lang == "zh" else lang
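Several of the tokenizer hunks above strip the region suffix from the language code before applying per-language limits and rules, with a zh -> zh-cn special case in encode. A small self-contained sketch of that normalization:

def normalize_lang(lang: str) -> str:
    base = lang.split("-")[0]  # "en-US" -> "en", "pt-br" -> "pt"
    return "zh-cn" if base == "zh" else base  # encode maps bare "zh" to "zh-cn"

assert normalize_lang("en-US") == "en"
assert normalize_lang("zh") == "zh-cn"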
14 changes: 5 additions & 9 deletions TTS/tts/models/xtts.py
@@ -513,13 +513,13 @@ def inference(
enable_text_splitting=False,
**hf_generate_kwargs,
):
language = language.split("-")[0] # remove the country code
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05)
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
else:
text = [text]

wavs = []
gpt_latents_list = []
for sent in text:
@@ -563,9 +563,7 @@

if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)

gpt_latents_list.append(gpt_latents.cpu())
@@ -623,7 +621,7 @@ def inference_stream(
enable_text_splitting=False,
**hf_generate_kwargs,
):
language = language.split("-")[0] # remove the country code
language = language.split("-")[0] # remove the country code
length_scale = 1.0 / max(speed, 0.05)
if enable_text_splitting:
text = split_sentence(text, language, self.tokenizer.char_limits[language])
@@ -675,9 +673,7 @@
gpt_latents = torch.cat(all_latents, dim=0)[None, :]
if length_scale != 1.0:
gpt_latents = F.interpolate(
gpt_latents.transpose(1, 2),
scale_factor=length_scale,
mode="linear"
gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
).transpose(1, 2)
wav_gen = self.hifigan_decoder(gpt_latents, g=speaker_embedding.to(self.device))
wav_chunk, wav_gen_prev, wav_overlap = self.handle_chunks(
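Both inference and inference_stream above implement the speed control the same way: length_scale = 1.0 / max(speed, 0.05), then the GPT latents are resampled along the time axis before vocoding. A self-contained sketch of that resampling step with dummy shapes (not the real model tensors):

import torch
import torch.nn.functional as F

speed = 2.0
length_scale = 1.0 / max(speed, 0.05)    # speed > 1 shortens the latent sequence
gpt_latents = torch.randn(1, 120, 1024)  # (batch, time, channels), dummy values

if length_scale != 1.0:
    gpt_latents = F.interpolate(
        gpt_latents.transpose(1, 2), scale_factor=length_scale, mode="linear"
    ).transpose(1, 2)

print(gpt_latents.shape)  # torch.Size([1, 60, 1024]): fewer frames -> faster speech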
4 changes: 2 additions & 2 deletions tests/zoo_tests/test_models.py
@@ -186,7 +186,7 @@ def test_xtts_v2_streaming():
"en",
gpt_cond_latent,
speaker_embedding,
speed=1.5
speed=1.5,
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
@@ -198,7 +198,7 @@ def test_xtts_v2_streaming():
"en",
gpt_cond_latent,
speaker_embedding,
speed=0.66
speed=0.66,
)
wav_chuncks = []
for i, chunk in enumerate(chunks):
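The streaming test above now passes speed as a keyword and drains the generator chunk by chunk. A minimal usage sketch along the same lines; the model setup and the input text are assumptions, since they are not shown in this diff:

import torch

chunks = model.inference_stream(
    "A short sentence to synthesize.",
    "en",
    gpt_cond_latent,
    speaker_embedding,
    speed=1.5,
)
wav_chunks = []
for i, chunk in enumerate(chunks):
    wav_chunks.append(chunk)  # each chunk is a waveform tensor yielded as it is generated
wav = torch.cat(wav_chunks, dim=0)  # concatenate into the full utterance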
