From acb4c658090747c1b3da1281006c8d97eec57373 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Sun, 27 Dec 2020 19:25:13 -0800
Subject: [PATCH] fix bug with residual attention

---
 setup.py                         | 2 +-
 x_transformers/x_transformers.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index 6c8e9850..41b0f640 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'x-transformers',
   packages = find_packages(exclude=['examples']),
-  version = '0.4.1',
+  version = '0.4.2',
   license='MIT',
   description = 'X-Transformers - Pytorch',
   author = 'Phil Wang',
diff --git a/x_transformers/x_transformers.py b/x_transformers/x_transformers.py
index 9eb43adb..7455cf15 100644
--- a/x_transformers/x_transformers.py
+++ b/x_transformers/x_transformers.py
@@ -233,10 +233,11 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
         dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
         mask_value = max_neg_value(dots)
-        pre_softmax_attn = dots
 
         if exists(prev_attn):
             dots = dots + prev_attn
 
+        pre_softmax_attn = dots
+
         if talking_heads:
             dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
 
@@ -269,7 +270,7 @@ def forward(self, x, context = None, mask = None, context_mask = None, rel_pos =
         out = einsum('b h i j, b h j d -> b h i d', attn, v)
         out = rearrange(out, 'b h n d -> b n (h d)')
 
-        return self.to_out(out), prev_attn
+        return self.to_out(out), pre_softmax_attn
 
 class AttentionLayers(nn.Module):
     def __init__(
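
Note on the fix: with residual attention, each attention block returns its pre-softmax scores so that the next block can add them to its own logits. Before this patch, pre_softmax_attn was captured before prev_attn was added and the block returned prev_attn itself, so downstream layers never received the current layer's contribution. Below is a minimal, self-contained sketch of the corrected flow, not the x-transformers implementation; the module name ResidualAttentionSketch and its hyperparameters are illustrative assumptions.

# Minimal sketch of residual attention with the corrected score hand-off.
# NOT the x-transformers code; names and hyperparameters here are illustrative only.

import torch
from torch import nn, einsum
from einops import rearrange

class ResidualAttentionSketch(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim)

    def forward(self, x, prev_attn = None):
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), (q, k, v))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        # residual attention: add the pre-softmax scores handed down from the previous layer
        if prev_attn is not None:
            dots = dots + prev_attn

        # the fix: snapshot the scores *after* the residual has been added, so the
        # next layer receives the accumulated scores rather than the stale prev_attn
        pre_softmax_attn = dots

        attn = dots.softmax(dim = -1)
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out), pre_softmax_attn

# usage: thread the returned scores through a stack of layers
layers = nn.ModuleList([ResidualAttentionSketch(512) for _ in range(3)])
x, prev_attn = torch.randn(1, 16, 512), None
for layer in layers:
    out, prev_attn = layer(x, prev_attn = prev_attn)
    x = x + out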