-
Notifications
You must be signed in to change notification settings - Fork 0
/
dcam.py
101 lines (82 loc) · 4.61 KB
/
dcam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from load_data import *
from utilities import transform_data4ResNet
import numpy as np
from utilities import plot_dcam
import torch
import timeit
from sklearn.metrics import accuracy_score
import pandas as pd
# importing srcs for dResNet and dCAM
import sys
# Prepend the dCAM source folders to sys.path so that the project-local
# imports which follow (CNN_models, DCAM) can resolve from the dCAM sources.
# Slice-assigning both entries at index 0 leaves "models" ahead of
# "explanation" -- the same order produced by inserting "explanation" at
# position 0 and then "models" at position 0.
base_path = "./dCAM/src/"
sys.path[0:0] = [base_path + "models", base_path + "explanation"]
#from dcam import *
from CNN_models import *
from DCAM import DCAM
# The plotting function taken from the dCAM notebook is imported above as plot_dcam (from utilities).
def _run_dcam_for_label(dcam, instance, label, nb_permutation, generate_all,
                        dataset_name, idx, plot_flag, column_names, error_msg):
    """Run dCAM for one instance and one target label, then plot the map.

    Args:
        dcam: the ``DCAM`` explainer built in ``main``.
        instance: the test instance passed to ``dcam.run`` and ``plot_dcam``.
        label: value forwarded to ``dcam.run`` as ``label_instance``.
        nb_permutation, generate_all: forwarded unchanged to ``dcam.run``.
        dataset_name, idx, plot_flag, column_names: forwarded unchanged to
            ``plot_dcam`` (``main`` passes ``plot_flag=True`` for the
            ground-truth label and ``False`` for the predicted label).
        error_msg: text written to stderr when an ``IndexError`` occurs.

    Returns:
        ``(dcam_map, permutation_success)`` as returned by ``dcam.run``, or the
        sentinel pair ``(np.array(-1), 0)`` if ``dcam.run`` or ``plot_dcam``
        raised ``IndexError``.
    """
    try:
        dcam_map, permutation_success = dcam.run(
            instance=instance, nb_permutation=nb_permutation,
            label_instance=label, generate_all=generate_all)
        # Plotting is deliberately inside the try: an IndexError while plotting
        # is recorded as a failed explanation instead of aborting the whole run.
        plot_dcam(dcam_map, instance, dataset_name, idx, plot_flag, column_names)
    except IndexError:
        sys.stderr.write(error_msg)
        return np.array(-1), 0
    return dcam_map, permutation_success


def main(dataset_group="MP", concat=False):
    """Explain every test instance of every dataset with dCAM and save the results.

    For each dataset returned by ``load_data(dataset_group, concat=concat)``:
      1. convert the data to PyTorch loaders with ``transform_data4ResNet``;
      2. load the previously trained dResNet from
         ``saved_model/resNet/<dataset>_concat_<concat>`` and predict the test set;
      3. print the test-set accuracy;
      4. run dCAM on every test instance twice -- once for the ground-truth
         label and once for the predicted label -- plotting each map;
      5. save the per-instance result dicts with ``np.save`` to
         ``explanations/dCAM_results/<dataset>_explenations`` (NumPy appends
         the ``.npy`` extension).

    Each per-instance dict holds ``dcam_tl``/``permutation_success_tl``
    (ground-truth label), ``dcam_ol``/``permutation_success_ol`` (output
    label), ``ground_truth_label`` and ``output_label``. When running or
    plotting dCAM raises ``IndexError`` the map is stored as ``np.array(-1)``
    and the success value as ``0``.

    Args:
        dataset_group: identifier forwarded to ``load_data`` (default ``"MP"``,
            the value this script always used).
        concat: forwarded to ``load_data``/``transform_data4ResNet``; also
            selects the saved model file and whether each instance gets an
            extra leading axis before being explained.
    """
    import os  # only needed to create the output directory

    output_dir = "explanations/dCAM_results/"
    # np.save does not create missing directories; make sure the target exists.
    os.makedirs(output_dir, exist_ok=True)

    all_data = load_data(dataset_group, concat=concat)
    for dataset_name, data in all_data.items():
        # transform data into pytorch format (train loader and class count unused here)
        _, test_dataloader, n_channels, _, device, enc = transform_data4ResNet(
            data, dataset_name, concat=concat)

        # load previously trained dResNet and predict each test set instance
        model_path = "saved_model/resNet/" + dataset_name + "_concat_" + str(concat)
        print(model_path)
        # NOTE(review): torch.load unpickles a whole model object here. Since
        # PyTorch 2.6 the default is weights_only=True, which may refuse such a
        # file; for trusted local checkpoints pass weights_only=False if needed.
        modelarch = torch.load(model_path)
        resnet = ModelCNN(model=modelarch, n_epochs_stop=30, device=device)
        cnn_output = resnet.predict(test_dataloader)

        # convert back to symbolic representation and get accuracy
        symbolic_output = enc.inverse_transform(cnn_output)
        y_test = data["y_test"]
        print(dataset_name, "concat", concat, "accuracy is",
              accuracy_score(symbolic_output, y_test))

        # variables used for dCAM
        last_conv_layer = resnet.model._modules['layers'][2]
        fc_layer_name = resnet.model._modules['final']
        X_test = data["X_test"]
        n_test = X_test.shape[0]
        target_idxs = enc.transform(y_test)
        # Decide from X_test's own type, since X_test.columns is what gets read.
        if isinstance(X_test, pd.DataFrame):
            column_names = X_test.columns.values
        else:
            column_names = list(range(n_channels))
        # Get the raw array once instead of calling .values for every instance.
        X_values = X_test if isinstance(X_test, np.ndarray) else X_test.values

        # CMJ has just 3 channels -> 3! = 6 possible permutations, and only
        # there is generate_all enabled; other datasets use 200 permutations.
        nb_permutation = 6 if dataset_name == "CMJ" else 200
        generate_all = dataset_name == "CMJ"

        # initialize dCAM object and explain
        dcam = DCAM(resnet.model, device, last_conv_layer=last_conv_layer,
                    fc_layer_name=fc_layer_name)
        explanation = []
        starttime = timeit.default_timer()
        for i in range(n_test):
            print("explaining ", i, "-th sample of", dataset_name, "out of", n_test)
            instance = X_values[i]
            if concat:
                instance = np.expand_dims(instance, 0)
            record = {}
            # explain the prediction for the ground-truth label ...
            record["dcam_tl"], record["permutation_success_tl"] = _run_dcam_for_label(
                dcam, instance, target_idxs[i], nb_permutation, generate_all,
                dataset_name, i, True, column_names,
                "index error in ground truth\n\n")
            # ... and for the label the network actually output
            record["dcam_ol"], record["permutation_success_ol"] = _run_dcam_for_label(
                dcam, instance, cnn_output[i], nb_permutation, generate_all,
                dataset_name, i, False, column_names,
                "index error in predictions\n\n")
            # keep both the ground truth and the output label next to the maps
            record["ground_truth_label"] = y_test[i]
            record["output_label"] = symbolic_output[i]
            explanation.append(record)
        elapsed = timeit.default_timer() - starttime
        print("total time spent was", elapsed)
        # Mean wall-clock time per test instance (two dCAM runs + plots each).
        print("average time spent was", elapsed / n_test if n_test else 0.0)
        # (sic) "_explenations" spelling kept so the output file names do not change.
        np.save(output_dir + dataset_name + "_explenations", explanation)
if __name__ == "__main__":
    # Run the dCAM explanation pipeline only when this file is executed as a
    # script, not when it is imported as a module.
    main()