No more NaN losses
TJ-Solergibert committed Sep 26, 2024
1 parent 06553df commit 3969aa2
Showing 1 changed file with 7 additions and 7 deletions.
src/nanotron/data/chat_dataset.py
@@ -125,13 +125,13 @@ def __iter__(self):
             position_ids = self.create_position_ids(sample_lengths)
 
-            # TODO(tj.solergibert) Delete (debug)
-            # assert len(sample_tokens) <= max_buffer_token_len
 
-            yield {
-                "input_ids": np.array(sample_tokens, dtype=np.int32),
-                "is_completitions": np.array(sample_completitions, dtype=np.bool_),
-                "position_ids": position_ids,
-            }
+            # Don't yield samples without ANY completitions tokens as this produces NaN losses
+            if True in sample_completitions:
+                yield {
+                    "input_ids": np.array(sample_tokens, dtype=np.int32),
+                    "is_completitions": np.array(sample_completitions, dtype=np.bool_),
+                    "position_ids": position_ids,
+                }
 
         # TODO(tj.solergibert) Change for log_rank (log_rank is problematic with JupyterNB)
         print("Consumed all samples, dataset is being re-looped.")
