-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathdimi-trainer.py
executable file
·124 lines (103 loc) · 3.91 KB
/
dimi-trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import sys
import itertools
# Refuse to run on anything but Python 3.  Report the problem the same way
# the rest of this script reports fatal errors: message on stderr, non-zero
# exit status.  sys.exit is used instead of the bare exit() helper because
# the latter is injected by the `site` module and is not always available.
if sys.version_info[0] != 3:
    sys.stderr.write("This script requires Python 3\n")
    sys.exit(-1)
import scripts.dimi_io as io
import configparser
import scripts.dimi as dimi
import os
from random import randint, random
import time
import multiprocessing
def _claim_run_dir(prefix):
    """Create and return the first free directory named ``<prefix>_<i>``.

    Counts i = 0, 1, 2, ... and lets ``os.makedirs`` itself decide whether a
    candidate is free, so two processes started at the same time can never
    both claim the same directory (a separate exists-then-create check can).
    """
    for i in itertools.count():
        candidate = '{}_{}'.format(prefix, i)
        try:
            os.makedirs(candidate)
        except FileExistsError:
            continue
        return candidate


def main(argv):
    """Start a new training run from a config file, or resume one.

    argv[0] is either a config file (new run) or a directory containing a
    ``config.ini`` (resume).  With exactly three arguments, argv[1] and
    argv[2] override the ``d`` and ``k`` parameters; with exactly four,
    argv[3] additionally overrides ``init_alpha``.  Any other argument count
    leaves all three to be read from the ``[params]`` section.

    For a new run a fresh output directory is created, the effective config
    and the current git revision are written into it, and sampling starts.
    Exits with status -1 if no argument is given or the path does not exist.
    """
    # Local import so this function does not depend on the file's import block.
    import subprocess

    if len(argv) < 1:
        sys.stderr.write("One required argument: <Config file|Resume directory>\n")
        sys.exit(-1)

    path = argv[0]
    D, K, init_alpha = 0, 0, 0
    if len(argv) == 3:
        D, K = argv[1], argv[2]
    elif len(argv) == 4:
        D, K, init_alpha = argv[1], argv[2], argv[3]

    if not os.path.exists(path):
        sys.stderr.write("Input file/dir does not exist!\n")
        sys.exit(-1)

    config = configparser.ConfigParser()

    # Random start-up delay of up to 10 seconds, kept from the original.
    # NOTE(review): presumably staggers runs launched together -- confirm
    # whether it is still wanted now that directory creation is atomic.
    time.sleep(random() * 10)

    if os.path.isdir(path):
        ## Resume mode
        config.read(os.path.join(path, "config.ini"))
        out_dir = config.get('io', 'output_dir')
        resume = True
    else:
        config.read(path)

        # 'init_seqs' is dropped from the config that gets saved for the run.
        if config.get('io', 'init_seqs', fallback=None) is not None:
            del config['io']['init_seqs']

        out_dir = config.get('io', 'output_dir')

        # Command-line values take precedence; otherwise use the config file.
        if not D and not K:
            D = config.get('params', 'd')
            K = config.get('params', 'k')
        if not init_alpha:
            init_alpha = config.get('params', 'init_alpha')
        # Normalise the textual form (e.g. "1" -> "1.0") used in the dir name.
        init_alpha = str(float(init_alpha))

        config['params']['d'] = D
        config['params']['k'] = K
        config['params']['init_alpha'] = init_alpha

        out_dir = _claim_run_dir(
            '{}_D{}K{}A{}'.format(out_dir, D, K, init_alpha))
        config['io']['output_dir'] = out_dir
        sys.stderr.write("The output directory for this run is {}.\n".format(out_dir))
        resume = False

        # NOTE(review): the source's indentation was lost; saving the config
        # and the git revision is placed in the new-run branch -- confirm.
        with open(os.path.join(out_dir, "config.ini"), 'w') as configfile:
            config.write(configfile)

        ## Write git hash of current branch to out directory
        # Argument list + explicit stdout file instead of a shell string, so
        # an out_dir containing spaces or shell metacharacters is safe.
        # Still best-effort: a failing or missing git does not abort the run.
        with open(os.path.join(out_dir, "git-rev.txt"), 'w') as rev_file:
            try:
                subprocess.call(['git', 'rev-parse', 'HEAD'], stdout=rev_file)
            except OSError as err:
                sys.stderr.write("Could not record git revision: {}\n".format(err))

    input_file = config.get('io', 'input_file')
    working_dir = config.get('io', 'working_dir', fallback=out_dir)
    dict_file = config.get('io', 'dict_file')

    ## Read in input file to get sequence for X
    (pos_seq, word_seq) = io.read_input_file(input_file)

    params = read_params(config)
    params['output_dir'] = out_dir

    ## Store tag sequences of gold tagged sentences
    if params.get('num_gold_sents') == 'all':
        gold_seq = {i: pos_seq[i] for i in range(len(pos_seq))}
    else:
        # Pick distinct random sentence indices until the requested number
        # (capped at the number of sentences) has been collected.
        gold_seq = dict()
        num_gold = min(int(params.get('num_gold_sents', 0)), len(word_seq))
        while len(gold_seq) < num_gold:
            rand = randint(0, len(word_seq) - 1)
            if rand not in gold_seq:
                gold_seq[rand] = pos_seq[rand]

    word_vecs = None
    if 'word_vecs_file' in params:
        word_vecs = io.read_word_vector_file(params.get('word_vecs_file'),
                                             io.read_dict_file(dict_file))

    dimi.wrapped_sample_beam(word_seq, params, working_dir, gold_seqs=gold_seq,
                             word_vecs=word_vecs,
                             word_dict_file=dict_file, resume=resume)
def read_params(config):
    """Flatten the ``[io]`` and ``[params]`` sections into one dict.

    Keys and values are the strings returned by ``config.items``.  When a key
    appears in both sections, the value from ``[params]`` wins because it is
    applied last.
    """
    params = dict(config.items('io'))
    params.update(config.items('params'))
    return params
if __name__ == "__main__":
    # Request the "fork" start method.  If that is refused -- RuntimeError
    # when a start method has already been set, ValueError when "fork" is not
    # available -- print the method that is in effect and carry on.  Only
    # those two exceptions are caught so that Ctrl-C and SystemExit are not
    # swallowed here.
    try:
        multiprocessing.set_start_method("fork")
    except (RuntimeError, ValueError):
        ctx = multiprocessing.get_start_method()
        print(ctx)
    main(sys.argv[1:])