From acc1ba470628eea675e3d492031bf56e2444a03f Mon Sep 17 00:00:00 2001 From: mollerhoj Date: Thu, 4 Oct 2018 14:10:37 +0200 Subject: [PATCH] fix main script --- main.py | 53 +++++++++++++++++++++++++++-------------------------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/main.py b/main.py index fd78a28..50ff2cb 100644 --- a/main.py +++ b/main.py @@ -4,6 +4,7 @@ from data_utils import Data from models.char_cnn_zhang import CharCNNZhang from models.char_cnn_kim import CharCNNKim +from models.char_tcn import CharTCN tf.flags.DEFINE_string("model", "char_cnn_zhang", "Specifies which model to use: char_cnn_zhang or char_cnn_kim") FLAGS = tf.flags.FLAGS @@ -13,17 +14,17 @@ # Load configurations config = json.load(open("config.json")) # Load training data - training_data = Data(data_source=config["data"]["data_source"], + training_data = Data(data_source=config["data"]["training_data_source"], alphabet=config["data"]["alphabet"], input_size=config["data"]["input_size"], - num_of_classes=config["data"]["no_of_classes"]) + num_of_classes=config["data"]["num_of_classes"]) training_data.load_data() training_inputs, training_labels = training_data.get_all_data() # Load validation data - validation_data = Data(data_source=config["data"]["data_source"], + validation_data = Data(data_source=config["data"]["validation_data_source"], alphabet=config["data"]["alphabet"], input_size=config["data"]["input_size"], - num_of_classes=config["data"]["no_of_classes"]) + num_of_classes=config["data"]["num_of_classes"]) validation_data.load_data() validation_inputs, validation_labels = validation_data.get_all_data() @@ -31,34 +32,34 @@ if FLAGS.model == "kim": model = CharCNNKim(input_size=config["data"]["input_size"], alphabet_size=config["data"]["alphabet_size"], - embedding_size=config["model"]["embedding_size"], - conv_layers=config["model"]["conv_layers"], - fully_connected_layers=config["model"]["fully_connected_layers"], + embedding_size=config["char_cnn_kim"]["embedding_size"], + conv_layers=config["char_cnn_kim"]["conv_layers"], + fully_connected_layers=config["char_cnn_kim"]["fully_connected_layers"], num_of_classes=config["data"]["num_of_classes"], - dropout_p=config["model"]["dropout_p"], - optimizer=config["model"]["optimizer"], - loss=config["model"]["loss"]) + dropout_p=config["char_cnn_kim"]["dropout_p"], + optimizer=config["char_cnn_kim"]["optimizer"], + loss=config["char_cnn_kim"]["loss"]) elif FLAGS.model == 'tcn': - model = CharTCN(input_size=config["data"]["input_size"], - alphabet_size=config["data"]["alphabet_size"], - embedding_size=config["model"]["embedding_size"], - conv_layers=config["model"]["conv_layers"], - fully_connected_layers=config["model"]["fully_connected_layers"], - num_of_classes=config["data"]["num_of_classes"], - dropout_p=config["model"]["dropout_p"], - optimizer=config["model"]["optimizer"], - loss=config["model"]["loss"]) + model = CharTCN(input_size=config["data"]["input_size"], + alphabet_size=config["data"]["alphabet_size"], + embedding_size=config["char_tcn"]["embedding_size"], + conv_layers=config["char_tcn"]["conv_layers"], + fully_connected_layers=config["char_tcn"]["fully_connected_layers"], + num_of_classes=config["data"]["num_of_classes"], + dropout_p=config["char_tcn"]["dropout_p"], + optimizer=config["char_tcn"]["optimizer"], + loss=config["char_tcn"]["loss"]) else: model = CharCNNZhang(input_size=config["data"]["input_size"], alphabet_size=config["data"]["alphabet_size"], - embedding_size=config["model"]["embedding_size"], - conv_layers=config["model"]["conv_layers"], - fully_connected_layers=config["model"]["fully_connected_layers"], + embedding_size=config["char_cnn_zhang"]["embedding_size"], + conv_layers=config["char_cnn_zhang"]["conv_layers"], + fully_connected_layers=config["char_cnn_zhang"]["fully_connected_layers"], num_of_classes=config["data"]["num_of_classes"], - threshold=config["model"]["threshold"], - dropout_p=config["model"]["dropout_p"], - optimizer=config["model"]["optimizer"], - loss=config["model"]["loss"]) + threshold=config["char_cnn_zhang"]["threshold"], + dropout_p=config["char_cnn_zhang"]["dropout_p"], + optimizer=config["char_cnn_zhang"]["optimizer"], + loss=config["char_cnn_zhang"]["loss"]) # Train model model.train(training_inputs=training_inputs, training_labels=training_labels,