-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathexplain_model.py
245 lines (213 loc) · 10.1 KB
/
explain_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
#%%
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import os
from skimage.segmentation import mark_boundaries
import lime
from lime import lime_image
from skimage.color import label2rgb
from lime.wrappers.scikit_image import SegmentationAlgorithm
import pickle
#%%
# Work from the repository root so every relative "model/..." path resolves.
# (On the lab machine the checkout lived at /home/hcrlab/cozmo/explainability/.)
repo_root = os.path.expanduser("~/repos/explainable_behavior/")
os.chdir(repo_root)
#%% load model
model = keras.models.load_model("model/cozmo_drive_model.h5")
#%% load data


def _load_pickle(path):
    """Return the object stored in the pickle file at ``path``.

    Only used on this project's own artifacts under model/pickles/; do not
    point it at untrusted files, since unpickling can execute arbitrary code.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


(train_images, val_images, test_images,
 train_labels, val_labels, test_labels) = _load_pickle("model/pickles/train_val_test.pkl")
#%%
# load label names
labels_dict, labels_list = _load_pickle("model/pickles/label_names.pkl")
#%%
# load top predictions
predictions, top_predictions = _load_pickle("model/pickles/predictions.pkl")
#%%
# Shared LIME explainer plus a SLIC superpixel segmenter used by later cells.
explainer = lime_image.LimeImageExplainer(verbose=False)
segmenter = SegmentationAlgorithm('slic', n_segments=100, compactness=1, sigma=1)
#%%
# Explain the first test image for the num_top_labels labels LIME ranks
# highest, using 10000 perturbed samples over the SLIC superpixels.
num_top_labels = 4
first_image = test_images[0]
explanation = explainer.explain_instance(
    first_image,
    classifier_fn=model.predict,
    top_labels=num_top_labels,
    hide_color=0,
    num_samples=10000,
    segmentation_fn=segmenter,
)
#%%
# Side-by-side LIME view for the ground-truth label of the first test image.
# NOTE(review): get_image_and_mask raises KeyError when test_labels[0] is not
# one of the num_top_labels labels explained in the previous cell.
true_label = test_labels[0]
true_name = labels_list[true_label]
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))

# Left: up to 5 segments that support the true label.
temp, mask = explanation.get_image_and_mask(true_label, positive_only=True,
                                            num_features=5, hide_rest=False)
ax1.imshow(label2rgb(mask, temp, bg_label=0), interpolation='nearest')
ax1.set_title('Positive Regions for {}'.format(true_name))

# Right: the 10 most influential segments, for or against the true label.
# With positive_only=False LIME marks supporting segments 1, opposing
# segments -1 and leaves everything else 0. Remap so unexplained pixels are
# label2rgb background (0), supporting = 1 and opposing = 2; otherwise no
# pixel equals bg_label and the whole image gets tinted.
temp, mask = explanation.get_image_and_mask(true_label, positive_only=False,
                                            num_features=10, hide_rest=False)
signed_regions = np.where(mask > 0, 1, np.where(mask < 0, 2, 0))
ax2.imshow(label2rgb(signed_regions, temp, bg_label=0), interpolation='nearest')
ax2.set_title('Positive/Negative Regions for {}'.format(true_name))
#%%
# One column per explained label: the top row overlays the segments that
# support that label on the first test image, and the bottom row shows a
# randomly chosen training image carrying the same label.
fig, m_axs = plt.subplots(2, num_top_labels, figsize=(12, 4))
for label, (explain_ax, example_ax) in zip(explanation.top_labels, m_axs.T):
    label_name = labels_list[label]

    temp, mask = explanation.get_image_and_mask(label, positive_only=True, num_features=5,
                                                hide_rest=False, min_weight=0.01)
    score_pct = 100 * predictions[0, label]
    explain_ax.imshow(label2rgb(mask, temp, bg_label=0), interpolation='nearest')
    explain_ax.set_title('Positive for {}\nScore:{:2.2f}%'.format(label_name, score_pct))
    explain_ax.axis('off')

    candidates = np.where(train_labels == label)[0]
    example_ax.imshow(train_images[np.random.choice(candidates)])
    example_ax.set_title('Example of {}'.format(label_name))
    example_ax.axis('off')
#%%
# Generate explanation summary images for correctly classified test images.
# For each class (paths relative to the repository root):
#   model/explanations/<label>/correct/<test index>.jpg  one summary per image
#   model/explanations/<label>/average_explanation.jpg   pixel-wise mode of masks
# NOTE(review): the "generate average explanations" cell further down saves to
# the same average_explanation.jpg path and overwrites this file if run later.
num_top_labels = 4
max_images_per_class = 10
# LIME masks have the (height, width) of a single test image.
mask_shape = test_images.shape[1:3]
# Test indices the model classifies correctly, regardless of class.
correct_locations = np.where(test_labels == top_predictions)[0]
for i in np.unique(test_labels):
    print("Generating explanations for correctly classified {} actions...".format(labels_list[i]))
    # Write to explicit paths instead of chdir-ing into the output folder, so
    # an exception part-way through cannot leave the process in the wrong cwd.
    class_dir = 'model/explanations/{}/'.format(labels_list[i])
    out_dir = os.path.join(class_dir, 'correct')
    os.makedirs(out_dir, exist_ok=True)
    # indices where the model predicts correctly AND the true label is i
    i_locations = np.intersect1d(correct_locations, np.where(test_labels == i)[0])
    # Randomly pick 10 *distinct* images (np.random.choice samples with
    # replacement by default); with fewer than that, take them all.
    if len(i_locations) > max_images_per_class:
        selection = np.random.choice(i_locations, max_images_per_class, replace=False)
    else:
        selection = i_locations
    # one mask per selected image, combined into the average below
    mask_array = np.zeros((len(selection),) + mask_shape, dtype=int)
    for j, image_index in enumerate(selection):
        explanation = explainer.explain_instance(test_images[image_index], classifier_fn=model.predict,
                                                 top_labels=num_top_labels, hide_color=0,
                                                 num_samples=1000, segmentation_fn=segmenter)
        # top row: supporting segments for each explained label;
        # bottom row: a random training example of that label
        fig, m_axs = plt.subplots(2, num_top_labels, figsize=(12, 4))
        for rank, (k, (c_ax, gt_ax)) in enumerate(zip(explanation.top_labels, m_axs.T)):
            temp, mask = explanation.get_image_and_mask(k, positive_only=True, num_features=5,
                                                        hide_rest=False, min_weight=0.01)
            if rank == 0:
                # Only the first entry of explanation.top_labels feeds the
                # average; for correctly classified images this is presumably
                # the true class i — confirm if the LIME ordering changes.
                mask_array[j] = mask
            c_ax.imshow(label2rgb(mask, temp, bg_label=0), interpolation='nearest')
            c_ax.set_title('Positive for {}\nScore:{:2.2f}%'.format(
                labels_list[k], 100 * predictions[image_index, k]))
            c_ax.axis('off')
            action_id = np.random.choice(np.where(train_labels == k)[0])
            gt_ax.imshow(train_images[action_id])
            gt_ax.set_title('Example of {}'.format(labels_list[k]))
            gt_ax.axis('off')
        fig.savefig(os.path.join(out_dir, "{}.jpg".format(image_index)))
        plt.close(fig)
    # save average explanation (pixel-wise mode across the selected masks)
    if len(selection) > 0:
        fig, ax = plt.subplots()
        average_explanation = np.reshape(stats.mode(mask_array, axis=0)[0], mask_shape)
        ax.imshow(label2rgb(average_explanation, bg_label=0), interpolation='nearest')
        ax.set_title("Average explanation")
        fig.savefig(os.path.join(class_dir, "average_explanation.jpg"))
        plt.close(fig)
    print("done")
#%%
# Generate explanation summary images for misclassified test images.
# Output per class (relative to the repository root):
#   model/explanations/<label>/incorrect/<test index>.jpg
num_top_labels = 4
max_images_per_class = 10
# Test indices the model gets wrong, regardless of class.
incorrect_locations = np.where(test_labels != top_predictions)[0]
for i in np.unique(test_labels):
    print("Generating explanations for incorrectly classified {} actions...".format(labels_list[i]))
    # Write to explicit paths instead of chdir-ing into the output folder, so
    # an exception part-way through cannot leave the process in the wrong cwd.
    out_dir = 'model/explanations/{}/incorrect/'.format(labels_list[i])
    os.makedirs(out_dir, exist_ok=True)
    # indices where the model predicts incorrectly AND the true label is i
    i_locations = np.intersect1d(incorrect_locations, np.where(test_labels == i)[0])
    # Randomly pick 10 *distinct* images (np.random.choice samples with
    # replacement by default); with fewer than that, take them all.
    if len(i_locations) > max_images_per_class:
        selection = np.random.choice(i_locations, max_images_per_class, replace=False)
    else:
        selection = i_locations
    for image_index in selection:
        explanation = explainer.explain_instance(test_images[image_index], classifier_fn=model.predict,
                                                 top_labels=num_top_labels, hide_color=0,
                                                 num_samples=1000, segmentation_fn=segmenter)
        # top row: supporting segments for each explained label;
        # bottom row: a random training example of that label
        fig, m_axs = plt.subplots(2, num_top_labels, figsize=(12, 4))
        for k, (c_ax, gt_ax) in zip(explanation.top_labels, m_axs.T):
            temp, mask = explanation.get_image_and_mask(k, positive_only=True, num_features=5,
                                                        hide_rest=False, min_weight=0.01)
            c_ax.imshow(label2rgb(mask, temp, bg_label=0), interpolation='nearest')
            c_ax.set_title('Positive for {}\nScore:{:2.2f}%'.format(
                labels_list[k], 100 * predictions[image_index, k]))
            c_ax.axis('off')
            action_id = np.random.choice(np.where(train_labels == k)[0])
            gt_ax.imshow(train_images[action_id])
            gt_ax.set_title('Example of {}'.format(labels_list[k]))
            gt_ax.axis('off')
        fig.savefig(os.path.join(out_dir, "{}.jpg".format(image_index)))
        plt.close(fig)
    print("done")
#%%
# Generate one average explanation per class, built from test images the
# model *predicts* as that class (right or wrong). Output (repo-relative):
#   model/explanations/<label>/average_explanation.jpg
num_top_labels = 4
max_images_per_class = 10
# LIME masks have the (height, width) of a single test image.
mask_shape = test_images.shape[1:3]
# Every output path below is relative to the repository root. Verify we are
# there by checking for the model file loaded at the top of the script, rather
# than by the checkout's folder name (which differs between machines).
if not os.path.isfile("model/cozmo_drive_model.h5"):
    raise RuntimeError(
        "expected to run from the repository root, but cwd is {}".format(os.getcwd()))
for i in np.unique(test_labels):
    print("Generating average explanation for {} actions...".format(labels_list[i]))
    class_dir = 'model/explanations/{}/'.format(labels_list[i])
    os.makedirs(class_dir, exist_ok=True)
    i_locations = np.where(top_predictions == i)[0]
    # Randomly pick 10 *distinct* images (np.random.choice samples with
    # replacement by default); with fewer than that, take them all.
    if len(i_locations) > max_images_per_class:
        selection = np.random.choice(i_locations, max_images_per_class, replace=False)
    else:
        selection = i_locations
    # one positive-region mask for class i per selected image
    mask_array = np.zeros((len(selection),) + mask_shape, dtype=int)
    for j, image_index in enumerate(selection):
        explanation = explainer.explain_instance(test_images[image_index], classifier_fn=model.predict,
                                                 top_labels=num_top_labels, hide_color=0,
                                                 num_samples=1000, segmentation_fn=segmenter)
        temp, mask = explanation.get_image_and_mask(i, positive_only=True, num_features=5,
                                                    hide_rest=False, min_weight=0.01)
        mask_array[j] = mask
    # save average explanation (pixel-wise mode across the selected masks)
    if len(selection) > 0:
        fig, ax = plt.subplots()
        average_explanation = np.reshape(stats.mode(mask_array, axis=0)[0], mask_shape)
        ax.imshow(label2rgb(average_explanation, bg_label=0), interpolation='nearest')
        ax.set_title("Average explanation")
        fig.savefig(os.path.join(class_dir, "average_explanation.jpg"))
        plt.close(fig)
    print("done")
#%%
# Explain the second test image with a fresh default explainer. No
# segmentation_fn is passed here, so LIME's own default segmentation is used.
explainer = lime_image.LimeImageExplainer()
second_image = test_images[1]
explanation = explainer.explain_instance(
    second_image,
    model.predict,
    top_labels=3,
    hide_color=0,
    num_samples=1000,
    batch_size=1,
)
#%%
# Outline the 5 most influential segments (supporting or opposing) for the
# first label in explanation.top_labels, ignoring weights below 0.1.
leading_label = explanation.top_labels[0]
temp, mask = explanation.get_image_and_mask(
    leading_label,
    positive_only=False,
    num_features=5,
    hide_rest=False,
    min_weight=0.1,
)
outlined = mark_boundaries(temp, mask)
plt.imshow(outlined)
#%%
# Compute and keep LIME explanations for the first 15 test images.
# Every result is kept in `explanations` (index == test image index);
# rebinding a single variable per iteration would keep only the last one.
num_saved_images = 15
explainer = lime_image.LimeImageExplainer()
explanations = []
for i in range(num_saved_images):
    explanations.append(explainer.explain_instance(test_images[i], model.predict, top_labels=3,
                                                   hide_color=0, num_samples=1000, batch_size=1))
# `explanation` still refers to the last image's result, as before.
explanation = explanations[-1]
#%%