Skip to content

Commit

Permalink
update train
Browse files Browse the repository at this point in the history
  • Loading branch information
koth committed Sep 3, 2017
1 parent 94404b5 commit 303b00f
Showing 1 changed file with 20 additions and 3 deletions.
23 changes: 20 additions & 3 deletions kcws/train/train_cws.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
tf.app.flags.DEFINE_integer("train_steps", 150000, "trainning steps")
tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate")
tf.app.flags.DEFINE_bool("use_idcnn", True, "whether use the idcnn")
tf.app.flags.DEFINE_integer("track_history", 6, "track max history accuracy")


def do_load_data(path):
Expand Down Expand Up @@ -200,6 +201,7 @@ def test_evaluate(sess, unary_score, test_sequence_length, transMatrix, inp,
total_labels += sequence_length_
accuracy = 100.0 * correct_labels / float(total_labels)
print("Accuracy: %.3f%%" % accuracy)
return accuracy


def inputs(path):
Expand Down Expand Up @@ -232,6 +234,8 @@ def main(unused_argv):
with sv.managed_session(master='') as sess:
# actual training loop
training_steps = FLAGS.train_steps
trackHist = 0
bestAcc = 0
for step in range(training_steps):
if sv.should_stop():
break
Expand All @@ -244,9 +248,22 @@ def main(unused_argv):
print("[%d] loss: [%r]" %
(step + 1, sess.run(total_loss)))
if (step + 1) % 1000 == 0 or step == 0:
test_evaluate(sess, test_unary_score,
test_sequence_length, trainsMatrix,
model.inp, tX, tY)
acc = test_evaluate(sess, test_unary_score,
test_sequence_length, trainsMatrix,
model.inp, tX, tY)
if acc > bestAcc:
if step:
sv.saver.save(
sess, FLAGS.log_dir + '/best_model')
bestAcc = acc
trackHist = 0
elif trackHist > FLAGS.track_history:
print(
"always not good enough in last %d histories, best accuracy:%.3f"
% (trackHist, bestAcc))
break
else:
trackHist += 1
except KeyboardInterrupt, e:
sv.saver.save(sess,
FLAGS.log_dir + '/model',
Expand Down

0 comments on commit 303b00f

Please sign in to comment.