eval.py
# -*- coding: utf-8 -*-
from __future__ import print_function
import codecs
import os

import numpy as np
import tensorflow as tf
from nltk.translate.bleu_score import corpus_bleu

from train import Graph
from params import Params as pm
from data_loader import load_data, load_vocab, getword_idx


def eval():
    en2idx, idx2en, de2idx, idx2de = getword_idx()
    g = Graph(en2idx, de2idx, is_training=False)
    print("MSG : Graph loaded!")

    X, Sources, Targets = load_data('test', en2idx, de2idx)

    with g.graph.as_default():
        sv = tf.train.Supervisor()
        with sv.managed_session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
            # Load the pre-trained model from the latest checkpoint.
            sv.saver.restore(sess, tf.train.latest_checkpoint(pm.checkpoint))
            print("MSG : Restore Model!")
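            # Name the result file after the checkpoint: the first quoted
            # string in TF's checkpoint index file is the latest model name.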
            mname = open(pm.checkpoint + '/checkpoint', 'r').read().split('"')[1]
            if not os.path.exists('Results'):
                os.mkdir('Results')
            with codecs.open("Results/" + mname, 'w', 'utf-8') as f:
                list_of_refs, predict = [], []
                # Walk the test set one batch at a time.
                for i in range(len(X) // pm.batch_size):
                    x = X[i * pm.batch_size: (i + 1) * pm.batch_size]
                    sources = Sources[i * pm.batch_size: (i + 1) * pm.batch_size]
                    targets = Targets[i * pm.batch_size: (i + 1) * pm.batch_size]
                    # Autoregressive inference: start from an all-zero target
                    # and decode one position per step, feeding the tokens
                    # predicted so far back into the decoder.
                    preds = np.zeros((pm.batch_size, pm.maxlen), dtype=np.int32)
                    for j in range(pm.maxlen):
                        _preds = sess.run(g.preds, feed_dict={g.inpt: x, g.outpt: preds})
                        preds[:, j] = _preds[:, j]
                    for source, target, pred in zip(sources, targets, preds):
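                        # Convert indices back to tokens and cut at <EOS>.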
                        got = " ".join(idx2de[idx] for idx in pred).split("<EOS>")[0].strip()
                        f.write("- Source: {}\n".format(source))
                        f.write("- Ground Truth: {}\n".format(target))
                        f.write("- Predict: {}\n\n".format(got))
                        f.flush()
                        # Collect reference/hypothesis pairs for BLEU, skipping
                        # sentences at or below the word-count floor.
                        ref = target.split()
                        prediction = got.split()
                        if len(ref) > pm.word_limit_lower and len(prediction) > pm.word_limit_lower:
                            list_of_refs.append([ref])
                            predict.append(prediction)
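                # Corpus-level BLEU over all retained pairs; corpus_bleu
                # expects a list of reference lists for each hypothesis.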
                score = corpus_bleu(list_of_refs, predict)
                f.write("Bleu Score = " + str(100 * score))


if __name__ == '__main__':
    eval()
    print("MSG : Done!")
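# Usage: python eval.py
# Per-sentence sources, references, and predictions are written to
# Results/<checkpoint name>, followed by the corpus BLEU score (x100).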