# coding: UTF-8
'''
file name: dataPretreatment.py
create time: Friday, June 23, 2017, 17:17:36
author: Jipeng Huang
e-mail: [email protected]
github: https://github.com/hjptriplebee
'''
from config import *   # provides batchSize (and other training settings)
import numpy as np

def pretreatment(filename):
    """Read the poem corpus, build the vocabulary and batched training data.

    Each line of the input file is expected to look like "title:poem".
    Returns X (input batches), Y (target batches shifted by one character),
    the vocabulary size, the word-to-ID mapping and the word tuple.
    """
    poems = []
    with open(filename, "r", encoding="utf-8") as file:
        for line in file:                                # every line is a poem
            if ":" not in line:
                continue
            title, poem = line.strip().split(":", 1)     # get title and poem
            poem = poem.replace(' ', '')
            # skip poems containing annotation or bracket characters
            if '_' in poem or '《' in poem or '[' in poem or '(' in poem or '(' in poem:
                continue
            if len(poem) < 10 or len(poem) > 128:        # filter by length
                continue
            poem = '[' + poem + ']'                      # add start and end signs
            poems.append(poem)
    print("Total Tang poems: %d" % len(poems))
    # count word (character) frequencies
    allWords = {}
    for poem in poems:
        for word in poem:
            if word not in allWords:
                allWords[word] = 1
            else:
                allWords[word] += 1
    # erase words which are not common (frequency < 2)
    erase = []
    for key in allWords:
        if allWords[key] < 2:
            erase.append(key)
    for key in erase:
        del allWords[key]
    # sort by descending frequency and assign IDs; a blank is appended as the padding word
    wordPairs = sorted(allWords.items(), key=lambda x: -x[1])
    words, _ = zip(*wordPairs)
    words += (" ", )
    wordToID = dict(zip(words, range(len(words))))       # word to ID
    wordTOIDFun = lambda A: wordToID.get(A, len(words))  # unknown words map to len(words)
    poemsVector = [[wordTOIDFun(word) for word in poem] for poem in poems]  # poem to vector
    # build batches, padding each poem to the longest poem in its batch
    batchNum = (len(poemsVector) - 1) // batchSize
    X = []
    Y = []
    for i in range(batchNum):
        batch = poemsVector[i * batchSize: (i + 1) * batchSize]
        maxLength = max([len(vector) for vector in batch])
        temp = np.full((batchSize, maxLength), wordTOIDFun(" "), np.int32)
        for j in range(batchSize):
            temp[j, :len(batch[j])] = batch[j]
        X.append(temp)
        temp2 = np.copy(temp)                            # copy so shifting Y does not modify X
        temp2[:, :-1] = temp[:, 1:]                      # target is input shifted left by one
        Y.append(temp2)
    return X, Y, len(words) + 1, wordToID, words
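
# Minimal usage sketch (not part of the original file): the corpus path
# "dataset/poetry.txt" is an assumption for illustration; the real path comes
# from the project's config.py.
if __name__ == "__main__":
    X, Y, vocabSize, wordToID, words = pretreatment("dataset/poetry.txt")
    print("batches: %d, vocabulary size: %d" % (len(X), vocabSize))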