add talking heads feature, turned on with attn_talking_heads = True
lucidrains committed Nov 5, 2020
1 parent 408ed81 commit 6749bc2
Showing 2 changed files with 26 additions and 8 deletions.
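For context, a minimal usage sketch (not part of the commit; the module path, dimensions and shapes are assumptions): the new flag is passed with the attn_ prefix and forwarded to every Attention layer by the constructors changed in the diff below.

import torch
from x_transformers.x_transformers import Encoder

# hypothetical dimensions; attn_talking_heads has its 'attn_' prefix trimmed
# inside the constructor and reaches Attention as talking_heads = True
enc = Encoder(dim = 512, depth = 6, heads = 8, attn_talking_heads = True)

x = torch.randn(1, 1024, 512)
out = enc(x)   # (1, 1024, 512)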
setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
setup(
name = 'x-transformers',
packages = find_packages(exclude=['examples']),
- version = '0.0.12',
+ version = '0.0.14',
license='MIT',
description = 'X-Transformers - Pytorch',
author = 'Phil Wang',
x_transformers/x_transformers.py (32 changes: 25 additions & 7 deletions)
@@ -135,7 +135,7 @@ def forward(self, x):
return self.net(x)

class Attention(nn.Module):
- def __init__(self, dim, dim_head = 64, heads = 8, causal = False, mask = None):
+ def __init__(self, dim, dim_head = 64, heads = 8, causal = False, mask = None, talking_heads = False):
super().__init__()
self.scale = dim_head ** -0.5
self.heads = heads
@@ -147,8 +147,13 @@ def __init__(self, dim, dim_head = 64, heads = 8, causal = False, mask = None):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
self.to_out = nn.Linear(inner_dim, dim)

+ self.talking_heads = talking_heads
+ if talking_heads:
+     self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
+     self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))

def forward(self, x, context = None, mask = None, context_mask = None, rel_pos = None):
- b, n, _, h, device = *x.shape, self.heads, x.device
+ b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x)

q = self.to_q(x)
@@ -157,6 +162,9 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, *kv))
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

+ if talking_heads:
+     dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj)

if exists(rel_pos):
dots = rel_pos(dots)

@@ -175,12 +183,16 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
del mask

attn = dots.softmax(dim = -1)

+ if talking_heads:
+     attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj)

out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
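The two einsums above implement talking heads: a learned linear remix of the per-head attention maps before and after the softmax. Below is a small self-contained check of what that einsum computes, with made-up shapes; the matmul equivalence is an illustration, not code from the repository.

import torch
from torch import einsum

b, h, n = 2, 8, 16                   # made-up batch size, heads, sequence length
dots = torch.randn(b, h, n, n)       # per-head attention logits
proj = torch.randn(h, h)             # head-mixing matrix (pre_softmax_proj / post_softmax_proj)

mixed = einsum('b h i j, h k -> b k i j', dots, proj)

# the same mix written as a plain matmul over the head dimension
ref = (dots.permute(0, 2, 3, 1) @ proj).permute(0, 3, 1, 2)
assert torch.allclose(mixed, ref, atol = 1e-5)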

class Encoder(nn.Module):
- def __init__(self, dim, depth, dim_head = 64, heads = 8, use_scalenorm = False, ff_glu = False, rel_pos_bias = False):
+ def __init__(self, dim, depth, dim_head = 64, heads = 8, use_scalenorm = False, rel_pos_bias = False, **kwargs):
super().__init__()
self.dim = dim
self.layers = nn.ModuleList([])
@@ -189,10 +201,13 @@ def __init__(self, dim, dim_head = 64, heads = 8, use_scalenorm = False,
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
prenorm_fn = partial(PreNorm, dim, norm_class = norm_class)

+ ff_kwargs, kwargs = group_by_key_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = group_by_key_prefix_and_trim('attn_', kwargs)

for _ in range(depth):
self.layers.append(nn.ModuleList([
- prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads)),
- prenorm_fn(FeedForward(dim, glu = ff_glu))
+ prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads, **attn_kwargs)),
+ prenorm_fn(FeedForward(dim, **ff_kwargs))
]))
def forward(self, x, context = None, mask = None):
for (self_attn, ff) in self.layers:
@@ -201,7 +216,7 @@ def forward(self, x, context = None, mask = None):
return x
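group_by_key_prefix_and_trim is defined elsewhere in x_transformers.py and is not part of this diff; the stand-in below (hypothetical name and body) only illustrates the routing the Encoder and Decoder constructors rely on: attn_-prefixed keyword arguments reach Attention and ff_-prefixed ones reach FeedForward, with the prefix stripped.

# hypothetical stand-in for group_by_key_prefix_and_trim, for illustration only
def split_by_prefix(prefix, kwargs):
    matched = {k[len(prefix):]: v for k, v in kwargs.items() if k.startswith(prefix)}
    rest = {k: v for k, v in kwargs.items() if not k.startswith(prefix)}
    return matched, rest

kwargs = dict(attn_talking_heads = True, ff_glu = True)
ff_kwargs, kwargs = split_by_prefix('ff_', kwargs)   # ff_kwargs == {'glu': True}
attn_kwargs, _ = split_by_prefix('attn_', kwargs)    # attn_kwargs == {'talking_heads': True}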

class Decoder(nn.Module):
- def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, use_scalenorm = False, ff_glu = False, rel_pos_bias = False):
+ def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, use_scalenorm = False, rel_pos_bias = False, **kwargs):
super().__init__()
self.dim = dim
self.layers = nn.ModuleList([])
@@ -210,11 +225,14 @@ def __init__(self, dim, depth, dim_head = 64, heads = 8, cross_attend = False, u
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
prenorm_fn = partial(PreNorm, dim, norm_class = norm_class)

+ ff_kwargs, kwargs = group_by_key_prefix_and_trim('ff_', kwargs)
+ attn_kwargs, _ = group_by_key_prefix_and_trim('attn_', kwargs)

for _ in range(depth):
self.layers.append(nn.ModuleList([
prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads, causal = True)),
prenorm_fn(Attention(dim, dim_head = dim_head, heads = heads)) if cross_attend else None,
- prenorm_fn(FeedForward(dim, glu = ff_glu)),
+ prenorm_fn(FeedForward(dim, **ff_kwargs)),
]))
def forward(self, x, context = None, mask = None, context_mask = None):
for (self_attn, cross_attn, ff) in self.layers:
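Finally, a hedged end-to-end sketch (again with assumed dimensions, not from the commit) pairing the two stacks, with cross attention enabled and the talking-heads flag forwarded to every attention layer:

import torch
from x_transformers.x_transformers import Encoder, Decoder

enc = Encoder(dim = 512, depth = 6, heads = 8, attn_talking_heads = True)
dec = Decoder(dim = 512, depth = 6, heads = 8, cross_attend = True, attn_talking_heads = True)

src = torch.randn(1, 256, 512)       # encoder input features
tgt = torch.randn(1, 128, 512)       # decoder input features

context = enc(src)
out = dec(tgt, context = context)    # (1, 128, 512)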
