-
Notifications
You must be signed in to change notification settings - Fork 5
/
class.py
348 lines (278 loc) · 30.6 KB
/
class.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from pytorch_pretrained_bert import BertModel, BertTokenizer
import torch
from torch import nn
import torch.utils.data as Data
import torch.nn.functional as F
from tqdm import tqdm
import numpy as np
from sklearn import metrics
import sklearn
import time
from pytorch_pretrained_bert.optimization import BertAdam
class Highway(nn.Module):
    """Stack of highway layers.

    Every layer applies one linear map of width ``2 * input_dim`` to its
    input and splits the result in two halves: the first half goes through a
    ReLU, the second half through a sigmoid and acts as a gate.  The layer
    output is ``gate * input + (1 - gate) * relu_half``, so the output has the
    same shape as the input.

    Args:
        input_dim: size of the last dimension of the tensors being processed.
        num_layers: how many highway layers to apply in sequence.
    """

    def __init__(self, input_dim, num_layers=1):
        super().__init__()
        projections = []
        for _ in range(num_layers):
            projections.append(nn.Linear(input_dim, 2 * input_dim))
        self._layers = nn.ModuleList(projections)
        for projection in self._layers:
            # The upper half of the bias feeds the gate; starting it at 1
            # pushes the sigmoid toward passing the input through.
            projection.bias.data[input_dim:].fill_(1)

    def forward(self, inputs):
        """Run ``inputs`` through every highway layer and return the result."""
        carried = inputs
        for projection in self._layers:
            candidate, gate_logits = projection(carried).chunk(2, dim=-1)
            gate = torch.sigmoid(gate_logits)
            carried = gate * carried + (1 - gate) * torch.relu(candidate)
        return carried
class Bert_HBiLSTM(nn.Module):
    """Frozen BERT features -> BiLSTM -> highway -> linear classifier.

    The input sequence is cut at position 512 and each part is encoded by
    ``config.bert`` separately; the two encodings are concatenated again along
    the sequence axis before the BiLSTM.  Every position is then reduced to a
    single score and the vector of per-position scores is mapped to
    ``config.num_class`` log-probabilities.

    ``config`` must provide ``bert``, ``embedding_dim``, ``hidden_dim``,
    ``num_layers``, ``drop_rate``, ``max_sequnce`` and ``num_class``.
    """

    def __init__(self, config):
        super().__init__()
        self.bert = config.bert
        self.config = config
        # BERT only serves as a feature extractor: none of its weights train.
        for weight in self.bert.parameters():
            weight.requires_grad_(False)
        self.lstm = nn.LSTM(
            config.embedding_dim,
            config.hidden_dim,
            num_layers=config.num_layers,
            batch_first=True,
            bidirectional=True,
        )
        self.drop = nn.Dropout(config.drop_rate)
        self.highway = Highway(config.hidden_dim * 2, 1)
        self.hidden2one = nn.Linear(config.hidden_dim * 2, 1)
        self.relu = nn.ReLU()  # registered but not applied in forward()
        # Consumes one score per position, so the concatenated sequence
        # length must equal config.max_sequnce.
        self.sequence2numclass = nn.Linear(config.max_sequnce, config.num_class)

    def _encode_segment(self, segment_ids, segment_mask):
        """Encode one segment with BERT and zero out its masked positions."""
        encoded, _ = self.bert(
            segment_ids,
            attention_mask=segment_mask,
            output_all_encoded_layers=False,
        )
        segment_mask.requires_grad_(False)
        return encoded * segment_mask.unsqueeze(-1).float()

    def forward(self, word_input, input_mask):
        """Return per-class log-probabilities for a batch.

        Args:
            word_input: token ids, indexed as (batch, position).
            input_mask: mask aligned with ``word_input``; it is passed to BERT
                as the attention mask and multiplied into the BERT output.
                # assumes 1 marks real tokens and 0 padding — confirm with caller

        Returns:
            Tensor of ``log_softmax`` scores taken over dimension 1.
        """
        head = self._encode_segment(word_input[:, :512], input_mask[:, :512])
        tail = self._encode_segment(word_input[:, 512:], input_mask[:, 512:])
        features = torch.cat([head, tail], dim=1)

        # bert -> bilstm -> highway.  The padded sequence is fed to the LSTM
        # directly; no pack_padded_sequence / pad_packed_sequence is used.
        recurrent, _ = self.lstm(features)
        gated = self.drop(self.highway(recurrent))

        # hidden_dim * 2 -> 1 score per position -> class scores
        position_scores = self.hidden2one(gated).squeeze(-1)
        class_scores = self.sequence2numclass(position_scores)
        return F.log_softmax(class_scores, dim=1)
