context_aware_attention.py
## Imports
import torch
import torch.nn as nn


class ContextAwareAttention(nn.Module):
    """Single-head attention whose keys and values are gated mixtures of the
    original keys/values and a projected external context."""

    def __init__(self, dim_model, dim_context, dropout_rate=0.0):
        super(ContextAwareAttention, self).__init__()
        self.dim_model = dim_model
        self.dim_context = dim_context
        self.dropout_rate = dropout_rate

        self.attention_layer = nn.MultiheadAttention(embed_dim=self.dim_model,
                                                     num_heads=1,
                                                     dropout=self.dropout_rate,
                                                     bias=True,
                                                     add_zero_attn=False,
                                                     batch_first=True)

        # Project the context into the model dimension and learn scalar gates
        # that decide how much context to mix into the keys and the values.
        self.u_k = nn.Linear(self.dim_context, self.dim_model, bias=False)
        self.w1_k = nn.Linear(self.dim_model, 1, bias=False)
        self.w2_k = nn.Linear(self.dim_model, 1, bias=False)

        self.u_v = nn.Linear(self.dim_context, self.dim_model, bias=False)
        self.w1_v = nn.Linear(self.dim_model, 1, bias=False)
        self.w2_v = nn.Linear(self.dim_model, 1, bias=False)

    def forward(self, q, k, v, context=None):
        # The context tensor is required; the gated fusion below has no
        # meaningful fallback when it is missing.
        if context is None:
            raise ValueError("ContextAwareAttention.forward requires a context tensor")

        # Project the context into the model space for keys and values.
        key_context = self.u_k(context)
        value_context = self.u_v(context)

        # Per-position gates in (0, 1) controlling the context contribution.
        lambda_k = torch.sigmoid(self.w1_k(k) + self.w2_k(key_context))
        lambda_v = torch.sigmoid(self.w1_v(v) + self.w2_v(value_context))

        # Convex combination of the original keys/values and the projected context.
        k_cap = (1 - lambda_k) * k + lambda_k * key_context
        v_cap = (1 - lambda_v) * v + lambda_v * value_context

        attention_output, _ = self.attention_layer(query=q,
                                                   key=k_cap,
                                                   value=v_cap)
        return attention_output
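
A minimal usage sketch appended to the same file (not part of the original): the dimensions, batch size, sequence length, and dropout value below are illustrative assumptions, chosen only to show the expected tensor shapes when fusing a context stream into the main modality via self-attention.

# Usage sketch: all sizes below are illustrative assumptions, not from the original file.
if __name__ == "__main__":
    model = ContextAwareAttention(dim_model=768, dim_context=256, dropout_rate=0.1)

    batch_size, seq_len = 4, 20
    x = torch.randn(batch_size, seq_len, 768)        # queries/keys/values from the main modality
    context = torch.randn(batch_size, seq_len, 256)  # aligned context features, one per position

    # Self-attention over x with context-modulated keys and values.
    out = model(q=x, k=x, v=x, context=context)
    print(out.shape)  # torch.Size([4, 20, 768])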