Commit
add all-attention memory key/values, set with attn_num_mem_kv = {int}
lucidrains committed Nov 5, 2020
1 parent 2e0f8ad commit 2dce3d2
Showing 2 changed files with 28 additions and 4 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.0.16',
+  version = '0.0.17',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
x_transformers/x_transformers.py (30 changes: 27 additions & 3 deletions)
@@ -4,8 +4,8 @@
 import torch.nn.functional as F
 from functools import partial
 from inspect import isfunction
-from einops import rearrange, repeat
 
+from einops import rearrange, repeat
 from entmax import entmax15
 
 from x_transformers.autoregressive_wrapper import AutoregressiveWrapper
@@ -137,7 +137,18 @@ def forward(self, x):
         return self.net(x)
 
 class Attention(nn.Module):
-    def __init__(self, dim, dim_head = 64, heads = 8, causal = False, mask = None, talking_heads = False, sparse_topk = None, use_entmax15 = False):
+    def __init__(
+        self,
+        dim,
+        dim_head = 64,
+        heads = 8,
+        causal = False,
+        mask = None,
+        talking_heads = False,
+        sparse_topk = None,
+        use_entmax15 = False,
+        num_mem_kv = 0
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
@@ -161,6 +172,12 @@ def __init__(self, dim, dim_head = 64, heads = 8, causal = False, mask = None, t
         # entmax
         self.attn_fn = entmax15 if use_entmax15 else F.softmax
 
+        # add memory key / values
+        self.num_mem_kv = num_mem_kv
+        if num_mem_kv > 0:
+            self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+            self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
+
     def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None):
         b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
         kv_input = default(context, x)
@@ -169,6 +186,12 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
         kv = self.to_kv(kv_input).chunk(2, dim = -1)
 
         q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
+
+        if self.num_mem_kv > 0:
+            mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (self.mem_k, self.mem_v))
+            k = torch.cat((mem_k, k), dim = -2)
+            v = torch.cat((mem_v, v), dim = -2)
+
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
 
         if talking_heads:
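
A minimal shape check of the memory key / value mechanism in the hunk above, using made-up sizes (not from the commit): prepending m learned memory slots to the keys and values grows the attention map from n x n to n x (n + m).

import torch
from einops import repeat

# made-up sizes: batch, heads, tokens, head dim, memory slots
b, h, n, d, m = 2, 8, 16, 64, 4

q = torch.randn(b, h, n, d)
k = torch.randn(b, h, n, d)
v = torch.randn(b, h, n, d)

mem_k = torch.randn(h, m, d)   # learned parameters, shared across the batch
mem_v = torch.randn(h, m, d)

# broadcast the memory over the batch and prepend it along the sequence axis
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b = b), (mem_k, mem_v))
k = torch.cat((mem_k, k), dim = -2)   # (b, h, n + m, d)
v = torch.cat((mem_v, v), dim = -2)

dots = torch.einsum('b h i d, b h j d -> b h i j', q, k)
assert dots.shape == (b, h, n, n + m)   # every query also attends to the m memory slots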
@@ -187,7 +210,8 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
             del mask
 
         if self.causal:
-            mask = torch.ones((n, n), device = device).triu_(1).bool()
+            i, j = dots.shape[-2:]
+            mask = torch.ones((i, j), device = device).triu_(j - i + 1).bool()
             dots.masked_fill_(mask, float('-inf'))
             del mask
 
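The causal mask change accounts for the extra memory columns: with m memory slots prepended to the keys, j = i + m, so triu_(j - i + 1) masks only future token positions while leaving the first m columns visible to every query. A small illustration with made-up sizes:

import torch

i, m = 4, 2   # made-up sizes: 4 query tokens, 2 memory slots
j = i + m     # key length = memory slots + tokens

mask = torch.ones((i, j)).triu_(j - i + 1).bool()
print(mask.long())
# tensor([[0, 0, 0, 1, 1, 1],
#         [0, 0, 0, 0, 1, 1],
#         [0, 0, 0, 0, 0, 1],
#         [0, 0, 0, 0, 0, 0]])
# the first m columns (the memory) are never masked; within the token block the
# usual strictly-upper-triangular causal mask is preserved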

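Per the commit message, the feature is enabled by passing attn_num_mem_kv = {int} through to the attention layers. A hypothetical usage sketch: TransformerWrapper and Decoder follow the library's later documented API, and the exact constructor surface at version 0.0.17 may differ.

import torch
from x_transformers import TransformerWrapper, Decoder  # names assumed from the library's documented API

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 512,
        depth = 6,
        heads = 8,
        attn_num_mem_kv = 16   # number of learned memory key / values per attention layer
    )
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x)   # (1, 1024, 20000)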