diff --git a/naturalspeech2_pytorch/naturalspeech2_pytorch.py b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
index aaa351b..9db47d6 100644
--- a/naturalspeech2_pytorch/naturalspeech2_pytorch.py
+++ b/naturalspeech2_pytorch/naturalspeech2_pytorch.py
@@ -75,21 +75,23 @@ def prob_mask_like(shape, prob, device):
     else:
         return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
 
-def generate_mask_from_lengths(lengths):
-    src = lengths.int()
-    device = src.device
-    tgt_length = src.sum(dim = -1).amax().item()
+def generate_mask_from_repeats(repeats):
+    repeats = repeats.int()
+    device = repeats.device
 
-    cumsum = src.cumsum(dim = -1)
+    lengths = repeats.sum(dim = -1)
+    max_length = lengths.amax().item()
+    cumsum = repeats.cumsum(dim = -1)
     cumsum_exclusive = F.pad(cumsum, (1, -1), value = 0.)
 
-    tgt_arange = torch.arange(tgt_length, device = device)
-    tgt_arange = repeat(tgt_arange, '... j -> ... i j', i = src.shape[-1])
+    seq = torch.arange(max_length, device = device)
+    seq = repeat(seq, '... j -> ... i j', i = repeats.shape[-1])
 
     cumsum = rearrange(cumsum, '... i -> ... i 1')
     cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')
 
-    mask = (tgt_arange < cumsum) & (tgt_arange >= cumsum_exclusive)
+    lengths = rearrange(lengths, 'b -> b 1 1')
+    mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
     return mask
 
 # sinusoidal positional embeds
@@ -1424,7 +1426,7 @@ def sample(
         duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
         pitch = rearrange(pitch, 'b n -> b 1 n')
 
-        aln_mask = generate_mask_from_lengths(duration).float()
+        aln_mask = generate_mask_from_repeats(duration).float()
 
         cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)
 
diff --git a/naturalspeech2_pytorch/version.py b/naturalspeech2_pytorch/version.py
index 10939f0..7525d19 100644
--- a/naturalspeech2_pytorch/version.py
+++ b/naturalspeech2_pytorch/version.py
@@ -1 +1 @@
-__version__ = '0.1.2'
+__version__ = '0.1.4'
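
Not part of the patch: a minimal, self-contained sketch of the renamed helper as it reads after this change, to show what the alignment mask looks like. It assumes only torch and einops (which the module already imports); the example repeat counts and prints are illustrative, not taken from the repository.

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

def generate_mask_from_repeats(repeats):
    # per-phoneme repeat counts (durations in frames) -> boolean alignment mask
    # of shape (batch, num_phonemes, max_frames)
    repeats = repeats.int()
    device = repeats.device

    lengths = repeats.sum(dim = -1)                          # total frames per batch element
    max_length = lengths.amax().item()                       # padded length of the frame axis
    cumsum = repeats.cumsum(dim = -1)                        # exclusive end frame of each phoneme
    cumsum_exclusive = F.pad(cumsum, (1, -1), value = 0.)    # start frame of each phoneme

    seq = torch.arange(max_length, device = device)
    seq = repeat(seq, '... j -> ... i j', i = repeats.shape[-1])

    cumsum = rearrange(cumsum, '... i -> ... i 1')
    cumsum_exclusive = rearrange(cumsum_exclusive, '... i -> ... i 1')

    lengths = rearrange(lengths, 'b -> b 1 1')
    # frame j belongs to phoneme i if it lies in [start_i, end_i) and within the sequence length
    mask = (seq < cumsum) & (seq >= cumsum_exclusive) & (seq < lengths)
    return mask

# hypothetical durations: batch of 2 utterances, 3 phonemes each
repeats = torch.tensor([[2, 1, 3], [1, 1, 1]])
mask = generate_mask_from_repeats(repeats)
print(mask.shape)      # torch.Size([2, 3, 6]) - frame axis padded to the longest total (2 + 1 + 3)
print(mask[0].int())   # phoneme 0 covers frames 0-1, phoneme 1 frame 2, phoneme 2 frames 3-5
print(mask[1].int())   # frames 3-5 stay False for the shorter batch element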