# api dict->api:num
# txt&label to num\tlabel
# num -> bert token -> bert ids
# train_iter,test_iter
def load_data(max_sequnce, data_file, label_file):
    """Read API traces and labels and turn them into padded BERT input tensors.

    ``label_file`` and ``data_file`` are paired by line index: line ``i`` of
    ``label_file`` holds the class name of the trace on line ``i`` of
    ``data_file``.  A trace line is split on whitespace, consecutive repeats of
    the same API name are collapsed to one occurrence, and every remaining name
    is expanded to its precomputed token ids via the module-level
    ``api_index`` table.

    The ids are truncated to ``max_sequnce - 4`` and laid out as one
    ``[CLS] ... [SEP]`` segment, or as two such segments (the first holding 510
    ids) when the trace is longer than 510 ids, then right-padded with ``PAD``
    to ``max_sequnce``.  The mask is 1 over real tokens and 0 over padding.

    Samples are grouped per class and each class is split 8:2, in file order,
    into a training part and a test part.

    Args:
        max_sequnce: padded length of every returned id / mask row.
        data_file: path of the UTF-8 text file with one API trace per line.
        label_file: path of the UTF-8 text file with one class name per line.

    Returns:
        ``(train_ids, train_mask, train_labels, test_ids, test_mask,
        test_labels)``, all ``torch.int64`` tensors.  Labels are the integer
        ids from the module-level ``label_index`` table.

    Raises:
        KeyError: an API name is missing from ``api_index`` or a label is not a
            known class name.
        IndexError: ``data_file`` has fewer lines than ``label_file``.
    """
    CLS, SEP, PAD = 101, 102, 0  # ids of [CLS], [SEP] and [PAD] in the BERT vocabulary
    SEGMENT_BODY = 510  # ids that fit in a 512-long segment next to [CLS] and [SEP]

    # Context managers so both file handles are closed deterministically.
    with open(data_file, 'r', encoding='utf-8') as handle:
        api_list = handle.readlines()
    with open(label_file, 'r', encoding='utf-8') as handle:
        lab_list = handle.readlines()

    # Per-class buckets of (ids, mask) pairs, later split 8:2 class by class.
    collected_by_label = {name: [] for name in label_index}

    for index, label_line in enumerate(tqdm(lab_list)):
        label = label_line.strip()
        # '\\s' is the literal two-character sequence backslash + s, exactly
        # what the original (invalid-escape) '\s' literal matched.
        # str.split() without arguments then splits on any run of whitespace
        # (spaces, tabs, non-breaking spaces), replacing the manual
        # space-collapsing loop, which never terminated as written.
        names = api_list[index].replace('\\s', ' ').split()
        # Collapse consecutive repeats of the same API call.
        simple_api = [
            name
            for pos, name in enumerate(names)
            if pos == 0 or name != names[pos - 1]
        ]

        # api name -> precomputed token ids
        ids = [token for name in simple_api for token in api_index[name]]

        limit = max_sequnce - 4  # leave room for two [CLS]/[SEP] pairs
        truncated = len(ids) > limit
        if truncated:
            ids = ids[:limit]
        if truncated or len(ids) > SEGMENT_BODY:
            ids = ([CLS] + ids[:SEGMENT_BODY] + [SEP]
                   + [CLS] + ids[SEGMENT_BODY:] + [SEP])
        else:
            ids = [CLS] + ids + [SEP]
        mask = [1] * len(ids)

        padding = max_sequnce - len(ids)
        ids = ids + [PAD] * padding
        mask = mask + [0] * padding
        collected_by_label[label].append((ids, mask))

    train_input_ids, train_input_mak, train_input_lab = [], [], []
    test_input_ids, test_input_mak, test_input_lab = [], [], []

    # 8:2 split per class, then merge all classes into train / test.
    for label, data in tqdm(collected_by_label.items()):
        class_id = label_index[label]
        # len * 8 // 10 keeps the ratio for any class size; the former
        # len // 10 * 8 gave classes with fewer than 10 samples no training
        # data at all.
        cut = len(data) * 8 // 10
        for ids, mask in data[:cut]:
            train_input_ids.append(ids)
            train_input_mak.append(mask)
            train_input_lab.append(class_id)
        for ids, mask in data[cut:]:
            test_input_ids.append(ids)
            test_input_mak.append(mask)
            test_input_lab.append(class_id)

    return tuple(
        torch.tensor(rows, dtype=torch.int64)
        for rows in (
            train_input_ids, train_input_mak, train_input_lab,
            test_input_ids, test_input_mak, test_input_lab,
        )
    )
def test(config, model, test_iter):
    """Evaluate ``model`` on ``test_iter`` and print the results.

    Prints the loss and accuracy on one line, followed by the classification
    report and the confusion matrix returned by ``evaluate``.  Nothing is
    returned.
    """
    model.eval()
    start_time = time.time()  # recorded, but the elapsed time is not reported
    results = evaluate(config, model, test_iter, test=True)
    test_acc, test_loss, test_report, test_confusion = results
    print('Test Loss: {0:>5.2}, Test Acc: {1:>6.2%}'.format(test_loss, test_acc))
    sections = (
        ("Precision, Recall and F1-Score...", test_report),
        ("Confusion Matrix...", test_confusion),
    )
    for heading, payload in sections:
        print(heading)
        print(payload)
def evaluate(config, model, data_iter, test=False):
    """Score ``model`` on every batch of ``data_iter`` without gradients.

    Args:
        config: provides ``device`` and, when ``test`` is true, ``class_list``
            (used as the target names of the classification report).
        model: called as ``model(ids, mask)``; switched to eval mode here.
        data_iter: iterable of ``(ids, mask, labels)`` batches supporting
            ``len()``.
        test: when true, also build the classification report and the
            confusion matrix.

    Returns:
        ``(accuracy, mean_loss)`` or, with ``test=True``,
        ``(accuracy, mean_loss, report, confusion)``.  ``mean_loss`` is the sum
        of the per-batch cross-entropy values divided by ``len(data_iter)``.
    """
    model.eval()
    loss_total = 0
    # Seed both chunk lists with an empty int array so the concatenation
    # below matches what repeated np.append calls would produce.
    label_chunks = [np.array([], dtype=int)]
    predict_chunks = [np.array([], dtype=int)]
    with torch.no_grad():
        for ids, mask, labels in data_iter:
            ids = ids.to(config.device)
            mask = mask.to(config.device)
            labels = labels.to(config.device)
            outputs = model(ids, mask).to(config.device)
            loss_total += F.cross_entropy(outputs, labels)
            _, predicted = torch.max(outputs.data, 1)
            label_chunks.append(labels.data.cpu().numpy())
            predict_chunks.append(predicted.cpu().numpy())
    labels_all = np.concatenate([np.ravel(chunk) for chunk in label_chunks])
    predict_all = np.concatenate([np.ravel(chunk) for chunk in predict_chunks])

    acc = metrics.accuracy_score(labels_all, predict_all)
    if not test:
        return acc, loss_total / len(data_iter)
    report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)
    confusion = metrics.confusion_matrix(labels_all, predict_all)
    return acc, loss_total / len(data_iter), report, confusion
