-
Notifications
You must be signed in to change notification settings - Fork 1
/
utils_ground.py
100 lines (85 loc) · 4.43 KB
/
utils_ground.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import glob
import os

import numpy as np
from keras.models import load_model
from keras.datasets import fashion_mnist, mnist, cifar10
def load_data_split(parameters, split_index=5000):
    """Load a Keras test set and split it into two consecutive parts.

    ``parameters.dataName`` selects the dataset ("fashion", "mnist" or
    "cifar10"); only its test split is used. Unless ``parameters.dataType``
    is ``"original-split"``, the images are replaced by the array stored at
    ``parameters.save_data_root_adv + "<dataType with '-split' removed>-<severity>.npy"``,
    while the dataset's own test labels are kept.

    Fashion-MNIST / MNIST images are reshaped to ``(N, 28, 28, 1)``. All
    images are cast to float32 and divided by 255, and labels are squeezed.

    Args:
        parameters: object exposing ``dataName``, ``dataType``, ``severity``
            and ``save_data_root_adv``.
        split_index: number of samples in the first part (default 5000).

    Returns:
        ``(x_first, y_first, x_second, y_second)``. The first part is samples
        ``[:split_index]`` and the second part is ``[split_index:]``. Returns
        ``None`` (after printing "invalid data name") if ``dataName`` is not
        one of the recognised names.

    Raises:
        Errors from loading the dataset or the ``.npy`` file (for example
        ``FileNotFoundError``) propagate to the caller. They are no longer
        misreported as an invalid data name.
    """
    data_name = parameters.dataName
    if data_name == "fashion":
        dataset = fashion_mnist
    elif data_name == "mnist":
        dataset = mnist
    elif data_name == "cifar10":
        dataset = cifar10
    else:
        # Unknown names keep the historical contract: report and return None.
        print("invalid data name")
        return None

    (_, _), (x_test, y_test) = dataset.load_data()
    if parameters.dataType != "original-split":
        adv_type = parameters.dataType.replace('-split', '')
        x_test = np.load(
            parameters.save_data_root_adv
            + "{0}-{1}.npy".format(adv_type, parameters.severity)
        )
    if data_name in ("fashion", "mnist"):
        # Add a trailing single-channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        img_rows, img_cols = 28, 28
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    # NOTE(review): the .npy replacement is also divided by 255, so it is
    # presumably stored as 0-255 pixel values -- confirm with the generator.
    x_test = x_test.astype('float32') / 255
    y_test = np.squeeze(y_test)
    return (x_test[:split_index], y_test[:split_index],
            x_test[split_index:], y_test[split_index:])
def load_data(parameters):
    """Load a Keras test set, optionally swapping in images from a ``.npy`` file.

    ``parameters.dataName`` selects the dataset ("fashion", "mnist" or
    "cifar10"); only its test split is used. Unless ``parameters.dataType``
    is ``"original"``, the images are replaced by the array stored at
    ``parameters.save_data_root_adv + "<dataType>-<severity>.npy"``, while the
    dataset's own test labels are kept.

    Fashion-MNIST / MNIST images are reshaped to ``(N, 28, 28, 1)``. All
    images are cast to float32 and divided by 255, and labels are squeezed.

    Args:
        parameters: object exposing ``dataName``, ``dataType``, ``severity``
            and ``save_data_root_adv``.

    Returns:
        ``(x_test, y_test)``, or ``None`` (after printing "invalid data name")
        if ``dataName`` is not one of the recognised names.

    Raises:
        Errors from loading the dataset or the ``.npy`` file (for example
        ``FileNotFoundError``) propagate to the caller. They are no longer
        misreported as an invalid data name.
    """
    data_name = parameters.dataName
    if data_name == "fashion":
        dataset = fashion_mnist
    elif data_name == "mnist":
        dataset = mnist
    elif data_name == "cifar10":
        dataset = cifar10
    else:
        # Unknown names keep the historical contract: report and return None.
        print("invalid data name")
        return None

    (_, _), (x_test, y_test) = dataset.load_data()
    if parameters.dataType != "original":
        x_test = np.load(
            parameters.save_data_root_adv
            + "{0}-{1}.npy".format(parameters.dataType, parameters.severity)
        )
    if data_name in ("fashion", "mnist"):
        # Add a trailing single-channel axis: (N, 28, 28) -> (N, 28, 28, 1).
        img_rows, img_cols = 28, 28
        x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    # NOTE(review): the .npy replacement is also divided by 255, so it is
    # presumably stored as 0-255 pixel values -- confirm with the generator.
    x_test = x_test.astype('float32') / 255
    y_test = np.squeeze(y_test)
    return x_test, y_test
