# ssd_minimal.py
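"""
Minimal chunked (block-parallel) implementation of the SSD recurrence, in the
style of the `ssd_minimal_discrete` reference code from the Mamba-2 codebase,
extended with zero-padding so the sequence length need not be a multiple of
the block length.
"""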
import torch
import torch.nn.functional as F
from einops import repeat, rearrange


def pad_by_size(x, pad_size):
    """Zero-pad the sequence dimension (dim 1) of a 3D or 4D tensor on the
    right by `pad_size` positions."""
    pad_shape = (0, 0, 0, 0, 0, pad_size, 0, 0) if len(x.shape) == 4 \
        else (0, 0, 0, pad_size, 0, 0)
    return F.pad(x, pad_shape, mode="constant", value=0)
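# Illustrative example: for x of shape (batch, 14, n_heads, d_head) and
# pad_size=2, pad_by_size(x, 2) returns a tensor of shape
# (batch, 16, n_heads, d_head) with the last two positions zero.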


def segsum(x):
    """Numerically stable segment sum: returns S with S[..., i, j] equal to
    x[..., j+1] + ... + x[..., i] for j <= i, and -inf above the diagonal."""
    T = x.size(-1)
    # Broadcast x along a new trailing dimension: entry (i, j) holds x[i].
    x = repeat(x, "... d -> ... d e", e=T)
    # Zero out everything on or above the diagonal, then cumsum down the rows
    # so entry (i, j) becomes x[j+1] + ... + x[i].
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    # Put -inf above the diagonal so exp(segsum(x)) is lower-triangular.
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=torch.bool), diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
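# Illustrative: for 1-D x = (x0, x1, x2), exp(segsum(x)) is the 3x3
# lower-triangular matrix with entry (i, j) = exp(x[j+1] + ... + x[i]) for
# j <= i (so 1 on the diagonal) and 0 above it -- the decay matrix L below.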


def ssd_minimal_discrete(X, dt, A, B, C, block_len, D=None, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        dt: (batch, length, n_heads)
        A: (n_heads,)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
        block_len: chunk length; the sequence is zero-padded to a multiple of it
        D: optional (n_heads,) weights for a skip connection
        initial_states: optional (batch, 1, n_heads, d_head, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
        final_state: (batch, n_heads, d_head, d_state)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    seq_len = X.shape[1]
    # Right-padding needed to make the length a multiple of block_len.
    pad_size = (block_len - seq_len % block_len) % block_len
    # (Optional) Prepare the D skip connection on the (padded) undiscretized input
    if D is not None:
        skip = D.unsqueeze(-1) * pad_by_size(X, pad_size)

    # Discretize X and A
    X = X * dt.unsqueeze(-1)
    A = A * dt

    # Pad to a multiple of block_len, then rearrange into blocks/chunks.
    # Padding A with zeros is safe: exp(0) = 1, so padded steps neither decay
    # the state nor (since X is zero there) contribute to it.
    X, A, B, C = [rearrange(pad_by_size(x, pad_size), "b (c l) ... -> b c l ...", l=block_len)
                  for x in (X, A, B, C)]
    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)
    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)
    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
    # (middle term of factorization of off-diagonal blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]
    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum("bclhn,bchpn,bhcl->bclhp", C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")

    # Add optional D residual
    if D is not None:
        Y = Y + skip

    # Trim any right-padding back to the original sequence length
    if pad_size > 0:
        Y = Y[:, :seq_len, :, :]
    return Y, final_state
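

# A minimal smoke test (not part of the original file): the shapes, sizes, and
# the deliberately non-divisible seq_len below are arbitrary assumptions,
# chosen only to exercise the padding path and check output shapes.
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seq_len, n_heads, d_head, d_state, block_len = 2, 17, 4, 8, 16, 8
    X = torch.randn(batch, seq_len, n_heads, d_head)
    dt = F.softplus(torch.randn(batch, seq_len, n_heads))
    A = -torch.exp(torch.randn(n_heads))  # negative so the SSM decays (stable)
    B = torch.randn(batch, seq_len, n_heads, d_state)
    C = torch.randn(batch, seq_len, n_heads, d_state)
    D = torch.randn(n_heads)

    Y, final_state = ssd_minimal_discrete(X, dt, A, B, C, block_len, D=D)
    assert Y.shape == (batch, seq_len, n_heads, d_head)
    assert final_state.shape == (batch, n_heads, d_head, d_state)
    print("Y:", tuple(Y.shape), "final_state:", tuple(final_state.shape))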