Skip to content

Commit

Permalink
Merge branch 'master' into logging-corpus
Browse files Browse the repository at this point in the history
  • Loading branch information
oadams authored May 22, 2018
2 parents 919ea05 + 81dfaab commit a79a535
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
3 changes: 3 additions & 0 deletions persephone/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,10 +493,13 @@ def ensure_no_set_overlap(self) -> None:

if train & valid:
logger.warning("train and valid have overlapping items: {}".format(train & valid))
raise PersephoneException("train and valid have overlapping items: {}".format(train & valid))
if train & test:
logger.warning("train and test have overlapping items: {}".format(train & test))
raise PersephoneException("train and test have overlapping items: {}".format(train & test))
if valid & test:
logger.warning("valid and test have overlapping items: {}".format(valid & test))
raise PersephoneException("valid and test have overlapping items: {}".format(valid & test))

def pickle(self):
""" Pickles the Corpus object in a file in tgt_dir. """
Expand Down
13 changes: 9 additions & 4 deletions persephone/corpus_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, ra
if batch_size:
self.batch_size = batch_size
if num_train % batch_size != 0:
raise PersephoneException("""Number of training examples %d not divisible
by batch size %d.""" % (num_train, batch_size))
logger.error("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, batch_size))
raise PersephoneException("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, batch_size))
else:
# Dynamically change batch size based on number of training
# examples.
Expand All @@ -69,9 +71,12 @@ def __init__(self, corpus, num_train=None, batch_size=None, max_samples=None, ra
self.batch_size = 64
# For now we hope that training numbers are powers of two or
# something... If not, crash before anything else happens.

if num_train % self.batch_size != 0:
logger.critical("Invalid batch size {} provided".format(self.batch_size))
assert num_train % self.batch_size == 0, "Invalid batch size"
logger.error("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, self.batch_size))
raise PersephoneException("Number of training examples {} not divisible"
" by batch size {}.".format(num_train, batch_size))

random.seed(rand_seed)

Expand Down
4 changes: 4 additions & 0 deletions persephone/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,10 @@ def zero_pad(matrix, to_length):
x is of length to_length."""

assert matrix.shape[0] <= to_length
if not matrix.shape[0] <= to_length:
logger.error("zero_pad cannot be performed on matrix with shape {}"
" to length {}".format(matrix.shape[0], to_length))
raise ValueError
result = np.zeros((to_length,) + matrix.shape[1:])
result[:matrix.shape[0]] = matrix
return result
Expand Down

0 comments on commit a79a535

Please sign in to comment.