diff --git a/persephone/model.py b/persephone/model.py index caa96f8..b7b1e01 100644 --- a/persephone/model.py +++ b/persephone/model.py @@ -11,7 +11,7 @@ import tensorflow as tf -from . import preprocess +from .preprocess import labels from . import utils from . import config from .exceptions import PersephoneException @@ -47,7 +47,7 @@ def dense_to_human_readable(dense_repr, index_to_label): def decode(model_path_prefix: Union[str, Path], input_paths: Sequence[Path], - labels: Set[str]) -> List[List[str]]: + label_set: Set[str]) -> List[List[str]]: model_path_prefix = str(model_path_prefix) @@ -75,7 +75,7 @@ def decode(model_path_prefix: Union[str, Path], dense_decoded = sess.run("SparseToDense:0", feed_dict=feed_dict) # Create a human-readable representation of the decoded. - indices_to_labels = preprocess.labels.make_indices_to_labels(labels) + indices_to_labels = labels.make_indices_to_labels(label_set) human_readable = dense_to_human_readable(dense_decoded, indices_to_labels) return human_readable diff --git a/persephone/tests/experiments/test_na.py b/persephone/tests/experiments/test_na.py index 452ad9b..7ab7732 100644 --- a/persephone/tests/experiments/test_na.py +++ b/persephone/tests/experiments/test_na.py @@ -12,7 +12,6 @@ from persephone import config from persephone import model from persephone import results -from persephone import run from persephone import corpus_reader from persephone import rnn_ctc from persephone import experiment @@ -27,7 +26,7 @@ # it should have a txt extension. TEST_PER_FN = "test/test_per" -def set_up_base_testing_dir(): +def set_up_base_testing_dir(data_base_dir=DATA_BASE_DIR): """ Creates a directory to store corpora and experimental directories used in testing. """