vit_grad_rollout.py
import torch
from PIL import Image
import sys
from torchvision import transforms
import numpy as np
import cv2


def grad_rollout(attentions, gradients, discard_ratio):
    # Start from the identity: before any layer, each token "attends" only
    # to itself.
    result = torch.eye(attentions[0].size(-1))
    with torch.no_grad():
        for attention, grad in zip(attentions, gradients):
            # Weight each attention map by its gradient and average over the
            # heads, keeping only the positively contributing entries.
            weights = grad
            attention_heads_fused = (attention * weights).mean(axis=1)
            attention_heads_fused[attention_heads_fused < 0] = 0

            # Drop the lowest attentions, but
            # don't drop the class token.
            flat = attention_heads_fused.view(attention_heads_fused.size(0), -1)
            _, indices = flat.topk(int(flat.size(-1) * discard_ratio), -1, False)
            #indices = indices[indices != 0]
            flat[0, indices] = 0

            # Add the identity to account for the residual connection, then
            # normalize each row and accumulate the rollout across layers.
            I = torch.eye(attention_heads_fused.size(-1))
            a = (attention_heads_fused + 1.0 * I) / 2
            a = a / a.sum(dim=-1, keepdim=True)  # keepdim so each row sums to 1
            result = torch.matmul(a, result)

    # Look at the total attention between the class token
    # and the image patches.
    mask = result[0, 0, 1:]
    # In case of a 224x224 image, this brings us from 196 to 14.
    width = int(mask.size(-1) ** 0.5)
    mask = mask.reshape(width, width).numpy()
    mask = mask / np.max(mask)
    return mask
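

# A minimal sketch (not in the original file): sanity-check grad_rollout with
# synthetic tensors. The 12 heads and 197 tokens (1 class token + 14x14
# patches of a 224x224 image with 16x16 patches) are illustrative assumptions,
# not values taken from this file.
def _grad_rollout_shape_check():
    attentions = [torch.rand(1, 12, 197, 197) for _ in range(12)]
    gradients = [torch.rand(1, 12, 197, 197) for _ in range(12)]
    mask = grad_rollout(attentions, gradients, discard_ratio=0.9)
    assert mask.shape == (14, 14)  # one value per image patch
    return mask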


class VITAttentionGradRollout:
    def __init__(self, model, attention_layer_name='attn_drop',
                 discard_ratio=0.9):
        self.model = model
        self.discard_ratio = discard_ratio
        # Hook every module whose name contains attention_layer_name so the
        # attention maps and their gradients are recorded during the forward
        # and backward passes. (register_backward_hook is deprecated in recent
        # PyTorch releases; register_full_backward_hook is its replacement.)
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(self.get_attention)
                module.register_backward_hook(self.get_attention_gradient)

        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor, category_index):
        self.model.zero_grad()
        output = self.model(input_tensor)
        # Back-propagate only the score of the requested category.
        category_mask = torch.zeros(output.size())
        category_mask[:, category_index] = 1
        loss = (output * category_mask).sum()
        loss.backward()

        # Note: the hook lists are not cleared here, so repeated calls on the
        # same instance accumulate maps from earlier runs.
        return grad_rollout(self.attentions, self.attention_gradients,
                            self.discard_ratio)
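

# A hedged usage sketch, not part of the original file. It assumes a timm
# ViT ('vit_base_patch16_224' and the 0.5/0.5 normalization stats are
# illustrative choices); any ViT whose attention modules contain 'attn_drop'
# in their names should work the same way.
# Run as: python vit_grad_rollout.py <image_path> <category_index>
if __name__ == '__main__':
    import timm  # assumption: timm is installed

    model = timm.create_model('vit_base_patch16_224', pretrained=True)
    model.eval()

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    img = Image.open(sys.argv[1]).convert('RGB')
    input_tensor = transform(img).unsqueeze(0)

    grad_rollout_attn = VITAttentionGradRollout(model, discard_ratio=0.9)
    mask = grad_rollout_attn(input_tensor, category_index=int(sys.argv[2]))

    # Upsample the patch-level mask to the image size and overlay it.
    np_img = np.array(img.resize((224, 224)))[:, :, ::-1]  # RGB -> BGR for cv2
    heatmap = cv2.applyColorMap(np.uint8(255 * cv2.resize(mask, (224, 224))),
                                cv2.COLORMAP_JET)
    overlay = np.uint8(0.5 * np_img + 0.5 * heatmap)
    cv2.imwrite('grad_rollout_overlay.png', overlay)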