Commit

fix bug with residual attention
lucidrains committed Dec 28, 2020
1 parent ee19224 commit acb4c65
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.4.1',
+  version = '0.4.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
5 changes: 3 additions & 2 deletions x_transformers/x_transformers.py
@@ -233,10 +233,11 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
         mask_value = max_neg_value(dots)

-        pre_softmax_attn = dots
         if exists(prev_attn):
             dots = dots + prev_attn

+        pre_softmax_attn = dots
+
         if talking_heads:
             dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()

@@ -269,7 +270,7 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')

-        return self.to_out(out), prev_attn
+        return self.to_out(out), pre_softmax_attn

 class AttentionLayers(nn.Module):
     def __init__(
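For context on the fix: with residual attention (RealFormer-style), each layer is meant to hand its accumulated pre-softmax attention logits to the next layer. Before this commit, pre_softmax_attn was snapshotted before prev_attn was added, and forward returned the incoming prev_attn rather than the updated logits, so the residual never actually accumulated across layers. The sketch below is a minimal, standalone illustration of the corrected flow, not the library's Attention module; the attention_step helper and the toy shapes are assumptions made for the example.

# Minimal sketch of the corrected residual-attention flow
# (assumed helper, not the actual x-transformers Attention module)
import torch
from torch import einsum

def attention_step(q, k, v, scale, prev_attn = None):
    # raw attention logits for this layer
    dots = einsum('b h i d, b h j d -> b h i j', q, k) * scale

    # residual attention: fold in the previous layer's pre-softmax logits first ...
    if prev_attn is not None:
        dots = dots + prev_attn

    # ... and only then snapshot them, so the residual accumulates layer by layer
    pre_softmax_attn = dots

    attn = dots.softmax(dim = -1)
    out = einsum('b h i j, b h j d -> b h i d', attn, v)

    # return the accumulated logits (not the stale prev_attn) for the next layer
    return out, pre_softmax_attn

# toy usage: chain two calls and thread the accumulated logits through
b, h, n, d = 2, 4, 8, 16
q = k = v = torch.randn(b, h, n, d)
out1, resid = attention_step(q, k, v, scale = d ** -0.5)
out2, resid = attention_step(q, k, v, scale = d ** -0.5, prev_attn = resid)

Chaining the two calls shows the intended behavior: the second layer's logits include the first layer's, which is exactly what returning pre_softmax_attn (rather than prev_attn) makes possible.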