def load_models(parameters):
    """Load the numbered Keras models stored under ``parameters.save_model_root``.

    The number of models is the count of (non-hidden) ``*.h5`` entries in that
    directory. Models are then loaded in index order from
    ``<root>/model-0.h5``, ``<root>/model-1.h5``, and so on.

    Returns:
        A list of loaded models. The list is empty when the directory holds
        no ``.h5`` files or does not exist.
    """
    root = parameters.save_model_root
    # glob.glob1 is an undocumented helper deprecated in Python 3.13. Globbing
    # the escaped directory counts the same files, and escaping keeps
    # characters such as '[' in the directory name from acting as wildcards.
    pattern = os.path.join(glob.escape(os.fspath(root)), "*.h5")
    model_count = len(glob.glob(pattern))
    # NOTE(review): assumes the files are numbered contiguously from 0; any
    # other *.h5 file in the directory inflates the count.
    return [load_model('{0}/model-{1}.h5'.format(root, i)) for i in range(model_count)]
def computeAcc(parameters, label_list, select_index, num):
    """Compute each model's accuracy on the rows picked by ``select_index``.

    Column 0 of ``label_list`` is the reference label, and columns
    ``1..parameters.model_num`` hold the models' predicted labels. This is the
    layout ``label_read`` builds. For each model, the matching selected rows
    are counted and divided by ``num``.

    Returns:
        A float64 ndarray of length ``parameters.model_num``.
    """
    n_models = parameters.model_num
    accuracies = np.zeros(n_models)
    reference = label_list[select_index, 0]
    for column in range(1, n_models + 1):
        hits = np.sum(label_list[select_index, column] == reference)
        accuracies[column - 1] = hits / num
    return accuracies
def label_read(model, x_test, y_test, parameters):
    """Tabulate reference labels next to every model's predicted class, then save.

    Column 0 holds ``y_test``. Column ``k`` (k >= 1) holds
    ``argmax(model[k - 1].predict(x_test), axis=1)``. The table is written with
    ``np.save`` to
    ``parameters.save_ground_root + "labels-<dataType>-<severity>.npy"``.

    Returns:
        A float64 ndarray of shape ``(len(x_test), len(model) + 1)``.
    """
    table = np.zeros((len(x_test), len(model) + 1))
    table[:, 0] = y_test
    for column, net in enumerate(model, start=1):
        scores = net.predict(x_test)
        table[:, column] = np.argmax(scores, axis=1)
    destination = parameters.save_ground_root + "labels-{0}-{1}.npy".format(
        parameters.dataType, parameters.severity)
    np.save(destination, table)
    return table
def pred_read(model, x_test, parameters):
    """Save each model's raw ``predict`` output for ``x_test`` to its own file.

    Model number ``idx`` is written with ``np.save`` to
    ``parameters.save_model_pre_root + "pre-<idx>-<dataType>-<severity>.npy"``.
    Returns None.
    """
    for idx, net in enumerate(model):
        outputs = net.predict(x_test)
        target = parameters.save_model_pre_root + "pre-{0}-{1}-{2}.npy".format(
            idx, parameters.dataType, parameters.severity)
        np.save(target, outputs)