test.py
"""Evaluate a saved defender Model on the goodware/malware test sets under data_test/."""
import glob
import argparse
import os
import time
import warnings

import numpy as np

# Not used directly below; presumably imported so serialized models that
# reference these classes can be restored by Model.load.
from defender.models.ml import MLClassifier
from defender.models.malware_torch import TorchModel, AttentionModelv2

from defender.utils import save_dict
from defender.model import Model, prediction_metrics

# set the maximum amount of memory to 1 GB
# import resource
# resource.setrlimit(resource.RLIMIT_AS, (1 * 1024 * 1024 * 1024, -1))

# Test data is expected in data_test/gw1..gw6 (goodware) and data_test/mw1..mw8 (malware).
test_path = "data_test"


def test(model, model_path):
    """Run the model over every test dataset, print per-dataset and overall metrics, and save the scores."""
    model.train = False
    # model.model.cuda()
    test_scores = {}
    labels_full = []
    preds_full = []
    files_full = []
    passed = []
    errors = []
    tic = time.time()
    # gw1..gw6 are goodware datasets (label 0), mw1..mw8 are malware datasets (label 1)
    for t, n in (('gw', np.arange(1, 7)), ('mw', np.arange(1, 9))):
        for k in n:
            labels = []
            preds = []
            files = []
            err = 0
            for f in glob.glob(os.path.join(test_path, f"{t}{k}", "*")):
                try:
                    p = model.predict_files([f]).item()
                    if np.isnan(p):
                        warnings.warn("Nan value")
                        p = 1
                    preds.append(p)
                    # labels.append(1 if t=='mw' else 0)
                    files.append(f)
                except Exception as e:
                    # a file that cannot be scored counts as an error and is predicted malicious
                    warnings.warn(str(e))
                    err += 1
                    preds.append(1)
                labels.append(1 if t == 'mw' else 0)
            # get per dataset metrics
            labels = np.asarray(labels)
            preds = np.asarray(preds)
            metrics = prediction_metrics(labels, preds)
            errors.append(err)
            print(f"Dataset {t}{k}: {metrics}")
            # check if passes minimum: FPR < 1% on goodware, TPR > 95% on malware
            # (metric order is "mcc, f1, fpr, tpr", see test_scores['labels'] below)
            if (t == 'gw' and metrics[2] < 0.01) or (t == 'mw' and metrics[3] > 0.95):
                passed.append(f"{t}{k}")
            test_scores[f"{t}{k}"] = metrics
            labels_full.append(labels)
            preds_full.append(preds)
            files_full.extend(files)
    total_time = time.time() - tic
    # get full metrics over all datasets combined
    labels = np.concatenate(labels_full, axis=0)
    preds = np.concatenate(preds_full, axis=0)
    metrics = prediction_metrics(labels, preds)
    print(f"Full: {metrics}")
    test_scores['total'] = metrics
    print(f"Passed datasets: {passed}")
    print(f"Took {total_time} s, {total_time/preds.shape[0]} s/it")
    print(f"Errors: {errors}")
    model.test_scores = {
        'labels': "mcc, f1, fpr, tpr",
        'scores': {k: np.round(v, 4) for k, v in test_scores.items()},
        'passed': passed,
        'errors': errors,
        'time': total_time / preds.shape[0],  # average seconds per sample
    }
    model.save(model_path)
    save_dict(model.test_scores, model_path + '.txt', as_str=True)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--model', type=str, required=True, help='model to evaluate')
    args = parser.parse_args()
    model = Model.load(args.model)
    test(model, args.model)
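
# Example invocation (the model path below is hypothetical; pass whatever path
# the model was saved under, and the scores are written next to it as <path>.txt):
#   python test.py -m saved_models/my_model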