Skip to content

Commit

Permalink
account for variable sequence lengths when generating alignment mask from durations
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Sep 2, 2023
1 parent 4be2fd3 commit deb97db
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 10 deletions.
20 changes: 11 additions & 9 deletions naturalspeech2_pytorch/naturalspeech2_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,23 @@ def prob_mask_like(shape, prob, device):
else:
return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob

def generate_mask_from_repeats(repeats):
    """Build a boolean alignment mask from per-position repeat counts.

    Each source position ``i`` owns a contiguous run of ``repeats[..., i]``
    target frames, laid out one after another. Entry ``mask[..., i, j]`` is
    True when target frame ``j`` falls inside the run owned by position ``i``.

    Args:
        repeats: tensor of shape ``(..., n)`` holding the number of target
            frames for each of the ``n`` source positions. It is truncated
            to integers with ``.int()``. Any number of leading (batch)
            dimensions is accepted, including none.

    Returns:
        Bool tensor of shape ``(..., n, max_length)`` where ``max_length`` is
        the largest total (sum over the last dim) found across the leading
        dimensions. Rows whose own total is shorter than ``max_length`` are
        False for every frame past their total.
    """
    repeats = repeats.int()
    device = repeats.device

    # total number of target frames per sequence, and the widest one
    lengths = repeats.sum(dim = -1)
    # amax raises on an empty tensor (e.g. batch size 0), so guard it
    max_length = int(lengths.amax()) if lengths.numel() > 0 else 0

    # run for position i covers frames [cumsum_exclusive[i], cumsum[i])
    cumsum = repeats.cumsum(dim = -1)
    cumsum_exclusive = cumsum - repeats

    # (max_length,) frame indices; broadcast against (..., n, 1) boundaries
    seq = torch.arange(max_length, device = device)

    cumsum = cumsum[..., None]
    cumsum_exclusive = cumsum_exclusive[..., None]
    lengths = lengths[..., None, None]

    # the final term keeps frames beyond a sequence's own total switched off
    # when it is shorter than the longest sequence in the batch
    mask = (seq >= cumsum_exclusive) & (seq < cumsum) & (seq < lengths)
    return mask


# previous name of this function, kept so existing call sites keep working
generate_mask_from_lengths = generate_mask_from_repeats

# sinusoidal positional embeds
Expand Down Expand Up @@ -1424,7 +1426,7 @@ def sample(
duration, pitch = self.duration_pitch(phoneme_enc, prompt_enc)
pitch = rearrange(pitch, 'b n -> b 1 n')

aln_mask = generate_mask_from_lengths(duration).float()
aln_mask = generate_mask_from_repeats(duration).float()

cond = self.expand_encodings(rearrange(phoneme_enc, 'b n d -> b d n'), rearrange(aln_mask, 'b n c -> b 1 n c'), pitch)

Expand Down
2 changes: 1 addition & 1 deletion naturalspeech2_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.1.2'
__version__ = '0.1.4'

0 comments on commit deb97db

Please sign in to comment.