Skip to content

Commit

Permalink
Merge pull request #5 from fprost/dataset_input_speeding
Browse files Browse the repository at this point in the history
simple improvements to current pipeline
  • Loading branch information
fprost authored Jul 25, 2018
2 parents b7fce6a + d744065 commit b14e6f3
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions experiments/tf_trainer/common/tfrecord_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from __future__ import division
from __future__ import print_function

import multiprocessing

import tensorflow as tf
from tf_trainer.common import dataset_input
from tf_trainer.common import types
Expand All @@ -25,7 +27,8 @@ def __init__(
feature_preprocessor_init: Callable[[], Callable[[str], List[str]]],
batch_size: int = 64,
max_seq_length: int = 300,
round_labels: bool = True) -> None:
round_labels: bool = True,
num_prefetch: int = 3) -> None:
self._train_path = train_path
self._validate_path = validate_path
self._text_feature = text_feature
Expand All @@ -34,6 +37,7 @@ def __init__(
self._max_seq_length = max_seq_length
self.feature_preprocessor_init = feature_preprocessor_init
self._round_labels = round_labels
self._num_prefetch = num_prefetch

def train_input_fn(self) -> types.FeatureAndLabelTensors:
"""input_fn for TF Estimators for training set."""
Expand All @@ -50,7 +54,8 @@ def _input_fn_from_file(self, filepath: str) -> types.FeatureAndLabelTensors:
# but inside the input_fn function.
feature_preprocessor = self.feature_preprocessor_init()
parsed_dataset = dataset.map(
lambda x: self._read_tf_example(x, feature_preprocessor))
lambda x: self._read_tf_example(x, feature_preprocessor),
num_parallel_calls=multiprocessing.cpu_count())
batched_dataset = parsed_dataset.padded_batch(
self._batch_size,
padded_shapes=(
Expand All @@ -59,6 +64,7 @@ def _input_fn_from_file(self, filepath: str) -> types.FeatureAndLabelTensors:
self._text_feature: [None]
},
{label: [] for label in self._labels}))
batched_dataset = batched_dataset.prefetch(self._num_prefetch)

# TODO: think about what happens when we run out of examples; should we be
# using something that repeats over the dataset many times to allow
Expand Down

0 comments on commit b14e6f3

Please sign in to comment.