diff --git a/kcws/train/train_cws.py b/kcws/train/train_cws.py index e68cdeb..1dad7a1 100644 --- a/kcws/train/train_cws.py +++ b/kcws/train/train_cws.py @@ -32,6 +32,7 @@ tf.app.flags.DEFINE_integer("train_steps", 150000, "trainning steps") tf.app.flags.DEFINE_float("learning_rate", 0.001, "learning rate") tf.app.flags.DEFINE_bool("use_idcnn", True, "whether use the idcnn") +tf.app.flags.DEFINE_integer("track_history", 6, "track max history accuracy") def do_load_data(path): @@ -200,6 +201,7 @@ def test_evaluate(sess, unary_score, test_sequence_length, transMatrix, inp, total_labels += sequence_length_ accuracy = 100.0 * correct_labels / float(total_labels) print("Accuracy: %.3f%%" % accuracy) + return accuracy def inputs(path): @@ -232,6 +234,8 @@ def main(unused_argv): with sv.managed_session(master='') as sess: # actual training loop training_steps = FLAGS.train_steps + trackHist = 0 + bestAcc = 0 for step in range(training_steps): if sv.should_stop(): break @@ -244,9 +248,22 @@ def main(unused_argv): print("[%d] loss: [%r]" % (step + 1, sess.run(total_loss))) if (step + 1) % 1000 == 0 or step == 0: - test_evaluate(sess, test_unary_score, - test_sequence_length, trainsMatrix, - model.inp, tX, tY) + acc = test_evaluate(sess, test_unary_score, + test_sequence_length, trainsMatrix, + model.inp, tX, tY) + if acc > bestAcc: + if step: + sv.saver.save( + sess, FLAGS.log_dir + '/best_model') + bestAcc = acc + trackHist = 0 + elif trackHist > FLAGS.track_history: + print( + "always not good enough in last %d histories, best accuracy:%.3f" + % (trackHist, bestAcc)) + break + else: + trackHist += 1 except KeyboardInterrupt, e: sv.saver.save(sess, FLAGS.log_dir + '/model',