"""Test a model and generate submission CSV.
Usage:
> python test.py --split SPLIT --load_path PATH --name NAME
where
> SPLIT is either "dev" or "test"
> PATH is a path to a checkpoint (e.g., save/train/model-01/best.pth.tar)
> NAME is a name to identify the test run
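For example (the checkpoint path and run name below are just placeholders):
> python test.py --split dev --load_path save/train/model-01/best.pth.tar --name dev-01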
"""
import csv
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
import util
from args import get_test_args
from collections import OrderedDict
from json import dumps
from models import BiDAF
from os.path import join
from tensorboardX import SummaryWriter
from tqdm import tqdm
from ujson import load as json_load
from util import collate_fn, SQuAD
from tfidf import TFIDF


def main(args):
    # Load TF-IDF from pickle
    scorer = TFIDF([])
    scorer.get_from_pickle()
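    # NOTE: this assumes TF-IDF statistics were precomputed and pickled beforehand
    # by the tfidf module; TFIDF([]) creates an empty scorer and get_from_pickle()
    # restores the saved state.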
    # Set up logging
    args.save_dir = util.get_save_dir(args.save_dir, args.name, training=False)
    log = util.get_logger(args.save_dir, args.name)
    log.info('Args: {}'.format(dumps(vars(args), indent=4, sort_keys=True)))
    device, gpu_ids = util.get_available_devices()
    args.batch_size *= max(1, len(gpu_ids))

    # Get embeddings
    log.info('Loading embeddings...')
    word_vectors = util.torch_from_json(args.word_emb_file)

    # Get data loader
    log.info('Building dataset...')
    record_file = vars(args)['{}_record_file'.format(args.split)]
    dataset = SQuAD(record_file, args.use_squad_v2)
    data_loader = data.DataLoader(dataset,
                                  batch_size=args.batch_size,
                                  shuffle=False,
                                  num_workers=args.num_workers,
                                  collate_fn=collate_fn)

    # Get model
    log.info('Building model...')
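    # NOTE: char_vocab_size=1376 is assumed to match the character vocabulary
    # built during preprocessing; adjust it if your char vocab has a different size.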
    model = BiDAF(word_vectors=word_vectors,
                  char_vocab_size=1376,
                  hidden_size=args.hidden_size)
    model = nn.DataParallel(model, gpu_ids)
    log.info('Loading checkpoint from {}...'.format(args.load_path))
    model = util.load_model(model, args.load_path, gpu_ids, return_step=False)
    model = model.to(device)
    model.eval()

    # Evaluate
    log.info('Evaluating on {} split...'.format(args.split))
    nll_meter = util.AverageMeter()
    pred_dict = {}  # Predictions for TensorBoard
    sub_dict = {}   # Predictions for submission
    eval_file = vars(args)['{}_eval_file'.format(args.split)]
    with open(eval_file, 'r') as fh:
        gold_dict = json_load(fh)
    with torch.no_grad(), \
            tqdm(total=len(dataset)) as progress_bar:
        for cw_idxs, cc_idxs, qw_idxs, qc_idxs, y1, y2, ids in data_loader:
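            # Batch fields (names assumed from the dataset): cw/cc = context
            # word/char indices, qw/qc = question word/char indices, y1/y2 = gold
            # start/end positions, ids = example ids for mapping predictions back.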
            # Setup for forward
            cw_idxs = cw_idxs.to(device)
            qw_idxs = qw_idxs.to(device)
            batch_size = cw_idxs.size(0)

            # Forward
            log_p1, log_p2 = model(cw_idxs, qw_idxs, cc_idxs, qc_idxs)
            y1, y2 = y1.to(device), y2.to(device)
            loss = F.nll_loss(log_p1, y1) + F.nll_loss(log_p2, y2)
            nll_meter.update(loss.item(), batch_size)

            # Get F1 and EM scores
            p1, p2 = log_p1.exp(), log_p2.exp()
            starts, ends = util.discretize(p1, p2, args.max_ans_len, args.use_squad_v2)
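            # util.discretize turns the start/end probability distributions into
            # concrete (start, end) indices, presumably the highest-scoring span of
            # length <= max_ans_len, with a no-answer option when use_squad_v2 is set.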
            # Log info
            progress_bar.update(batch_size)
            if args.split != 'test':
                # No labels for the test set, so NLL would be invalid
                progress_bar.set_postfix(NLL=nll_meter.avg)

            idx2pred, uuid2pred = util.convert_tokens(gold_dict,
                                                      ids.tolist(),
                                                      starts.tolist(),
                                                      ends.tolist(),
                                                      args.use_squad_v2)
            pred_dict.update(idx2pred)
            sub_dict.update(uuid2pred)
    if args.use_tfidf:
        # Apply TF-IDF filtering to pred_dict: predictions whose normalized additive
        # IDF score (ignoring common words) falls below the threshold are blanked out,
        # i.e. turned into no-answer predictions.
        # Note: only pred_dict is filtered here; sub_dict, which is written to the
        # submission CSV below, is left untouched.
        tf_idf_threshold = 2
        tf_idf_common_threshold = 1
        for key, value in pred_dict.items():
            if value != "":
                tf_idf_score = scorer.normalized_additive_idf_ignore_common_words(
                    value, threshold_frequency=tf_idf_common_threshold)
                if tf_idf_score < tf_idf_threshold:
                    pred_dict[key] = ''
                    # print("pred_dict: {}, pruned".format(tf_idf_score))
                # else:
                #     print("pred_dict: {}, kept".format(tf_idf_score))
    # Log results (except for test set, since it does not come with labels)
    if args.split != 'test':
        results = util.eval_dicts(gold_dict, pred_dict, args.use_squad_v2)
        results_list = [('NLL', nll_meter.avg),
                        ('F1', results['F1']),
                        ('EM', results['EM'])]
        if args.use_squad_v2:
            results_list.append(('AvNA', results['AvNA']))
        results = OrderedDict(results_list)

        # Log to console
        results_str = ', '.join('{}: {:05.2f}'.format(k, v)
                                for k, v in results.items())
        log.info('{} {}'.format(args.split.title(), results_str))

        # Log to TensorBoard
        tbx = SummaryWriter(args.save_dir)
        util.visualize(tbx,
                       pred_dict=pred_dict,
                       eval_path=eval_file,
                       step=0,
                       split=args.split,
                       num_visuals=args.num_visuals)
    # Write submission file
    sub_path = join(args.save_dir, args.split + '_' + args.sub_file)
    log.info('Writing submission file to {}...'.format(sub_path))
    with open(sub_path, 'w') as csv_fh:
        csv_writer = csv.writer(csv_fh, delimiter=',')
        csv_writer.writerow(['Id', 'Predicted'])
        for uuid in sorted(sub_dict):
            csv_writer.writerow([uuid, sub_dict[uuid]])


if __name__ == '__main__':
    main(get_test_args())