diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 2c22d9e35..5efcfbd7d 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -127,7 +127,7 @@ eval_per_device_batch_size: 0
 max_corpus_chars: 10_000_000
 # dataset_type: c4 # must be c4, array_record or synthetic
 dataset_type: array_record
-grain_worker_count: 0
+grain_worker_count: 1
 
 # Training loop
 steps: 150_001 # If set to -1 then will inherit value from learning_rate_schedule_steps
diff --git a/MaxText/input_pipeline.py b/MaxText/input_pipeline.py
index 066b2d3c1..1b8211d67 100644
--- a/MaxText/input_pipeline.py
+++ b/MaxText/input_pipeline.py
@@ -167,24 +167,6 @@ def preprocessing_pipeline(
   # Return multi-host jax.Array prep iterator
   return multihost_gen
 
-# def preprocessing_pipeline_lazydata(
-#   dataset,
-#   vocab_path,
-#   batch_size: int,
-#   global_mesh,
-#   shuffle: bool,
-#   num_epochs: Optional[int] = 1,
-#   pack_examples: bool = True,
-#   shuffle_buffer_size: int = 1024,
-#   max_length: int = 512,
-#   shift: bool = True,
-#   drop_remainder: bool = True,
-#   data_sharding = None,
-#   data_shuffle_seed = 0,
-# ):
-#   dataset = normalize_features(dataset)
-#   dataset = dataset.filter(length_filter(max_length))
-
 
 def preprocessing_pipeline_pygrain(
   dataset,
@@ -202,7 +184,7 @@ def preprocessing_pipeline_pygrain(
   data_sharding = None,
   data_shuffle_seed = 0,
 ):
-
+  """Apply pygrain operations to preprocess the given dataset."""
   operations = []
   operations.append(pygrain_operations.ParseFeatures())
   operations.append(pygrain_operations.NormalizeFeatures())
@@ -279,6 +261,7 @@ def get_datasets_pygrain(
   config: ml_collections.ConfigDict,
   read_config = None,
 ):
+  """Load dataset from array_record files for use with pygrain."""
   data_dir = os.path.join(config.dataset_path, config.dataset_name)
   train_files = [data_dir + '/' + f for f in os.listdir(data_dir) if re.match(r'.*train.*', f)]
   train_ds = pygrain.ArrayRecordDataSource(train_files)
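
For context, below is a minimal sketch (not the MaxText implementation) of how the pieces touched by this diff fit together: the file-listing logic mirrors get_datasets_pygrain, the operations list stands in for preprocessing_pipeline_pygrain (the real pipeline adds MaxText's pygrain_operations such as ParseFeatures and NormalizeFeatures, omitted here), and worker_count is where the new grain_worker_count setting ends up. The paths, batch size, epoch count, and seed are placeholder values, and exact Grain API details may differ between versions.

# Minimal sketch, NOT the MaxText code: shows how an ArrayRecordDataSource,
# a list of Grain operations, and a DataLoader with worker_count fit together.
import os
import re

import grain.python as pygrain


class PassThrough(pygrain.MapTransform):
  """Placeholder for MaxText's pygrain_operations (ParseFeatures, etc.)."""

  def map(self, element):
    return element  # real code would parse/normalize the serialized example


dataset_path = "/tmp/maxtext_datasets"  # hypothetical; MaxText reads config.dataset_path
dataset_name = "c4"                     # hypothetical; MaxText reads config.dataset_name
grain_worker_count = 1                  # mirrors the base.yml change above

# File discovery mirrors get_datasets_pygrain in the diff.
data_dir = os.path.join(dataset_path, dataset_name)
train_files = [os.path.join(data_dir, f)
               for f in os.listdir(data_dir) if re.match(r'.*train.*', f)]
train_ds = pygrain.ArrayRecordDataSource(train_files)

# Operations run in order; Batch is a built-in Grain operation.
operations = [
    PassThrough(),
    pygrain.Batch(batch_size=8, drop_remainder=True),
]

sampler = pygrain.IndexSampler(
    num_records=len(train_ds),
    num_epochs=1,
    shard_options=pygrain.ShardByJaxProcess(),  # one data shard per JAX process
    shuffle=True,
    seed=0,
)

data_loader = pygrain.DataLoader(
    data_source=train_ds,
    operations=operations,
    sampler=sampler,
    worker_count=grain_worker_count,  # 0 = load in-process, >=1 = child processes
)

for batch in data_loader:
  # Hand batches to the training loop here.
  break

Raising grain_worker_count above 0 moves record reading and preprocessing into child processes, which is usually needed to keep accelerators fed once array_record input is enabled.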