Commit

Optional validation
TJ-Solergibert committed Aug 27, 2024
1 parent ce068fd commit 8e6f8ab
Showing 3 changed files with 27 additions and 11 deletions.
6 changes: 5 additions & 1 deletion run_train.py
@@ -354,7 +354,11 @@ def get_args():
     # Load trainer and data
     trainer = DistributedTrainer(config_file)
     train_dataloader = get_dataloader(trainer)
-    valid_dataloader = get_valid_dataloader(trainer)
+
+    # NOTE(tj.solergibert) Build validation dataloaders only if necessary
+    valid_dataloader = None
+    if trainer.config.tokens.val_check_interval != -1:
+        valid_dataloader = get_valid_dataloader(trainer)
 
     # Train
     trainer.train(train_dataloader, valid_dataloader)
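The change above gates the comparatively expensive validation-dataloader construction behind the val_check_interval != -1 sentinel. A minimal, self-contained sketch of the same pattern (the TokensArgs/Config stubs and get_valid_dataloader_stub are hypothetical stand-ins, not nanotron's real classes):

# Minimal sketch of the gating above; -1 is assumed to be the sentinel for
# "validation disabled". TokensArgs/Config are hypothetical stand-ins.
from dataclasses import dataclass


@dataclass
class TokensArgs:
    val_check_interval: int = -1  # -1 => never run validation


@dataclass
class Config:
    tokens: TokensArgs


def get_valid_dataloader_stub(config: Config) -> object:
    # Stand-in for the real (expensive) per-language dataloader build.
    return object()


def build_valid_dataloader(config: Config):
    # Build the validation dataloader only if validation is enabled.
    valid_dataloader = None
    if config.tokens.val_check_interval != -1:
        valid_dataloader = get_valid_dataloader_stub(config)
    return valid_dataloader


assert build_valid_dataloader(Config(TokensArgs())) is None
assert build_valid_dataloader(Config(TokensArgs(val_check_interval=100))) is not None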
25 changes: 17 additions & 8 deletions src/nanotron/config/config.py
@@ -110,13 +110,13 @@ def __post_init__(self):
 @dataclass
 class MultilingualNanosetDatasetsArgs:
     training_folder: Union[str, dict, List[str]]
-    validation_folder: Union[str, List[str]]
-    languages: List[str]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
+    validation_folder: Optional[Union[str, List[str]]]
+    languages: Optional[List[str]]  # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB
 
     def __post_init__(self):
         if isinstance(self.training_folder, str):  # Case 1: 1 Dataset folder
             self.training_folder = [self.training_folder]
-            self.validation_folder = [self.validation_folder]
+            self.validation_folder = [self.validation_folder] if self.validation_folder is not None else None
             self.dataset_weights = [1]
         elif isinstance(self.training_folder, List):  # Case 2: > 1 Dataset folder
             self.dataset_weights = None  # Set to None so we consume all the samples randomly
@@ -125,20 +125,23 @@ def __post_init__(self):
             self.training_folder = list(tmp_training_folder.keys())
             self.dataset_weights = list(tmp_training_folder.values())
 
-        assert len(self.training_folder) == len(
-            self.languages
+        assert (
+            len(self.training_folder) == len(self.languages) if self.languages else True
         ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})"
 
-        assert len(self.training_folder) == len(
-            self.validation_folder
+        assert (
+            len(self.training_folder) == len(self.validation_folder) if self.validation_folder else True
         ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})"
 
+        if not self.languages and self.validation_folder:
+            raise ValueError(f"You must specify languages to perform the validation step w/ {self.validation_folder}")
 
 
 @dataclass
 class DataArgs:
     """Arguments related to the data and data files processing"""
 
-    dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs]
+    dataset: Union[MultilingualNanosetDatasetsArgs]
     seed: Optional[int]
     num_loading_workers: Optional[int] = 1
@@ -416,6 +419,12 @@ def __post_init__(self):
         # if self.checkpoints.lighteval is not None:
         #     assert self.tokenizer.tokenizer_name_or_path is not None
 
+        if not self.data_stages[0].data.dataset.validation_folder:
+            # NOTE(tj.solergibert) We use print NOT log_rank because at this moment the process group is not
+            # initialized
+            print("No validation data provided, skipping validation step")
+            self.tokens.val_check_interval = -1
+
     @property
     def global_batch_size(self):
         return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
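Taken together, the config changes make validation_folder and languages optional while keeping the three accepted shapes of training_folder. A standalone sketch of the normalization logic, re-implemented in simplified form for illustration (this is not the nanotron class itself):

# Simplified re-implementation of the __post_init__ normalization above, for
# illustration only; field names mirror the diff, the class itself is a sketch.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Union


@dataclass
class MultilingualDatasetSketch:
    training_folder: Union[str, Dict[str, float], List[str]]
    validation_folder: Optional[Union[str, List[str]]] = None
    languages: Optional[List[str]] = None
    dataset_weights: Optional[List[float]] = field(default=None, init=False)

    def __post_init__(self):
        if isinstance(self.training_folder, str):  # Case 1: single dataset folder
            self.training_folder = [self.training_folder]
            self.validation_folder = [self.validation_folder] if self.validation_folder is not None else None
            self.dataset_weights = [1]
        elif isinstance(self.training_folder, list):  # Case 2: several folders, consumed randomly
            self.dataset_weights = None
        elif isinstance(self.training_folder, dict):  # Case 3: folder -> weight mapping
            self.dataset_weights = list(self.training_folder.values())
            self.training_folder = list(self.training_folder.keys())
        if self.validation_folder and not self.languages:
            raise ValueError("You must specify languages to perform the validation step")


# Omitting validation_folder is now legal: training proceeds without validation.
train_only = MultilingualDatasetSketch(training_folder="datasets/en")
assert train_only.validation_folder is None and train_only.dataset_weights == [1]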
7 changes: 5 additions & 2 deletions src/nanotron/trainer.py
@@ -507,7 +507,7 @@ def train(
         ],
         valid_dataloader_or_dls: Dict[
             str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]]
-        ],
+        ] = None,
         **kwargs,
     ) -> None:
         self.pre_training(**kwargs)
@@ -543,7 +543,10 @@ def train(
             self.training_step_time = time.time()
 
             # Validation stage
-            if self.iteration_step % self.config.tokens.val_check_interval == 0:
+            if (
+                self.iteration_step % self.config.tokens.val_check_interval == 0
+                and self.config.tokens.val_check_interval != -1
+            ):
                 self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls)
                 val_global_loss, val_lang_losses = self.validation_step(
                     dataloader=self.current_validation_dataloader
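A subtlety in the new condition: because `and` evaluates left to right, the modulo runs before the sentinel check. That is still safe in Python, since n % -1 == 0 for any integer n (no exception is raised), and the right-hand operand then switches validation off. A quick check of that behavior:

# Quick check that the guard above is safe when val_check_interval == -1:
# Python's modulo with divisor -1 always yields 0, so the left operand of
# `and` evaluates cleanly and the sentinel comparison disables validation.
val_check_interval = -1
for iteration_step in range(1, 1000):
    run_validation = (
        iteration_step % val_check_interval == 0
        and val_check_interval != -1
    )
    assert not run_validation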