def train(config, model, train_iter, test_iter):
    """Train ``model`` with RMSprop, evaluating and checkpointing every epoch.

    Args:
        config: provides ``device``, ``learning_rate``, ``num_epochs`` and
            ``model_path`` (where the whole model object is saved).
        model: module called as ``model(ids, mask)``; moved to
            ``config.device`` before training.
        train_iter: iterable of ``(ids, mask, labels)`` training batches.
        test_iter: batches passed to ``test`` after each epoch.

    Returns:
        None.  The model is updated in place and written to
        ``config.model_path``.
    """
    model = model.to(config.device)
    model.train()

    # Only hand trainable weights to the optimizer.  Frozen parameters never
    # receive gradients, and this matches the filter used by the script's
    # entry point.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.RMSprop(trainable, lr=config.learning_rate)

    # Make sure the checkpoint directory exists before the first save.
    save_dir = os.path.dirname(config.model_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    for epoch in range(config.num_epochs):
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # test() below switches the model to eval mode, so training mode is
        # restored at the top of every epoch.
        model.train()
        for ids, mask, labels in tqdm(train_iter, desc="BatchIteration"):
            ids = ids.to(config.device)
            mask = mask.to(config.device)
            labels = labels.to(config.device)

            model.zero_grad()
            outputs = model(ids, mask).to(config.device)
            train_loss = F.cross_entropy(outputs, labels)
            train_loss.backward()
            print(train_loss)
            optimizer.step()

        # NOTE(review): the source had lost its indentation; evaluating and
        # saving once per epoch is inferred from the model.train() reset
        # above — confirm against the original file.
        test(config, model, test_iter)
        torch.save(model, config.model_path)
class Config:
    """Hyper-parameters, paths and the pretrained BERT encoder for this script.

    Constructing a ``Config`` loads BERT from ``./bert-large/`` and, when CUDA
    is available, selects GPU 0 as the current device.
    """

    def __init__(self):
        cuda_ready = torch.cuda.is_available()

        # Sequence / model dimensions
        self.max_sequnce = 1024    # padded length of every input sequence
        self.hidden_dim = 100      # LSTM hidden size per direction
        self.embedding_dim = 1024  # LSTM input size (width of the BERT output it consumes)
        self.num_class = 8

        # Optimisation settings
        self.batch_size = 64
        self.learning_rate = 0.01
        self.num_epochs = 100
        self.drop_rate = 0.5
        self.num_layers = 1        # number of stacked LSTM layers

        # Runtime resources
        self.device = torch.device('cuda' if cuda_ready else 'cpu')
        self.bert = BertModel.from_pretrained('./bert-large/')
        self.model_path = './save_model/model.pkl'
        self.use_cuda = cuda_ready
        if cuda_ready:
            torch.cuda.set_device(0)

        # Class names, used as target names in the classification report.
        self.class_list = ["Trojan","Backdoor","Downloader","Worms","Spyware","Adware","Dropper","Virus"]
# Class name -> integer class id used as the training target.
# An earlier one-hot encoding of the classes (e.g. "Trojan" ->
# [0,0,0,0,0,0,0,1]) was replaced by these plain indices.
label_index = {
    class_name: class_id
    for class_id, class_name in enumerate(
        (
            "Trojan",
            "Backdoor",
            "Downloader",
            "Worms",
            "Spyware",
            "Adware",
            "Dropper",
            "Virus",
        )
    )
}
# To speed up preprocessing, the BERT token ids for every API name were computed ahead of time and stored in this dictionary.
api_index = {"__process__": [1035, 1035, 2832, 1035, 1035], "__anomaly__": [1035, 1035, 28685, 1035, 1035], "__exception__": [1035, 1035, 6453, 1035, 1035], "__missing__": [1035, 1035, 4394, 1035, 1035], "certcontrolstore": [8292, 5339, 8663, 13181, 4877, 19277], "certcreatecertificatecontext": [8292, 5339, 16748, 3686, 17119, 3775, 8873, 16280, 8663, 18209], "certopenstore": [8292, 5339, 26915, 23809, 2063], "certopensystemstorea": [8292, 5339, 26915, 29390, 19277, 2050], "certopensystemstorew": [8292, 5339, 26915, 29390, 19277, 2860], "cryptacquirecontexta": [19888, 6305, 15549, 2890, 8663, 18209, 2050], "cryptacquirecontextw": [19888, 6305, 15549, 2890, 8663, 18209, 2860], "cryptcreatehash": [19888, 16748, 3686, 14949, 2232], "cryptdecrypt": [19888, 3207, 26775, 22571, 2102], "cryptencrypt": [19888, 2368, 26775, 22571, 2102], "cryptexportkey": [19888, 10288, 6442, 14839], "cryptgenkey": [19888, 6914, 14839], "crypthashdata": [19888, 14949, 14945, 6790], "cryptdecodemessage": [19888, 3207, 16044, 7834, 3736, 3351], "cryptdecodeobjectex": [19888, 3207, 16044, 16429, 20614, 10288], "cryptdecryptmessage": [19888, 3207, 26775, 22571, 21246, 7971, 4270], "cryptencryptmessage": [19888, 2368, 26775, 22571, 21246, 7971, 4270], "crypthashmessage": [19888, 14949, 14227, 7971, 4270], "cryptprotectdata": [19888, 21572, 26557, 2102, 2850, 2696], "cryptprotectmemory": [19888, 21572, 26557, 21246, 6633, 10253], "cryptunprotectdata": [19888, 4609, 21572, 26557, 2102, 2850, 2696], "cryptunprotectmemory": [19888, 4609, 21572, 26557, 21246, 6633, 10253], "prf": [10975, 2546], "ssl3generatekeymaterial": [7020, 2140, 2509, 6914, 22139, 14839, 8585, 14482], "setunhandledexceptionfilter": [2275, 4609, 11774, 3709, 10288, 24422, 8873, 21928], "rtladdvectoredcontinuehandler": [19387, 27266, 2094, 3726, 16761, 2098, 8663, 7629, 5657, 11774, 3917], "rtladdvectoredexceptionhandler": [19387, 27266, 2094, 3726, 16761, 14728, 2595, 24422, 11774, 3917], "rtldispatchexception": [19387, 6392, 
2483, 4502, 10649, 10288, 24422], "rtlremovevectoredcontinuehandler": [19387, 20974, 6633, 21818, 3726, 16761, 2098, 8663, 7629, 5657, 11774, 3917], "rtlremovevectoredexceptionhandler": [19387, 20974, 6633, 21818, 3726, 16761, 14728, 2595, 24422, 11774, 3917], "copyfilea": [6100, 8873, 19738], "copyfileexw": [6100, 8873, 10559, 2595, 2860], "copyfilew": [6100, 8873, 2571, 2860], "createdirectoryexw": [2580, 7442, 16761, 6672, 2595, 2860], "createdirectoryw": [2580, 7442, 16761, 2100, 2860], "deletefilew": [3972, 12870, 8873, 2571, 2860], "deviceiocontrol": [5080, 3695, 8663, 13181, 2140], "findfirstfileexa": [2424, 8873, 12096, 8873, 10559, 18684], "findfirstfileexw": [2424, 8873, 12096, 8873, 10559, 2595, 2860], "getfileattributesexw": [2131, 8873, 19738, 4779, 3089, 8569, 4570, 10288, 2860], "getfileattributesw": [2131, 8873, 19738, 4779, 3089, 8569, 4570, 2860], "getfileinformationbyhandle": [2131, 8873, 19856, 14192, 3370, 3762, 11774, 2571], "getfileinformationbyhandleex": [2131, 8873, 19856, 14192, 3370, 3762, 11774, 10559, 2595], "getfilesize": [2131, 8873, 4244, 4697], "getfilesizeex": [2131, 8873, 4244, 4697, 10288], "getfiletype": [2131, 8873, 7485, 18863], "getshortpathnamew": [4152, 27794, 15069, 18442, 2860], "getsystemdirectorya": [4152, 27268, 6633, 4305, 2890, 16761, 3148], "getsystemdirectoryw": [4152, 27268, 6633, 4305, 2890, 16761, 2100, 2860], "getsystemwindowsdirectorya": [4152, 27268, 6633, 11101, 15568, 4305, 2890, 16761, 3148], "getsystemwindowsdirectoryw": [4152, 27268, 6633, 11101, 15568, 4305, 2890, 16761, 2100, 2860], "gettemppathw": [2131, 18532, 13944, 2705, 2860], "getvolumenameforvolumemountpointw": [2131, 6767, 12942, 8189, 4168, 29278, 6767, 12942, 6633, 21723, 8400, 2860], "getvolumepathnamew": [2131, 6767, 12942, 13699, 8988, 18442, 2860], "getvolumepathnamesforvolumenamew": [2131, 6767, 12942, 13699, 8988, 18442, 22747, 2953, 6767, 12942, 8189, 4168, 2860], "movefilewithprogressw": [2693, 8873, 2571, 24415, 21572, 17603, 4757, 
2860], "removedirectorya": [3718, 7442, 16761, 3148], "removedirectoryw": [3718, 7442, 16761, 2100, 2860], "searchpathw": [3945, 15069, 2860], "setendoffile": [2275, 10497, 7245, 9463], "setfileattributesw": [2275, 8873, 19738, 4779, 3089, 8569, 4570, 2860], "setfileinformationbyhandle": [2275, 8873, 19856, 14192, 3370, 3762, 11774, 2571], "setfilepointer": [2275, 8873, 2571, 8400, 2121], "setfilepointerex": [2275, 8873, 2571, 8400, 7869, 2595], "ntcreatedirectoryobject": [23961, 16748, 4383, 7442, 16761, 7677, 2497, 20614], "ntcreatefile": [23961, 16748, 3686, 8873, 2571], "ntdeletefile": [23961, 9247, 12870, 8873, 2571], "ntdeviceiocontrolfile": [23961, 24844, 6610, 3695, 8663, 13181, 10270, 9463], "ntopendirectoryobject": [23961, 26915, 4305, 2890, 16761, 7677, 2497, 20614], "ntopenfile": [23961, 26915, 8873, 2571], "ntqueryattributesfile": [23961, 4226, 20444, 4779, 3089, 8569, 4570, 8873, 2571], "ntquerydirectoryfile": [23961, 4226, 2854, 4305, 2890, 16761, 2100, 8873, 2571], "ntqueryfullattributesfile": [23961, 4226, 2854, 3993, 20051, 18886, 8569, 4570, 8873, 2571], "ntqueryinformationfile": [23961, 4226, 2854, 2378, 14192, 3370, 8873, 2571], "ntreadfile": [23961, 16416, 20952, 9463], "ntsetinformationfile": [23961, 13462, 2378, 14192, 3370, 8873, 2571], "ntwritefile": [23961, 26373, 8873, 2571], "colescript_compile": [5624, 22483, 1035, 4012, 22090], "cdocument_write": [3729, 10085, 27417, 2102, 1035, 4339], "celement_put_innerhtml": [8292, 16930, 4765, 1035, 2404, 1035, 5110, 11039, 19968], "chyperlink_seturlcomponent": [10381, 18863, 19403, 2243, 1035, 2275, 3126, 22499, 8737, 5643, 3372], "ciframeelement_createelement": [25022, 15643, 12260, 3672, 1035, 3443, 12260, 3672], "cscriptelement_put_src": [20116, 23235, 12260, 3672, 1035, 2404, 1035, 5034, 2278], "cwindow_addtimeoutcode": [19296, 22254, 5004, 1035, 5587, 7292, 5833, 16044], "getusernamea": [2131, 20330, 18442, 2050], "getusernamew": [2131, 20330, 18442, 2860], "lookupaccountsidw": [2298, 6279, 
6305, 3597, 16671, 5332, 2094, 2860], "getcomputernamea": [2131, 9006, 18780, 11795, 14074, 2050], "getcomputernamew": [2131, 9006, 18780, 11795, 14074, 2860], "getdiskfreespaceexw": [2131, 10521, 2243, 23301, 23058, 10288, 2860], "getdiskfreespacew": [2131, 10521, 2243, 23301, 23058, 2860], "gettimezoneinformation": [2131, 7292, 15975, 2378, 14192, 3370], "writeconsolea": [4339, 8663, 19454, 5243], "writeconsolew": [4339, 8663, 19454, 7974], "coinitializesecurity": [9226, 29050, 3669, 11254, 8586, 25137], "uuidcreate": [1057, 21272, 16748, 3686], "getusernameexa": [2131, 20330, 18442, 10288, 2050], "getusernameexw": [2131, 20330, 18442, 10288, 2860], "readcabinetstate": [3191, 3540, 16765, 3215, 12259], "shgetfolderpathw": [14021, 18150, 10371, 2121, 15069, 2860], "shgetspecialfolderlocation": [14021, 18150, 13102, 8586, 4818, 10371, 2121, 4135, 10719], "enumwindows": [4372, 2819, 11101, 15568], "getcursorpos": [2131, 10841, 25301, 14536, 2891], "getsystemmetrics": [4152, 27268, 6633, 12589, 2015], "netgetjoininformation": [5658, 18150, 5558, 5498, 2078, 14192, 3370], "netusergetinfo": [5658, 20330, 18150, 2378, 14876], "netusergetlocalgroups": [5658, 20330, 18150, 4135, 9289, 17058, 2015], "netshareenum": [16996, 8167, 12129, 2819], "dnsquery_a": [1040, 3619, 4226, 2854, 1035, 1037], "dnsquery_utf8": [1040, 3619, 4226, 2854, 1035, 21183, 2546, 2620], "dnsquery_w": [1040, 3619, 4226, 2854, 1035, 1059], "getadaptersaddresses": [2131, 8447, 13876, 2545, 4215, 16200, 11393, 2015], "getadaptersinfo": [2131, 8447, 13876, 2545, 2378, 14876], "getbestinterfaceex": [2131, 12681, 7629, 3334, 12172, 10288], "getinterfaceinfo": [2131, 18447, 2121, 12172, 2378, 14876], "obtainuseragentstring": [6855, 20330, 4270, 7666, 18886, 3070], "urldownloadtofilew": [24471, 6392, 12384, 11066, 3406, 8873, 2571, 2860], "deleteurlcacheentrya": [3972, 12870, 3126, 15472, 15395, 4765, 20444], "deleteurlcacheentryw": [3972, 12870, 3126, 15472, 15395, 4765, 2854, 2860], "httpopenrequesta": 
[8299, 26915, 2890, 15500, 2050], "httpopenrequestw": [8299, 26915, 2890, 15500, 2860], "httpqueryinfoa": [8299, 4226, 2854, 2378, 14876, 2050], "httpsendrequesta": [16770, 10497, 2890, 15500, 2050], "httpsendrequestw": [16770, 10497, 2890, 15500, 2860], "internetclosehandle": [4274, 20464, 9232, 11774, 2571], "internetconnecta": [4274, 8663, 2638, 25572], "internetconnectw": [4274, 8663, 2638, 6593, 2860], "internetcrackurla": [4274, 26775, 8684, 3126, 2721], "internetcrackurlw": [4274, 26775, 8684, 3126, 2140, 2860], "internetgetconnectedstate": [4274, 18150, 24230, 9153, 2618], "internetgetconnectedstateexa": [4274, 18150, 24230, 9153, 17389, 18684], "internetgetconnectedstateexw": [4274, 18150, 24230, 9153, 17389, 2595, 2860], "internetopena": [4274, 26915, 2050], "internetopenurla": [4274, 26915, 3126, 2721], "internetopenurlw": [4274, 26915, 3126, 2140, 2860], "internetopenw": [4274, 26915, 2860], "internetqueryoptiona": [4274, 4226, 2854, 7361, 3508, 2050], "internetreadfile": [4274, 16416, 20952, 9463], "internetsetoptiona": [4274, 13462, 7361, 3508, 2050], "internetsetstatuscallback": [4274, 13462, 9153, 5809, 9289, 20850, 8684], "internetwritefile": [4274, 26373, 8873, 2571], "connectex": [7532, 10288], "getaddrinfow": [2131, 4215, 13626, 2378, 14876, 2860], "transmitfile": [19818, 8873, 2571], "wsaaccept": [1059, 3736, 6305, 3401, 13876], "wsaconnect": [1059, 3736, 8663, 2638, 6593], "wsarecv": [1059, 10286, 8586, 2615], "wsarecvfrom": [1059, 10286, 8586, 2615, 19699, 5358], "wsasend": [1059, 20939, 10497], "wsasendto": [1059, 20939, 10497, 3406], "wsasocketa": [1059, 20939, 7432, 12928], "wsasocketw": [1059, 20939, 7432, 3388, 2860], "wsastartup": [1059, 20939, 7559, 8525, 2361], "accept": [5138], "bind": [14187], "closesocket": [14572, 7432, 3388], "connect": [7532], "getaddrinfo": [2131, 4215, 13626, 2378, 14876], "gethostbyname": [2131, 15006, 2102, 3762, 18442], "getsockname": [4152, 7432, 18442], "ioctlsocket": [25941, 19646, 6499, 19869, 2102], 
"listen": [4952], "recv": [28667, 2615], "recvfrom": [28667, 2615, 19699, 5358], "select": [7276], "send": [4604], "sendto": [4604, 3406], "setsockopt": [4520, 7432, 7361, 2102], "shutdown": [3844, 7698], "socket": [22278], "cocreateinstance": [2522, 16748, 3686, 7076, 26897], "coinitializeex": [9226, 29050, 3669, 23940, 2595], "oleinitialize": [15589, 5498, 20925, 4697], "createprocessinternalw": [3443, 21572, 9623, 11493, 16451, 2389, 2860], "createremotethread": [3443, 28578, 12184, 2705, 16416, 2094], "createthread": [3443, 2705, 16416, 2094], "createtoolhelp32snapshot": [3443, 3406, 4747, 16001, 2361, 16703, 2015, 2532, 4523, 12326], "module32firstw": [11336, 16703, 8873, 12096, 2860], "module32nextw": [11336, 16703, 2638, 18413, 2860], "process32firstw": [2832, 16703, 8873, 12096, 2860], "process32nextw": [2832, 16703, 2638, 18413, 2860], "readprocessmemory": [3191, 21572, 9623, 6491, 6633, 10253], "thread32first": [11689, 16703, 8873, 12096], "thread32next": [11689, 16703, 2638, 18413], "writeprocessmemory": [4339, 21572, 9623, 6491, 6633, 10253], "system": [2291], "ntallocatevirtualmemory": [23961, 8095, 24755, 2618, 21663, 26302, 13728, 6633, 10253], "ntcreateprocess": [23961, 16748, 3686, 21572, 9623, 2015], "ntcreateprocessex": [23961, 16748, 3686, 21572, 9623, 3366, 2595], "ntcreatesection": [23961, 16748, 8520, 18491], "ntcreatethread": [23961, 16748, 3686, 2705, 16416, 2094], "ntcreatethreadex": [23961, 16748, 3686, 2705, 16416, 3207, 2595], "ntcreateuserprocess": [23961, 16748, 3686, 20330, 21572, 9623, 2015], "ntfreevirtualmemory": [23961, 23301, 21663, 26302, 13728, 6633, 10253], "ntgetcontextthread": [23961, 18150, 8663, 18209, 2705, 16416, 2094], "ntmakepermanentobject": [23961, 2863, 3489, 4842, 2386, 4765, 16429, 20614], "ntmaketemporaryobject": [23961, 2863, 3489, 18532, 17822, 5649, 16429, 20614], "ntmapviewofsection": [23961, 2863, 2361, 8584, 11253, 29015], "ntopenprocess": [23961, 26915, 21572, 9623, 2015], "ntopensection": [23961, 26915, 
29015], "ntopenthread": [23961, 26915, 2705, 16416, 2094], "ntprotectvirtualmemory": [23961, 21572, 26557, 9189, 4313, 26302, 13728, 6633, 10253], "ntqueueapcthread": [23961, 4226, 5657, 9331, 6593, 28362, 4215], "ntreadvirtualmemory": [23961, 16416, 2094, 21663, 26302, 13728, 6633, 10253], "ntresumethread": [23961, 6072, 17897, 2705, 16416, 2094], "ntsetcontextthread": [23961, 13462, 8663, 18209, 2705, 16416, 2094], "ntsuspendthread": [23961, 13203, 11837, 11927, 28362, 4215], "ntterminateprocess": [23961, 3334, 19269, 21572, 9623, 2015], "ntterminatethread": [23961, 3334, 19269, 2705, 16416, 2094], "ntunmapviewofsection": [23961, 4609, 2863, 2361, 8584, 11253, 29015], "ntwritevirtualmemory": [23961, 26373, 21663, 26302, 13728, 6633, 10253], "rtlcreateuserprocess": [19387, 15472, 29313, 20330, 21572, 9623, 2015], "rtlcreateuserthread": [19387, 15472, 29313, 20330, 2705, 16416, 2094], "shellexecuteexw": [5806, 10288, 8586, 10421, 10288, 2860], "regclosekey": [19723, 20464, 9232, 14839], "regcreatekeyexa": [19723, 16748, 3686, 14839, 10288, 2050], "regcreatekeyexw": [19723, 16748, 3686, 14839, 10288, 2860], "regdeletekeya": [19723, 9247, 12870, 14839, 2050], "regdeletekeyw": [19723, 9247, 12870, 14839, 2860], "regdeletevaluea": [19723, 9247, 12870, 10175, 5657, 2050], "regdeletevaluew": [19723, 9247, 12870, 10175, 5657, 2860], "regenumkeyexa": [19723, 2368, 2819, 14839, 10288, 2050], "regenumkeyexw": [19723, 2368, 2819, 14839, 10288, 2860], "regenumkeyw": [19723, 2368, 2819, 14839, 2860], "regenumvaluea": [19723, 2368, 2819, 10175, 5657, 2050], "regenumvaluew": [19723, 2368, 2819, 10175, 5657, 2860], "regopenkeyexa": [19723, 26915, 14839, 10288, 2050], "regopenkeyexw": [19723, 26915, 14839, 10288, 2860], "regqueryinfokeya": [19723, 4226, 2854, 2378, 14876, 14839, 2050], "regqueryinfokeyw": [19723, 4226, 2854, 2378, 14876, 14839, 2860], "regqueryvalueexa": [19723, 4226, 2854, 10175, 5657, 10288, 2050], "regqueryvalueexw": [19723, 4226, 2854, 10175, 5657, 10288, 
2860], "regsetvalueexa": [19723, 13462, 10175, 5657, 10288, 2050], "regsetvalueexw": [19723, 13462, 10175, 5657, 10288, 2860], "ntcreatekey": [23961, 16748, 3686, 14839], "ntdeletekey": [23961, 9247, 12870, 14839], "ntdeletevaluekey": [23961, 9247, 12870, 10175, 5657, 14839], "ntenumeratekey": [23961, 2368, 17897, 11657, 14839], "ntenumeratevaluekey": [23961, 2368, 17897, 11657, 10175, 5657, 14839], "ntloadkey": [23961, 11066, 14839], "ntloadkey2": [23961, 11066, 14839, 2475], "ntloadkeyex": [23961, 11066, 14839, 10288], "ntopenkey": [23961, 26915, 14839], "ntopenkeyex": [23961, 26915, 14839, 10288], "ntquerykey": [23961, 4226, 2854, 14839], "ntquerymultiplevaluekey": [23961, 4226, 2854, 12274, 7096, 11514, 20414, 2389, 5657, 14839], "ntqueryvaluekey": [23961, 4226, 2854, 10175, 5657, 14839], "ntrenamekey": [23961, 7389, 14074, 14839], "ntreplacekey": [23961, 2890, 24759, 10732, 14839], "ntsavekey": [23961, 3736, 3726, 14839], "ntsavekeyex": [23961, 3736, 3726, 14839, 10288], "ntsetvaluekey": [23961, 13462, 10175, 5657, 14839], "findresourcea": [2424, 6072, 8162, 21456], "findresourceexa": [2424, 6072, 8162, 3401, 10288, 2050], "findresourceexw": [2424, 6072, 8162, 3401, 10288, 2860], "findresourcew": [2424, 6072, 8162, 3401, 2860], "loadresource": [7170, 6072, 8162, 3401], "sizeofresource": [2946, 11253, 6072, 8162, 3401], "controlservice": [7711, 2121, 7903, 2063], "createservicea": [9005, 2121, 7903, 5243], "createservicew": [9005, 2121, 7903, 7974], "deleteservice": [3972, 12870, 8043, 7903, 2063], "enumservicesstatusa": [4372, 18163, 2121, 7903, 7971, 29336, 10383], "enumservicesstatusw": [4372, 18163, 2121, 7903, 7971, 29336, 2271, 2860], "openscmanagera": [7480, 27487, 5162, 4590, 2050], "openscmanagerw": [7480, 27487, 5162, 4590, 2860], "openservicea": [7480, 2121, 7903, 5243], "openservicew": [7480, 2121, 7903, 7974], "startservicea": [4627, 2121, 7903, 5243], "startservicew": [4627, 2121, 7903, 7974], "getlocaltime": [2131, 4135, 9289, 7292], 
"getsystemtime": [4152, 27268, 6633, 7292], "getsystemtimeasfiletime": [4152, 27268, 6633, 7292, 3022, 8873, 7485, 14428], "gettickcount": [2131, 26348, 3597, 16671], "ntcreatemutant": [23961, 16748, 3686, 28120, 4630], "ntdelayexecution": [23961, 9247, 4710, 10288, 8586, 13700], "ntquerysystemtime": [23961, 4226, 24769, 27268, 6633, 7292], "timegettime": [2051, 18150, 7292], "lookupprivilegevaluew": [2298, 6279, 18098, 12848, 9463, 3351, 10175, 5657, 2860], "getnativesysteminfo": [2131, 19833, 24653, 27268, 23238, 2078, 14876], "getsysteminfo": [4152, 27268, 23238, 2078, 14876], "isdebuggerpresent": [2003, 3207, 8569, 13327, 28994, 4765], "outputdebugstringa": [6434, 3207, 8569, 5620, 18886, 13807], "seterrormode": [2275, 2121, 29165, 5302, 3207], "ldrgetdllhandle": [25510, 20800, 2102, 19422, 2140, 11774, 2571], "ldrgetprocedureaddress": [25510, 20800, 25856, 3217, 11788, 5397, 4215, 16200, 4757], "ldrloaddll": [25510, 12190, 10441, 14141, 3363], "ldrunloaddll": [25510, 15532, 11066, 19422, 2140], "ntclose": [23961, 20464, 9232], "ntduplicateobject": [23961, 8566, 24759, 24695, 16429, 20614], "ntloaddriver": [23961, 11066, 23663, 2099], "ntunloaddriver": [23961, 4609, 11066, 23663, 2099], "rtlcompressbuffer": [19387, 22499, 8737, 8303, 8569, 12494], "rtldecompressbuffer": [19387, 17920, 9006, 20110, 8569, 12494], "rtldecompressfragment": [19387, 17920, 9006, 20110, 27843, 21693, 4765], "exitwindowsex": [6164, 11101, 15568, 10288], "getasynckeystate": [2131, 3022, 6038, 29183, 9153, 2618], "getkeystate": [2131, 14839, 9153, 2618], "getkeyboardstate": [2131, 14839, 15271, 12259], "sendnotifymessagea": [4604, 17048, 8757, 7834, 3736, 3351, 2050], "sendnotifymessagew": [4604, 17048, 8757, 7834, 3736, 3351, 2860], "setwindowshookexa": [2275, 11101, 15568, 6806, 11045, 18684], "setwindowshookexw": [2275, 11101, 15568, 6806, 11045, 2595, 2860], "unhookwindowshookex": [4895, 6806, 6559, 11101, 15568, 6806, 11045, 2595], "drawtextexa": [4009, 18209, 10288, 2050], 
"drawtextexw": [4009, 18209, 10288, 2860], "findwindowa": [2424, 11101, 21293], "findwindowexa": [2424, 11101, 29385, 18684], "findwindowexw": [2424, 11101, 29385, 2595, 2860], "findwindoww": [2424, 11101, 5004, 2860], "getforegroundwindow": [2131, 29278, 13910, 22494, 4859, 11101, 5004], "loadstringa": [15665, 18886, 13807], "loadstringw": [15665, 18886, 3070, 2860], "messageboxtimeouta": [4471, 8758, 7292, 5833, 2050], "messageboxtimeoutw": [4471, 8758, 7292, 5833, 2860], "couninitialize": [2522, 19496, 3490, 20925, 4697], "ntopenmutant": [23961, 26915, 28120, 4630], "ntquerysysteminformation": [23961, 4226, 24769, 27268, 23238, 2078, 14192, 3370], "globalmemorystatus": [3795, 4168, 5302, 24769, 29336, 2271], "globalmemorystatusex": [3795, 4168, 5302, 24769, 29336, 8557, 2595], "setfiletime": [2275, 8873, 7485, 14428], "getfileversioninfosizew": [2131, 8873, 20414, 2545, 3258, 2378, 14876, 5332, 4371, 2860], "getfileversioninfow": [2131, 8873, 20414, 2545, 3258, 2378, 14876, 2860], "createactctxw": [3443, 18908, 6593, 2595, 2860], "cogetclassobject": [2522, 18150, 26266, 16429, 20614], "cocreateinstanceex": [2522, 16748, 3686, 7076, 26897, 10288], "iwbemservices_execquery": [1045, 2860, 4783, 5244, 2121, 7903, 2229, 1035, 4654, 8586, 4226, 2854], "setstdhandle": [4520, 2102, 17516, 4859, 2571], "registerhotkey": [4236, 12326, 14839], "createjobobjectw": [3443, 5558, 5092, 2497, 20614, 2860], "setinformationjobobject": [2275, 2378, 14192, 3370, 5558, 5092, 2497, 20614], "assignprocesstojobobject": [23911, 21572, 9623, 16033, 5558, 5092, 2497, 20614], "createremotethreadex": [3443, 28578, 12184, 2705, 16416, 3207, 2595], "iwbemservices_execmethod": [1045, 2860, 4783, 5244, 2121, 7903, 2229, 1035, 4654, 8586, 11368, 6806, 2094], "wnetgetprovidernamew": [1059, 7159, 18150, 21572, 17258, 11795, 14074, 2860], "ntshutdownsystem": [23961, 14235, 2102, 7698, 6508, 13473, 2213]}
if __name__ == '__main__':
    # Entry point: build the datasets, cache them on disk, then train.
    data_file = './all_analysis_data.txt'
    label_file = './labels.csv'
    config = Config()

    (train_input_ids, train_input_mak, train_input_lab,
     test_input_ids, test_input_mak, test_input_lab) = load_data(
        config.max_sequnce, data_file, label_file)
    train_set = Data.TensorDataset(train_input_ids, train_input_mak, train_input_lab)
    test_set = Data.TensorDataset(test_input_ids, test_input_mak, test_input_lab)

    # Cache the tensorised datasets for later runs.  The directory is created
    # first so the save cannot fail on a fresh checkout.
    os.makedirs('./save_tensor', exist_ok=True)
    torch.save(train_set, './save_tensor/train.pkl')
    torch.save(test_set, './save_tensor/test.pkl')
    # The datasets are used straight from memory.  Reloading the files just
    # written was redundant, and torch.load refuses to unpickle a
    # TensorDataset when weights_only defaults to True (newer PyTorch).

    train_iter = Data.DataLoader(train_set, config.batch_size, shuffle=True)
    test_iter = Data.DataLoader(test_set, config.batch_size, shuffle=True)

    model = Bert_HBiLSTM(config)
    # train() builds its own optimizer, so none is created here.
    train(config, model, train_iter, test_iter)