# -*- coding: utf-8 -*-
from __future__ import print_function
from params import Params as pm
import codecs
import sys
import numpy as np
import tensorflow as tf

def load_vocab(vocab):
    '''
    Load word tokens from an encoding dictionary and build the
    word2idx and idx2word mappings.

    Args:
        vocab: [String], vocabulary file name under ./dictionary/
    '''
    lines = codecs.open('./dictionary/{}'.format(vocab), 'r', 'utf-8').read().splitlines()
    # keep only words whose count reaches pm.word_limit_size
    vocab = [line.split()[0] for line in lines if int(line.split()[1]) >= pm.word_limit_size]
    word2idx_dic = {word: idx for idx, word in enumerate(vocab)}
    idx2word_dic = {idx: word for idx, word in enumerate(vocab)}
    return word2idx_dic, idx2word_dic
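
# A minimal usage sketch (an illustration, not part of the original module),
# assuming pm.enc_vocab names a file under ./dictionary/ whose lines look
# like "word count":
#
#     en2idx, idx2en = load_vocab(pm.enc_vocab)
#     idx = en2idx.get(u'hello', 1)   # index 1 is used as <UNK> below
#     word = idx2en[idx]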

def getword_idx():
    # build the word-index mappings for both the encoder and decoder vocabularies
    en2idx, idx2en = load_vocab(pm.enc_vocab)
    de2idx, idx2de = load_vocab(pm.dec_vocab)
    return en2idx, idx2en, de2idx, idx2de

def generate_dataset(source_sents, target_sents, en2idx, de2idx):
    '''
    Parse source and target sentences from the corpus, convert each
    sentence into a list of word-token ids, and pad every list to
    pm.maxlen.

    Args:
        source_sents: [List], encoding sentences from the src-train file
        target_sents: [List], decoding sentences from the tgt-train file
        en2idx: [Dict], encoder word-to-index mapping
        de2idx: [Dict], decoder word-to-index mapping
    '''
    in_list, out_list, Sources, Targets = [], [], [], []
    for source_sent, target_sent in zip(source_sents, target_sents):
        # 1 means <UNK>
        inpt = [en2idx.get(word, 1) for word in (source_sent + u" <EOS>").split()]
        outpt = [de2idx.get(word, 1) for word in (target_sent + u" <EOS>").split()]
        # keep only sentence pairs that fit within pm.maxlen
        if max(len(inpt), len(outpt)) <= pm.maxlen:
            # sentence token lists
            in_list.append(np.array(inpt))
            out_list.append(np.array(outpt))
            # raw sentence lists
            Sources.append(source_sent)
            Targets.append(target_sent)

    # pad every token list with zeros (<PAD>) up to pm.maxlen
    X = np.zeros([len(in_list), pm.maxlen], np.int32)
    Y = np.zeros([len(out_list), pm.maxlen], np.int32)
    for i, (x, y) in enumerate(zip(in_list, out_list)):
        X[i] = np.pad(x, (0, pm.maxlen - len(x)), 'constant', constant_values=(0, 0))
        Y[i] = np.pad(y, (0, pm.maxlen - len(y)), 'constant', constant_values=(0, 0))
    return X, Y, Sources, Targets
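
# A worked sketch of the padding step above (hypothetical values, assuming
# pm.maxlen = 5): a token list [4, 7, 2] becomes the row [4, 7, 2, 0, 0],
# so X and Y are fixed-width (num_sentences, pm.maxlen) int32 matrices with
# trailing zeros acting as <PAD>.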

def readlines(filepath):
    # read a UTF-8 file and return its lines with surrounding whitespace stripped
    with codecs.open(filepath, 'r', 'utf-8') as f:
        lines = f.readlines()
    newline = [line.strip() for line in lines]
    return newline

def load_data(l_data, en2idx, de2idx):
    '''
    Read train or test data from the input datasets

    Args:
        l_data: [String], 'train' or 'test', the dataset used to generate tokens
        en2idx: [Dict], encoder word-to-index mapping
        de2idx: [Dict], decoder word-to-index mapping
    '''
    if l_data == 'train':
        en_sents = readlines(pm.src_train)
        de_sents = readlines(pm.tgt_train)
        if len(en_sents) == len(de_sents):
            inpt, outpt, Sources, Targets = generate_dataset(en_sents, de_sents, en2idx, de2idx)
        else:
            print("MSG : Source length is different from target length.")
            sys.exit(1)
        return inpt, outpt
    elif l_data == 'test':
        en_sents = readlines(pm.src_test)
        de_sents = readlines(pm.tgt_test)
        if len(en_sents) == len(de_sents):
            inpt, outpt, Sources, Targets = generate_dataset(en_sents, de_sents, en2idx, de2idx)
        else:
            print("MSG : Source length is different from target length.")
            sys.exit(1)
        return inpt, Sources, Targets
    else:
        print("MSG : Error when loading data: l_data must be 'train' or 'test'.")
        sys.exit(1)
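
# A usage sketch for the test split (an illustration, assuming pm.src_test
# and pm.tgt_test point at parallel files with one sentence per line):
#
#     X, Sources, Targets = load_data('test', en2idx, de2idx)
#     # X: (num_kept_pairs, pm.maxlen) int32 token matrix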

def get_batch_data(en2idx, de2idx):
    '''
    A batch dataset generator built on the TF 1.x input-queue pipeline
    '''
    inpt, outpt = load_data("train", en2idx, de2idx)
    batch_num = len(inpt) // pm.batch_size
    inpt = tf.convert_to_tensor(inpt, tf.int32)
    outpt = tf.convert_to_tensor(outpt, tf.int32)
    # slice the data into a queue so it can be consumed as a pipeline generator
    input_queues = tf.train.slice_input_producer([inpt, outpt])
    # multi-threaded, shuffled batching
    x, y = tf.train.shuffle_batch(input_queues,
                                  num_threads=12,
                                  batch_size=pm.batch_size,
                                  capacity=pm.batch_size * 64,
                                  min_after_dequeue=pm.batch_size * 32,
                                  allow_smaller_final_batch=False)
    return x, y, batch_num
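
if __name__ == '__main__':
    # A minimal consumption sketch (an addition for illustration, not part of
    # the original module). TF 1.x queue pipelines only yield data once their
    # queue runners are started, so a Coordinator drives the threads here.
    en2idx, idx2en, de2idx, idx2de = getword_idx()
    x, y, batch_num = get_batch_data(en2idx, de2idx)
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        x_batch, y_batch = sess.run([x, y])  # each: (batch_size, pm.maxlen) int32
        print(x_batch.shape, y_batch.shape)
        coord.request_stop()
        coord.join(threads)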