
Commit

first commit
TJ-Solergibert committed Aug 22, 2024
1 parent d5228bb commit 71122d3
Showing 7 changed files with 87 additions and 24 deletions.
1 change: 1 addition & 0 deletions run_train.py
@@ -163,6 +163,7 @@ def get_dataloader_from_data_stage(
train_dataloader = build_nanoset_dataloader(
train_dataset,
trainer.sequence_length,
remove_document_xattention=data.dataset.remove_document_xattention,
parallel_context=trainer.parallel_context,
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
5 changes: 3 additions & 2 deletions src/nanotron/config/config.py
@@ -94,12 +94,13 @@ def __post_init__(self):
@dataclass
class NanosetDatasetsArgs:
dataset_folder: Union[str, dict, List[str]]
remove_document_xattention: bool = False

def __post_init__(self):
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file
if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset folder
self.dataset_folder = [self.dataset_folder]
self.dataset_weights = [1]
elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file
elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset folder
self.dataset_weights = None # Set to None so we consume all the samples randomly
elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights
tmp_dataset_folder = self.dataset_folder.copy()
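For reference, a minimal sketch of how the dataclass above normalizes its inputs, assuming the fork is importable as `nanotron`. The folder paths are made up, and the dict case (folder-to-weight mapping) is truncated in the hunk above, so it is not exercised here.

from nanotron.config.config import NanosetDatasetsArgs

# Case 1: a single folder is wrapped into a one-element list with weight 1.
single = NanosetDatasetsArgs(dataset_folder="datasets/c4-es")
assert single.dataset_folder == ["datasets/c4-es"]
assert single.dataset_weights == [1]
assert single.remove_document_xattention is False  # the new flag defaults to off

# Case 2: several folders keep dataset_weights=None, so samples are consumed from all of them.
mixed = NanosetDatasetsArgs(
    dataset_folder=["datasets/c4-es", "datasets/c4-en"],
    remove_document_xattention=True,  # opt in to dropping cross-document attention
)
assert mixed.dataset_weights is None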
9 changes: 5 additions & 4 deletions src/nanotron/data/chat_dataset.py
@@ -1,8 +1,6 @@
from typing import List

import numpy as np
from datasets import load_dataset
from datasets.distributed import split_dataset_by_node
from nanotron.data.chat_tokenizer import ChatTokenizer
from nanotron.data.collator import (
build_labels,
@@ -13,6 +11,9 @@
from torch.utils.data import IterableDataset
from transformers import AutoTokenizer

from datasets import load_dataset
from datasets.distributed import split_dataset_by_node


class ChatDataset(IterableDataset):
"""
@@ -116,10 +117,10 @@ def __iter__(self):
buffer_lengths = [len(tokens)]

# TODO(tj.solergibert) Delete (debug), just for switching the train-only-on-completions setting
sample_completitions = self.create_labels(sample_tokens, sample_completitions)
sample_completitions = self.create_labels(sample_completitions)

# TODO(tj.solergibert) Delete (debug), just for switching the remove-cross-attention setting
position_ids = self.create_position_ids(sample_lengths, self.sequence_length)
position_ids = self.create_position_ids(sample_lengths)

# TODO(tj.solergibert) Delete (debug)
# assert len(sample_tokens) <= max_buffer_token_len
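A toy illustration of the two simplified call signatures above, assuming create_labels and create_position_ids are bound to the collator helpers this file imports (only build_labels is visible in the import hunk; build_position_ids is assumed). The inputs are made up.

from nanotron.data.collator import build_labels, build_position_ids

is_completitions = [False, False, True, True]  # made-up per-token "is completion" flags
print(build_labels(is_completitions))  # [True, True, True, True]: train on every token
print(build_position_ids([3, 2]))      # [0 1 2 0 1]: positions restart for each packed document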
84 changes: 68 additions & 16 deletions src/nanotron/data/collator.py
@@ -7,6 +7,34 @@
from nanotron.parallel.context import ParallelContext
from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer

LLAMA3_EOS_TOKEN = 128001 # NOTE(tj.solergibert) Currently, we hardcode this value as we only support Llama3 for removing the document cross attention


def build_position_ids_and_label_mask(input_ids, sequence_length):
"""
For each sample in the batch, create:
1. Position ids for each document
2. A label mask that, at each document boundary, masks both the prediction of the EOS token and the prediction made from the EOS token
"""
position_ids_list = []
label_mask_list = []

for sample in input_ids:
# Position ids
document_ends = (sample == LLAMA3_EOS_TOKEN).nonzero().flatten().tolist()
document_ends.append(sequence_length)
lengths = [end - start for start, end in zip([0] + document_ends[:-1], document_ends)]
position_ids_list.append(build_position_ids(lengths))

# Label ids
label_mask = torch.ones(sequence_length, dtype=torch.bool)
for eos_token in document_ends[:-1]:
label_mask[eos_token - 1] = False
label_mask[eos_token] = False

label_mask_list.append(label_mask)
return torch.tensor(np.stack((position_ids_list))), torch.stack(label_mask_list)


@dataclass
class NanosetDataCollatorForCLM:
@@ -22,6 +50,7 @@ class NanosetDataCollatorForCLM:
input_pp_rank: int
output_pp_rank: int
parallel_context: ParallelContext
remove_document_xattention: bool

def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
# Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data.
@@ -31,13 +60,19 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
self.output_pp_rank,
]:
assert all(len(example) == 0 for example in examples)
return {
result = {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"input_mask": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}

if self.remove_document_xattention:
result["position_ids"] = TensorPointer(group_rank=self.input_pp_rank)
else:
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)

return result

# Make sure we load only what's necessary, i.e., we only load an `input_ids` column.
assert all(list(example.keys()) == ["input_ids"] for example in examples)

@@ -48,23 +83,40 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {}

result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)

assert (
expanded_input_length == self.sequence_length + 1
), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"

# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)
if self.remove_document_xattention:
# LlamaForTraining requires input_mask while LlamaForSFT requires position_ids
result["position_ids"] = TensorPointer(group_rank=self.input_pp_rank)
position_ids, label_mask = build_position_ids_and_label_mask(input_ids, self.sequence_length)
# TODO(tj.solergibert) assert the shapes of these two new tensors
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["position_ids"] = position_ids

# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = label_mask

else:
# LlamaForTraining requires input_mask while LlamaForSFT requires position_ids
result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank)
# Process inputs: last token is the label
if current_pp_rank == self.input_pp_rank:
result["input_ids"] = input_ids[:, :-1]
result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

# Process labels: shift them to the left
if current_pp_rank == self.output_pp_rank:
result["label_ids"] = input_ids[:, 1:]
result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool)

if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length:
raise ValueError(
@@ -81,23 +133,23 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:


# TODO(tj.solergibert) After "Beta", delete all the functions except `build_position_ids` and move `build_position_ids` to chat_dataset.py
def build_position_ids(lengths, sequence_length) -> np.array:
def build_position_ids(lengths) -> np.array:
position_ids = [list(range(length)) for length in lengths] # Create position ids list
return np.array([x for xs in position_ids for x in xs], dtype=np.int32) # Flatten list of position ids


# TODO(tj.solergibert) Delete (debug), just for switching the remove-cross-attention setting
def build_position_ids_dummy(lengths, sequence_length) -> np.array:
def build_position_ids_dummy(lengths) -> np.array:
return np.array(list(range(sum(lengths))), dtype=np.int32) # TODO numpy arange


# TODO(tj.solergibert) Delete (debug), just for switching the train-only-on-completions setting.
def build_labels_completions_only(input_ids, is_completitions):
def build_labels_completions_only(is_completitions):
return is_completitions


# TODO(tj.solergibert) Delete (debug), just for switching the train-only-on-completions setting
def build_labels(input_ids, is_completitions):
def build_labels(is_completitions):
return [True for _ in range(len(is_completitions))]


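A quick sanity-check sketch of the new build_position_ids_and_label_mask helper, assuming the fork is importable as `nanotron`. The non-EOS token ids and the short sequence length are made up for illustration; only the hardcoded Llama3 EOS id (128001) comes from the code above.

import torch
from nanotron.data.collator import build_position_ids_and_label_mask

EOS = 128001  # LLAMA3_EOS_TOKEN from the hunk above

# One packed sample of sequence_length + 1 = 9 tokens with two EOS-separated document boundaries.
batch = torch.tensor([[10, 11, 12, EOS, 20, 21, EOS, 30, 31]])
position_ids, label_mask = build_position_ids_and_label_mask(batch, sequence_length=8)

# position_ids -> [[0, 1, 2, 0, 1, 2, 0, 1]]: positions restart after each EOS
# label_mask   -> [[True, True, False, False, True, False, False, True]]:
#                 the loss is masked both where EOS is the label and where EOS is the input
print(position_ids, label_mask)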
2 changes: 2 additions & 0 deletions src/nanotron/data/dataloader_builder.py
@@ -15,6 +15,7 @@
def build_nanoset_dataloader(
dataset,
sequence_length: int,
remove_document_xattention: bool,
parallel_context: ParallelContext,
input_pp_rank: int,
output_pp_rank: int,
@@ -37,6 +38,7 @@ def build_nanoset_dataloader(
input_pp_rank=input_pp_rank,
output_pp_rank=output_pp_rank,
parallel_context=parallel_context,
remove_document_xattention=remove_document_xattention,
)

# Compute size and rank of dataloader workers
5 changes: 4 additions & 1 deletion src/nanotron/trainer.py
@@ -688,7 +688,10 @@ def init_model(self) -> Union[NanotronModel, DistributedDataParallel]:
def _init_model_instance(self) -> NanotronModel:
model_config_cls = self.model_config.__class__.__name__

if model_config_cls == "LlamaConfig" and isinstance(self.config.data_stages[0].data.dataset, ChatDatasetsArgs):
if model_config_cls == "LlamaConfig" and (
isinstance(self.config.data_stages[0].data.dataset, ChatDatasetsArgs)
or self.config.data_stages[0].data.dataset.remove_document_xattention
):
model_config_cls = "LlamaConfigForSFT"

assert (
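As a design note, both the chat/SFT data path and the new packed-pretraining path feed position_ids instead of input_mask (see the collator comment above), so either condition switches the model config to LlamaConfigForSFT. Below is a hedged sketch of that check pulled out as a standalone predicate; the import path and the getattr guard for dataset types that do not define the flag are assumptions, not part of this commit.

from nanotron.config import ChatDatasetsArgs  # assumed export location

def needs_sft_model(config) -> bool:
    # Mirrors the condition above: chat data, or Nanoset data with cross-document attention removed.
    dataset_args = config.data_stages[0].data.dataset
    return isinstance(dataset_args, ChatDatasetsArgs) or getattr(
        dataset_args, "remove_document_xattention", False
    )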
5 changes: 4 additions & 1 deletion tools/preprocess_data.py
@@ -98,7 +98,9 @@ def main(args):
dataset_options={"split": args.split},
)
elif args.readers == "parquet":
datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)
datatrove_reader = ParquetReader(
data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern
)
else:
datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern)

@@ -109,6 +111,7 @@
output_folder=args.output_folder,
tokenizer_name_or_path=args.tokenizer_name_or_path,
eos_token=args.eos_token,
shuffle=False,
max_tokens_per_file=1e9,
),
],
