-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathtdnn.py
68 lines (57 loc) · 3.18 KB
/
tdnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch.nn as nn
from torch.nn.utils import weight_norm
__author__ = 'Jonas Van Der Donckt'
class TDNN(nn.Module):
def __init__(self, context: list, input_channels: int, output_channels: int, full_context: bool = True):
"""
Implementation of a 'Fast' TDNN layer by exploiting the dilation argument of the PyTorch Conv1d class
Due to its fastness the context has gained two constraints:
* The context must be symmetric
* The context must have equal spacing between each consecutive element
For example: the non-full and symmetric context {-3, -2, 0, +2, +3} is not valid since it doesn't have
equal spacing; The non-full context {-6, -3, 0, 3, 6} is both symmetric and has an equal spacing, this is
considered valid.
:param context: The temporal context
:param input_channels: The number of input channels
:param output_channels: The number of channels produced by the temporal convolution
:param full_context: Indicates whether a full context needs to be used
"""
super(TDNN, self).__init__()
self.full_context = full_context
self.input_dim = input_channels
self.output_dim = output_channels
context = sorted(context)
self.check_valid_context(context, full_context)
if full_context:
kernel_size = context[-1] - context[0] + 1 if len(context) > 1 else 1
self.temporal_conv = weight_norm(nn.Conv1d(input_channels, output_channels, kernel_size))
else:
# use dilation
delta = context[1] - context[0]
self.temporal_conv = weight_norm(
nn.Conv1d(input_channels, output_channels, kernel_size=len(context), dilation=delta))
def forward(self, x):
"""
:param x: is one batch of data, x.size(): [batch_size, input_channels, sequence_length]
sequence length is the dimension of the arbitrary length data
:return: [batch_size, output_dim, len(valid_steps)]
"""
return self.temporal_conv(x)
@staticmethod
def check_valid_context(context: list, full_context: bool) -> None:
"""
Check whether the context is symmetrical and whether and whether the passed
context can be used for creating a convolution kernel with dil
:param full_context: indicates whether the full context (dilation=1) will be used
:param context: The context of the model, must be symmetric if no full context and have an equal spacing.
"""
if full_context:
assert len(context) <= 2, "If the full context is given one must only define the smallest and largest"
if len(context) == 2:
assert context[0] + context[-1] == 0, "The context must be symmetric"
else:
assert len(context) % 2 != 0, "The context size must be odd"
assert context[len(context) // 2] == 0, "The context contain 0 in the center"
if len(context) > 1:
delta = [context[i] - context[i - 1] for i in range(1, len(context))]
assert all(delta[0] == delta[i] for i in range(1, len(delta))), "Intra context spacing must be equal!"