forked from hjptriplebee/Chinese_poem_generator
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
34 lines (32 loc) · 1.21 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
# coding: UTF-8
'''''''''''''''''''''''''''''''''''''''''''''''''''''
file name: main.py
create time: 2017年06月23日 星期五 16时41分54秒
author: Jipeng Huang
e-mail: [email protected]
github: https://github.com/hjptriplebee
'''''''''''''''''''''''''''''''''''''''''''''''''''''
import argparse
import dataPretreatment
import model
from config import *
def defineArgs(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: list of argument strings to parse. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which is what the script
            entry point relies on. Passing an explicit list lets callers
            and tests parse arguments without touching ``sys.argv``.

    Returns:
        argparse.Namespace with a single ``mode`` attribute. Its value is one
        of ``"train"``, ``"test"`` or ``"head"`` and defaults to ``"test"``.

    Raises:
        SystemExit: raised by argparse for ``-h/--help`` or when ``--mode``
            is not one of the allowed choices.
    """
    parser = argparse.ArgumentParser(description="Chinese_poem_generator.")
    parser.add_argument(
        "-m", "--mode",
        help="select mode by 'train' or test or head",
        # `choices` makes argparse reject any other value, so callers only
        # ever see these three modes.
        choices=["train", "test", "head"],
        default="test",
    )
    return parser.parse_args(argv)
if __name__ == "__main__":
    # Parse arguments first: `--help` or an invalid `--mode` makes argparse
    # exit right away, instead of only after the corpus preprocessing below.
    args = defineArgs()

    # `trainPoems` is not defined in this file; it presumably comes from
    # `from config import *` (TODO confirm in config.py).
    X, Y, wordNum, wordToID, words = dataPretreatment.pretreatment(trainPoems)

    if args.mode == "train":
        print("training...")
        model.train(X, Y, wordNum)
    elif args.mode == "test":
        print("generating...")
        # NOTE(review): the return value was previously bound to an unused
        # `poems` variable and is discarded here too. If model.test does not
        # print its own output, the generated poems are never shown. Confirm.
        model.test(wordNum, wordToID, words)
    else:
        # Only "head" can reach this branch: argparse `choices` rejects
        # anything other than train/test/head.
        characters = input("please input chinese character:")
        print("generating...")
        model.testHead(wordNum, wordToID, words, characters)