diff --git a/eval.py b/eval.py index 2f95addd6..941da86be 100755 --- a/eval.py +++ b/eval.py @@ -9,29 +9,31 @@ from text_cnn import TextCNN from tensorflow.contrib import learn import csv +import argparse # Parameters # ================================================== -# Data Parameters -tf.flags.DEFINE_string("positive_data_file", "./data/rt-polaritydata/rt-polarity.pos", "Data source for the positive data.") -tf.flags.DEFINE_string("negative_data_file", "./data/rt-polaritydata/rt-polarity.neg", "Data source for the negative data.") +parser = argparse.ArgumentParser() -# Eval Parameters -tf.flags.DEFINE_integer("batch_size", 64, "Batch Size (default: 64)") -tf.flags.DEFINE_string("checkpoint_dir", "", "Checkpoint directory from training run") -tf.flags.DEFINE_boolean("eval_train", False, "Evaluate on all training data") +#Data Parameters +parser.add_argument('--positive_data_file', type=str, default='./data/rt-polaritydata/rt-polarity.pos', help='Data source for the positive data.') +parser.add_argument('--negative_data_file', type=str, default='./data/rt-polaritydata/rt-polarity.neg', help='Data source for the positive data.') -# Misc Parameters -tf.flags.DEFINE_boolean("allow_soft_placement", True, "Allow device soft device placement") -tf.flags.DEFINE_boolean("log_device_placement", False, "Log placement of ops on devices") +#Eval Parameters +parser.add_argument('--batch_size', type=int, default=64, help='Batch Size (default: 64).') +parser.add_argument('--checkpoint_dir', type=str, default=None, help='Checkpoint directory from training run.') +parser.add_argument('--eval_train', type=bool, default=False, help='Evaluate on all training data.') +#Misc Parameters +parser.add_argument('--allow_soft_placement', type=bool, default=True, help='Allow device soft device placement.') +parser.add_argument('--log_device_placement', type=bool, default=False, help='Log placement of ops on devices.') + +FLAGS = parser.parse_args() -FLAGS = tf.flags.FLAGS -FLAGS._parse_flags() print("\nParameters:") -for attr, value in sorted(FLAGS.__flags.items()): - print("{}={}".format(attr.upper(), value)) +for attr in vars(FLAGS): + print("{}={}".format(attr.upper(), getattr(FLAGS, attr))) print("") # CHANGE THIS: Load data. Load your own data here