diff --git a/persephone/corpus.py b/persephone/corpus.py index 84f8ee1..1874e3f 100644 --- a/persephone/corpus.py +++ b/persephone/corpus.py @@ -222,6 +222,10 @@ def test_prefix_fn(self) -> Path: return self.tgt_dir / "test_prefixes.txt" def set_and_check_directories(self, tgt_dir: Path) -> None: + """ + Make sure that the required directories exist in the target directory. + set variables accordingly. + """ logger.info("Setting up directories for corpus in %s", tgt_dir) # Set the directory names @@ -243,6 +247,8 @@ def set_and_check_directories(self, tgt_dir: Path) -> None: "The supplied path requires a 'label' subdirectory.") def initialize_labels(self, labels): + + logger.debug("Creating mappings for labels") self.labels = labels self.vocab_size = len(self.labels) self.LABEL_TO_INDEX = {label: index for index, label in enumerate( @@ -253,11 +259,13 @@ def initialize_labels(self, labels): def prepare_feats(self): """ Prepares input features""" + logger.debug("Preparing input features") self.feat_dir.mkdir(parents=True, exist_ok=True) should_extract_feats = False for path in self.wav_dir.iterdir(): if not path.suffix == ".wav": + logger.info("Non wav file found in wav directory: %s", path) continue prefix = os.path.basename(os.path.splitext(str(path))[0]) mono16k_wav_path = self.feat_dir / "{}.wav".format(prefix) @@ -282,6 +290,7 @@ def make_data_splits(self, max_samples): test_f_exists = self.test_prefix_fn.is_file() if train_f_exists and valid_f_exists and test_f_exists: + logger.debug("Split for training, validation and tests specficied by files") self.train_prefixes = self.read_prefixes(self.train_prefix_fn) self.valid_prefixes = self.read_prefixes(self.valid_prefix_fn) self.test_prefixes = self.read_prefixes(self.test_prefix_fn) @@ -294,6 +303,8 @@ def make_data_splits(self, max_samples): self.feat_dir, prefixes, self.feat_type, max_samples) if not train_f_exists and not valid_f_exists and not test_f_exists: + logger.debug("No files supplied to define the split for training, validation" + " and tests. Using default.") train_prefixes, valid_prefixes, test_prefixes = self.divide_prefixes(prefixes) self.train_prefixes = train_prefixes self.valid_prefixes = valid_prefixes @@ -319,10 +330,13 @@ def make_data_splits(self, max_samples): @staticmethod def read_prefixes(prefix_fn: Path) -> List[str]: - assert prefix_fn.is_file() + if not prefix_fn.is_file(): + logger.critical("Expected a prefix file at path {}, but this path is" + " not a file".format(prefix_fn)) + assert prefix_fn.is_file(), "Path {} was not a file".format(prefix_fn) with prefix_fn.open() as prefix_f: prefixes = [line.strip() for line in prefix_f] - if prefixes == []: + if not prefixes: raise PersephoneException( "Empty prefix file {}. Either delete it\ or put something in it".format(prefix_fn)) @@ -330,7 +344,7 @@ def read_prefixes(prefix_fn: Path) -> List[str]: @staticmethod def write_prefixes(prefixes: List[str], prefix_fn: Path) -> None: - if prefixes == []: + if not prefixes: raise PersephoneException( "No prefixes. Will not write {}".format(prefix_fn)) with prefix_fn.open("w") as prefix_f: @@ -339,6 +353,7 @@ def write_prefixes(prefixes: List[str], prefix_fn: Path) -> None: @staticmethod def divide_prefixes(prefixes, seed=0): + """Divide data into training, validation and test subsets""" Ratios = namedtuple("Ratios", ["train", "valid", "test"]) ratios=Ratios(.90, .05, .05) train_end = int(ratios.train*len(prefixes)) @@ -352,9 +367,9 @@ def divide_prefixes(prefixes, seed=0): # TODO Adjust code to cope properly with toy datasets where these # subsets might actually be empty. - assert train_prefixes - assert valid_prefixes - assert test_prefixes + assert train_prefixes, "Got empty set for training data" + assert valid_prefixes, "Got empty set for validation data" + assert test_prefixes, "Got empty set for testing data" return train_prefixes, valid_prefixes, test_prefixes @@ -404,9 +419,11 @@ def get_train_fns(self): return self.prefixes_to_fns(self.train_prefixes) def get_valid_fns(self): + """ Fetches the validation set of the corpus.""" return self.prefixes_to_fns(self.valid_prefixes) def get_test_fns(self): + """ Fetches the test set of the corpus.""" return self.prefixes_to_fns(self.test_prefixes) def get_untranscribed_prefixes(self): @@ -418,6 +435,9 @@ def get_untranscribed_prefixes(self): prefixes = f.readlines() return [prefix.strip() for prefix in prefixes] + else: + logger.warning("Attempting to get untranscribed prefixes but the file ({})" + " that should specify these does not exist".format(untranscribed_prefix_fn)) return None @@ -437,7 +457,7 @@ def determine_prefixes(self) -> List[str]: # Take the intersection; sort for determinism. prefixes = sorted(list(set(label_prefixes) & set(wav_prefixes))) - if prefixes == []: + if not prefixes: raise PersephoneException("""WARNING: Corpus object has no data. Are you sure it's in the correct directories? WAVs should be in {} and transcriptions in {} with the extension .{}""".format( diff --git a/persephone/corpus_reader.py b/persephone/corpus_reader.py index 47cf0a6..3e9cb6f 100644 --- a/persephone/corpus_reader.py +++ b/persephone/corpus_reader.py @@ -71,7 +71,7 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, ra self.batch_size = 64 # For now we hope that training numbers are powers of two or # something... If not, crash before anything else happens. - assert num_train % self.batch_size == 0 + if num_train % self.batch_size != 0: logger.error("Number of training examples {} not divisible" " by batch size {}.".format(num_train, self.batch_size))