Skip to content

Commit

Permalink
fix main script
Browse files Browse the repository at this point in the history
  • Loading branch information
mollerhoj committed Oct 4, 2018
1 parent 966513b commit acc1ba4
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from data_utils import Data
from models.char_cnn_zhang import CharCNNZhang
from models.char_cnn_kim import CharCNNKim
from models.char_tcn import CharTCN

tf.flags.DEFINE_string("model", "char_cnn_zhang", "Specifies which model to use: char_cnn_zhang or char_cnn_kim")
FLAGS = tf.flags.FLAGS
Expand All @@ -13,52 +14,52 @@
# Load configurations
config = json.load(open("config.json"))
# Load training data
training_data = Data(data_source=config["data"]["data_source"],
training_data = Data(data_source=config["data"]["training_data_source"],
alphabet=config["data"]["alphabet"],
input_size=config["data"]["input_size"],
num_of_classes=config["data"]["no_of_classes"])
num_of_classes=config["data"]["num_of_classes"])
training_data.load_data()
training_inputs, training_labels = training_data.get_all_data()
# Load validation data
validation_data = Data(data_source=config["data"]["data_source"],
validation_data = Data(data_source=config["data"]["validation_data_source"],
alphabet=config["data"]["alphabet"],
input_size=config["data"]["input_size"],
num_of_classes=config["data"]["no_of_classes"])
num_of_classes=config["data"]["num_of_classes"])
validation_data.load_data()
validation_inputs, validation_labels = validation_data.get_all_data()

# Load model configurations and build model
if FLAGS.model == "kim":
model = CharCNNKim(input_size=config["data"]["input_size"],
alphabet_size=config["data"]["alphabet_size"],
embedding_size=config["model"]["embedding_size"],
conv_layers=config["model"]["conv_layers"],
fully_connected_layers=config["model"]["fully_connected_layers"],
embedding_size=config["char_cnn_kim"]["embedding_size"],
conv_layers=config["char_cnn_kim"]["conv_layers"],
fully_connected_layers=config["char_cnn_kim"]["fully_connected_layers"],
num_of_classes=config["data"]["num_of_classes"],
dropout_p=config["model"]["dropout_p"],
optimizer=config["model"]["optimizer"],
loss=config["model"]["loss"])
dropout_p=config["char_cnn_kim"]["dropout_p"],
optimizer=config["char_cnn_kim"]["optimizer"],
loss=config["char_cnn_kim"]["loss"])
elif FLAGS.model == 'tcn':
model = CharTCN(input_size=config["data"]["input_size"],
alphabet_size=config["data"]["alphabet_size"],
embedding_size=config["model"]["embedding_size"],
conv_layers=config["model"]["conv_layers"],
fully_connected_layers=config["model"]["fully_connected_layers"],
num_of_classes=config["data"]["num_of_classes"],
dropout_p=config["model"]["dropout_p"],
optimizer=config["model"]["optimizer"],
loss=config["model"]["loss"])
model = CharTCN(input_size=config["data"]["input_size"],
alphabet_size=config["data"]["alphabet_size"],
embedding_size=config["char_tcn"]["embedding_size"],
conv_layers=config["char_tcn"]["conv_layers"],
fully_connected_layers=config["char_tcn"]["fully_connected_layers"],
num_of_classes=config["data"]["num_of_classes"],
dropout_p=config["char_tcn"]["dropout_p"],
optimizer=config["char_tcn"]["optimizer"],
loss=config["char_tcn"]["loss"])
else:
model = CharCNNZhang(input_size=config["data"]["input_size"],
alphabet_size=config["data"]["alphabet_size"],
embedding_size=config["model"]["embedding_size"],
conv_layers=config["model"]["conv_layers"],
fully_connected_layers=config["model"]["fully_connected_layers"],
embedding_size=config["char_cnn_zhang"]["embedding_size"],
conv_layers=config["char_cnn_zhang"]["conv_layers"],
fully_connected_layers=config["char_cnn_zhang"]["fully_connected_layers"],
num_of_classes=config["data"]["num_of_classes"],
threshold=config["model"]["threshold"],
dropout_p=config["model"]["dropout_p"],
optimizer=config["model"]["optimizer"],
loss=config["model"]["loss"])
threshold=config["char_cnn_zhang"]["threshold"],
dropout_p=config["char_cnn_zhang"]["dropout_p"],
optimizer=config["char_cnn_zhang"]["optimizer"],
loss=config["char_cnn_zhang"]["loss"])
# Train model
model.train(training_inputs=training_inputs,
training_labels=training_labels,
Expand Down

0 comments on commit acc1ba4

Please sign in to comment.