Skip to content

Commit

Permalink
just make a guess about pyworld and make sure it runs
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Sep 1, 2023
1 parent fb9e1d5 commit 4be2fd3
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 19 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,9 @@ trainer.train()
- [x] complete perceiver then cross attention conditioning on ddpm side
- [x] add classifier free guidance, even if not in paper
- [x] complete duration / pitch prediction during training - thanks to Manmay
- [x] make sure pyworld way of computing pitch can also work

- [ ] make sure pyworld way of computing pitch can also work
- [ ] consult phd student in TTS field about pyworld usage
- [ ] also offer direct summation conditioning using spear-tts text-to-semantic module, if available
- [ ] add self-conditioning on ddpm side
- [ ] take care of automatic slicing of audio for prompt, being aware of minimal audio segment as allowed by the codec model
Expand Down
72 changes: 55 additions & 17 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def default(val, d):
return val
return d() if callable(d) else d

def divisible_by(num, den):
    """Return True when `num` is an exact multiple of `den`."""
    remainder = num % den
    return remainder == 0

def identity(t, *args, **kwargs):
    """Pass-through helper: return the first argument unchanged.

    Any extra positional or keyword arguments are accepted and ignored,
    so this can stand in for callbacks with arbitrary signatures.
    """
    return t

Expand Down Expand Up @@ -94,7 +97,7 @@ def generate_mask_from_lengths(lengths):
class LearnedSinusoidalPosEmb(nn.Module):
def __init__(self, dim):
    """Set up learned sinusoidal position embedding weights.

    `dim` must be even: the embedding is later split into equal
    sin / cos halves, each driven by `dim // 2` learned frequencies.
    """
    super().__init__()
    assert divisible_by(dim, 2)
    # one learnable frequency per sin/cos pair
    self.weights = nn.Parameter(torch.randn(dim // 2))

Expand All @@ -115,19 +118,37 @@ def compute_pitch_pytorch(wav, sample_rate):

#as mentioned in paper using pyworld

def compute_pitch(spec, sample_rate, hop_length, pitch_fmax=640.0):
# align F0 length to the spectrogram length
if len(spec) % hop_length == 0:
spec = np.pad(spec, (0, hop_length // 2), mode="reflect")
def compute_pitch_pyworld(wav, sample_rate, hop_length, pitch_fmax=640.0):
    """Compute per-frame fundamental frequency (F0) with pyworld (dio + stonemask).

    Args:
        wav: batched audio of shape (batch, samples) — torch tensor or numpy array.
             (The caller asserts ndim == 2 before calling; the per-sample loop
             below relies on that.)
        sample_rate: audio sampling rate in Hz.
        hop_length: hop size used by the mel spectrogram, so the F0 frames line
             up with the mel frames (frame_period is derived from it).
        pitch_fmax: F0 ceiling passed to pyworld's dio.

    Returns:
        F0 contours of shape (batch, frames), as a torch tensor on the input's
        device when a tensor was passed in, otherwise a numpy array.
    """
    is_tensor_input = torch.is_tensor(wav)

    if is_tensor_input:
        device = wav.device
        wav = wav.contiguous().cpu().numpy()

    # Align F0 length to the spectrogram length by reflect-padding the TIME
    # axis. The previous code tested `len(wav)` — which for a batched (b, n)
    # array is the batch size, not the sample length — and padded every axis,
    # which would also grow the batch dimension.
    if divisible_by(wav.shape[-1], hop_length):
        wav = np.pad(wav, ((0, 0), (0, hop_length // 2)), mode = "reflect")

    # pyworld requires float64 input
    wav = wav.astype(np.double)

    outs = []

    for sample in wav:
        # dio gives a coarse F0 estimate + frame times; stonemask refines it
        f0, t = pw.dio(
            sample,
            fs = sample_rate,
            f0_ceil = pitch_fmax,
            frame_period = 1000 * hop_length / sample_rate,
        )

        f0 = pw.stonemask(sample, f0, t, sample_rate)
        outs.append(f0)

    outs = np.stack(outs)

    if is_tensor_input:
        outs = torch.from_numpy(outs).to(device)

    return outs

def f0_to_coarse(f0, f0_bin = 256, f0_max = 1100.0, f0_min = 50.0):
f0_mel_max = 1127 * torch.log(1 + torch.tensor(f0_max) / 700)
Expand Down Expand Up @@ -1115,6 +1136,8 @@ def __init__(
num_phoneme_tokens: int = 150,
pitch_emb_dim: int = 256,
pitch_emb_pp_hidden_dim: int= 512,
calc_pitch_with_pyworld = True, # pyworld or kaldi from torchaudio
mel_hop_length = 160,
audio_to_mel_kwargs: dict = dict(),
scale = 1., # this will be set to < 1. for better convergence when training on higher resolution images
duration_loss_weight = 1.,
Expand Down Expand Up @@ -1145,11 +1168,16 @@ def __init__(
if exists(self.target_sample_hz):
audio_to_mel_kwargs.update(sampling_rate = self.target_sample_hz)

self.mel_hop_length = mel_hop_length

self.audio_to_mel = AudioToMel(
n_mels = aligner_dim_in,
hop_length = mel_hop_length,
**audio_to_mel_kwargs
)

self.calc_pitch_with_pyworld = calc_pitch_with_pyworld

self.phoneme_enc = PhonemeEncoder(tokenizer=tokenizer, num_tokens=num_phoneme_tokens)
self.prompt_enc = SpeechPromptEncoder(dim_codebook=dim_codebook)
self.duration_pitch = DurationPitchPredictor(dim=duration_pitch_dim)
Expand Down Expand Up @@ -1456,21 +1484,31 @@ def forward(
prompt_enc = self.prompt_enc(prompt)
phoneme_enc = self.phoneme_enc(text)

# process pitch
# process pitch with kaldi

if not exists(pitch):
assert exists(audio) and audio.ndim == 2
assert exists(self.target_sample_hz)

pitch = compute_pitch_pytorch(audio, self.target_sample_hz)
if self.calc_pitch_with_pyworld:
pitch = compute_pitch_pyworld(
audio,
sample_rate = self.target_sample_hz,
hop_length = self.mel_hop_length
)
else:
pitch = compute_pitch_pytorch(audio, self.target_sample_hz)

pitch = rearrange(pitch, 'b n -> b 1 n')

# process mel

if not exists(mel):
assert exists(audio) and audio.ndim == 2
mel = self.audio_to_mel(audio)
mel = mel[..., :pitch.shape[-1]]

if exists(pitch):
mel = mel[..., :pitch.shape[-1]]

mel_max_length = mel.shape[-1]

Expand Down Expand Up @@ -1803,7 +1841,7 @@ def train(self):
if accelerator.is_main_process:
self.ema.update()

if self.step % self.save_and_sample_every == 0:
if divisible_by(self.step, self.save_and_sample_every):
milestone = self.step // self.save_and_sample_every

models = [(self.unwrapped_model, str(self.step))]
Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.1'
__version__ = '0.1.2'

0 comments on commit 4be2fd3

Please sign in to comment.