From 7aa3940716084dd29f5c1cf3f68397db2c5c20e8 Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 12:38:55 +0000 Subject: [PATCH 01/24] XGLM work in progress: Causal Attention and Positional Embeddings work --- examples/xglm/__init__.py | 0 examples/xglm/convert_hf2nt.py | 28 ++ examples/xglm/tests/test_attn.py | 74 +++++ examples/xglm/tests/test_implementation.py | 90 ++++++ src/nanotron/config/models_config.py | 36 +++ src/nanotron/models/gpt3.py | 358 +++++++++++++++++++++ 6 files changed, 586 insertions(+) create mode 100644 examples/xglm/__init__.py create mode 100644 examples/xglm/convert_hf2nt.py create mode 100644 examples/xglm/tests/test_attn.py create mode 100644 examples/xglm/tests/test_implementation.py create mode 100644 src/nanotron/models/gpt3.py diff --git a/examples/xglm/__init__.py b/examples/xglm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py new file mode 100644 index 00000000..e008f859 --- /dev/null +++ b/examples/xglm/convert_hf2nt.py @@ -0,0 +1,28 @@ +import torch + +from transformers.models.xglm.modeling_xglm import XGLMAttention +from nanotron.models.gpt3 import CausalSelfAttention + + +def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): + q_ws = torch.chunk(attn_hf.q_proj.weight, attn_hf.num_heads) + k_ws = torch.chunk(attn_hf.k_proj.weight, attn_hf.num_heads) + v_ws = torch.chunk(attn_hf.v_proj.weight, attn_hf.num_heads) + + q_bs = torch.chunk(attn_hf.q_proj.bias, attn_hf.num_heads) + k_bs = torch.chunk(attn_hf.k_proj.bias, attn_hf.num_heads) + v_bs = torch.chunk(attn_hf.v_proj.bias, attn_hf.num_heads) + + qkv_w = [] + qkv_b = [] + for q_w, k_w, v_w, q_b, k_b, v_b in zip(q_ws, k_ws, v_ws, q_bs, k_bs, v_bs): + qkv_w += [q_w, k_w, v_w] + qkv_b += [q_b, k_b, v_b] + qkv_w = torch.cat(qkv_w) + qkv_b = torch.cat(qkv_b) + + with torch.no_grad(): + attn_nt.query_key_value.weight.data = qkv_w.clone() + attn_nt.query_key_value.bias.data = qkv_b.clone() + attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() + attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py new file mode 100644 index 00000000..2fcdb3a8 --- /dev/null +++ b/examples/xglm/tests/test_attn.py @@ -0,0 +1,74 @@ +import torch +from torch.nn import functional as F +#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) + +# inputs = (batchsize * qlen, heads, head_dim) +# outputs = (batchsize*qlen, heads, head_dim) +def sdpa(query, key, value, batchsize: int): + def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) + return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) + + batchsize_x_qlen, heads, head_dim = query.size() + qlen = batchsize_x_qlen//batchsize + out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) + return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) + + +# inputs = (batchsize * qlen, heads, head_dim) +# outputs = (batchsize*qlen, heads, head_dim) +def fa(query_states, key_states, value_states, batchsize: int): + from flash_attn.flash_attn_interface import flash_attn_varlen_func + + batchsize_x_qlen, heads, head_dim = query_states.size() + qlen = batchsize_x_qlen//batchsize + + q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") + kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") + + # TODO @thomasw21: Compute once, 
instead of computing for each layers. + cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) + cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) + torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) + torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) + + # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not + # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. + causal = False if q_sequence_mask.shape[1] == 1 else True + attn_output = flash_attn_varlen_func( + q=query_states, + k=key_states, + v=value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=q_sequence_mask.shape[1], + max_seqlen_k=kv_sequence_mask.shape[1], + dropout_p=0.0, + softmax_scale=None, # defaults to 1/sqrt(d_qk) + causal=causal, + window_size=(-1, -1), + return_attn_probs=False, + ) + return attn_output + + +def main(): + batchsize = 5 + qlen = 6 + heads = 2 + head_dim = 16 + + query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) + + out_pt = sdpa(query, key, value, batchsize) + out_fa = fa(query, key, value, batchsize) + + assert out_pt.size() == out_fa.size() + + torch.testing.assert_close(out_pt, out_fa) + + + +if __name__ == "__main__": + main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py new file mode 100644 index 00000000..10f0302a --- /dev/null +++ b/examples/xglm/tests/test_implementation.py @@ -0,0 +1,90 @@ +import numpy as np +import torch +import pytest + +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.parallel import ParallelContext + +from tests.helpers.utils import init_distributed + +from examples.xglm.convert_hf2nt import convert_attention + + +SEQUENCE_LENGTH = 2048 +BATCH_SIZE = 4 +HIDDEN_SIZE = 1024 +DTYPE = torch.float64 + +CONFIG = GPT3Config( + attn_pdrop=0.0, + embd_pdrop=0.0, + resid_pdrop=0.0, + eos_token_id=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=4096, + layer_norm_epsilon=1e-05, + max_position_embeddings=SEQUENCE_LENGTH, + num_attention_heads=16, + num_hidden_layers=24, + scale_attn_weights=True, + vocab_size=256008, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=True +) + + +@pytest.fixture +def hidden_states() -> torch.Tensor: + return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + dtype=DTYPE) + + +@pytest.fixture +def input_mask() -> torch.Tensor: + return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + + +def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + 
assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + # Build xglm mask. + mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + + convert_attention(attn_nt, attn_hf) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_attention)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_position_embeddings(parallel_context: ParallelContext): + position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + + emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + + assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() + torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) + + out_nt = emb_nt(position_ids)["position_embeds"] + out_hf = emb_hf(position_ids).permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + +def test_position_embeddings(): + init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 57225243..5b8ac999 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -133,4 +133,40 @@ def n_inner(self): return self.intermediate_size +@dataclass +class GPT3Config: + """Configuration for a GPT3 model""" + + activation_function: str = "gelu" + attn_pdrop: float = 0.1 + embd_pdrop: float = 0.1 + eos_token_id: int = 49152 + hidden_size: int = 2048 + intermediate_size: Optional[int] = None + layer_norm_epsilon: float = 1e-05 + max_position_embeddings: int = 4096 + num_attention_heads: int = 16 + num_hidden_layers: int = 24 + resid_pdrop: float = 0.1 + scale_attention_softmax_in_fp32: bool = True + scale_attn_weights: bool = True + vocab_size: int = 49280 + sinusoidal_position_embedding: bool = True + position_embedding_offset: int = 2 + use_spda: bool = False + + def as_starcoder2(self) -> Starcoder2Config: + config = dict(**vars(self)) + del config["sinusoidal_position_embedding"] + del config["use_spda"] + del config["position_embedding_offset"] + return Starcoder2Config( + grouped_query=True, + num_kv_heads=self.num_attention_heads, + use_rotary_embeddings=False, + **config + ) + + NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] + diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py new file mode 100644 index 00000000..8cea58c4 --- /dev/null +++ b/src/nanotron/models/gpt3.py @@ -0,0 +1,358 @@ +"""PyTorch GPT-3 model.""" + +import math +from typing import Optional + +import torch +from torch import nn +from torch.nn import functional as F + +from nanotron import distributed as dist +from nanotron.parallel import ParallelContext +from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.generation.generate_store import AttachableStore +from nanotron.models.starcoder2 import MLP as 
Starcoder2MLP +from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention +from nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.random import RandomStates, branch_random_state +from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode +from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding +from nanotron.parallel.tied_parameters import tie_parameters + +# NOTES: +# - tie head_weight with embeddings I think. + +# TODO: +# - class GPT3Config: config lol +# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. +# - from starcoder import Embedding +# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding +# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - from starcoder import Loss + + +class CoreAttention(Starcoder2CoreAttention): + def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + super().__init__(config.as_starcoder2(), parallel_config, layer_idx) + self.gpt3config = config + + def forward(self, + query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] + key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] + q_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, q_length] (can be broadcasted to that size) + kv_sequence_mask: torch.Tensor, # torch.BoolTensor [batch_size, kv_length] (can be broadcasted to that size) + ): + + if self.gpt3config.use_spda: + assert torch.all(q_sequence_mask) + assert torch.all(kv_sequence_mask) + + batch_size, q_length = q_sequence_mask.size() + kv_length = kv_sequence_mask.size(1) + _, q_heads, head_dim = query_states.size() + kv_heads = key_states.size(1) + + attention_output = F.scaled_dot_product_attention( + query_states.view(batch_size, q_length, q_heads, head_dim).permute(0, 2, 1, 3), + key_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + value_states.view(batch_size, kv_length, kv_heads, head_dim).permute(0, 2, 1, 3), + dropout_p=self.dropout if self.training else 0.0, + is_causal=True, + ) # [batch, q_length, q_heads, head_dim] + attention_output = attention_output.permute(0, 2, 1, 3) + attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + return attention_output + + assert query_states.dtype in {torch.bfloat16, torch.float16} + return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) + + +class CausalSelfAttention(CausalSelfGQA): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. + self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + + +class MLP(Starcoder2MLP): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + ): + # TODO: GPT3Config -> Starcoder2Config. 
+ super().__init__(config, parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + + def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] + hidden_states = self.c_fc(hidden_states) + hidden_states = self.act(hidden_states) + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.dropout(input=hidden_states) + hidden_states = self.c_proj(hidden_states) + return {"hidden_states": hidden_states} + + +class GPTBlock(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + random_states: RandomStates, + layer_idx: int, + ): + super(GPTBlock, self).__init__() + self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.attn = CausalSelfAttention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx + ) + self.attn_dropout = config.attn_pdrop + + self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff_dropout = config.resid_pdrop + + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + def forward( + self, + hidden_states: torch.Tensor | TensorPointer, + sequence_mask: torch.Tensor | TensorPointer, + ) -> dict[str, torch.Tensor | TensorPointer]: + + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) + hidden_states = output["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.ln_2(hidden_states) + hidden_states = self.ff(hidden_states=hidden_states)["hidden_states"] + + if self.training: + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode is TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) + else: + # No need for random state context manager + # TODO: add dropout scaling? 
+ hidden_states = hidden_states + residual + + return { + "hidden_states": hidden_states, + "sequence_mask": output["sequence_mask"], + } + + +class PositionEmbedding(nn.Module, AttachableStore): + def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + + self.config = config + if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: + dummy_pos = 0 + else: + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos + + if config.sinusoidal_position_embedding: + weight = self._make_weights(tp_pg, true_max_size, config.hidden_size) + else: + weight = None + + position_embedding = TensorParallelEmbedding( + num_embeddings=true_max_size, + embedding_dim=config.hidden_size, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + _weight=weight + ) + self.pg = tp_pg + + # Sinusoidal position embeddings are usually not trainable. + # We adjust that by setting the module self.position_embedding without gradient. + if config.sinusoidal_position_embedding: + with torch.no_grad(): + self.position_embedding = position_embedding.requires_grad_(False) + else: + self.position_embedding = position_embedding + + def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] + position_ids = position_ids.transpose(0, 1) + position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) + return {"position_embeds": position_embeds} + + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, + embedding_dim: int) -> torch.Tensor: + rank = dist.get_rank(group=tp_pg) + tp_size = tp_pg.size() + + assert 0 <= rank < tp_size + assert num_embeddings % tp_size == 0 + assert embedding_dim % 2 == 0 + block_size = num_embeddings//tp_size + + half_dim = embedding_dim//2 + emb = math.log(10_000)/(half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) + emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) + return emb + + +class GPT3Model(nn.Module): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + + self.token_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"input_ids"}, + module_output_keys={"input_embeds"}, + ) + self.position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=PositionEmbedding, + module_kwargs={ + "tp_pg": parallel_context.tp_pg, + "config": config, + "parallel_config": parallel_config, + }, + module_input_keys={"position_ids"}, + module_output_keys={"position_embeds"}, + ) + + self.embeds_dropout = PipelineBlock( + p2p=self.p2p, + module_builder=nn.Dropout, + module_kwargs={"p": config.embd_pdrop}, + 
module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=GPTBlock, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "random_states": random_states, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "sequence_mask"}, + module_output_keys={"hidden_states", "sequence_mask"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonLayerNorm, + module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Understand that this means that we return sharded logits that are going to need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + # TODO: refactor so that we store that default in a single place. + "mode": self.tp_mode, + "async_communication": parallel_config.tp_linear_async_communication + if parallel_config is not None + else False, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + self.cast_to_fp32 = PipelineBlock( + p2p=self.p2p, + module_builder=lambda: lambda x: x.float(), + module_kwargs={}, + module_input_keys={"x"}, + module_output_keys={"output"}, + ) + + + def forward( + self, + input_ids: torch.Tensor | TensorPointer, # [batch_size, seq_length] + input_mask: torch.Tensor | TensorPointer, # [batch_size, seq_length] + ): + # all tensors are optional as most ranks don't need anything from the dataloader. 
+ + position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) + input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] + position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + hidden_states = input_embeds + position_embeds + + with branch_random_state( + self.random_states, "tp_synced", enabled=self.tp_mode == TensorParallelLinearMode.ALL_REDUCE + ): + hidden_states = self.embeds_dropout(input=hidden_states)["hidden_states"] + + hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} + for encoder_block in self.decoder: + hidden_encoder_states = encoder_block(**hidden_encoder_states) + + hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] + + return fp32_sharded_logits From 78dd53cdfdb467961edd1a56b04d8426fd2819df Mon Sep 17 00:00:00 2001 From: AleHD Date: Wed, 26 Jun 2024 17:24:53 +0000 Subject: [PATCH 02/24] WIP: GPT arch almost done, hf->nt converters working perfectly for non-distributed inference --- examples/xglm/convert_hf2nt.py | 70 +++++++- examples/xglm/tests/test_attn.py | 74 --------- examples/xglm/tests/test_implementation.py | 135 +++++++++++++-- src/nanotron/config/models_config.py | 4 + src/nanotron/models/gpt3.py | 184 ++++++++++----------- 5 files changed, 287 insertions(+), 180 deletions(-) delete mode 100644 examples/xglm/tests/test_attn.py diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index e008f859..6e6ddff1 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,7 +1,44 @@ import torch +from torch import nn -from transformers.models.xglm.modeling_xglm import XGLMAttention -from nanotron.models.gpt3 import CausalSelfAttention +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from nanotron.config.models_config import GPT3Config + + +def convert_config(config: XGLMConfig) -> GPT3Config: + # TODOs: + # dropout=0.1, + # layerdrop=0.0, + # init_std=0.02, + # use_cache=True, + # decoder_start_token_id=2, + # pad_token_id=1, + # bos_token_id=0, + + # TODO: when going gpt3->xglm: + # - assert layernorm is 1e-05 + return GPT3Config( + activation_function=config.activation_function, + attn_pdrop=config.attention_dropout, + embd_pdrop=0.0, # TODO + eos_token_id=config.eos_token_id, + hidden_size=config.d_model, + intermediate_size=config.ffn_dim, + layer_norm_epsilon=1e-05, + max_position_embeddings=config.max_position_embeddings, + num_attention_heads=config.attention_heads, + num_hidden_layers=config.num_layers, + resid_pdrop=0.0, # TODO + scale_attention_softmax_in_fp32=True, + scale_attn_weights=True, + vocab_size=config.vocab_size, + sinusoidal_position_embedding=True, + position_embedding_offset=2, + use_spda=False, + act_pdrop=config.activation_dropout, + scale_embedding=config.scale_embedding, + ) def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): @@ -26,3 +63,32 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention): attn_nt.query_key_value.bias.data = qkv_b.clone() attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone() attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone() + + +def 
convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer): + convert_generic(mlp_nt.c_fc, block_hf.fc1) + convert_generic(mlp_nt.c_proj, block_hf.fc2) + + +def convert_decoder(block_nt: GPTBlock, block_hf: XGLMDecoderLayer): + convert_generic(block_nt.ln_1, block_hf.self_attn_layer_norm) + convert_attention(block_nt.attn, block_hf.self_attn) + convert_generic(block_nt.ln_2, block_hf.final_layer_norm) + convert_mlp(block_nt.ff, block_hf) + + +def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): + convert_generic(model_nt.model.token_embeddings.pp_block.token_embedding, model_hf.model.embed_tokens) + for layer_nt, layer_hf in zip(model_nt.model.decoder, model_hf.model.layers): + convert_decoder(layer_nt.pp_block, layer_hf) + convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) + convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) diff --git a/examples/xglm/tests/test_attn.py b/examples/xglm/tests/test_attn.py deleted file mode 100644 index 2fcdb3a8..00000000 --- a/examples/xglm/tests/test_attn.py +++ /dev/null @@ -1,74 +0,0 @@ -import torch -from torch.nn import functional as F -#torch.Size([4, 2048, 16, 64]), torch.Size([2048, 4, 1024]) - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def sdpa(query, key, value, batchsize: int): - def reshape(tensor): # output = (batchsize, heads, qlen, head_dim) - return tensor.view(batchsize, qlen, heads, head_dim).permute(0, 2, 1, 3) - - batchsize_x_qlen, heads, head_dim = query.size() - qlen = batchsize_x_qlen//batchsize - out = F.scaled_dot_product_attention(reshape(query), reshape(key), reshape(value), is_causal=True) # (b,h,q,d) - return out.permute(0, 2, 1, 3).reshape(batchsize*qlen, heads, head_dim) - - -# inputs = (batchsize * qlen, heads, head_dim) -# outputs = (batchsize*qlen, heads, head_dim) -def fa(query_states, key_states, value_states, batchsize: int): - from flash_attn.flash_attn_interface import flash_attn_varlen_func - - batchsize_x_qlen, heads, head_dim = query_states.size() - qlen = batchsize_x_qlen//batchsize - - q_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - kv_sequence_mask = torch.ones(batchsize, qlen, dtype=torch.bool, device="cuda") - - # TODO @thomasw21: Compute once, instead of computing for each layers. - cu_seqlens_q = torch.zeros((q_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - cu_seqlens_k = torch.zeros((kv_sequence_mask.shape[0] + 1), dtype=torch.int32, device=query_states.device) - torch.cumsum(q_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_q[1:]) - torch.cumsum(kv_sequence_mask.sum(-1, dtype=torch.int32), dim=0, dtype=torch.int32, out=cu_seqlens_k[1:]) - - # TODO(kunhao): flash attn's causal means that the query can only attend to the keys before it. This is not - # what we want if we are using kv cache. This is a hack as we always have q_length == 1 when using kv cache. 
- causal = False if q_sequence_mask.shape[1] == 1 else True - attn_output = flash_attn_varlen_func( - q=query_states, - k=key_states, - v=value_states, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=q_sequence_mask.shape[1], - max_seqlen_k=kv_sequence_mask.shape[1], - dropout_p=0.0, - softmax_scale=None, # defaults to 1/sqrt(d_qk) - causal=causal, - window_size=(-1, -1), - return_attn_probs=False, - ) - return attn_output - - -def main(): - batchsize = 5 - qlen = 6 - heads = 2 - head_dim = 16 - - query = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - key = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - value = torch.randn(batchsize*qlen, heads, head_dim, device="cuda", dtype=torch.bfloat16) - - out_pt = sdpa(query, key, value, batchsize) - out_fa = fa(query, key, value, batchsize) - - assert out_pt.size() == out_fa.size() - - torch.testing.assert_close(out_pt, out_fa) - - - -if __name__ == "__main__": - main() diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 10f0302a..3636415b 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,27 +1,33 @@ +from typing import Optional + import numpy as np import torch import pytest -from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMSinusoidalPositionalEmbedding +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM +import nanotron from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, PositionEmbedding +from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext from tests.helpers.utils import init_distributed -from examples.xglm.convert_hf2nt import convert_attention +from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert SEQUENCE_LENGTH = 2048 BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.float64 +DTYPE = torch.bfloat16 +TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( attn_pdrop=0.0, embd_pdrop=0.0, resid_pdrop=0.0, + act_pdrop=0.0, eos_token_id=2, hidden_size=HIDDEN_SIZE, intermediate_size=4096, @@ -42,11 +48,22 @@ def hidden_states() -> torch.Tensor: return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) - @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) +@pytest.fixture +def input_ids() -> torch.Tensor: + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + + +def attention_mask() -> torch.Tensor: + # XGLM causal attention mask. 
+ mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) + mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) + return mask + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -56,14 +73,9 @@ def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tens attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) - # Build xglm mask. - mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) - mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) - mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) - convert_attention(attn_nt, attn_hf) out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] - out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=mask)[0].permute(1, 0, 2) + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) @@ -88,3 +100,104 @@ def _test_position_embeddings(parallel_context: ParallelContext): def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() + + +def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = XGLMConfig() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + decoder_nt = GPTBlock(config_nt, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + + convert_decoder(decoder_nt, decoder_hf) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, + input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # Get hf model. + if model_hf is None: + config_hf = XGLMConfig() + model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + else: + model_hf = model_hf.cuda().to(DTYPE).eval() + config_hf = model_hf.config + + # Get nanotron model and make the conversion. 
+ config_nt = convert_config(config_hf) + if DTYPE not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=DTYPE, + device="cuda", + ).eval() + convert(model_nt, model_hf) + + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + + # Get outputs and assert. + with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + +def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + _test_model(None, parallel_context, input_ids, input_mask) + + +def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) + + +def _test_xglm7B(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm7B(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() + + +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 5b8ac999..12bac0fb 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -154,12 +154,16 @@ class GPT3Config: sinusoidal_position_embedding: bool = True position_embedding_offset: int = 2 use_spda: bool = False + act_pdrop: float = 0.0 + scale_embedding: bool = True def as_starcoder2(self) -> Starcoder2Config: config = dict(**vars(self)) del config["sinusoidal_position_embedding"] del config["use_spda"] del config["position_embedding_offset"] + del config["act_pdrop"] + del config["scale_embedding"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 8cea58c4..99f6ea85 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -2,6 +2,7 @@ import math from typing import Optional +from contextlib import contextmanager import torch from torch import nn @@ -9,11 +10,15 @@ from nanotron import distributed as dist from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs +from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore +from nanotron.models import starcoder2 +from nanotron.nn.layer_norm import 
TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention -from nanotron.models.starcoder2 import CausalSelfGQA +from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode @@ -28,10 +33,55 @@ # - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. # - from starcoder import Embedding # - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBLock: very similar to starcoder2 but make it so it support non-GQA or MQA +# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA # - from starcoder import Loss +@contextmanager +def replace_coreattention(gpt3config: GPT3Config): + orig = starcoder2.CoreAttention + try: + def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention + yield + finally: + starcoder2.CoreAttention = orig + + +@contextmanager +def replace_decoder(gpt3config: GPT3Config): + orig = starcoder2.PipelineBlock + try: + def create_pp_block(module_builder, module_kwargs, **kwargs): + if module_builder is Starcoder2GPTBlock: + # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. + # Let's return a PipelineBlock with a GPT3Block instead. + # This also requires to replace starcoders2's config with gpt3's config. + module_kwargs["config"] = gpt3config + return orig(module_builder=GPTBlock, module_kwargs=module_kwargs, **kwargs) + # Else, they are setting up other modules, which we also want unchanged. 
+ return orig(module_builder=module_builder, module_kwargs=module_kwargs, **kwargs) + + starcoder2.PipelineBlock = create_pp_block + yield + finally: + starcoder2.PipelineBlock = orig + + +@contextmanager +def replace_gpt3model(gpt3config: GPT3Config): + orig = starcoder2.GPTModel + try: + def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel + yield + finally: + starcoder2.GPTModel = orig + + class CoreAttention(Starcoder2CoreAttention): def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): super().__init__(config.as_starcoder2(), parallel_config, layer_idx) @@ -63,7 +113,7 @@ def forward(self, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) - return attention_output + return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} return super().forward(query_states, key_states, value_states, q_sequence_mask, kv_sequence_mask) @@ -77,9 +127,10 @@ def __init__( tp_pg: dist.ProcessGroup, layer_idx: int, ): - super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) + with replace_coreattention(config): + super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. + #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -88,10 +139,12 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + random_states: RandomStates ): - # TODO: GPT3Config -> Starcoder2Config. 
- super().__init__(config, parallel_config, tp_pg) - self.dropout = nn.Dropout(p=config.dropout) # TODO: correct config.dropout name + super().__init__(config.as_starcoder2(), parallel_config, tp_pg) + self.dropout = nn.Dropout(p=config.act_pdrop) + self.random_states = random_states + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE def forward(self, hidden_states): # [seq_length, batch_size, hidden_dim] hidden_states = self.c_fc(hidden_states) @@ -113,6 +166,7 @@ def __init__( random_states: RandomStates, layer_idx: int, ): + #print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( @@ -124,7 +178,7 @@ def __init__( self.attn_dropout = config.attn_pdrop self.ln_2 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) - self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg) + self.ff = MLP(config=config, parallel_config=parallel_config, tp_pg=tp_pg, random_states=random_states) self.ff_dropout = config.resid_pdrop self.random_states = random_states @@ -138,8 +192,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) + #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] + #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -227,7 +283,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, return emb -class GPT3Model(nn.Module): +class GPT3Model(GPTModel): def __init__( self, config: GPT3Config, @@ -235,24 +291,9 @@ def __init__( parallel_config: Optional[ParallelismArgs], random_states: RandomStates, ): - super().__init__() + with replace_decoder(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - # Declare all the nodes - self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) - self.random_states = random_states - self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - - self.token_embeddings = PipelineBlock( - p2p=self.p2p, - module_builder=Embedding, - module_kwargs={ - "tp_pg": parallel_context.tp_pg, - "config": config, - "parallel_config": parallel_config, - }, - module_input_keys={"input_ids"}, - module_output_keys={"input_embeds"}, - ) self.position_embeddings = PipelineBlock( p2p=self.p2p, module_builder=PositionEmbedding, @@ -264,69 +305,7 @@ def __init__( module_input_keys={"position_ids"}, module_output_keys={"position_embeds"}, ) - - self.embeds_dropout = PipelineBlock( - p2p=self.p2p, - module_builder=nn.Dropout, - module_kwargs={"p": config.embd_pdrop}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.decoder = nn.ModuleList( - [ - PipelineBlock( - p2p=self.p2p, - module_builder=GPTBlock, - module_kwargs={ - "config": config, - "parallel_config": parallel_config, - "tp_pg": parallel_context.tp_pg, - "random_states": random_states, - "layer_idx": layer_idx, - }, - module_input_keys={"hidden_states", "sequence_mask"}, - module_output_keys={"hidden_states", "sequence_mask"}, - ) - for layer_idx in range(config.num_hidden_layers) - ] - ) - - self.final_layer_norm = PipelineBlock( - p2p=self.p2p, - 
module_builder=TritonLayerNorm, - module_kwargs={"normalized_shape": config.hidden_size, "eps": config.layer_norm_epsilon}, - module_input_keys={"input"}, - module_output_keys={"hidden_states"}, - ) - - self.lm_head = PipelineBlock( - p2p=self.p2p, - # Understand that this means that we return sharded logits that are going to need to be gathered - module_builder=TensorParallelColumnLinear, - module_kwargs={ - "in_features": config.hidden_size, - "out_features": config.vocab_size, - "pg": parallel_context.tp_pg, - "bias": False, - # TODO: refactor so that we store that default in a single place. - "mode": self.tp_mode, - "async_communication": parallel_config.tp_linear_async_communication - if parallel_config is not None - else False, - }, - module_input_keys={"x"}, - module_output_keys={"logits"}, - ) - - self.cast_to_fp32 = PipelineBlock( - p2p=self.p2p, - module_builder=lambda: lambda x: x.float(), - module_kwargs={}, - module_input_keys={"x"}, - module_output_keys={"output"}, - ) - + self.embed_scale = config.hidden_size**0.5 if config.scale_embedding else 1.0 def forward( self, @@ -335,9 +314,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. + input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) - input_embeds = self.token_embeddings(input_ids=input_ids)["input_embeds"] - position_embeds = self.position_embeds(position_ids=position_ids)["position_embeds"] + position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds with branch_random_state( @@ -348,6 +327,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) + #return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] @@ -356,3 +336,21 @@ def forward( fp32_sharded_logits = self.cast_to_fp32(x=sharded_logits)["output"] return fp32_sharded_logits + + +# TODO: maybe reimplement: +# - tie_custom_params +# - get_embeddings_lm_head_tied_names +# - get_block_compute_costs +# - get_flops_per_sec +class GPT3ForTraining(Starcoder2ForTraining): + def __init__( + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): + with replace_gpt3model(config): + super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) + From a74c71ad3a56a33501153c7bd00f4418d4ef1cb6 Mon Sep 17 00:00:00 2001 From: AleHD Date: Tue, 9 Jul 2024 16:46:55 +0200 Subject: [PATCH 03/24] Added hf2nt frontend + tested training --- examples/xglm/README.md | 13 +++ examples/xglm/convert_hf2nt.py | 86 ++++++++++++++-- examples/xglm/example_config.yaml | 98 +++++++++++++++++++ src/nanotron/config/models_config.py | 6 +- src/nanotron/models/gpt3.py | 23 +---- .../optimizer_from_gradient_accumulator.py | 3 +- src/nanotron/trainer.py | 2 + 7 files changed, 199 insertions(+), 32 deletions(-) create mode 100644 examples/xglm/README.md create mode 100644 examples/xglm/example_config.yaml diff --git a/examples/xglm/README.md b/examples/xglm/README.md new file mode 100644 index 00000000..abc50f95 --- /dev/null +++ b/examples/xglm/README.md @@ 
-0,0 +1,13 @@ +# How to use XGLM? + +1. First, make sure to convert the weights from huggingface, for instance: + ``` + torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M + ``` + +1. Now you are ready to use XGLM. + Make sure you use a .yaml configuration with proper GPT3 config and then run for instance: + ``` + torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml + ``` + If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`. diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 6e6ddff1..9db5ed93 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -1,27 +1,42 @@ +""" +Converts a HF model to nanotron format +Command: + torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights +""" + +import json +import warnings +import dataclasses +from argparse import ArgumentParser +from pathlib import Path + import torch from torch import nn - from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +import nanotron from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + def convert_config(config: XGLMConfig) -> GPT3Config: # TODOs: - # dropout=0.1, # layerdrop=0.0, # init_std=0.02, # use_cache=True, - # decoder_start_token_id=2, # pad_token_id=1, # bos_token_id=0, - - # TODO: when going gpt3->xglm: - # - assert layernorm is 1e-05 + if config.dropout != config.attention_dropout: + warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. 
" + "Nanotron implementation needs these two values to be equal " + "for correct conversion.") return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, - embd_pdrop=0.0, # TODO + embd_pdrop=config.dropout, eos_token_id=config.eos_token_id, hidden_size=config.d_model, intermediate_size=config.ffn_dim, @@ -29,12 +44,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: max_position_embeddings=config.max_position_embeddings, num_attention_heads=config.attention_heads, num_hidden_layers=config.num_layers, - resid_pdrop=0.0, # TODO + resid_pdrop=config.dropout, scale_attention_softmax_in_fp32=True, scale_attn_weights=True, vocab_size=config.vocab_size, sinusoidal_position_embedding=True, - position_embedding_offset=2, + position_embedding_offset=config.decoder_start_token_id, use_spda=False, act_pdrop=config.activation_dropout, scale_embedding=config.scale_embedding, @@ -92,3 +107,56 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM): convert_decoder(layer_nt.pp_block, layer_hf) convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm) convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head) + + +def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining: + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + return model_nt + + +def main(hf_path: str, save_path: Path): + # Load hf. + print("Loading hf...") + model_hf = XGLMForCausalLM.from_pretrained(hf_path) + + # Init nanotron. + print("Initializing nt...") + config_nt = convert_config(model_hf.config) + model_nt = create_nt_model(config_nt) + + # Copy weights and save model. 
+ print("Copying weights...") + convert(model_nt, model_hf) + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, + root_folder=save_path) + with open(save_path/"model_config.json", "w+") as f: + json.dump(dataclasses.asdict(config_nt), f) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") + parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/example_config.yaml b/examples/xglm/example_config.yaml new file mode 100644 index 00000000..2d7e9926 --- /dev/null +++ b/examples/xglm/example_config.yaml @@ -0,0 +1,98 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/xglm + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_overwrite_cache: false + dataset_processing_num_proc_per_process: 64 + hf_dataset_config_name: null + hf_dataset_or_datasets: DKYoon/SlimPajama-6B + hf_dataset_splits: train + text_column_name: text + num_loading_workers: 1 + seed: 42 + name: Finetuning + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: xglm-test + run: xglm-dp4tp1pp1 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + path: /capstor/scratch/cscs/ahernnde/checkpoints/xglm-564M + make_vocab_size_divisible_by: 1 + model_config: + activation_function: gelu + attn_pdrop: 0.1 + embd_pdrop: 0.1 + scale_embedding: true + eos_token_id: 2 + hidden_size: 1024 + intermediate_size: 4096 + layer_norm_epsilon: 0.00001 + max_position_embeddings: 2048 + num_attention_heads: 16 + num_hidden_layers: 24 + resid_pdrop: 0.1 + scale_attention_softmax_in_fp32: true + scale_attn_weights: true + vocab_size: 256008 + sinusoidal_position_embedding: true + position_embedding_offset: 2 + use_spda: false + act_pdrop: 0.0 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 900 + lr_decay_style: cosine + lr_warmup_steps: 100 + lr_warmup_style: linear + min_decay_lr: 1.0e-04 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 4 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: facebook/xglm-564M + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 4 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 8 + sequence_length: 2048 + train_steps: 1000 + val_check_interval: -1 diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 12bac0fb..6c568e80 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -164,6 +164,8 @@ def as_starcoder2(self) -> Starcoder2Config: del 
config["position_embedding_offset"] del config["act_pdrop"] del config["scale_embedding"] + if "_is_using_mup" in config: + del config["_is_using_mup"] return Starcoder2Config( grouped_query=True, num_kv_heads=self.num_attention_heads, @@ -171,6 +173,4 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) - -NanotronConfigs = Union[LlamaConfig, Starcoder2Config, Any] - +NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 99f6ea85..33661c8b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -18,24 +18,13 @@ from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel +from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train from nanotron.random import RandomStates, branch_random_state from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding from nanotron.parallel.tied_parameters import tie_parameters -# NOTES: -# - tie head_weight with embeddings I think. - -# TODO: -# - class GPT3Config: config lol -# - check that attention (i.e. nanotron.attn vs xglm.self_attn) is the same. -# - from starcoder import Embedding -# - class PositionEmbedding: my sinusoidal embedding extends from TensorParallelEmbedding -# - class GPTBlock: very similar to starcoder2 but make it so it support non-GQA or MQA -# - from starcoder import Loss - @contextmanager def replace_coreattention(gpt3config: GPT3Config): @@ -130,7 +119,6 @@ def __init__( with replace_coreattention(config): super().__init__(config.as_starcoder2(), parallel_config, tp_pg, layer_idx) self.maybe_rotary = lambda q, k, **_: (q, k) # Overwrite possible rotary with identity. - #self.attention = CoreAttention(config, parallel_config=parallel_config, layer_idx=layer_idx) # Use our custom CoreAttention. class MLP(Starcoder2MLP): @@ -204,7 +192,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.attn_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? hidden_states = hidden_states + residual residual = hidden_states @@ -218,7 +205,6 @@ def forward( hidden_states = dropout_add_fused_train(hidden_states, residual=residual, prob=self.ff_dropout) else: # No need for random state context manager - # TODO: add dropout scaling? 
hidden_states = hidden_states + residual return { @@ -235,7 +221,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % k) + dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -278,7 +264,7 @@ def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, half_dim = embedding_dim//2 emb = math.log(10_000)/(half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(num_embeddings, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb @@ -315,6 +301,7 @@ def forward( # all tensors are optional as most ranks don't need anything from the dataloader. input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + # TODO: position_ids could be cached. position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] hidden_states = input_embeds + position_embeds @@ -339,8 +326,6 @@ def forward( # TODO: maybe reimplement: -# - tie_custom_params -# - get_embeddings_lm_head_tied_names # - get_block_compute_costs # - get_flops_per_sec class GPT3ForTraining(Starcoder2ForTraining): diff --git a/src/nanotron/optim/optimizer_from_gradient_accumulator.py b/src/nanotron/optim/optimizer_from_gradient_accumulator.py index 01be7cb5..9883c720 100644 --- a/src/nanotron/optim/optimizer_from_gradient_accumulator.py +++ b/src/nanotron/optim/optimizer_from_gradient_accumulator.py @@ -38,7 +38,8 @@ def __init__( **{k: v for k, v in named_param_group.items() if k != "named_params"}, "named_params": [ (name, gradient_accumulator.get_parameter_for_optimizer(name)) - for name, _ in named_param_group["named_params"] + for name, param in named_param_group["named_params"] + if param.requires_grad ], } for named_param_group in named_param_groups diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b6752f38..f01caa3e 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -58,6 +58,7 @@ from nanotron.models.base import check_model_has_grad from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp @@ -103,6 +104,7 @@ CONFIG_TO_MODEL_CLASS = { "LlamaConfig": LlamaForTraining, "Starcoder2Config": Starcoder2ForTraining, + "GPT3Config": GPT3ForTraining, } try: From 04eaef956a091bbcc40e5c2ef140aad7b577f003 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 13:38:52 +0200 Subject: [PATCH 04/24] Added nt2hf conversion + tests :) --- examples/xglm/README.md | 5 + examples/xglm/convert_hf2nt.py | 38 
+----
 examples/xglm/convert_nt2hf.py             | 126 +++++++++++++++
 examples/xglm/convert_utils.py             |  59 +++++++
 examples/xglm/tests/test_implementation.py | 177 +++++++++++++++++----
 src/nanotron/config/models_config.py       |   4 +
 src/nanotron/models/gpt3.py                |   2 +-
 7 files changed, 347 insertions(+), 64 deletions(-)
 create mode 100644 examples/xglm/convert_nt2hf.py
 create mode 100644 examples/xglm/convert_utils.py

diff --git a/examples/xglm/README.md b/examples/xglm/README.md
index abc50f95..22765f52 100644
--- a/examples/xglm/README.md
+++ b/examples/xglm/README.md
@@ -11,3 +11,8 @@
 torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml
 ```
 If you use this configuration file make sure to modify at least the loading path in `model.init_method.path`.
+
+1. If you want to convert your finetuned checkpoint back to huggingface use:
+   ```
+   torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
+   ```
diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py
index 9db5ed93..0efcceca 100644
--- a/examples/xglm/convert_hf2nt.py
+++ b/examples/xglm/convert_hf2nt.py
@@ -18,11 +18,11 @@
 from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining
 from nanotron.config.models_config import GPT3Config
 from nanotron.trainer import mark_tied_parameters
-
+from examples.xglm.convert_utils import convert_generic, create_nt_model


 def convert_config(config: XGLMConfig) -> GPT3Config:
-    # TODOs:
+    # These settings seem to be unused:
     # layerdrop=0.0,
     # init_std=0.02,
     # use_cache=True,
@@ -80,15 +80,6 @@ def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention):
     attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone()


-def convert_generic(module1: nn.Module, module2: nn.Module):
-    names1 = {name for name, _ in module1.named_parameters()}
-    names2 = {name for name, _ in module2.named_parameters()}
-    assert names1 == names2, f"{names1} != {names2}"
-    params2 = dict(module2.named_parameters())
-    for name, param in module1.named_parameters():
-        param.data = params2[name].clone()
-
-
 def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer):
     convert_generic(mlp_nt.c_fc, block_hf.fc1)
     convert_generic(mlp_nt.c_proj, block_hf.fc2)
@@ -109,31 +100,6 @@ def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM):
     convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head)


-def create_nt_model(model_config: GPT3Config, device: torch.device = torch.device("cuda"),
-                    dtype: torch.dtype = torch.bfloat16) -> GPT3ForTraining:
-
-    parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1)
-    parallel_context = nanotron.parallel.ParallelContext(
-        data_parallel_size=parallel_config.dp,
-        pipeline_parallel_size=parallel_config.pp,
-        tensor_parallel_size=parallel_config.tp,
-    )
-    #random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()})
-    model_nt = nanotron.models.build_model(
-        model_builder=lambda: GPT3ForTraining(
-            config=model_config,
-            parallel_context=parallel_context,
-            parallel_config=parallel_config,
-            random_states=None,
-        ),
-        parallel_context=parallel_context,
-        dtype=dtype,
-        device=device,
-    )
-    mark_tied_parameters(model=model_nt, parallel_context=parallel_context)
-    return model_nt
-
-
 def main(hf_path: str, save_path: Path):
     # Load hf.
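A quick editorial sketch of the fused-QKV layout that the hf→nt converter above and the nt→hf converter below both assume: nanotron's `query_key_value` weight stacks, for every head, a `head_dim`-sized q block, then k, then v, so one direction interleaves per-head chunks and the other splits the fused tensor into `head_dim` blocks and de-interleaves them by index modulo 3. The toy sizes below are made up for illustration:

```
import torch

heads, head_dim, hidden = 2, 3, 6
q = torch.randn(heads * head_dim, hidden)
k = torch.randn(heads * head_dim, hidden)
v = torch.randn(heads * head_dim, hidden)

# hf -> nt: interleave per-head chunks as [q_0, k_0, v_0, q_1, k_1, v_1]
fused = torch.cat([t for trio in zip(q.chunk(heads), k.chunk(heads), v.chunk(heads)) for t in trio])

# nt -> hf: split into head_dim-sized blocks and de-interleave by i % 3
blocks = fused.split(head_dim)
q2, k2, v2 = (torch.cat(blocks[i::3]) for i in range(3))
assert torch.equal(q, q2) and torch.equal(k, k2) and torch.equal(v, v2)
```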
print("Loading hf...") diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py new file mode 100644 index 00000000..422695a1 --- /dev/null +++ b/examples/xglm/convert_nt2hf.py @@ -0,0 +1,126 @@ +""" +Converts a nanotron model to HF format +Command: + torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights +""" + +from argparse import ArgumentParser +from typing import Optional +from pathlib import Path + +import torch +from transformers import AutoTokenizer +from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM + +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining +from examples.xglm.convert_utils import convert_generic, create_nt_model + + +def convert_config(config: GPT3Config) -> XGLMConfig: + if config.embd_pdrop != config.resid_pdrop: + warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion.") + if config.layer_norm_epsilon != 1e-5: + warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") + return XGLMConfig( + activation_function=config.activation_function, + attention_dropout=config.attn_pdrop, + dropout=config.embd_pdrop, + eos_token_id=config.eos_token_id, + d_model=config.hidden_size, + ffn_dim=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, + attention_heads=config.num_attention_heads, + num_layers=config.num_hidden_layers, + vocab_size=config.vocab_size, + decoder_start_token_id=config.position_embedding_offset, + activation_dropout=config.act_pdrop, + scale_embedding=config.scale_embedding, + ) + + +def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): + qs_w = [] + ks_w = [] + vs_w = [] + qs_b = [] + ks_b = [] + vs_b = [] + + head_dim = attn_hf.head_dim + qkv_ws = list(attn_nt.query_key_value.weight.split(head_dim)) + qkv_bs = list(attn_nt.query_key_value.bias.split(head_dim)) + for i, (w, b) in enumerate(zip(qkv_ws, qkv_bs)): + if i % 3 == 0: + qs_w.append(w) + qs_b.append(b) + elif i % 3 == 1: + ks_w.append(w) + ks_b.append(b) + else: + vs_w.append(w) + vs_b.append(b) + + q_w = torch.cat(qs_w) + k_w = torch.cat(ks_w) + v_w = torch.cat(vs_w) + q_b = torch.cat(qs_b) + k_b = torch.cat(ks_b) + v_b = torch.cat(vs_b) + + with torch.no_grad(): + attn_hf.q_proj.weight.data = q_w.clone() + attn_hf.k_proj.weight.data = k_w.clone() + attn_hf.v_proj.weight.data = v_w.clone() + attn_hf.q_proj.bias.data = q_b.clone() + attn_hf.k_proj.bias.data = k_b.clone() + attn_hf.v_proj.bias.data = v_b.clone() + + attn_hf.out_proj.weight.data = attn_nt.dense.weight.data.clone() + attn_hf.out_proj.bias.data = attn_nt.dense.bias.data.clone() + + +def convert_decoder(block_hf: XGLMDecoderLayer, block_nt: GPTBlock): + convert_generic(block_hf.self_attn_layer_norm, block_nt.ln_1) + convert_attention(block_hf.self_attn, block_nt.attn) + convert_generic(block_hf.final_layer_norm, block_nt.ln_2) + convert_generic(block_hf.fc1, block_nt.ff.c_fc) + convert_generic(block_hf.fc2, block_nt.ff.c_proj) + + +def convert(model_hf: XGLMForCausalLM, model_nt: GPT3ForTraining): + convert_generic(model_hf.model.embed_tokens, model_nt.model.token_embeddings.pp_block.token_embedding) + for layer_hf, layer_nt in zip(model_hf.model.layers, model_nt.model.decoder): + 
convert_decoder(layer_hf, layer_nt.pp_block) + convert_generic(model_hf.model.layer_norm, model_nt.model.final_layer_norm.pp_block) + convert_generic(model_hf.lm_head, model_nt.model.lm_head.pp_block) + + +def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): + # Load nanotron model. + model_nt = create_nt_model(checkpoint_path=checkpoint_path) + + # Init huggingface model. + model_config_hf = convert_config(model_nt.config) + model_hf = XGLMForCausalLM._from_config(model_config_hf) + + # Copy weights, initialize tokenizer and save model. + if tokenizer_name is not None: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) + tokenizer.save_pretrained(save_path) + convert(model_hf, model_nt) + model_hf.save_pretrained(save_path) + print(f"Model saved to {save_path}") + + +if __name__ == "__main__": + parser = ArgumentParser(description="Convert HF weights to nanotron format") + parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") + parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") + args = parser.parse_args() + main(args.checkpoint_path, args.save_path, args.tokenizer_name) + diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py new file mode 100644 index 00000000..88a731a1 --- /dev/null +++ b/examples/xglm/convert_utils.py @@ -0,0 +1,59 @@ +import json +from pathlib import Path +from typing import Optional + +import torch +from torch import nn + +import nanotron +from nanotron.models.gpt3 import GPT3ForTraining +from nanotron.config.models_config import GPT3Config +from nanotron.trainer import mark_tied_parameters + + +def convert_generic(module1: nn.Module, module2: nn.Module): + names1 = {name for name, _ in module1.named_parameters()} + names2 = {name for name, _ in module2.named_parameters()} + assert names1 == names2, f"{names1} != {names2}" + params2 = dict(module2.named_parameters()) + for name, param in module1.named_parameters(): + param.data = params2[name].clone() + + +def create_nt_model( + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None + ): + + if model_config is None: + assert checkpoint_path is not None + with open(checkpoint_path / "model_config.json") as f: + model_config = GPT3Config(**json.load(f)) + + parallel_config = nanotron.config.ParallelismArgs(dp=1, pp=1, tp=1) + parallel_context = nanotron.parallel.ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=model_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + random_states=None, + ), + parallel_context=parallel_context, + dtype=dtype, + device=device, + ) + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + if checkpoint_path is not None: + nanotron.serialize.load_weights( + model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path + ) + + return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index 3636415b..d9dc0f85 100644 --- a/examples/xglm/tests/test_implementation.py +++ 
b/examples/xglm/tests/test_implementation.py @@ -8,6 +8,7 @@ from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM import nanotron +from nanotron.trainer import mark_tied_parameters from nanotron.config.models_config import GPT3Config from nanotron.models.gpt3 import GPT3ForTraining, CausalSelfAttention, PositionEmbedding, GPTBlock from nanotron.parallel import ParallelContext @@ -15,12 +16,17 @@ from tests.helpers.utils import init_distributed from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf +from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf +from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf +from examples.xglm.convert_nt2hf import convert as convert_nt2hf -SEQUENCE_LENGTH = 2048 +MAX_SEQUENCE_LENGTH = 2048 +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 -DTYPE = torch.bfloat16 +DTYPE = torch.float64 TEXT = "Hello. This is a relatively long text. I will use this text to test the conversion scripts. Let's finish this text soon because I don't have much more to say. Final note:" CONFIG = GPT3Config( @@ -32,7 +38,7 @@ hidden_size=HIDDEN_SIZE, intermediate_size=4096, layer_norm_epsilon=1e-05, - max_position_embeddings=SEQUENCE_LENGTH, + max_position_embeddings=MAX_SEQUENCE_LENGTH, num_attention_heads=16, num_hidden_layers=24, scale_attn_weights=True, @@ -45,25 +51,39 @@ @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) @pytest.fixture def input_mask() -> torch.Tensor: - return torch.ones(BATCH_SIZE, SEQUENCE_LENGTH, dtype=torch.bool) + return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) @pytest.fixture def input_ids() -> torch.Tensor: - return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, SEQUENCE_LENGTH)) + return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) + + +def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, + max_far: float = 0.0, far_atol: float = 0.01): + very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) + not_very_close = ~very_close + + if torch.all(very_close): + return + assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: # XGLM causal attention mask. 
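As a concrete illustration of the additive causal mask built just below (zeros where a query may attend, -inf elsewhere, later broadcast to `(BATCH_SIZE, 1, T, T)`), here is the same construction for a toy length of 3:

```
import numpy as np
import torch

T = 3  # toy sequence length, for illustration only
toy_mask = torch.where(torch.ones(T, T, dtype=torch.bool).tril(diagonal=0), 0.0, -np.inf)
print(toy_mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])
```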
- mask = torch.ones(SEQUENCE_LENGTH, SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) + mask = torch.ones(TEST_SEQUENCE_LENGTH, TEST_SEQUENCE_LENGTH, dtype=torch.bool, device="cuda").tril(diagonal=0) mask = torch.where(mask, 0.0, -np.inf).to(DTYPE) mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask +## +# FROM HERE DOWN (until next comment), all tests are hf->nt +## def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() @@ -85,10 +105,10 @@ def test_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_position_embeddings(parallel_context: ParallelContext): - position_ids = torch.arange(SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, SEQUENCE_LENGTH) + position_ids = torch.arange(TEST_SEQUENCE_LENGTH, device="cuda").unsqueeze(0) # shape = (1, TEST_SEQUENCE_LENGTH) emb_nt = PositionEmbedding(parallel_context.tp_pg, CONFIG, None).cuda() - emb_hf = XGLMSinusoidalPositionalEmbedding(SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() + emb_hf = XGLMSinusoidalPositionalEmbedding(MAX_SEQUENCE_LENGTH, HIDDEN_SIZE).cuda() assert emb_nt.position_embedding.weight.size() == emb_hf.weights.size() torch.testing.assert_close(emb_nt.position_embedding.weight, emb_hf.weights) @@ -120,7 +140,7 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt, out_hf) + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): @@ -129,21 +149,25 @@ def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() input_mask = input_mask.cuda() + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + # Get hf model. if model_hf is None: config_hf = XGLMConfig() - model_hf = XGLMForCausalLM(config_hf).cuda().to(DTYPE).eval() + model_hf = XGLMForCausalLM(config_hf).cuda().to(new_dtype).eval() else: - model_hf = model_hf.cuda().to(DTYPE).eval() + model_hf = model_hf.cuda().to(new_dtype).eval() config_hf = model_hf.config # Get nanotron model and make the conversion. config_nt = convert_config(config_hf) - if DTYPE not in {torch.bfloat16, torch.float16}: + if new_dtype not in {torch.bfloat16, torch.float16}: config_nt.use_spda = True model_nt = nanotron.models.build_model( model_builder=lambda: GPT3ForTraining( @@ -153,7 +177,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC random_states=random_states, ), parallel_context=parallel_context, - dtype=DTYPE, + dtype=new_dtype, device="cuda", ).eval() convert(model_nt, model_hf) @@ -162,42 +186,141 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC # Get outputs and assert. 
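To make the two-tier tolerance used by `almost_close` above concrete (made-up numbers, default `atol`/`rtol`): every element must satisfy the tight bound except for at most a `max_far` fraction, and those outliers must still differ by less than `far_atol`:

```
import torch

t1 = torch.tensor([1.0000, 2.0000, 0.5090])
t2 = torch.tensor([1.0000, 2.0001, 0.5000])

tight = torch.abs(t1 - t2) <= 1e-5 + 0.016 * torch.abs(t2)  # the "very close" test
print(tight)                           # tensor([ True,  True, False])
print((~tight).float().mean().item())  # ~0.333 -> needs max_far >= 1/3
print(torch.abs(t1 - t2)[~tight] <= 0.01)  # the outlier still passes far_atol
```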
with torch.no_grad(): - out_nt = model_nt.model(input_ids, input_mask).to(DTYPE) + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) del model_nt torch.cuda.empty_cache() out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) del model_hf torch.cuda.empty_cache() assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.cpu(), out_hf.cpu()) + return out_nt.cpu(), out_hf.cpu() + def _test_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): - _test_model(None, parallel_context, input_ids, input_mask) + out_nt, out_hf = _test_model(None, parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.05) def test_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) +def _test_xglm500M(parallel_context: ParallelContext): + tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") + tokenized = tok(TEXT) + model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) + + +def test_xglm500M(): + init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + + def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model(model_hf, parallel_context, + torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) def test_xglm7B(): init_distributed(tp=1, dp=1, pp=1)(_test_xglm7B)() -def _test_xglm500M(parallel_context: ParallelContext): - tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") - tokenized = tok(TEXT) - model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) +## +# From here down we test nt->hf converters +## +def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() -def test_xglm500M(): - init_distributed(tp=1, dp=1, pp=1)(_test_xglm500M)() + attn_nt = CausalSelfAttention(CONFIG, None, parallel_context.tp_pg, 0).cuda().eval().to(DTYPE) + attn_hf = XGLMAttention(CONFIG.hidden_size, CONFIG.num_attention_heads, CONFIG.attn_pdrop).cuda().eval().to(DTYPE) + assert sum(map(torch.numel, attn_nt.parameters())) == sum(map(torch.numel, attn_hf.parameters())) + + convert_attention_nt2hf(attn_hf, attn_nt) + out_nt = attn_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = attn_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt, out_hf) + + +def test_nt2hf_attention(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_attention)(hidden_states=hidden_states, 
sequence_mask=input_mask) + + +def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + hidden_states = hidden_states.cuda() + sequence_mask = sequence_mask.cuda() + + config_hf = convert_config_nt2hf(CONFIG) + decoder_nt = GPTBlock(CONFIG, None, parallel_context.tp_pg, random_states, 0).cuda().to(DTYPE).eval() + decoder_hf = XGLMDecoderLayer(config_hf).cuda().to(DTYPE).eval() + + convert_decoder_nt2hf(decoder_hf, decoder_nt) + + out_nt = decoder_nt(hidden_states, sequence_mask)["hidden_states"] + out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) + + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + + +def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) + + +def _test_nt2hf_model(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) + input_ids = input_ids.cuda() + input_mask = input_mask.cuda() + + # unfortunately, we can't use float64 with huggingface xglm. + new_dtype = torch.float32 if DTYPE == torch.float64 else DTYPE + + # Get nanotron model. + config_nt = GPT3Config(**vars(CONFIG)) + if new_dtype not in {torch.bfloat16, torch.float16}: + config_nt.use_spda = True + model_nt = nanotron.models.build_model( + model_builder=lambda: GPT3ForTraining( + config=config_nt, + parallel_context=parallel_context, + parallel_config=None, + random_states=random_states, + ), + parallel_context=parallel_context, + dtype=new_dtype, + device="cuda", + ).eval() + mark_tied_parameters(model=model_nt, parallel_context=parallel_context) + + # Create empty model_hf and make conversion. + model_hf = XGLMForCausalLM(convert_config_nt2hf(config_nt)).cuda().to(new_dtype).eval() + convert_nt2hf(model_hf, model_nt) + + # Get outputs and assert. 
+ with torch.no_grad(): + out_nt = model_nt.model(input_ids, input_mask).to(new_dtype) + del model_nt + torch.cuda.empty_cache() + out_hf = model_hf(input_ids=input_ids, attention_mask=input_mask).logits.permute(1, 0, 2) + del model_hf + torch.cuda.empty_cache() + assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" + return out_nt.cpu(), out_hf.cpu() + + +def _test_nt2hf_dummy_xglm(parallel_context: ParallelContext, input_ids: torch.Tensor, input_mask: torch.Tensor): + out_nt, out_hf = _test_nt2hf_model(parallel_context, input_ids, input_mask) + almost_close(out_nt, out_hf, max_far=0.01, far_atol=0.02) + + +def test_nt2hf_dummy_xglm(input_ids: torch.Tensor, input_mask: torch.Tensor): + init_distributed(tp=1, dp=1, pp=1)(_test_nt2hf_dummy_xglm)(input_ids=input_ids, input_mask=input_mask) diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 6c568e80..80f956d1 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -173,4 +173,8 @@ def as_starcoder2(self) -> Starcoder2Config: **config ) + @property + def n_inner(self): + return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 33661c8b..7d4e6f82 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -338,4 +338,4 @@ def __init__( ): with replace_gpt3model(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) - + self.config = config From 138da5ff5a5a9c34ac6191149dfdc83603b08e20 Mon Sep 17 00:00:00 2001 From: AleHD Date: Thu, 11 Jul 2024 14:32:44 +0200 Subject: [PATCH 05/24] precommit --- examples/xglm/convert_hf2nt.py | 33 ++++---- examples/xglm/convert_nt2hf.py | 28 ++++--- examples/xglm/convert_utils.py | 21 +++-- examples/xglm/tests/test_implementation.py | 89 ++++++++++++++-------- src/nanotron/config/models_config.py | 8 +- src/nanotron/models/gpt3.py | 85 ++++++++++++--------- src/nanotron/trainer.py | 2 +- 7 files changed, 154 insertions(+), 112 deletions(-) diff --git a/examples/xglm/convert_hf2nt.py b/examples/xglm/convert_hf2nt.py index 0efcceca..c18a1ab8 100644 --- a/examples/xglm/convert_hf2nt.py +++ b/examples/xglm/convert_hf2nt.py @@ -4,20 +4,18 @@ torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights """ +import dataclasses import json import warnings -import dataclasses from argparse import ArgumentParser from pathlib import Path +import nanotron import torch -from torch import nn +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import MLP, CausalSelfAttention, GPT3ForTraining, GPTBlock from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -import nanotron -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining -from nanotron.config.models_config import GPT3Config -from nanotron.trainer import mark_tied_parameters from examples.xglm.convert_utils import convert_generic, create_nt_model @@ -29,10 +27,12 @@ def convert_config(config: XGLMConfig) -> GPT3Config: # pad_token_id=1, # bos_token_id=0, if config.dropout != config.attention_dropout: - warnings.warn(f"huggingface.dropout = {config.dropout} does not match with " - f"huggingface.attention_dropout = {config.attention_dropout}. 
" - "Nanotron implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"huggingface.dropout = {config.dropout} does not match with " + f"huggingface.attention_dropout = {config.attention_dropout}. " + "Nanotron implementation needs these two values to be equal " + "for correct conversion." + ) return GPT3Config( activation_function=config.activation_function, attn_pdrop=config.attention_dropout, @@ -113,16 +113,19 @@ def main(hf_path: str, save_path: Path): # Copy weights and save model. print("Copying weights...") convert(model_nt, model_hf) - nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, - root_folder=save_path) - with open(save_path/"model_config.json", "w+") as f: + nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, root_folder=save_path) + with open(save_path / "model_config.json", "w+") as f: json.dump(dataclasses.asdict(config_nt), f) print(f"Model saved to {save_path}") if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint") - parser.add_argument("--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model") + parser.add_argument( + "--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model" + ) args = parser.parse_args() main(args.checkpoint_path, args.save_path) diff --git a/examples/xglm/convert_nt2hf.py b/examples/xglm/convert_nt2hf.py index 422695a1..81816aa9 100644 --- a/examples/xglm/convert_nt2hf.py +++ b/examples/xglm/convert_nt2hf.py @@ -4,25 +4,28 @@ torchrun --nproc-per-node=1 convert_nt2hf.py --checkpoint-path=nanotron_weights --save-path=hf_weights """ +import warnings from argparse import ArgumentParser -from typing import Optional from pathlib import Path +from typing import Optional import torch +from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining, GPTBlock from transformers import AutoTokenizer from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM -from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import CausalSelfAttention, GPTBlock, MLP, GPT3ForTraining from examples.xglm.convert_utils import convert_generic, create_nt_model def convert_config(config: GPT3Config) -> XGLMConfig: if config.embd_pdrop != config.resid_pdrop: - warnings.warn(f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " - f"nanotron.resid_pdrop = {config.resid_pdrop}. " - "XGLM implementation needs these two values to be equal " - "for correct conversion.") + warnings.warn( + f"nanotron.embd_pdrop = {config.embd_pdrop} does not match with " + f"nanotron.resid_pdrop = {config.resid_pdrop}. " + "XGLM implementation needs these two values to be equal " + "for correct conversion." 
+ ) if config.layer_norm_epsilon != 1e-5: warnings.warn(f"nanotron.layer_norm_epsilon must be 1e-5, not {config.layer_norm_epsilon}") return XGLMConfig( @@ -70,7 +73,7 @@ def convert_attention(attn_hf: XGLMAttention, attn_nt: XGLMAttention): q_b = torch.cat(qs_b) k_b = torch.cat(ks_b) v_b = torch.cat(vs_b) - + with torch.no_grad(): attn_hf.q_proj.weight.data = q_w.clone() attn_hf.k_proj.weight.data = k_w.clone() @@ -118,9 +121,12 @@ def main(checkpoint_path: Path, save_path: Path, tokenizer_name: Optional[str]): if __name__ == "__main__": parser = ArgumentParser(description="Convert HF weights to nanotron format") - parser.add_argument("--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint") - parser.add_argument("--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model") + parser.add_argument( + "--checkpoint-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to the nanotron checkpoint" + ) + parser.add_argument( + "--save-path", type=Path, default="facebook/xglm-7.5B", help="Path to save the huggingface model" + ) parser.add_argument("--tokenizer-name", type=str, default="facebook/xglm-7.5B") args = parser.parse_args() main(args.checkpoint_path, args.save_path, args.tokenizer_name) - diff --git a/examples/xglm/convert_utils.py b/examples/xglm/convert_utils.py index 88a731a1..75d67782 100644 --- a/examples/xglm/convert_utils.py +++ b/examples/xglm/convert_utils.py @@ -2,13 +2,12 @@ from pathlib import Path from typing import Optional -import torch -from torch import nn - import nanotron -from nanotron.models.gpt3 import GPT3ForTraining +import torch from nanotron.config.models_config import GPT3Config +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.trainer import mark_tied_parameters +from torch import nn def convert_generic(module1: nn.Module, module2: nn.Module): @@ -21,11 +20,11 @@ def convert_generic(module1: nn.Module, module2: nn.Module): def create_nt_model( - model_config: Optional[GPT3Config] = None, - device: torch.device = torch.device("cuda"), - dtype: torch.dtype = torch.bfloat16, - checkpoint_path: Optional[Path] = None - ): + model_config: Optional[GPT3Config] = None, + device: torch.device = torch.device("cuda"), + dtype: torch.dtype = torch.bfloat16, + checkpoint_path: Optional[Path] = None, +): if model_config is None: assert checkpoint_path is not None @@ -52,8 +51,6 @@ def create_nt_model( mark_tied_parameters(model=model_nt, parallel_context=parallel_context) if checkpoint_path is not None: - nanotron.serialize.load_weights( - model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path - ) + nanotron.serialize.load_weights(model=model_nt, parallel_context=parallel_context, root_folder=checkpoint_path) return model_nt diff --git a/examples/xglm/tests/test_implementation.py b/examples/xglm/tests/test_implementation.py index d9dc0f85..a25d7881 100644 --- a/examples/xglm/tests/test_implementation.py +++ b/examples/xglm/tests/test_implementation.py @@ -1,29 +1,31 @@ from typing import Optional +import nanotron import numpy as np -import torch import pytest - -from transformers import XGLMTokenizer -from transformers.models.xglm.modeling_xglm import XGLMConfig, XGLMAttention, XGLMSinusoidalPositionalEmbedding, XGLMDecoderLayer, XGLMForCausalLM - -import nanotron -from nanotron.trainer import mark_tied_parameters +import torch from nanotron.config.models_config import GPT3Config -from nanotron.models.gpt3 import GPT3ForTraining, 
CausalSelfAttention, PositionEmbedding, GPTBlock +from nanotron.models.gpt3 import CausalSelfAttention, GPT3ForTraining, GPTBlock, PositionEmbedding from nanotron.parallel import ParallelContext +from nanotron.trainer import mark_tied_parameters +from transformers import XGLMTokenizer +from transformers.models.xglm.modeling_xglm import ( + XGLMAttention, + XGLMConfig, + XGLMDecoderLayer, + XGLMForCausalLM, + XGLMSinusoidalPositionalEmbedding, +) -from tests.helpers.utils import init_distributed - -from examples.xglm.convert_hf2nt import convert_attention, convert_config, convert_decoder, convert +from examples.xglm.convert_hf2nt import convert, convert_attention, convert_config, convert_decoder +from examples.xglm.convert_nt2hf import convert as convert_nt2hf from examples.xglm.convert_nt2hf import convert_attention as convert_attention_nt2hf from examples.xglm.convert_nt2hf import convert_config as convert_config_nt2hf from examples.xglm.convert_nt2hf import convert_decoder as convert_decoder_nt2hf -from examples.xglm.convert_nt2hf import convert as convert_nt2hf - +from tests.helpers.utils import init_distributed MAX_SEQUENCE_LENGTH = 2048 -TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. +TEST_SEQUENCE_LENGTH = 128 # If we test with a very large sequence length, precision errors get more significant independent of the correct implementation. BATCH_SIZE = 4 HIDDEN_SIZE = 1024 DTYPE = torch.float64 @@ -45,33 +47,44 @@ vocab_size=256008, sinusoidal_position_embedding=True, position_embedding_offset=2, - use_spda=True + use_spda=True, ) @pytest.fixture def hidden_states() -> torch.Tensor: - return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, - dtype=DTYPE) + return torch.randn(TEST_SEQUENCE_LENGTH, BATCH_SIZE, HIDDEN_SIZE, dtype=DTYPE) + @pytest.fixture def input_mask() -> torch.Tensor: return torch.ones(BATCH_SIZE, TEST_SEQUENCE_LENGTH, dtype=torch.bool) + @pytest.fixture def input_ids() -> torch.Tensor: return torch.randint(0, CONFIG.vocab_size, (BATCH_SIZE, TEST_SEQUENCE_LENGTH)) -def almost_close(t1: torch.Tensor, t2: torch.Tensor, atol: float = 1e-5, rtol: float = 0.016, - max_far: float = 0.0, far_atol: float = 0.01): - very_close = torch.abs(t1 - t2) <= atol + rtol*torch.abs(t2) +def almost_close( + t1: torch.Tensor, + t2: torch.Tensor, + atol: float = 1e-5, + rtol: float = 0.016, + max_far: float = 0.0, + far_atol: float = 0.01, +): + very_close = torch.abs(t1 - t2) <= atol + rtol * torch.abs(t2) not_very_close = ~very_close if torch.all(very_close): return - assert torch.mean(not_very_close.float()) <= max_far, f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" - assert torch.all(torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" + assert ( + torch.mean(not_very_close.float()) <= max_far + ), f"not very close found: {100*torch.mean(not_very_close.float()):.1f}%" + assert torch.all( + torch.abs(t1[not_very_close] - t2[not_very_close]) <= far_atol + ), f"Worse deviation found: {torch.max(torch.abs(t1 - t2)):.4f}" def attention_mask() -> torch.Tensor: @@ -81,10 +94,12 @@ def attention_mask() -> torch.Tensor: mask = mask.repeat(BATCH_SIZE, 1, 1).unsqueeze(1) return mask + ## # FROM HERE DOWN (until next comment), all tests are hf->nt ## + def _test_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = 
hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -118,6 +133,7 @@ def _test_position_embeddings(parallel_context: ParallelContext): assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" torch.testing.assert_close(out_nt, out_hf) + def test_position_embeddings(): init_distributed(tp=1, dp=1, pp=1)(_test_position_embeddings)() @@ -140,15 +156,21 @@ def _test_decoder(parallel_context: ParallelContext, hidden_states: torch.Tensor out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): init_distributed(tp=1, dp=1, pp=1)(_test_decoder)(hidden_states=hidden_states, sequence_mask=input_mask) -def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelContext, - input_ids: torch.Tensor, input_mask: torch.Tensor): +def _test_model( + model_hf: Optional[XGLMForCausalLM], + parallel_context: ParallelContext, + input_ids: torch.Tensor, + input_mask: torch.Tensor, +): random_states = nanotron.random.RandomStates({"tp_synced": nanotron.random.get_current_random_state()}) input_ids = input_ids.cuda() @@ -182,7 +204,7 @@ def _test_model(model_hf: Optional[XGLMForCausalLM], parallel_context: ParallelC ).eval() convert(model_nt, model_hf) - print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters()))/1000/1000) + print("Parameter count (M):", sum(map(torch.numel, model_hf.parameters())) / 1000 / 1000) # Get outputs and assert. with torch.no_grad(): @@ -209,8 +231,9 @@ def _test_xglm500M(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-564M") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-564M") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.1, far_atol=0.05) @@ -222,8 +245,9 @@ def _test_xglm7B(parallel_context: ParallelContext): tok = XGLMTokenizer.from_pretrained("facebook/xglm-7.5B") tokenized = tok(TEXT) model_hf = XGLMForCausalLM.from_pretrained("facebook/xglm-7.5B") - out_nt, out_hf = _test_model(model_hf, parallel_context, - torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]])) + out_nt, out_hf = _test_model( + model_hf, parallel_context, torch.tensor([tokenized["input_ids"]]), torch.tensor([tokenized["attention_mask"]]) + ) almost_close(out_nt, out_hf, max_far=0.15, far_atol=0.1) @@ -235,6 +259,7 @@ def test_xglm7B(): # From here down we test nt->hf converters ## + def _test_nt2hf_attention(parallel_context: ParallelContext, hidden_states: torch.Tensor, sequence_mask: torch.Tensor): hidden_states = hidden_states.cuda() sequence_mask = sequence_mask.cuda() @@ -269,7 +294,9 @@ def _test_nt2hf_decoder(parallel_context: ParallelContext, hidden_states: torch. 
out_hf = decoder_hf(hidden_states.permute(1, 0, 2), attention_mask=attention_mask())[0].permute(1, 0, 2) assert out_nt.size() == out_hf.size(), f"{out_nt.size()}, {out_hf.size()}" - torch.testing.assert_close(out_nt.bfloat16(), out_hf.bfloat16()) # We cast to bf16 to get more relaxed constraints. + torch.testing.assert_close( + out_nt.bfloat16(), out_hf.bfloat16() + ) # We cast to bf16 to get more relaxed constraints. def test_nt2hf_decoder(hidden_states: torch.Tensor, input_mask: torch.Tensor): diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index 80f956d1..37593a54 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, Union +from typing import List, Optional @dataclass @@ -167,14 +167,12 @@ def as_starcoder2(self) -> Starcoder2Config: if "_is_using_mup" in config: del config["_is_using_mup"] return Starcoder2Config( - grouped_query=True, - num_kv_heads=self.num_attention_heads, - use_rotary_embeddings=False, - **config + grouped_query=True, num_kv_heads=self.num_attention_heads, use_rotary_embeddings=False, **config ) @property def n_inner(self): return self.intermediate_size + NanotronConfigs = LlamaConfig | Starcoder2Config | GPT3Config diff --git a/src/nanotron/models/gpt3.py b/src/nanotron/models/gpt3.py index 7d4e6f82..25e5f78b 100644 --- a/src/nanotron/models/gpt3.py +++ b/src/nanotron/models/gpt3.py @@ -1,37 +1,40 @@ """PyTorch GPT-3 model.""" import math -from typing import Optional from contextlib import contextmanager +from typing import Optional import torch from torch import nn from torch.nn import functional as F from nanotron import distributed as dist -from nanotron.parallel import ParallelContext -from nanotron.config import Config, GPT3Config, ParallelismArgs, Starcoder2Config +from nanotron.config import GPT3Config, ParallelismArgs, Starcoder2Config from nanotron.generation.generate_store import AttachableStore from nanotron.models import starcoder2 -from nanotron.nn.layer_norm import TritonLayerNorm from nanotron.models.starcoder2 import MLP as Starcoder2MLP -from nanotron.parallel.pipeline_parallel.block import PipelineBlock +from nanotron.models.starcoder2 import CausalSelfGQA, GPTModel, Starcoder2ForTraining, dropout_add_fused_train from nanotron.models.starcoder2 import CoreAttention as Starcoder2CoreAttention from nanotron.models.starcoder2 import GPTBlock as Starcoder2GPTBlock -from nanotron.models.starcoder2 import CausalSelfGQA, Starcoder2ForTraining, GPTModel, dropout_add_fused_train -from nanotron.random import RandomStates, branch_random_state +from nanotron.nn.layer_norm import TritonLayerNorm +from nanotron.parallel import ParallelContext +from nanotron.parallel.pipeline_parallel.block import PipelineBlock from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode from nanotron.parallel.tensor_parallel.nn import TensorParallelEmbedding -from nanotron.parallel.tied_parameters import tie_parameters +from nanotron.random import RandomStates, branch_random_state @contextmanager def replace_coreattention(gpt3config: GPT3Config): orig = starcoder2.CoreAttention try: - def create_core_attention(config: Starcoder2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + + def create_core_attention( + config: Starcoder2Config, parallel_config: 
Optional[ParallelismArgs], layer_idx: int + ): return CoreAttention(gpt3config, parallel_config, layer_idx) + starcoder2.CoreAttention = create_core_attention yield finally: @@ -42,6 +45,7 @@ def create_core_attention(config: Starcoder2Config, parallel_config: Optional[Pa def replace_decoder(gpt3config: GPT3Config): orig = starcoder2.PipelineBlock try: + def create_pp_block(module_builder, module_kwargs, **kwargs): if module_builder is Starcoder2GPTBlock: # Starcoder2's GPT module is trying to instantiate a Starcoder2 GPTBlock. @@ -62,9 +66,15 @@ def create_pp_block(module_builder, module_kwargs, **kwargs): def replace_gpt3model(gpt3config: GPT3Config): orig = starcoder2.GPTModel try: - def create_gptmodel(config: Starcoder2Config, parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], random_states: RandomStates): + + def create_gptmodel( + config: Starcoder2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): return GPT3Model(gpt3config, parallel_context, parallel_config, random_states) + starcoder2.GPTModel = create_gptmodel yield finally: @@ -76,7 +86,8 @@ def __init__(self, config: GPT3Config, parallel_config: Optional[ParallelismArgs super().__init__(config.as_starcoder2(), parallel_config, layer_idx) self.gpt3config = config - def forward(self, + def forward( + self, query_states: torch.Tensor, # [batch_size * q_length, q_heads, inner_dim] key_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] value_states: torch.Tensor, # [batch_size * kv_length, kv_heads, inner_dim] @@ -101,7 +112,7 @@ def forward(self, is_causal=True, ) # [batch, q_length, q_heads, head_dim] attention_output = attention_output.permute(0, 2, 1, 3) - attention_output = attention_output.reshape(batch_size*q_length, q_heads, head_dim) + attention_output = attention_output.reshape(batch_size * q_length, q_heads, head_dim) return attention_output.contiguous() assert query_states.dtype in {torch.bfloat16, torch.float16} @@ -127,7 +138,7 @@ def __init__( config: GPT3Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, - random_states: RandomStates + random_states: RandomStates, ): super().__init__(config.as_starcoder2(), parallel_config, tp_pg) self.dropout = nn.Dropout(p=config.act_pdrop) @@ -154,14 +165,11 @@ def __init__( random_states: RandomStates, layer_idx: int, ): - #print("New gpt block created :D") + # print("New gpt block created :D") super(GPTBlock, self).__init__() self.ln_1 = TritonLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.attn = CausalSelfAttention( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - layer_idx=layer_idx + config=config, parallel_config=parallel_config, tp_pg=tp_pg, layer_idx=layer_idx ) self.attn_dropout = config.attn_pdrop @@ -180,10 +188,10 @@ def forward( residual = hidden_states hidden_states = self.ln_1(hidden_states) - #hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) + # hidden_states = torch.arange(hidden_states.numel()).to(hidden_states.device).to(hidden_states.dtype).view(hidden_states.size()) output = self.attn(hidden_states=hidden_states, sequence_mask=sequence_mask) hidden_states = output["hidden_states"] - #return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} + # return {"hidden_states": hidden_states, "sequence_mask": sequence_mask} if self.training: with branch_random_state( @@ -221,7 +229,9 @@ 
def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config if (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() == 0: dummy_pos = 0 else: - dummy_pos = tp_pg.size() - ((config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size()) + dummy_pos = tp_pg.size() - ( + (config.max_position_embeddings + config.position_embedding_offset) % tp_pg.size() + ) true_max_size = config.max_position_embeddings + config.position_embedding_offset + dummy_pos if config.sinusoidal_position_embedding: @@ -234,7 +244,7 @@ def __init__(self, tp_pg: dist.ProcessGroup, config: GPT3Config, parallel_config embedding_dim=config.hidden_size, pg=tp_pg, mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, - _weight=weight + _weight=weight, ) self.pg = tp_pg @@ -251,32 +261,31 @@ def forward(self, position_ids: torch.Tensor): # [batch_size, seq_length] position_embeds = self.position_embedding(position_ids + self.config.position_embedding_offset) return {"position_embeds": position_embeds} - def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, - embedding_dim: int) -> torch.Tensor: + def _make_weights(self, tp_pg: dist.ProcessGroup, num_embeddings: int, embedding_dim: int) -> torch.Tensor: rank = dist.get_rank(group=tp_pg) tp_size = tp_pg.size() assert 0 <= rank < tp_size assert num_embeddings % tp_size == 0 assert embedding_dim % 2 == 0 - block_size = num_embeddings//tp_size + block_size = num_embeddings // tp_size - half_dim = embedding_dim//2 - emb = math.log(10_000)/(half_dim - 1) + half_dim = embedding_dim // 2 + emb = math.log(10_000) / (half_dim - 1) emb = torch.exp(torch.arange(half_dim, dtype=torch.int64).float() * -emb) - emb = (rank*block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) + emb = (rank * block_size + torch.arange(block_size, dtype=torch.int64).float().unsqueeze(1)) * emb.unsqueeze(0) emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(block_size, embedding_dim) return emb class GPT3Model(GPTModel): def __init__( - self, - config: GPT3Config, - parallel_context: ParallelContext, - parallel_config: Optional[ParallelismArgs], - random_states: RandomStates, - ): + self, + config: GPT3Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: RandomStates, + ): with replace_decoder(config): super().__init__(config.as_starcoder2(), parallel_context, parallel_config, random_states) @@ -300,7 +309,9 @@ def forward( ): # all tensors are optional as most ranks don't need anything from the dataloader. - input_embeds = self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"]*self.embed_scale + input_embeds = ( + self.token_embeddings(input_ids=input_ids, input_mask=input_mask)["input_embeds"] * self.embed_scale + ) # TODO: position_ids could be cached. 
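For readers unfamiliar with the fairseq/XGLM-style table built by `_make_weights` above, here is a non-tensor-parallel sketch of the same computation (each TP rank only materializes its `block_size` slice of these rows, and lookups are then shifted by `position_embedding_offset`):

```
import math
import torch

def sinusoidal_table(num_embeddings: int, embedding_dim: int) -> torch.Tensor:
    # Frequencies follow the construction used in _make_weights above.
    half_dim = embedding_dim // 2
    freqs = torch.exp(torch.arange(half_dim).float() * -(math.log(10_000) / (half_dim - 1)))
    angles = torch.arange(num_embeddings).float().unsqueeze(1) * freqs.unsqueeze(0)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # (num_embeddings, embedding_dim)
```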
position_ids = torch.arange(input_ids.size(1), device="cuda").repeat(input_ids.size(0)).view(*input_ids.size()) position_embeds = self.position_embeddings(position_ids=position_ids)["position_embeds"] @@ -314,7 +325,7 @@ def forward( hidden_encoder_states = {"hidden_states": hidden_states, "sequence_mask": input_mask} for encoder_block in self.decoder: hidden_encoder_states = encoder_block(**hidden_encoder_states) - #return hidden_encoder_states["hidden_states"] + # return hidden_encoder_states["hidden_states"] hidden_states = self.final_layer_norm(input=hidden_encoder_states["hidden_states"])["hidden_states"] diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f01caa3e..bc81e326 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -56,9 +56,9 @@ ) from nanotron.models import NanotronModel, build_model from nanotron.models.base import check_model_has_grad +from nanotron.models.gpt3 import GPT3ForTraining from nanotron.models.llama import LlamaForTraining, RotaryEmbedding from nanotron.models.starcoder2 import Starcoder2ForTraining -from nanotron.models.gpt3 import GPT3ForTraining from nanotron.optim.clip_grads import clip_grad_norm from nanotron.parallel import ParallelContext from nanotron.parallel.data_parallel.utils import sync_gradients_across_dp From 0485fd64dc4cf8b68eaa963bfd586de4e6c4ac67 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:45:28 +0000 Subject: [PATCH 06/24] Added MultilingualNanoset Config --- src/nanotron/config/config.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 05b49955..bfd20227 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -107,6 +107,27 @@ def __post_init__(self): self.dataset_weights = list(tmp_dataset_folder.values()) +@dataclass +class MultilingualNanosetDatasetsArgs: + dataset_folder: Union[str, dict, List[str]] + dataset_tokens: List[ + int + ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + + def __post_init__(self): + if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file + self.dataset_folder = [self.dataset_folder] + self.dataset_weights = [1] + elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + self.dataset_weights = None # Set to None so we consume all the samples randomly + elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights + tmp_dataset_folder = self.dataset_folder.copy() + self.dataset_folder = list(tmp_dataset_folder.keys()) + self.dataset_weights = list(tmp_dataset_folder.values()) + + assert len(self.dataset_folder) == len(self.dataset_tokens) + + @dataclass class DataArgs: """Arguments related to the data and data files processing""" From 539832ade4914ac92bc8c66b55dbf031f3195ec6 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 11:48:51 +0000 Subject: [PATCH 07/24] Added MultilingualNanoset --- run_train.py | 125 +++++++++++- src/nanotron/data/multilingual_nanoset.py | 221 ++++++++++++++++++++++ 2 files changed, 343 insertions(+), 3 deletions(-) create mode 100644 src/nanotron/data/multilingual_nanoset.py diff --git a/run_train.py b/run_train.py index 021d955d..649784ca 100644 --- a/run_train.py +++ b/run_train.py @@ -12,7 +12,13 @@ import numpy as np from nanotron import logging -from nanotron.config import DataArgs, DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs +from nanotron.config import ( + DataArgs, + DatasetStageArgs, + MultilingualNanosetDatasetsArgs, + NanosetDatasetsArgs, + PretrainDatasetsArgs, +) from nanotron.data.dataloader_builder import build_nanoset_dataloader from nanotron.dataloader import ( clm_process, @@ -171,6 +177,40 @@ def get_dataloader_from_data_stage( dataloader_drop_last=True, ) + return train_dataloader + # Case 4: MultilingualNanosets + elif isinstance(data.dataset, MultilingualNanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + train_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + random_seed=data.seed, + ) + + # Prepare dataloader + train_dataloader = build_nanoset_dataloader( + train_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=consumed_train_samples, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + return train_dataloader else: raise ValueError(f"Unhandled case of `self.config.data.dataset`. 
Got: {data.dataset}") @@ -178,6 +218,57 @@ def get_dataloader_from_data_stage( return dataloader +def get_valid_dataloader_from_data_stage( + trainer: DistributedTrainer, + data: DataArgs, + valid_split_num_samples: int, + # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples +): + + # First, we need to know which ranks to feed the dataloader to + input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) + + # Only support Validation with MultilingualNanosets + if isinstance(data.dataset, NanosetDatasetsArgs): + # Get tokenizer cardinality + tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) + token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 + del tokenizer + # Create Multilingual Nanoset + from nanotron.data.multilingual_nanoset import MultilingualNanoset + + with main_rank_first(trainer.parallel_context.world_pg): + valid_dataset = MultilingualNanoset( + dataset_folders=data.dataset.dataset_folder, + dataset_weights=data.dataset.dataset_weights, + sequence_length=trainer.sequence_length, + token_size=token_size, + train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + valid_split_num_samples=valid_split_num_samples, + is_valid=True, + random_seed=data.seed, + ) + + # Prepare dataloader + valid_dataloader = build_nanoset_dataloader( + valid_dataset, + trainer.sequence_length, + parallel_context=trainer.parallel_context, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + micro_batch_size=trainer.micro_batch_size, + consumed_train_samples=0, + dataloader_num_workers=data.num_loading_workers, + dataloader_drop_last=True, + ) + + return valid_dataloader + else: + raise ValueError( + f"Unhandled case of `self.config.data.dataset`. Got: {data.dataset}. 
Validation is currently just supported for MultilingualNanoset" + ) + + def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloaders = {} @@ -219,6 +310,33 @@ def get_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: return dataloaders +def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: + dataloaders = {} + + for stage_idx, stage in enumerate(trainer.config.data_stages): + # NOTE: we only create the dataloader for the first stage, + # then we lazy initialize the dataloader for the other stages + stage = cast(DatasetStageArgs, stage) + valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size + + log_rank( + f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = ( + get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + if stage_idx == 0 + else lambda stage=stage: get_dataloader_from_data_stage( + trainer, stage.data, valid_split_num_samples=valid_split_num_samples + ) + ) + dataloaders[stage.name] = dataloader + return dataloaders + + def get_args(): parser = argparse.ArgumentParser() parser.add_argument("--config-file", type=str, required=True, help="Path to the YAML or python config file") @@ -231,7 +349,8 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) - dataloader = get_dataloader(trainer) + train_dataloader = get_dataloader(trainer) + valid_dataloader = get_valid_dataloader(trainer) # Train - trainer.train(dataloader) + trainer.train(train_dataloader, valid_dataloader) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py new file mode 100644 index 00000000..40e06b87 --- /dev/null +++ b/src/nanotron/data/multilingual_nanoset.py @@ -0,0 +1,221 @@ +import os +import warnings +from typing import Dict, List, Tuple, Union + +import numpy as np +import torch +from datatrove.utils.dataset import DatatroveFolderDataset +from nanotron import logging +from nanotron.data.utils import count_dataset_indexes, normalize +from nanotron.logging import log_rank +from numba import jit + +logger = logging.get_logger(__name__) + + +class MultilingualNanoset(torch.utils.data.Dataset): + """ + The Nanoset dataset + + Args: + dataset_folders (List[str]): List of folders with tokenized datasets + dataset_weights (Union[List[float], None]): List with the weights for weighted datasets. If None, consume all samples from all datasets without weighting. Weights are normalized in __init__ + sequence_length (int): Sequence length of the built samples + token_size (int): Number of bytes for the tokens stored in the processed dataset files. 2 for vocab sizes < 65535, 4 otherwise + train_split_num_samples (int): Number of samples the dataset needs. It's the training steps * global batch size + """ + + def __init__( + self, + dataset_folders: List[str], + sequence_length: int, + token_size: int, + train_split_num_samples: int, + valid_split_num_samples: int, + is_valid: bool = False, + dataset_weights: Union[List[float], None] = None, + random_seed: int = 1234, + ) -> None: + + # Checks + if isinstance(dataset_folders, str): + warnings.warn("dataset_folders should be of type List[str] but str was provided. 
Converting to List[str]") + dataset_folders = [dataset_folders] + + # Init + self.dataset_folders = dataset_folders + self.sequence_length = sequence_length + self.token_size = token_size + self.train_split_num_samples = train_split_num_samples + self.valid_split_num_samples = valid_split_num_samples + self.is_valid = is_valid + self.random_seed = random_seed + self.datatrove_datasets = [] + for dataset_folder in self.dataset_folders: + self.datatrove_datasets.append( + DatatroveFolderDataset( + folder_path=dataset_folder, + filename_pattern=os.path.join(dataset_folder, "*.ds"), + seq_len=sequence_length, + recursive=False, + token_size=token_size, + shuffle=True, + ) + ) + + # Build Nanoset Index + ## To build the index we need the length of each dataset + self.dataset_lengths = [len(datatrove_dataset) for datatrove_dataset in self.datatrove_datasets] + ## Set dataset weights + if ( + dataset_weights is None + ): # Case of training with > 1 datasets without weighting them: Consume both datasets entirely on each epoch + self.dataset_weights = normalize(self.dataset_lengths) + else: + self.dataset_weights = normalize(dataset_weights) + assert len(dataset_folders) == len( + self.dataset_weights + ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." + ## Build dataset index and dataset sample index + ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts + self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + # Assert that we have sufficient samples to build the valid split + for ds_index in range(len(self.dataset_lengths)): + assert ( + self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
+ self.train_dataset_lenghts = [ + a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) + ] # Subtract the valid samples from the training dataset + + if is_valid: # Valid MultilingualNanoset + self.split_num_samples = valid_split_num_samples + self.split_samples_per_epoch = valid_split_num_samples + self.num_epochs = 1 + self.split_dataset_lenghts = self.valid_dataset_lenghts + self.split_dataset_offsets = self.train_dataset_lenghts + + else: # Train MultilingualNanoset + self.split_num_samples = train_split_num_samples + self.split_samples_per_epoch = sum(self.train_dataset_lenghts) + self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 + self.split_dataset_lenghts = self.train_dataset_lenghts + self.split_dataset_offsets = [ + 0 for _ in range(len(self.dataset_lengths)) + ] # For training there is NO offset + + self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + + self.print_nanoset_info() + + def __len__(self) -> int: + """ + Returns: + int: The number of samples of the Nanoset + """ + + return len(self.dataset_index) + + def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: + """ + Returns sequence_length + 1 tokens from the memmap dataset + + Args: + idx (int): The index into the dataset + + Returns: + Dict[str, torch.LongTensor]: The input ids wrapped in a dictionary + """ + dataset = self.dataset_index[idx] + dataset_sample = self.dataset_sample_index[idx] + + return self.datatrove_datasets[dataset][dataset_sample] + + def build_nanoset_index(self) -> np.ndarray: + """ + Build dataset index and dataset sample index + """ + # Build the dataset indexes for 1 epoch + dataset_index, dataset_sample_index = build_nanoset_index_helper( + n_samples=self.split_samples_per_epoch, + weights=self.dataset_weights, + dataset_sizes=self.split_dataset_lengths, + offsets=self.split_dataset_offsets, + ) + # Shuffle the indexes the same way + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_index) + numpy_random_state = np.random.RandomState(self.random_seed) + numpy_random_state.shuffle(dataset_sample_index) + # Concatenate num_epochs the shuffled indexes + dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + # Just keep the necessary samples + dataset_index = dataset_index[: self.split_num_samples] + dataset_sample_index = dataset_sample_index[: self.split_num_samples] + + return dataset_index, dataset_sample_index + + def print_nanoset_info(self): + + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of samples: {len(self)}", + logger=logger, + level=logging.INFO, + rank=0, + ) + log_rank( + f"> [{'Validation' if self.is_valid else 'Training'} dataset] Total number of tokens: {len(self) * self.sequence_length}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + # Print samples from each dataset + weight + dataset_sample_count = count_dataset_indexes(self.dataset_index, len(self.dataset_folders)) + for index, sample_count in enumerate(dataset_sample_count): + log_rank( + f"> Total number of {'validation' if self.is_valid else 'training'} samples from the {self.dataset_folders[index]} dataset: {sample_count} ({round(normalize(dataset_sample_count).tolist()[index], 2)})", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +@jit(nopython=True, cache=True) +def build_nanoset_index_helper( + n_samples: int, 
weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +) -> Tuple[np.ndarray, np.ndarray]: + """ + Given multiple datasets and a weighting array, build samples indexes + such that it follows those weights. + For train and valid splits we split each dataset_folder in train (first part) and valid splits. We set the offsets to the train lengths + for generating the valid split + """ + # Create empty arrays for dataset indices and dataset sample indices + dataset_index = np.empty((n_samples,), dtype="uint") + dataset_sample_index = np.empty((n_samples,), dtype="long") # Supports dataset with up to 2**64 samples + + # Initialize buffer for number of samples used for each dataset + current_samples = np.zeros((len(weights),), dtype="long") + + # Iterate over all samples + for sample_idx in range(n_samples): + + # Convert sample index to float for comparison against weights + sample_idx_float = max(sample_idx, 1.0) + + # Find the dataset with the highest error + errors = weights * sample_idx_float - current_samples + max_error_index = np.argmax(errors) + + # Assign the dataset index and update the sample index + dataset_index[sample_idx] = max_error_index + dataset_sample_index[sample_idx] = ( + current_samples[max_error_index] % dataset_sizes[max_error_index] + ) + offsets[max_error_index] + + # Update the total samples for the selected dataset + current_samples[max_error_index] += 1 + + return dataset_index, dataset_sample_index From d9f06703d49762b261075467981a066bc01f9249 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:25:17 +0000 Subject: [PATCH 08/24] Added Language token --- examples/config_multilingual_nanoset.yaml | 120 ++++++++++++++++++++++ src/nanotron/data/multilingual_nanoset.py | 7 +- 2 files changed, 126 insertions(+), 1 deletion(-) create mode 100644 examples/config_multilingual_nanoset.yaml diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml new file mode 100644 index 00000000..00ae6570 --- /dev/null +++ b/examples/config_multilingual_nanoset.yaml @@ -0,0 +1,120 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: checkpoints/ + checkpoints_path_is_shared_file_system: false + resume_checkpoint_path: null + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: datasets/c4-es/tokenized + dataset_tokens: + - 15 + num_loading_workers: 1 + seed: 42 + name: General purpose training (Single dataset) + start_training_step: 1 +- data: + dataset: + dataset_folder: + - datasets/SlimPajama-6B/tokenized + - datasets/c4-es/tokenized + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Second purpose training (> 1 dataset) + start_training_step: 15 +- data: + dataset: + dataset_folder: + datasets/SlimPajama-6B/tokenized: 0.8 + datasets/c4-es/tokenized: 0.2 + dataset_tokens: + - 16 + - 15 + num_loading_workers: 1 + seed: 42 + name: Third purpose training (Blended dataset) + start_training_step: 25 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: Nanoset + run: llama + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + bos_token_id: 1 + eos_token_id: 2 + hidden_act: silu + hidden_size: 4096 + initializer_range: 0.02 + intermediate_size: 11008 + is_llama_config: true + max_position_embeddings: 4096 
+ num_hidden_layers: 32 + num_attention_heads: 32 + num_key_value_heads: 8 + pad_token_id: null + pretraining_tp: 1 + rope_interleaved: false + rope_theta: 500000.0 + rms_norm_eps: 1.0e-06 + rope_scaling: null + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 98 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + zero_stage: 0 +parallelism: + dp: 1 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + tp: 1 + tp_linear_async_communication: false + tp_mode: REDUCE_SCATTER +profiler: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 10 + micro_batch_size: 2 + sequence_length: 1024 + train_steps: 200 + val_check_interval: -1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 40e06b87..6526659d 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -32,6 +32,7 @@ def __init__( token_size: int, train_split_num_samples: int, valid_split_num_samples: int, + dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -48,6 +49,7 @@ def __init__( self.token_size = token_size self.train_split_num_samples = train_split_num_samples self.valid_split_num_samples = valid_split_num_samples + self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -129,7 +131,10 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset = self.dataset_index[idx] dataset_sample = self.dataset_sample_index[idx] - return self.datatrove_datasets[dataset][dataset_sample] + tokens = self.datatrove_datasets[dataset][dataset_sample] + tokens[0] = self.dataset_tokens[dataset] # Prepend language token + + return tokens def build_nanoset_index(self) -> np.ndarray: """ From efe87209103382f004a33ddfd940df75c0deef89 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 12:51:42 +0000 Subject: [PATCH 09/24] Forgot the trainer ups --- src/nanotron/trainer.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index bc81e326..3f4c5189 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -393,7 +393,10 @@ def find_stage_idx_to_resume(): def train( self, - dataloader_or_dls: Dict[ + train_dataloader_or_dls: Dict[ + str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] + ], + valid_dataloader_or_dls: Dict[ str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] ], **kwargs, @@ -424,7 +427,7 @@ def train( prof.step() self.iteration_start_time = time.time() - self._update_dataloader_based_on_training_stages(dataloader_or_dls) + self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) From 25ad39b2b25fe80c380065dba9e211dba31ed11e Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 
14:12:57 +0000 Subject: [PATCH 10/24] Fix minor errors. Everything works --- run_train.py | 6 ++++-- src/nanotron/config/config.py | 2 +- src/nanotron/data/multilingual_nanoset.py | 11 +++++++---- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/run_train.py b/run_train.py index 649784ca..9b77da77 100644 --- a/run_train.py +++ b/run_train.py @@ -195,6 +195,7 @@ def get_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, + dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -229,7 +230,7 @@ def get_valid_dataloader_from_data_stage( input_pp_rank, output_pp_rank = get_input_output_pp_ranks(model=trainer.model) # Only support Validation with MultilingualNanosets - if isinstance(data.dataset, NanosetDatasetsArgs): + if isinstance(data.dataset, MultilingualNanosetDatasetsArgs): # Get tokenizer cardinality tokenizer = AutoTokenizer.from_pretrained(trainer.config.tokenizer.tokenizer_name_or_path) token_size = 4 if len(tokenizer) > np.iinfo(np.uint16).max + 1 else 2 @@ -245,6 +246,7 @@ def get_valid_dataloader_from_data_stage( token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, valid_split_num_samples=valid_split_num_samples, + dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -320,7 +322,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Training Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index bfd20227..924a2cdf 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -132,7 +132,7 @@ def __post_init__(self): class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs] + dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 6526659d..cd8be195 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,5 +1,6 @@ import os import warnings +from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -80,11 +81,13 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." 
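The `token_size` selection repeated in both dataloader builders above just picks the smallest integer dtype that can hold every token id: 2 bytes while the vocabulary fits in uint16, 4 bytes otherwise. A quick hedged check (the vocabulary sizes below are the usual published ones, not read from any config):

import numpy as np

def token_size_bytes(vocab_size: int) -> int:
    # 2 bytes is enough while every token id fits in uint16 (0..65535).
    return 2 if vocab_size <= np.iinfo(np.uint16).max + 1 else 4

assert token_size_bytes(50257) == 2    # e.g. GPT-2
assert token_size_bytes(128256) == 4   # e.g. Llama-3, matching the example config's vocab_size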
## Build dataset index and dataset sample index ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = self.dataset_weights * valid_split_num_samples + self.valid_dataset_lenghts = [ + ceil(weight * valid_split_num_samples) for weight in self.dataset_weights + ] # Better not tu use numpy so we don't get overflow issues # Assert that we have sufficient samples to build the valid split for ds_index in range(len(self.dataset_lengths)): assert ( - self.valid_dataset_lenghts[ds_index] > self.dataset_lengths[ds_index] + self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." self.train_dataset_lenghts = [ a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) @@ -132,7 +135,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens[0] = self.dataset_tokens[dataset] # Prepend language token + tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token return tokens @@ -144,7 +147,7 @@ def build_nanoset_index(self) -> np.ndarray: dataset_index, dataset_sample_index = build_nanoset_index_helper( n_samples=self.split_samples_per_epoch, weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lengths, + dataset_sizes=self.split_dataset_lenghts, offsets=self.split_dataset_offsets, ) # Shuffle the indexes the same way From d91f9e1e8b67ffa51a14fff9bb0e408c02920631 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 16 Jul 2024 14:13:33 +0000 Subject: [PATCH 11/24] Updated config file with GPT2 tokenized datasets in RCP --- examples/config_multilingual_nanoset.yaml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 00ae6570..3c4476a0 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,7 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: datasets/c4-es/tokenized + dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 15 num_loading_workers: 1 @@ -17,8 +17,8 @@ data_stages: - data: dataset: dataset_folder: - - datasets/SlimPajama-6B/tokenized - - datasets/c4-es/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized + - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized dataset_tokens: - 16 - 15 @@ -29,8 +29,8 @@ data_stages: - data: dataset: dataset_folder: - datasets/SlimPajama-6B/tokenized: 0.8 - datasets/c4-es/tokenized: 0.2 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 + /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 dataset_tokens: - 16 - 15 @@ -65,7 +65,7 @@ model: initializer_range: 0.02 intermediate_size: 11008 is_llama_config: true - max_position_embeddings: 4096 + max_position_embeddings: 1024 num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 @@ -108,7 +108,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_name_or_path: gpt2 tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 From 
d0c14e38054cb9bef16d75940e2ee076cde26bea Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 10:13:57 +0000 Subject: [PATCH 12/24] Before lunch --- run_train.py | 13 +--- src/nanotron/config/config.py | 6 +- src/nanotron/data/multilingual_nanoset.py | 76 +++++++++-------------- 3 files changed, 37 insertions(+), 58 deletions(-) diff --git a/run_train.py b/run_train.py index 9b77da77..57e0ec25 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=trainer.config.tokens.limit_val_batches * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -222,7 +221,6 @@ def get_dataloader_from_data_stage( def get_valid_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, - valid_split_num_samples: int, # consumed_train_samples: int, We will never use this because in each valid iteration we consume all the samples ): @@ -245,7 +243,6 @@ def get_valid_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - valid_split_num_samples=valid_split_num_samples, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, @@ -259,7 +256,6 @@ def get_valid_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=0, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, ) @@ -319,21 +315,18 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: # NOTE: we only create the dataloader for the first stage, # then we lazy initialize the dataloader for the other stages stage = cast(DatasetStageArgs, stage) - valid_split_num_samples = trainer.config.tokens.limit_val_batches * trainer.global_batch_size log_rank( - f"[Validation Plan] Stage {stage.name} has {valid_split_num_samples} samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", logger=logger, level=logging.INFO, rank=0, ) dataloader = ( - get_valid_dataloader_from_data_stage(trainer, stage.data, valid_split_num_samples=valid_split_num_samples) + get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, stage.data, valid_split_num_samples=valid_split_num_samples - ) + else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) ) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 924a2cdf..fb3e49dd 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -109,7 +109,8 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: - dataset_folder: Union[str, dict, List[str]] + training_folder: Union[str, dict, List[str]] + validation_folder: Union[str, dict, List[str]] dataset_tokens: List[ int ] # Set token for each language previously defined. 
We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) @@ -125,7 +126,8 @@ def __post_init__(self): self.dataset_folder = list(tmp_dataset_folder.keys()) self.dataset_weights = list(tmp_dataset_folder.values()) - assert len(self.dataset_folder) == len(self.dataset_tokens) + assert len(self.training_folder) == len(self.validation_folder) + assert len(self.training_folder) == len(self.dataset_tokens) @dataclass diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index cd8be195..f634fd98 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -1,6 +1,5 @@ import os import warnings -from math import ceil from typing import Dict, List, Tuple, Union import numpy as np @@ -32,7 +31,6 @@ def __init__( sequence_length: int, token_size: int, train_split_num_samples: int, - valid_split_num_samples: int, dataset_tokens: List[int], is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -49,7 +47,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.valid_split_num_samples = valid_split_num_samples self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed @@ -80,36 +77,11 @@ def __init__( self.dataset_weights ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index - ### Split dataset_lengths into train_dataset_lenghts & valid_dataset_lenghts - self.valid_dataset_lenghts = [ - ceil(weight * valid_split_num_samples) for weight in self.dataset_weights - ] # Better not tu use numpy so we don't get overflow issues - # Assert that we have sufficient samples to build the valid split - for ds_index in range(len(self.dataset_lengths)): - assert ( - self.dataset_lengths[ds_index] > self.valid_dataset_lenghts[ds_index] - ), f"Trying to build validation dataset with {self.valid_dataset_lenghts[ds_index]} samples but {dataset_folders[ds_index]} just have {self.dataset_lengths[ds_index]} samples." 
- self.train_dataset_lenghts = [ - a - b for a, b in zip(self.dataset_lengths, self.valid_dataset_lenghts) - ] # Subtract the valid samples from the training dataset - if is_valid: # Valid MultilingualNanoset - self.split_num_samples = valid_split_num_samples - self.split_samples_per_epoch = valid_split_num_samples - self.num_epochs = 1 - self.split_dataset_lenghts = self.valid_dataset_lenghts - self.split_dataset_offsets = self.train_dataset_lenghts + self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset - self.split_num_samples = train_split_num_samples - self.split_samples_per_epoch = sum(self.train_dataset_lenghts) - self.num_epochs = int(self.split_num_samples / self.split_samples_per_epoch) + 1 - self.split_dataset_lenghts = self.train_dataset_lenghts - self.split_dataset_offsets = [ - 0 for _ in range(len(self.dataset_lengths)) - ] # For training there is NO offset - - self.dataset_index, self.dataset_sample_index = self.build_nanoset_index() + self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() self.print_nanoset_info() @@ -139,16 +111,16 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: return tokens - def build_nanoset_index(self) -> np.ndarray: + def build_train_nanoset_index(self) -> np.ndarray: """ - Build dataset index and dataset sample index + Build train dataset index and dataset sample index """ + # Compute samples per epoch and number of epochs + samples_per_epoch = sum(self.dataset_lengths) + num_epochs = int(self.train_split_num_samples / samples_per_epoch) + 1 # Build the dataset indexes for 1 epoch - dataset_index, dataset_sample_index = build_nanoset_index_helper( - n_samples=self.split_samples_per_epoch, - weights=self.dataset_weights, - dataset_sizes=self.split_dataset_lenghts, - offsets=self.split_dataset_offsets, + dataset_index, dataset_sample_index = build_train_nanoset_index_helper( + n_samples=samples_per_epoch, weights=self.dataset_weights, dataset_sizes=self.dataset_lengths ) # Shuffle the indexes the same way numpy_random_state = np.random.RandomState(self.random_seed) @@ -156,14 +128,28 @@ def build_nanoset_index(self) -> np.ndarray: numpy_random_state = np.random.RandomState(self.random_seed) numpy_random_state.shuffle(dataset_sample_index) # Concatenate num_epochs the shuffled indexes - dataset_index = np.concatenate([dataset_index for _ in range(self.num_epochs)]) - dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(self.num_epochs)]) + dataset_index = np.concatenate([dataset_index for _ in range(num_epochs)]) + dataset_sample_index = np.concatenate([dataset_sample_index for _ in range(num_epochs)]) # Just keep the necessary samples - dataset_index = dataset_index[: self.split_num_samples] - dataset_sample_index = dataset_sample_index[: self.split_num_samples] + dataset_index = dataset_index[: self.train_split_num_samples] + dataset_sample_index = dataset_sample_index[: self.train_split_num_samples] return dataset_index, dataset_sample_index + @jit(nopython=True, cache=True) + def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") + def 
print_nanoset_info(self): log_rank( @@ -191,8 +177,8 @@ def print_nanoset_info(self): @jit(nopython=True, cache=True) -def build_nanoset_index_helper( - n_samples: int, weights: np.ndarray, dataset_sizes: List[int], offsets: List[int] +def build_train_nanoset_index_helper( + n_samples: int, weights: np.ndarray, dataset_sizes: List[int] ) -> Tuple[np.ndarray, np.ndarray]: """ Given multiple datasets and a weighting array, build samples indexes @@ -219,9 +205,7 @@ def build_nanoset_index_helper( # Assign the dataset index and update the sample index dataset_index[sample_idx] = max_error_index - dataset_sample_index[sample_idx] = ( - current_samples[max_error_index] % dataset_sizes[max_error_index] - ) + offsets[max_error_index] + dataset_sample_index[sample_idx] = current_samples[max_error_index] % dataset_sizes[max_error_index] # Update the total samples for the selected dataset current_samples[max_error_index] += 1 From 9cfc5ea954505d880ffe19580ef8e60b4c8acd70 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 17 Jul 2024 14:10:03 +0000 Subject: [PATCH 13/24] After lunch --- examples/config_multilingual_nanoset.yaml | 42 +++++++++++++++-------- run_train.py | 6 ++-- src/nanotron/config/config.py | 21 ++++++------ src/nanotron/data/multilingual_nanoset.py | 33 +++++++++--------- tools/preprocess_data.py | 5 ++- 5 files changed, 61 insertions(+), 46 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 3c4476a0..238f8269 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -7,7 +7,8 @@ checkpoints: data_stages: - data: dataset: - dataset_folder: /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: datasets/c4-es/train + validation_folder: datasets/c4-es/validation dataset_tokens: - 15 num_loading_workers: 1 @@ -16,24 +17,37 @@ data_stages: start_training_step: 1 - data: dataset: - dataset_folder: - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized - - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized + training_folder: + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) start_training_step: 15 - data: dataset: - dataset_folder: - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/SlimPajama-6B/tokenized: 0.8 - /mloscratch/homes/solergib/nanotrove/nanotron/datasets/c4-es/tokenized: 0.2 + training_folder: + datasets/c4-es/train: 0.6 + datasets/c4-en/train: 0.3 + datasets/c4-fr/train: 0.1 + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation dataset_tokens: - - 16 - 15 + - 16 + - 17 + num_loading_workers: 1 seed: 42 name: Third purpose training (Blended dataset) @@ -61,12 +75,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 4096 + hidden_size: 512 initializer_range: 0.02 - intermediate_size: 11008 + intermediate_size: 512 is_llama_config: true max_position_embeddings: 1024 - num_hidden_layers: 32 + num_hidden_layers: 2 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -108,13 +122,13 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: gpt2 + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B 
tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 2 + micro_batch_size: 4 sequence_length: 1024 train_steps: 200 val_check_interval: -1 diff --git a/run_train.py b/run_train.py index 57e0ec25..39cda23b 100644 --- a/run_train.py +++ b/run_train.py @@ -189,7 +189,7 @@ def get_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): train_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, + dataset_folders=data.dataset.training_folder, dataset_weights=data.dataset.dataset_weights, sequence_length=trainer.sequence_length, token_size=token_size, @@ -238,11 +238,9 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.dataset_folder, - dataset_weights=data.dataset.dataset_weights, + dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fb3e49dd..ce61a249 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -110,21 +110,20 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] - validation_folder: Union[str, dict, List[str]] - dataset_tokens: List[ - int - ] # Set token for each language previously defined. We use a List and not a dict because this way we support specifyng weights (dict) or not (List[str]) + validation_folder: Union[str, List[str]] + dataset_tokens: List[int] # Set token for each language previously defined def __post_init__(self): - if isinstance(self.dataset_folder, str): # Case 1: 1 Dataset file - self.dataset_folder = [self.dataset_folder] + if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder + self.training_folder = [self.training_folder] + self.validation_folder = [self.validation_folder] self.dataset_weights = [1] - elif isinstance(self.dataset_folder, List): # Case 2: > 1 Dataset file + elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly - elif isinstance(self.dataset_folder, dict): # Case 3: dict with > 1 dataset_folder and weights - tmp_dataset_folder = self.dataset_folder.copy() - self.dataset_folder = list(tmp_dataset_folder.keys()) - self.dataset_weights = list(tmp_dataset_folder.values()) + elif isinstance(self.training_folder, dict): # Case 3: dict with > 1 training_folder and weights + tmp_training_folder = self.training_folder.copy() + self.training_folder = list(tmp_training_folder.keys()) + self.dataset_weights = list(tmp_training_folder.values()) assert len(self.training_folder) == len(self.validation_folder) assert len(self.training_folder) == len(self.dataset_tokens) diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index f634fd98..7af57448 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,8 +30,8 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - train_split_num_samples: int, dataset_tokens: List[int], + train_split_num_samples: int = None, is_valid: 
bool = False, dataset_weights: Union[List[float], None] = None, random_seed: int = 1234, @@ -78,7 +78,7 @@ def __init__( ), f"Specified {len(self.dataset_weights)} weights but {len(dataset_folders)} datasets were provided." ## Build dataset index and dataset sample index if is_valid: # Valid MultilingualNanoset - self.dataset_index, self.dataset_sample_index = self.build_valid_nanoset_index(self.dataset_lengths) + self.dataset_index, self.dataset_sample_index = build_valid_nanoset_index(self.dataset_lengths) else: # Train MultilingualNanoset self.dataset_index, self.dataset_sample_index = self.build_train_nanoset_index() @@ -136,20 +136,6 @@ def build_train_nanoset_index(self) -> np.ndarray: return dataset_index, dataset_sample_index - @jit(nopython=True, cache=True) - def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: - """ - Build valid dataset index and dataset sample index - """ - dataset_index = [] - dataset_sample_index = [] - - for i, length in enumerate(dataset_lengths): - dataset_index.extend([i] * length) - dataset_sample_index.extend(range(length)) - - return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") - def print_nanoset_info(self): log_rank( @@ -211,3 +197,18 @@ def build_train_nanoset_index_helper( current_samples[max_error_index] += 1 return dataset_index, dataset_sample_index + + +@jit(nopython=True, cache=True) +def build_valid_nanoset_index(dataset_lengths: List[int]) -> np.ndarray: + """ + Build valid dataset index and dataset sample index + """ + dataset_index = [] + dataset_sample_index = [] + + for i, length in enumerate(dataset_lengths): + dataset_index.extend([i] * length) + dataset_sample_index.extend(range(length)) + + return np.array(dataset_index, dtype="uint"), np.array(dataset_sample_index, dtype="long") diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index c668aa58..8383ba38 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -98,7 +98,9 @@ def main(args): dataset_options={"split": args.split}, ) elif args.readers == "parquet": - datatrove_reader = ParquetReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) + datatrove_reader = ParquetReader( + data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern + ) else: datatrove_reader = JsonlReader(data_folder=args.dataset, text_key=args.column, glob_pattern=args.glob_pattern) @@ -107,6 +109,7 @@ def main(args): datatrove_reader, DocumentTokenizer( output_folder=args.output_folder, + shuffle=False, tokenizer_name_or_path=args.tokenizer_name_or_path, eos_token=args.eos_token, max_tokens_per_file=1e9, From eed7bce10712a9137eec78ac9c3b6c609fcb28d5 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Thu, 18 Jul 2024 10:48:00 +0000 Subject: [PATCH 14/24] Ready --- examples/config_multilingual_nanoset.yaml | 20 ++++++++++---------- src/nanotron/config/config.py | 11 ++++++++--- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 238f8269..599bff6c 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -9,8 +9,8 @@ data_stages: dataset: training_folder: datasets/c4-es/train validation_folder: datasets/c4-es/validation - dataset_tokens: - - 15 + lang_to_ids: + es: 128002 num_loading_workers: 1 seed: 42 name: General purpose training (Single dataset) @@ -25,10 +25,10 @@ data_stages: - datasets/c4-es/validation - 
datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 name: Second purpose training (> 1 dataset) @@ -43,10 +43,10 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - dataset_tokens: - - 15 - - 16 - - 17 + lang_to_ids: + es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index ce61a249..dd2c157d 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - dataset_tokens: List[int] # Set token for each language previously defined + lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,8 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - assert len(self.training_folder) == len(self.validation_folder) - assert len(self.training_folder) == len(self.dataset_tokens) + self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.validation_folder + ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + assert len(self.training_folder) == len( + self.dataset_tokens + ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass From 27133e1e4b433c07f7e423b66bd1eb9845dc948c Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Mon, 22 Jul 2024 16:39:47 +0000 Subject: [PATCH 15/24] Just in case --- examples/config_multilingual_nanoset.yaml | 2 +- run_train.py | 4 +- src/nanotron/config/config.py | 1 + src/nanotron/models/llama.py | 30 ++-- .../parallel/pipeline_parallel/engine.py | 27 +++- src/nanotron/trainer.py | 143 ++++++++++++++++-- 6 files changed, 173 insertions(+), 34 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 599bff6c..33f9db41 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -131,4 +131,4 @@ tokens: micro_batch_size: 4 sequence_length: 1024 train_steps: 200 - val_check_interval: -1 + val_check_interval: 3 diff --git a/run_train.py b/run_train.py index 39cda23b..ed9b5607 100644 --- a/run_train.py +++ b/run_train.py @@ -238,10 +238,10 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.validation_folder, + dataset_folders=data.dataset.validation_folder, # TODO Just 1 folder sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, + dataset_tokens=data.dataset.dataset_tokens, # TODO Just 1 lang is_valid=True, random_seed=data.seed, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index dd2c157d..e5ea3ec1 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -125,6 +125,7 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) 
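Since `lang_to_ids` is what ends up written into position 0 of every sample (via the `tokens["input_ids"][0] = self.dataset_tokens[dataset]` line earlier in this series), a tiny illustration of the effect on one sample may help; the ids below come from the example config, the sample itself is made up:

import torch

lang_to_ids = {"es": 128002, "en": 128003, "fr": 128004}
dataset_tokens = list(lang_to_ids.values())

# A fake sample shaped like what DatatroveFolderDataset returns (made-up token ids).
sample = {"input_ids": torch.tensor([128000, 11, 22, 33, 44])}
dataset = 0  # this sample came from the first ("es") folder

sample["input_ids"][0] = dataset_tokens[dataset]  # first token is overwritten in place, not shifted
print(sample["input_ids"])  # tensor([128002, 11, 22, 33, 44])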
self.dataset_weights = list(tmp_training_folder.values()) + self.ids_to_lang = {v: k for k, v in self.lang_to_ids.items()} self.dataset_tokens = list(self.lang_to_ids.values()) assert len(self.training_folder) == len( self.validation_folder diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 2411e5fa..7ae34dd5 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -801,7 +801,14 @@ def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch @torch.jit.script def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor - return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( + dim=1 + ) # TODO this already yields float/float = float + + +# TODO the per-sample loss !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)) +# And it passes assert_close!! +# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))) class Loss(nn.Module): @@ -818,14 +825,16 @@ def forward( # Megatron by defaults cast everything in fp32. `--f16-lm-cross-entropy` is an option you can use to keep current precision. # https://github.com/NVIDIA/Megatron-LM/blob/f267e6186eae1d6e2055b412b00e2e545a8e896a/megatron/model/gpt_model.py#L38 - loss = sharded_cross_entropy( + sample_loss = sharded_cross_entropy( sharded_logits, label_ids.transpose(0, 1).contiguous(), group=self.tp_pg, dtype=torch.float ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. - loss = masked_mean(loss, label_mask, dtype=torch.float) - # I think indexing causes a sync we don't actually want - # loss = loss[label_mask].sum() - return {"loss": loss} + sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) + # NOTE @tj.solergibert: masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort.
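A self-contained check of what the per-sample `masked_mean` above returns compared to the old scalar version; the two only agree exactly when every sample has the same number of unmasked tokens, which is the fully packed case used in this sketch:

import torch

loss = torch.rand(4, 16)                          # (batch, seq) per-token losses
label_mask = torch.ones(4, 16, dtype=torch.bool)  # fully packed: no padding anywhere

old = (loss * label_mask).sum(dtype=torch.float) / label_mask.sum()               # one scalar for the batch
new = (loss * label_mask).sum(dim=1, dtype=torch.float) / label_mask.sum(dim=1)   # one loss per sample

assert new.shape == (4,)
torch.testing.assert_close(new.mean(), old)  # equal here because all rows have the same mask count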
+ # TODO @thomasw21: I think indexing causes a sync we don't actually want + # TODO @thomasw21: loss = loss[label_mask].sum() + return {"sample_loss": sample_loss} class LlamaForTraining(NanotronModel): @@ -847,7 +856,7 @@ def __init__( "label_ids", "label_mask", }, - module_output_keys={"loss"}, + module_output_keys={"sample_loss"}, ) self.parallel_context = parallel_context self.config = config @@ -864,12 +873,13 @@ def forward( input_ids=input_ids, input_mask=input_mask, ) - loss = self.loss( + outputs = self.loss( sharded_logits=sharded_logits, label_ids=label_ids, label_mask=label_mask, - )["loss"] - return {"loss": loss} + ) + outputs["loss"] = torch.mean(outputs["sample_loss"]) + return outputs @torch.no_grad() def init_model_randomly(self, config: Config): diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index ca9df312..bf690bd0 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -2,6 +2,9 @@ from typing import Dict, Iterable, Optional, Union import torch +from torch import nn as torch_nn +from torch.nn.parallel import DistributedDataParallel + from nanotron import distributed as dist from nanotron import logging from nanotron.distributed import ProcessGroup @@ -12,8 +15,6 @@ from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -from torch import nn as torch_nn -from torch.nn.parallel import DistributedDataParallel logger = logging.get_logger(__name__) @@ -29,6 +30,7 @@ def forward( state: PipelineTrainBatchState, micro_batch: Dict[str, Union[torch.Tensor, TensorPointer]], model: torch_nn.Module, + is_validation: bool = False, ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: # Increment the number of backwards state.nb_forwards += 1 @@ -52,7 +54,7 @@ def forward( output["loss"] = output["loss"] / self.nb_microbatches # Add output as activations that require backward pass - if not isinstance(output["loss"], TensorPointer): + if not isinstance(output["loss"], TensorPointer) and not is_validation: assert output["loss"].requires_grad state.register_activation_requiring_backward(output["loss"]) return output @@ -138,12 +140,15 @@ def validate_batch_iter( self.nb_microbatches = nb_microbatches outputs = [] + lang_ids = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward for micro_batch in batch: context = self._get_fwd_context(model=model) - output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + output = self.forward( + context=context, state=state, micro_batch=micro_batch, model=model, is_validation=True + ) # TODO @thomasw21: Somehow this needs to be done somewhere else to support interleaving. Somewhere right after a "stage" for _ in range(len(state.microbatches_activations_to_send)): send_activation = state.microbatches_activations_to_send.popleft() @@ -151,15 +156,23 @@ def validate_batch_iter( send_activation() # We make `output` a dict + # TODO convert to dict other items returned by the model (MoE aux loss for example) + # But in next if statement be careful if we return other items in all of the pp processes + # This conversion to dicts is kind of useless as the model already returns a dict with loss key. Maybe the PP ranks return TensorPointer Objects? 
if not isinstance(output, dict): output = {"loss": output} # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - outputs.append(output) - - return outputs + # TODO check what this output actually is and how it gets stored in outputs. Where is the mean taken? In the training step lol + # Here we should split by language, because this is the only point where we still have the language!! Or at least "tag" it or accumulate it per language + # 1. Build a dict with one key per language 2. each key holds a list we append the tensors to 3. in the valid step do the stack and allreduces + # Finally: here we only collect the lang ids; the results get accumulated in trainer.py. + outputs.extend(list(output["sample_loss"])) # TODO flatten?????? or extend?????? + lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) # TODO should this be an extend???? + + return outputs, lang_ids class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 3f4c5189..583068cd 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -300,7 +300,9 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): + def _update_dataloader_based_on_training_stages( + self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False + ): from collections.abc import Generator if not hasattr(self.config, "data_stages") or self.config.data_stages is None: @@ -309,9 +311,16 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da dataloader = dataloaders[0] else: dataloader = dataloaders - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + + if is_validation: + self.current_validation_dataloader_lenght = len(dataloader) +
self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + else: + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) def train( self, @@ -428,9 +443,23 @@ def train( self.iteration_start_time = time.time() self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) + self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) + self.training_step_time = time.time() + + # Validation step + # TODO A ver, en este loop solo se lleva a cabo una training iteration pero claro hay un porron de validation iteration... mmmmm + # Tal vez deberiamos mover esto a otro lugar? Es decir, aqui se have un training step pero hacemos varios validation steps + # Lo podemos dejar aqui solamente que las metricas de throughput y tokens consumidos se tendrian que revisar + # Porque actualmente utilizan la global batch size, que es correcta ya que es la que tiene cada training step pero claro, + # Cada validation es mucho mas largo que un training step + # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas + if self.iteration_step % self.config.tokens.val_check_interval == 0: + global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader) + self.validation_step_time = time.time() + self.validation_step_logs(global_loss, lang_losses) # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -546,12 +575,36 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs = self.pipeline_engine.validate_batch_iter( + outputs, lang_ids = self.pipeline_engine.validate_batch_iter( model=self.model, - batch=(next(dataloader) for _ in range(self.limit_val_batches)), - nb_microbatches=self.limit_val_batches, + batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), + nb_microbatches=self.current_validation_dataloader_lenght, ) - return outputs + + lang_losses = { + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys() + } + # Compute losses + if isinstance(outputs[0], torch.Tensor): + # Multilingual losses + for loss, lang_id in zip(outputs, lang_ids): + lang_losses[ + self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id] + ].append(loss) + # Global loss + global_loss_avg = torch.stack(outputs).sum() + # Sync losses across DP + for lang in lang_losses.keys(): + lang_losses[lang] = torch.stack(lang_losses[lang]).sum() + dist.all_reduce( + lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG + ) # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos... + dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + else: + global_loss_avg = None + lang_losses = None + + return global_loss_avg, lang_losses def train_step_logs( self, @@ -561,7 +614,7 @@ def train_step_logs( # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. 
Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() - elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 + elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) ) # tokens_per_sec is calculated using sequence_length @@ -641,6 +694,68 @@ def train_step_logs( else: exit(0) + def validation_step_logs( + self, + global_loss: torch.Tensor, + lang_losses: torch.Tensor, + ) -> None: + # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 + dist.barrier() + torch.cuda.synchronize() + total_validation_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + tokens_per_sec = ( + total_validation_samples * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) + ) # tokens_per_sec is calculated using sequence_length + # TODO para el valid ojo con cambiar global_batch_size = len dataloader * mbs + model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=total_validation_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches + ) + + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: + assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + + log_entries = [ + # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), + LogItem( + "validation_consumed_tokens", + self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format" + ), # , ".1f"), + LogItem("validation_tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), + LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), + ] + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. 
+ # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [LogItem(f"{lang}_validation_loss", loss.item(), "human_format") for lang, loss in lang_losses.items()] + ) + + # NOTE: only one rank writes to wandb + if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: + wandb.log( + { + **{log_item.tag: log_item.scalar_value for log_item in log_entries}, + "iteration_step": self.iteration_step, + } + ) + + self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) + def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: """Initialize the model and load weights from checkpoint if needed.""" # TODO: add max_position_embeddings From 5c09e11a1a1df814b5edccb1ba6ac0a026897d04 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 23 Jul 2024 16:01:09 +0000 Subject: [PATCH 16/24] just in case --- run_train.py | 3 +- src/nanotron/data/dataloader_builder.py | 3 +- .../parallel/pipeline_parallel/engine.py | 4 +- .../parallel/pipeline_parallel/state.py | 4 + src/nanotron/serialize/metadata.py | 2 + src/nanotron/trainer.py | 171 ++++++++++++++---- 6 files changed, 148 insertions(+), 39 deletions(-) diff --git a/run_train.py b/run_train.py index ed9b5607..80c7a426 100644 --- a/run_train.py +++ b/run_train.py @@ -256,6 +256,7 @@ def get_valid_dataloader_from_data_stage( micro_batch_size=trainer.micro_batch_size, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + shuffle=True, ) return valid_dataloader @@ -315,7 +316,7 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: stage = cast(DatasetStageArgs, stage) log_rank( - f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples in the validation set", + f"[Validation Plan] Stage {stage.name} has {len(stage.data.dataset.validation_folder)} folders with samples for the validation set", logger=logger, level=logging.INFO, rank=0, diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index 9d3285f6..b8bfb303 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -23,6 +23,7 @@ def build_nanoset_dataloader( consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, + shuffle: bool = False, ) -> DataLoader: # Case of ranks not requiring data. 
We give them a dummy dataset, then the collator will do his job @@ -49,7 +50,7 @@ def build_nanoset_dataloader( dl_rank=dp_rank, drop_last=dataloader_drop_last, consumed_train_samples=consumed_train_samples, - shuffle=False, + shuffle=shuffle, ) return DataLoader( diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index bf690bd0..bc6dc5b5 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -12,7 +12,7 @@ from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model -from nanotron.parallel.pipeline_parallel.state import PipelineTrainBatchState +from nanotron.parallel.pipeline_parallel.state import PipelineEvalBatchState, PipelineTrainBatchState from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers @@ -136,7 +136,7 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineTrainBatchState() # TODO: do i need state? + state = PipelineEvalBatchState() # PipelineTrainBatchState() # TODO: do i need state? self.nb_microbatches = nb_microbatches outputs = [] diff --git a/src/nanotron/parallel/pipeline_parallel/state.py b/src/nanotron/parallel/pipeline_parallel/state.py index e07cc89a..f22d6571 100644 --- a/src/nanotron/parallel/pipeline_parallel/state.py +++ b/src/nanotron/parallel/pipeline_parallel/state.py @@ -4,6 +4,7 @@ from typing import List import torch + from nanotron import distributed as dist from nanotron import logging from nanotron.logging import log_rank @@ -203,6 +204,9 @@ class PipelineEvalBatchState(PipelineBatchState): microbatches_activations_to_recv = collections.deque() activations_buffer = collections.deque() + # Reinitialise counter + nb_forwards = 0 + def register_activation_requiring_backward(self, activation: torch.Tensor): pass diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 0d8708f9..4bd36c19 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -46,6 +46,8 @@ class TrainingMetadata: last_stage_idx: Optional[int] = None data_stages: Optional[List[DataStageMetadata]] = None + last_validation_stage_idx: Optional[int] = None + def __post_init__(self): # NOTE: this is a sanity check after loading a trained checkpoint total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 583068cd..b1cc36ad 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -231,7 +231,11 @@ def __init__( for stage in self.config.data_stages ] self.metadata: TrainingMetadata = TrainingMetadata( - consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + consumed_train_samples=0, + last_train_step=0, + last_stage_idx=0, + data_stages=data_stages, + last_validation_stage_idx=0, ) # Setup tensorboard write and log writers on output rank @@ -253,6 +257,8 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches # NOTE: the dataloader currently in use for the current training stage self.current_dataloader: Optional[DataLoader] = None + # NOTE: the dataloader currently 
in use for the current validation stage + self.current_validation_dataloader: Optional[DataLoader] = None self.post_init() @@ -300,9 +306,108 @@ def _print_training_plan(self): ) log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - def _update_dataloader_based_on_training_stages( - self, dataloaders: Union[List[DataLoader], DataLoader], is_validation: bool = False - ): + def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataLoader], DataLoader]): + # NOTE(tj.solergibert) Similar to _update_dataloader_based_on_training_stages BUT: + # 1. We call this function EVERY TIME we run the validation loop + # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset + # in the first iteration and subsequent validations will fail + # TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders + # TODO(tj.solergibert) Check the tuple case below + from collections.abc import Generator + + if not hasattr(self.config, "data_stages") or self.config.data_stages is None: + + if isinstance(dataloaders, tuple): # TODO(tj.solergibert) Check this tuple case + dataloader = dataloaders[0] + else: + dataloader = dataloaders + + self.current_validation_dataloader_lenght = len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + + return + elif isinstance(dataloaders, Generator): + # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader + # remove this in the next PR + self.current_validation_dataloader = dataloaders + return + + assert len(dataloaders) > 0, "No dataloaders provided" + assert len(dataloaders) == len( + self.config.data_stages + ), "Number of dataloaders should match the number of dataset stages" + + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + import gc + + log_rank( + f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory", + logger=logger, + level=logging.INFO, + ) + + # NOTE: Clear dataloader from memory + del dataloader.dataset + del dataloader.sampler + del dataloader.batch_sampler + + gc.collect() + + dataloader = None + + for stage_idx, stage in enumerate(self.config.data_stages): + if stage_idx < self.metadata.last_stage_idx: + continue + + if ( + stage_idx is not self.metadata.last_validation_stage_idx + and self.metadata.last_validation_stage_idx is not None + ): + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # Si cambiamos de stage borramo el antiguo + # En ambos casos recrear el que toca !!! + # TODO Aqui nos quedamos!!! Tenemos que borrar el anterior dataloader cuando sea necesario y hacer el sanity del current dataloader SIEMPRE + stage = cast(DatasetStageArgs, stage) + print( + stage.name + ) # TODO como actualizamos el last stage index en el training aqui estamos mirando el dataloader de la siguiente iteracion que mal por dios!!!!! + + log_rank( + f"Ese print bueno {stage.name}", + logger=logger, + level=logging.INFO, + rank=0, + ) + # self.metadata.last_stage_idx = stage_idx + """ + if self.current_validation_dataloader is not None: # TODO Si hay algun dataloader ya lo eliminamos. Igualmente creamos de nuevo. 
Bueno el dataloader como tal ya esta creado, solo hay que devolver el sanity check raro + prev_stage_name = self.config.data_stages[stage_idx - 1].name + prev_dataloader = dataloaders[prev_stage_name] + + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + """ + log_rank( + f"Preparing validation DataLoader from stage {stage.name}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + dataloader = dataloaders[stage.name] + # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it + dataloader = dataloader() if callable(dataloader) else dataloader + break + + self.current_validation_dataloader_lenght = 200 # TODO len(dataloader) + self.current_validation_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) + + def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator if not hasattr(self.config, "data_stages") or self.config.data_stages is None: @@ -311,16 +416,9 @@ def _update_dataloader_based_on_training_stages( dataloader = dataloaders[0] else: dataloader = dataloaders - - if is_validation: - self.current_validation_dataloader_lenght = len(dataloader) - self.current_validation_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) - else: - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, parallel_context=self.parallel_context, config=self.config + ) return elif isinstance(dataloaders, Generator): # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader @@ -337,7 +435,7 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): import gc log_rank( - f"[{'Validation' if is_validation else 'Training'} Stage: {stage_name}] Clearing the previous {'validation' if is_validation else 'training'} stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -360,7 +458,7 @@ def find_stage_idx_to_resume(): stage_idx_to_resume = find_stage_idx_to_resume() - for stage_idx, stage in enumerate(self.config.data_stages): + for stage_idx, stage in enumerate(self.config.data_stages): # TODO check metadatalaststageindex init if stage_idx < self.metadata.last_stage_idx: continue @@ -378,7 +476,7 @@ def find_stage_idx_to_resume(): self.metadata.last_stage_idx = stage_idx - if is_resume_from_training and not is_validation: + if is_resume_from_training: remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( stage, self.config, self.metadata ) @@ -396,15 +494,9 @@ def find_stage_idx_to_resume(): break if dataloader is not None: - if is_validation: - self.current_validation_dataloader_lenght = len(dataloader) - self.current_validation_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) - else: - self.current_dataloader = sanity_check_dataloader( - dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + self.current_dataloader = sanity_check_dataloader( + dataloader=dataloader, 
parallel_context=self.parallel_context, config=self.config + ) def train( self, @@ -443,7 +535,6 @@ def train( self.iteration_start_time = time.time() self._update_dataloader_based_on_training_stages(train_dataloader_or_dls) - self._update_dataloader_based_on_training_stages(valid_dataloader_or_dls, is_validation=True) # Training step outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) @@ -457,9 +548,18 @@ def train( # Cada validation es mucho mas largo que un training step # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas if self.iteration_step % self.config.tokens.val_check_interval == 0: - global_loss, lang_losses = self.validation_step(dataloader=self.current_validation_dataloader) + log_rank( + f"KOMO???? {self.iteration_step}", + logger=logger, + level=logging.INFO, + rank=0, + ) + self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) + val_global_loss, val_lang_losses = self.validation_step( + dataloader=self.current_validation_dataloader + ) self.validation_step_time = time.time() - self.validation_step_logs(global_loss, lang_losses) + self.validation_step_logs(val_global_loss, val_lang_losses) # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -592,10 +692,10 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id] ].append(loss) # Global loss - global_loss_avg = torch.stack(outputs).sum() + global_loss_avg = torch.mean(torch.stack(outputs)) # Sync losses across DP for lang in lang_losses.keys(): - lang_losses[lang] = torch.stack(lang_losses[lang]).sum() + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) dist.all_reduce( lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG ) # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos... @@ -630,7 +730,6 @@ def train_step_logs( lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -718,7 +817,6 @@ def validation_step_logs( assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" log_entries = [ - # LogItem("consumed_samples", self.consumed_train_samples, "human_format"), # , "12d"), LogItem( "validation_consumed_tokens", self.metadata.consumed_train_samples * self.config.tokens.sequence_length, @@ -734,7 +832,7 @@ def validation_step_logs( "human_format", ), # , "1.6E"), LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), - LogItem("validation_model_tflops_per_gpu", model_tflops, "human_format"), # , ".2f"), + LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"), LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), ] @@ -746,12 +844,15 @@ def validation_step_logs( ) # NOTE: only one rank writes to wandb + # NOTE(tj.solergibert) By default wandb.log performs a step in the x-axis every time. 
+ # Set commit=False to log values with the next wandb.log with the training logs if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: wandb.log( { **{log_item.tag: log_item.scalar_value for log_item in log_entries}, "iteration_step": self.iteration_step, - } + }, + commit=False, ) self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) From 94d6c2a9931cf735366be9b356a01270e657a9a1 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 16:18:36 +0000 Subject: [PATCH 17/24] This looks good --- examples/config_multilingual_nanoset.yaml | 60 +++++----- run_train.py | 9 +- src/nanotron/models/llama.py | 7 +- .../parallel/pipeline_parallel/engine.py | 16 +-- src/nanotron/trainer.py | 103 +++++++++--------- 5 files changed, 99 insertions(+), 96 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 33f9db41..5573a224 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 1000 + checkpoint_interval: 1000000 checkpoints_path: checkpoints/ checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -7,38 +7,40 @@ checkpoints: data_stages: - data: dataset: - training_folder: datasets/c4-es/train - validation_folder: datasets/c4-es/validation + training_folder: + datasets/c4-es/train: 0.85 + datasets/c4-en/train: 0.05 + datasets/c4-fr/train: 0.1 + validation_folder: + - datasets/c4-es/validation + - datasets/c4-en/validation + - datasets/c4-fr/validation lang_to_ids: es: 128002 + en: 128003 + fr: 128004 num_loading_workers: 1 seed: 42 - name: General purpose training (Single dataset) + name: General purpose training (Blended dataset) start_training_step: 1 - data: dataset: training_folder: - datasets/c4-es/train - - datasets/c4-en/train - - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - - datasets/c4-en/validation - - datasets/c4-fr/validation lang_to_ids: es: 128002 - en: 128003 - fr: 128004 num_loading_workers: 1 seed: 42 - name: Second purpose training (> 1 dataset) - start_training_step: 15 + name: Second purpose training (Single dataset) + start_training_step: 100 - data: dataset: training_folder: - datasets/c4-es/train: 0.6 - datasets/c4-en/train: 0.3 - datasets/c4-fr/train: 0.1 + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - datasets/c4-en/validation @@ -50,13 +52,13 @@ data_stages: num_loading_workers: 1 seed: 42 - name: Third purpose training (Blended dataset) - start_training_step: 25 + name: Third purpose training (>1 dataset) + start_training_step: 200 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Nanoset + project: Multilingual run: llama seed: 42 step: null @@ -75,12 +77,12 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 512 + hidden_size: 4096 initializer_range: 0.02 - intermediate_size: 512 + intermediate_size: 14336 is_llama_config: true - max_position_embeddings: 1024 - num_hidden_layers: 2 + max_position_embeddings: 4096 + num_hidden_layers: 32 num_attention_heads: 32 num_key_value_heads: 8 pad_token_id: null @@ -89,7 +91,7 @@ model: rope_theta: 500000.0 rms_norm_eps: 1.0e-06 rope_scaling: null - tie_word_embeddings: true + tie_word_embeddings: false use_cache: true vocab_size: 128256 optimizer: @@ -116,19 +118,19 @@ parallelism: expert_parallel_size: 1 
pp: 1 pp_engine: 1f1b - tp: 1 + tp: 4 tp_linear_async_communication: false tp_mode: REDUCE_SCATTER profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B + tokenizer_name_or_path: /mloscratch/homes/solergib/models/Meta-Llama-3-8B tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 limit_test_batches: 0 limit_val_batches: 10 - micro_batch_size: 4 - sequence_length: 1024 - train_steps: 200 - val_check_interval: 3 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 800 + val_check_interval: 50 diff --git a/run_train.py b/run_train.py index 80c7a426..2ddff5ad 100644 --- a/run_train.py +++ b/run_train.py @@ -325,8 +325,15 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: dataloader = ( get_valid_dataloader_from_data_stage(trainer, stage.data) if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage(trainer, stage.data) + else lambda stage=stage: get_valid_dataloader_from_data_stage(trainer, stage.data) ) + # TODO(tj.solergibert) As we are creating again the valid dataloader in every validation stage, we print multiple times + # the validation MultilingualNanoset info (Number of samples, etc.) [UPDATE: ]. In order to solve that, we could get rid of this lambda + # funcs and directly create all dataloaders. + # + # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead of creating multiple DataLoaders + # 2. Consume less memory as the lambda func is lighter that the DataLoader object with the Dataset, collator, etc. + # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling from the Nanoset dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 7ae34dd5..133442af 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -803,12 +803,7 @@ def masked_mean(loss, label_mask, dtype): # type: (Tensor, Tensor, torch.dtype) -> Tensor return (loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum( dim=1 - ) # TODO esto de entrada da float/float = float - - -# TODO la loss de cada uno !!!! ((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1)) -# Y pasa el assert close!! -# assert_close(((loss * label_mask).sum(dtype=dtype) / label_mask.sum()), torch.mean((loss * label_mask).sum(dim=1, dtype=dtype) / label_mask.sum(dim=1))) + ) # NOTE(tj.solergibert) Added dim=1 to return a tensor with shape [Batch size, 1] instead of [1] class Loss(nn.Module): diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index bc6dc5b5..549ef5eb 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -136,7 +136,7 @@ def validate_batch_iter( nb_microbatches: int, ) -> Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]]: # Assign a new state for the current batch - state = PipelineEvalBatchState() # PipelineTrainBatchState() # TODO: do i need state? 
+ state = PipelineEvalBatchState() self.nb_microbatches = nb_microbatches outputs = [] @@ -156,21 +156,17 @@ def validate_batch_iter( send_activation() # We make `output` a dict - # TODO convert to dict other items returned by the model (MoE aux loss for example) - # But in next if statement be careful if we return other items in all of the pp processes - # This conversion to dicts is kind of useless as the model already returns a dict with loss key. Maybe the PP ranks return TensorPointer Objects? if not isinstance(output, dict): output = {"loss": output} # Store the loss for each microbatch if not isinstance(output["loss"], TensorPointer): output = {k: v.detach() for k, v in output.items()} - # TODO ver este output que es y tambien ver en outputs como se guarda. Donde se have la media? En el training step lol - # Aqui deberiamos segregar por languagues porque es el unico punto en el que tenemos la languague!! O al menos "etiquetarla" o acumularla por language - # 1. Hacemos dict con key para cada idioma 2. cada key tiene una lista donde append los tensors 3. en valid step hacemos lo del stack y allreduces - # Finalmente: Aqui metemos solo el lang ids, en trainer.py acumularemos los resultados y tal. - outputs.extend(list(output["sample_loss"])) # TODO flatten?????? o extend?????? - lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) # TODO esto deberia se un extend???? + + outputs.extend( + list(output["sample_loss"]) + ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors + lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) return outputs, lang_ids diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index b1cc36ad..f720446a 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -311,7 +311,25 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL # 1. We call this function EVERY TIME we run the validation loop # 2. Every time it returns a NEW validation iterator DataLoader. If you don't do this you'll consume the whole validation dataset # in the first iteration and subsequent validations will fail - # TODO(tj.solergibert) Delete previous DataLoaders from memory like we do with training DataLoaders + # `dataloaders` are either torch DataLoaders (the very first stage) OR functions that we call later that provide torch DataLoaders (subsequent stages) + # From this torch DataLoaders objects we then call `sanity_check_dataloader` that will return a iterator. + # In short, `sanity_check_dataloader` just places the input tensors in the GPU when necessary (TensorPointers stay in the CPU) + # + # TBH, the for loop below it's just for deleting the DataLoaders of previous stages, which is not so problematic. The important part is returning the + # DataLoader iterator every time we call this function from the current training stage, which is tracked during training + # + # Also, keep in mind that if val_check_interval = 5 & data.start_training_step = 10 we will already perform the evaluation with the SECOND data stage + # after just training for the current iteration, so it might not be a good idea to set evals during the stage in which we change of data stage + # + # NOTE(tj.solergibert) Further investigation should be done, but there is a extrange behaiviour when deleting the DataLoaders////lambda functs. 
As they + # are converted into Iterators with `sanity_check_dataloader` we can't access anymore the DataLoader object to del the dataset (After first stage, + # in this function we locally create the DataLoder from the lambda func --> Return Iterator) + # + # Also when the gc deletes the first stage dataloader, all the `DatatroveFileDataset._f` are already None AND the `del` thing are deleting a copy of the + # object, not the object itself + # + # FINAL NOTE(tj.solergibert) I will open a Issue in nanotron to check with them if they are aware of this useless deletitions + # # TODO(tj.solergibert) Check the tuple case below from collections.abc import Generator @@ -339,11 +357,11 @@ def _prepare_dataloader_for_validation_stage(self, dataloaders: Union[List[DataL self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Validation Stage: {stage_name}] Clearing the previous validation stage's dataloader and dataset from memory", + f"[Validation Stage: {stage_name}] Clearing the previous validation stage's ({prev_stage_name}) dataloader and dataset from memory", logger=logger, level=logging.INFO, ) @@ -355,57 +373,38 @@ def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): gc.collect() - dataloader = None - for stage_idx, stage in enumerate(self.config.data_stages): if stage_idx < self.metadata.last_stage_idx: continue + # NOTE(tj.solergibert) From this point stage_idx = self.metadata.last_stage_idx. We update self.metadata.last_stage_idx (which keeps track of the training stage) + # in each and every training step. if ( stage_idx is not self.metadata.last_validation_stage_idx - and self.metadata.last_validation_stage_idx is not None - ): + ): # When stage_idx (= self.metadata.last_stage_idx, the training stage index) is different than the last validation stage index self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index - # Si cambiamos de stage borramo el antiguo - # En ambos casos recrear el que toca !!! - # TODO Aqui nos quedamos!!! Tenemos que borrar el anterior dataloader cuando sea necesario y hacer el sanity del current dataloader SIEMPRE - stage = cast(DatasetStageArgs, stage) - print( - stage.name - ) # TODO como actualizamos el last stage index en el training aqui estamos mirando el dataloader de la siguiente iteracion que mal por dios!!!!! - - log_rank( - f"Ese print bueno {stage.name}", - logger=logger, - level=logging.INFO, - rank=0, - ) - # self.metadata.last_stage_idx = stage_idx - """ - if self.current_validation_dataloader is not None: # TODO Si hay algun dataloader ya lo eliminamos. Igualmente creamos de nuevo. 
Bueno el dataloader como tal ya esta creado, solo hay que devolver el sanity check raro + # Delete previous stage DataLoader prev_stage_name = self.config.data_stages[stage_idx - 1].name prev_dataloader = dataloaders[prev_stage_name] - if isinstance(prev_dataloader, DataLoader): - # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) - """ - log_rank( - f"Preparing validation DataLoader from stage {stage.name}", - logger=logger, - level=logging.INFO, - rank=0, - ) + if isinstance(prev_dataloader, DataLoader): + # NOTE: we don't need to clear dummy data generator from memory + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) + + self.metadata.last_validation_stage_idx = stage_idx # Update validation stage index + # NOTE(tj.solergibert) Create AGAIN the DataLoader dataloader = dataloaders[stage.name] # NOTE: if a dataloader is lazy initialized, we need to call it to initialize it dataloader = dataloader() if callable(dataloader) else dataloader break - self.current_validation_dataloader_lenght = 200 # TODO len(dataloader) + self.current_validation_dataloader_lenght = len(dataloader) self.current_validation_dataloader = sanity_check_dataloader( dataloader=dataloader, parallel_context=self.parallel_context, config=self.config - ) + ) # NOTE(tj.solergibert) Create a Iterator from the DataLoader def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]): from collections.abc import Generator @@ -431,11 +430,11 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da self.config.data_stages ), "Number of dataloaders should match the number of dataset stages" - def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): + def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str, prev_stage_name: str): import gc log_rank( - f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory", + f"[Training Stage: {stage_name}] Clearing the previous training stage's ({prev_stage_name}) dataloader and datasets from memory", logger=logger, level=logging.INFO, ) @@ -472,7 +471,9 @@ def find_stage_idx_to_resume(): if isinstance(prev_dataloader, DataLoader): # NOTE: we don't need to clear dummy data generator from memory - clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name) + clear_dataloader_from_memory( + prev_dataloader, stage_name=stage.name, prev_stage_name=prev_stage_name + ) self.metadata.last_stage_idx = stage_idx @@ -548,18 +549,14 @@ def train( # Cada validation es mucho mas largo que un training step # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas if self.iteration_step % self.config.tokens.val_check_interval == 0: - log_rank( - f"KOMO???? 
{self.iteration_step}", - logger=logger, - level=logging.INFO, - rank=0, - ) self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) val_global_loss, val_lang_losses = self.validation_step( dataloader=self.current_validation_dataloader ) self.validation_step_time = time.time() - self.validation_step_logs(val_global_loss, val_lang_losses) + self.validation_step_logs( + val_global_loss, val_lang_losses + ) # TODO(tj.solergibert) Check what happens when val_check_interval % iteration_step_info_interval != 0 # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -684,6 +681,14 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten lang_losses = { lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys() } + # WARNING(tj.solergibert) This mechanism will fail in the following [corner] case: + # If the lang_losses dict for a given lang IS EMPTY aka in the validation step in a Data Parallel Group + # we have 0 SAMPLES of a given lang, lang_losses[lang] will be a empty python list so the toch.stack call + # will fail with "stack expects a non-empty TensorList". I've tried setting this lang_losses[lang] to torch.empty + # but of course it doesn't works as we then do the average across the DP group. + # We will fix this issue in the future if we encounter this problem again. + # A bit of inspo https://blog.speechmatics.com/Sparse-All-Reduce-Part-1 + # Compute losses if isinstance(outputs[0], torch.Tensor): # Multilingual losses @@ -696,9 +701,7 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten # Sync losses across DP for lang in lang_losses.keys(): lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) - dist.all_reduce( - lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG - ) # TODO Estas averages dan enormes porque debe de hacer el average con un solo valor!!!!!!!! Debe de set loss per batch o asi no? Sino meter en el outputs de arriba coger el "loss" y comparar a mano vamos... + dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) else: global_loss_avg = None @@ -833,7 +836,7 @@ def validation_step_logs( ), # , "1.6E"), LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"), - LogItem("validation_hardware_tflops_per_gpu", hardware_tflops, "human_format"), # , ".2f"), + LogItem("validation_hardware_tflops_per_gpu", hardware_tflops / 3, "human_format"), # , ".2f"), ] # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. 
From 5cccf16711e296caa517fb6619ae0dbd5d7ede75 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 16:47:54 +0000 Subject: [PATCH 18/24] This looks better --- run_train.py | 13 ++-- src/nanotron/models/llama.py | 4 +- src/nanotron/trainer.py | 142 ++++++++++++++++------------------- 3 files changed, 76 insertions(+), 83 deletions(-) diff --git a/run_train.py b/run_train.py index 2ddff5ad..a51caf59 100644 --- a/run_train.py +++ b/run_train.py @@ -238,10 +238,10 @@ def get_valid_dataloader_from_data_stage( with main_rank_first(trainer.parallel_context.world_pg): valid_dataset = MultilingualNanoset( - dataset_folders=data.dataset.validation_folder, # TODO Just 1 folder + dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, # TODO Just 1 lang + dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -331,9 +331,12 @@ def get_valid_dataloader(trainer: DistributedTrainer) -> Dict[str, DataLoader]: # the validation MultilingualNanoset info (Number of samples, etc.) [UPDATE: ]. In order to solve that, we could get rid of this lambda # funcs and directly create all dataloaders. # - # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead of creating multiple DataLoaders - # 2. Consume less memory as the lambda func is lighter that the DataLoader object with the Dataset, collator, etc. - # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling from the Nanoset + # This lambda functs (Used in training too) are for creating the DataLoaders lazyly FOR 1. Start training faster instead + # of creating multiple DataLoaders 2. Consume less memory as the lambda func is lighter that the DataLoader object with + # the Dataset, collator, etc. + # BUT 1. The Nanoset creation process is very fast and 2. Nanosets doesn't consume any memory at all till we start sampling + # from the Nanoset. Also they later transform the DataLoader into a Iterator object so it's impossible to retrieve + # the DataLoader object again to delete it (More comments in trainer.py) dataloaders[stage.name] = dataloader return dataloaders diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 133442af..8c4125b7 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -825,8 +825,10 @@ def forward( ).transpose(0, 1) # TODO @thomasw21: It's unclear what kind of normalization we want to do. sample_loss = masked_mean(sample_loss, label_mask, dtype=torch.float) - # NOTE @tj.solergibert: masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. + # NOTE(tj.solergibert) masked_mean returns a single scalar with the batch loss. We've changed it to compute the SAMPLE loss. # We will continue using "loss" as the batch loss but we add "sample_loss" for the multilingual effort. 
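# --- Standalone sketch (toy shapes; not the module's real tensors) of the two reductions ---
# `masked_mean(..., dim=1)` now yields one loss per sample ("sample_loss", shape [batch_size]);
# the scalar batch loss used for the backward pass is recovered in `LlamaForTraining` by averaging
# it. With a uniform label mask both routes agree; with ragged masks (padding) they can differ
# slightly, because the per-sample mean weights every sample equally rather than every token.
import torch

token_loss = torch.rand(4, 8)                 # [batch_size, seq_len] per-token losses
label_mask = torch.ones(4, 8)                 # every token contributes in this toy case
sample_loss = (token_loss * label_mask).sum(dim=1, dtype=torch.float) / label_mask.sum(dim=1)
batch_loss = sample_loss.mean()               # scalar batch loss, as computed in LlamaForTraining
torch.testing.assert_close(batch_loss, (token_loss * label_mask).sum() / label_mask.sum())
# --- end of sketch ---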
+ # WARN(tj.solergibert) Don't panic, the batch loss used to update the parameters is computed in `LlamaForTraining` + # TODO @thomasw21: I think indexing causes a sync we don't actually want # TODO @thomasw21: loss = loss[label_mask].sum() return {"sample_loss": sample_loss} diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index f720446a..c327f508 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -457,7 +457,7 @@ def find_stage_idx_to_resume(): stage_idx_to_resume = find_stage_idx_to_resume() - for stage_idx, stage in enumerate(self.config.data_stages): # TODO check metadatalaststageindex init + for stage_idx, stage in enumerate(self.config.data_stages): if stage_idx < self.metadata.last_stage_idx: continue @@ -541,22 +541,17 @@ def train( outputs, loss_avg = self.training_step(dataloader=self.current_dataloader) self.training_step_time = time.time() - # Validation step - # TODO A ver, en este loop solo se lleva a cabo una training iteration pero claro hay un porron de validation iteration... mmmmm - # Tal vez deberiamos mover esto a otro lugar? Es decir, aqui se have un training step pero hacemos varios validation steps - # Lo podemos dejar aqui solamente que las metricas de throughput y tokens consumidos se tendrian que revisar - # Porque actualmente utilizan la global batch size, que es correcta ya que es la que tiene cada training step pero claro, - # Cada validation es mucho mas largo que un training step - # Puede que el len valid dataloader de el numero de valid batches por lo que con eso y la batch size podemos tirar millas + # Validation stage if self.iteration_step % self.config.tokens.val_check_interval == 0: self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) val_global_loss, val_lang_losses = self.validation_step( dataloader=self.current_validation_dataloader ) self.validation_step_time = time.time() - self.validation_step_logs( - val_global_loss, val_lang_losses - ) # TODO(tj.solergibert) Check what happens when val_check_interval % iteration_step_info_interval != 0 + else: + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + val_global_loss, val_lang_losses = None, None # Training Logs # TODO(xrsrke): refactor using callbacks would be better @@ -567,7 +562,9 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg) + self.train_step_logs( + outputs=outputs, loss_avg=loss_avg, global_loss=val_global_loss, lang_losses=val_lang_losses + ) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -711,12 +708,14 @@ def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, Ten def train_step_logs( self, - outputs: Iterable[Dict[str, Union[torch.Tensor, TensorPointer]]], loss_avg: Optional[torch.Tensor], + global_loss: torch.Tensor, + lang_losses: torch.Tensor, ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 dist.barrier() torch.cuda.synchronize() + # Training metrics elapsed_time_per_iteration_ms = (self.training_step_time - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) @@ -727,9 +726,24 @@ def train_step_logs( global_batch_size=self.global_batch_size, ) + # Validation metrics + if global_loss is not None: + validation_total_samples = self.current_validation_dataloader_lenght * self.micro_batch_size + validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 + validation_tokens_per_sec = ( + validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) + ) # tokens_per_sec is calculated using sequence_length + + validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( + iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, + sequence_length=self.sequence_length, + global_batch_size=validation_total_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches + ) + if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" + # Training metrics lr = self.lr_scheduler.get_last_lr()[0] log_entries = [ @@ -753,6 +767,44 @@ def train_step_logs( if self.config.optimizer.clip_grad is not None: log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f")) + # Validation metrics + if global_loss is not None: + log_entries = [ + LogItem( + "validation_consumed_tokens", + self.metadata.consumed_train_samples * self.config.tokens.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), + ] + + # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. + # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 + # GitHub complains: https://github.com/wandb/wandb/issues/3035 + log_entries.extend( + [ + LogItem(f"{lang}_validation_loss", loss.item(), "human_format") + for lang, loss in lang_losses.items() + ] + ) + # Log not too often the memory if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0: total, used, free = shutil.disk_usage("/") @@ -796,70 +848,6 @@ def train_step_logs( else: exit(0) - def validation_step_logs( - self, - global_loss: torch.Tensor, - lang_losses: torch.Tensor, - ) -> None: - # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. 
https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 - dist.barrier() - torch.cuda.synchronize() - total_validation_samples = self.current_validation_dataloader_lenght * self.micro_batch_size - elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 - tokens_per_sec = ( - total_validation_samples * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) - ) # tokens_per_sec is calculated using sequence_length - # TODO para el valid ojo con cambiar global_batch_size = len dataloader * mbs - model_tflops, hardware_tflops = self.unwrapped_model.get_flops_per_sec( - iteration_time_in_sec=elapsed_time_per_iteration_ms / 1000, - sequence_length=self.sequence_length, - global_batch_size=total_validation_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches - ) - - if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: - assert self.loggerwriter is not None, "loggerwriter should be defined on logger ranks" - - log_entries = [ - LogItem( - "validation_consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, - "human_format", - ), # , "12d"), - LogItem( - "validation_elapsed_time_per_iteration_ms", elapsed_time_per_iteration_ms, "human_format" - ), # , ".1f"), - LogItem("validation_tokens_per_sec", tokens_per_sec, "human_format"), # , "1.6E"), - LogItem( - "validation_tokens_per_sec_per_gpu", - tokens_per_sec / self.parallel_context.world_pg.size(), - "human_format", - ), # , "1.6E"), - LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), - LogItem("validation_model_tflops_per_gpu", model_tflops / 3, "human_format"), # , ".2f"), - LogItem("validation_hardware_tflops_per_gpu", hardware_tflops / 3, "human_format"), # , ".2f"), - ] - - # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. - # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 - # GitHub complains: https://github.com/wandb/wandb/issues/3035 - log_entries.extend( - [LogItem(f"{lang}_validation_loss", loss.item(), "human_format") for lang, loss in lang_losses.items()] - ) - - # NOTE: only one rank writes to wandb - # NOTE(tj.solergibert) By default wandb.log performs a step in the x-axis every time. 
- # Set commit=False to log values with the next wandb.log with the training logs - if dist.get_rank(self.parallel_context.world_pg) == self.logger_ranks[0] and wandb is not None: - wandb.log( - { - **{log_item.tag: log_item.scalar_value for log_item in log_entries}, - "iteration_step": self.iteration_step, - }, - commit=False, - ) - - self.loggerwriter.add_scalars_from_list(log_entries, self.iteration_step) - def init_model(self) -> Union[NanotronModel, DistributedDataParallel]: """Initialize the model and load weights from checkpoint if needed.""" # TODO: add max_position_embeddings From d75038dad4ba8786344f06725b24b213059f9b97 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 19:21:22 +0000 Subject: [PATCH 19/24] last fixes --- src/nanotron/config/config.py | 7 +++++ src/nanotron/trainer.py | 56 +++++++++++++++++------------------ 2 files changed, 35 insertions(+), 28 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index e5ea3ec1..80229ca2 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -406,6 +406,13 @@ def __post_init__(self): for i in range(len(self.data_stages) - 1) ), "The stages are not sorted by start_training_step in increasing order" + # NOTE(tj.solergibert) As we are reporting the training & validation metrics together, we + # must comply with val_check_interval % iteration_step_info_interval = 0 + if not self.tokens.val_check_interval % self.logging.iteration_step_info_interval == 0: + raise ValueError( + f"It is necessary to run the validation stage during a logging step. Validation interval: {self.tokens.val_check_interval}, Logging interval: {self.logging.iteration_step_info_interval}" + ) + # # if lighteval, we need tokenizer to be defined # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index c327f508..c6ca734c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -562,9 +562,7 @@ def train( ].consumed_train_samples += self.global_batch_size if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs( - outputs=outputs, loss_avg=loss_avg, global_loss=val_global_loss, lang_losses=val_lang_losses - ) + self.train_step_logs(loss_avg=loss_avg, global_loss=val_global_loss, lang_losses=val_lang_losses) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -769,31 +767,33 @@ def train_step_logs( # Validation metrics if global_loss is not None: - log_entries = [ - LogItem( - "validation_consumed_tokens", - self.metadata.consumed_train_samples * self.config.tokens.sequence_length, - "human_format", - ), # , "12d"), - LogItem( - "validation_elapsed_time_per_iteration_ms", - validation_elapsed_time_per_iteration_ms, - "human_format", - ), # , ".1f"), - LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), - LogItem( - "validation_tokens_per_sec_per_gpu", - validation_tokens_per_sec / self.parallel_context.world_pg.size(), - "human_format", - ), # , "1.6E"), - LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), - LogItem( - "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" - ), # , ".2f"), - LogItem( - "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" - ), # , ".2f"), - ] + log_entries.extend( + [ + LogItem( + "validation_consumed_tokens", + 
validation_total_samples * self.sequence_length, + "human_format", + ), # , "12d"), + LogItem( + "validation_elapsed_time_per_iteration_ms", + validation_elapsed_time_per_iteration_ms, + "human_format", + ), # , ".1f"), + LogItem("validation_tokens_per_sec", validation_tokens_per_sec, "human_format"), # , "1.6E"), + LogItem( + "validation_tokens_per_sec_per_gpu", + validation_tokens_per_sec / self.parallel_context.world_pg.size(), + "human_format", + ), # , "1.6E"), + LogItem("validation_loss", global_loss.item(), "human_format"), # , "1.6E"), + LogItem( + "validation_model_tflops_per_gpu", validation_model_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + LogItem( + "validation_hardware_tflops_per_gpu", validation_hardware_tflops / 3, "human_format" + ), # , ".2f"), # NOTE(tj.solergibert) Check llama.py --> def get_flops() --> model_flops for explanation of the / 3 factor + ] + ) # NOTE Currently you have to log each lang metric one by one and then merge them manually in the same plot through the wandb UI. # Example: https://community.wandb.ai/t/log-multiple-variables-at-the-same-plot/2474 From ab1dd835ba34d2bc5651ff78ebc60bdf050164aa Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 19:26:42 +0000 Subject: [PATCH 20/24] Fixed tokenizer config --- examples/config_multilingual_nanoset.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 5573a224..596e5e32 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -124,7 +124,7 @@ parallelism: profiler: null tokenizer: tokenizer_max_length: null - tokenizer_name_or_path: /mloscratch/homes/solergib/models/Meta-Llama-3-8B + tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B tokenizer_revision: null tokens: batch_accumulation_per_replica: 1 From 2d911544cb915c2cccc264971e3f4ec285ebd27e Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 24 Jul 2024 19:29:51 +0000 Subject: [PATCH 21/24] deleted comments --- src/nanotron/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index c6ca734c..62fe6bcc 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -730,12 +730,12 @@ def train_step_logs( validation_elapsed_time_per_iteration_ms = (self.validation_step_time - self.training_step_time) * 1000 validation_tokens_per_sec = ( validation_total_samples * self.sequence_length / (validation_elapsed_time_per_iteration_ms / 1000) - ) # tokens_per_sec is calculated using sequence_length + ) validation_model_tflops, validation_hardware_tflops = self.unwrapped_model.get_flops_per_sec( iteration_time_in_sec=validation_elapsed_time_per_iteration_ms / 1000, sequence_length=self.sequence_length, - global_batch_size=validation_total_samples, # TODO con esto de la global batch size yo la pondria a 1 y multiplicaba por el numero de batches + global_batch_size=validation_total_samples, ) if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: From ce068fd5a9d1d29805dedd9e3493fafd883ab847 Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 7 Aug 2024 19:44:23 +0000 Subject: [PATCH 22/24] Last fixes --- examples/config_multilingual_nanoset.yaml | 39 +++++----- run_train.py | 4 +- src/nanotron/config/config.py | 11 ++- src/nanotron/data/collator.py | 73 +++++++++++++++++++ 
src/nanotron/data/dataloader_builder.py | 11 ++- src/nanotron/data/multilingual_nanoset.py | 4 +- src/nanotron/distributed.py | 4 - src/nanotron/models/llama.py | 10 ++- .../parallel/pipeline_parallel/engine.py | 6 +- src/nanotron/trainer.py | 50 +++++++++---- 10 files changed, 156 insertions(+), 56 deletions(-) diff --git a/examples/config_multilingual_nanoset.yaml b/examples/config_multilingual_nanoset.yaml index 596e5e32..cc66cd70 100644 --- a/examples/config_multilingual_nanoset.yaml +++ b/examples/config_multilingual_nanoset.yaml @@ -8,17 +8,17 @@ data_stages: - data: dataset: training_folder: - datasets/c4-es/train: 0.85 - datasets/c4-en/train: 0.05 - datasets/c4-fr/train: 0.1 + - datasets/c4-es/train + - datasets/c4-en/train + - datasets/c4-fr/train validation_folder: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 name: General purpose training (Blended dataset) @@ -29,12 +29,12 @@ data_stages: - datasets/c4-es/train validation_folder: - datasets/c4-es/validation - lang_to_ids: - es: 128002 + languages: + - es num_loading_workers: 1 seed: 42 name: Second purpose training (Single dataset) - start_training_step: 100 + start_training_step: 1000 - data: dataset: training_folder: @@ -45,20 +45,19 @@ data_stages: - datasets/c4-es/validation - datasets/c4-en/validation - datasets/c4-fr/validation - lang_to_ids: - es: 128002 - en: 128003 - fr: 128004 - + languages: + - es + - en + - fr num_loading_workers: 1 seed: 42 name: Third purpose training (>1 dataset) - start_training_step: 200 + start_training_step: 2000 general: benchmark_csv_path: null consumed_train_samples: null ignore_sanity_checks: true - project: Multilingual + project: MultilingualV2 run: llama seed: 42 step: null @@ -114,7 +113,7 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 1 + dp: 2 expert_parallel_size: 1 pp: 1 pp_engine: 1f1b @@ -132,5 +131,5 @@ tokens: limit_val_batches: 10 micro_batch_size: 3 sequence_length: 4096 - train_steps: 800 - val_check_interval: 50 + train_steps: 500 + val_check_interval: 100 diff --git a/run_train.py b/run_train.py index a51caf59..809d8d41 100644 --- a/run_train.py +++ b/run_train.py @@ -194,7 +194,6 @@ def get_dataloader_from_data_stage( sequence_length=trainer.sequence_length, token_size=token_size, train_split_num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, - dataset_tokens=data.dataset.dataset_tokens, random_seed=data.seed, ) @@ -209,6 +208,7 @@ def get_dataloader_from_data_stage( consumed_train_samples=consumed_train_samples, dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, + is_multilingual=True, ) return train_dataloader @@ -241,7 +241,6 @@ def get_valid_dataloader_from_data_stage( dataset_folders=data.dataset.validation_folder, sequence_length=trainer.sequence_length, token_size=token_size, - dataset_tokens=data.dataset.dataset_tokens, is_valid=True, random_seed=data.seed, ) @@ -257,6 +256,7 @@ def get_valid_dataloader_from_data_stage( dataloader_num_workers=data.num_loading_workers, dataloader_drop_last=True, shuffle=True, + is_multilingual=True, ) return valid_dataloader diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 80229ca2..b3c755a5 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,7 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: 
Union[str, dict, List[str]] validation_folder: Union[str, List[str]] - lang_to_ids: dict # Mapping from the previously defined folders to tokens. Respect the order + languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -125,14 +125,13 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - self.ids_to_lang = {v: k for k, v in self.lang_to_ids.items()} - self.dataset_tokens = list(self.lang_to_ids.values()) + assert len(self.training_folder) == len( + self.languages + ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})" + assert len(self.training_folder) == len( self.validation_folder ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" - assert len(self.training_folder) == len( - self.dataset_tokens - ), f"The sizes of training_folder and lang_to_ids mismatch ({len(self.training_folder)} vs {len(self.dataset_tokens)})" @dataclass diff --git a/src/nanotron/data/collator.py b/src/nanotron/data/collator.py index 199527e1..fd217b1a 100644 --- a/src/nanotron/data/collator.py +++ b/src/nanotron/data/collator.py @@ -78,3 +78,76 @@ def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Uni ) return result + + +@dataclasses.dataclass +class MultilingualNanosetDataCollatorForCLM: + """ + Data collator used for causal language modeling with Nanosets dataset. + + - input_pp_rank: Discards last input id token + - output_pp_rank: Discards first label id token + - other pp ranks: Don't have data. Instead, we use `TensorPointer` to point to the rank having the data. + """ + + sequence_length: int + input_pp_rank: int + output_pp_rank: int + parallel_context: ParallelContext + + def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # Process the case when current rank doesn't require data. We return `TensorPointer` that points to ranks having the data. + current_pp_rank = dist.get_rank(self.parallel_context.pp_pg) + if current_pp_rank not in [ + self.input_pp_rank, + self.output_pp_rank, + ]: + assert all(len(example) == 0 for example in examples) + return { + "input_ids": TensorPointer(group_rank=self.input_pp_rank), + "input_mask": TensorPointer(group_rank=self.input_pp_rank), + "lang_code": TensorPointer(group_rank=self.input_pp_rank), + "label_ids": TensorPointer(group_rank=self.output_pp_rank), + "label_mask": TensorPointer(group_rank=self.output_pp_rank), + } + + # TODO @nouamanetazi: Is it better to have examples as np.array or torch.Tensor? 
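# --- Editorial sketch (not part of the patch): the shapes this collator works with. ---
# Assumptions: sequence_length = 4 and two samples; the real dataset yields `input_ids` of
# length sequence_length + 1 plus a scalar `lang_code` (see multilingual_nanoset.py below).
#
#   import torch
#   examples = [
#       {"input_ids": torch.arange(5), "lang_code": torch.tensor(0)},
#       {"input_ids": torch.arange(5, 10), "lang_code": torch.tensor(2)},
#   ]
#   input_ids = torch.vstack([ex["input_ids"] for ex in examples])  # (2, 5)
#   lang_code = torch.vstack([ex["lang_code"] for ex in examples])  # (2, 1)
#   inputs, labels = input_ids[:, :-1], input_ids[:, 1:]            # (2, 4) each, i.e. (b, s)
# --- end sketch ---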
+ input_ids = torch.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s) + lang_code = torch.vstack([examples[i]["lang_code"] for i in range(len(examples))]) # (b, 1) + batch_size, expanded_input_length = input_ids.shape + + result: Dict[str, Union[torch.LongTensor, TensorPointer]] = {} + + result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank) + result["input_mask"] = TensorPointer(group_rank=self.input_pp_rank) + result["lang_code"] = TensorPointer(group_rank=self.input_pp_rank) + result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank) + result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank) + + assert ( + expanded_input_length == self.sequence_length + 1 + ), f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}" + + # Process inputs: last token is the label + if current_pp_rank == self.input_pp_rank: + result["input_ids"] = input_ids[:, :-1] + result["input_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + result["lang_code"] = lang_code + + # Process labels: shift them to the left + if current_pp_rank == self.output_pp_rank: + result["label_ids"] = input_ids[:, 1:] + result["label_mask"] = torch.ones((batch_size, self.sequence_length), dtype=torch.bool) + + if isinstance(result["input_ids"], torch.Tensor) and result["input_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['input_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." + ) + if isinstance(result["label_ids"], torch.Tensor) and result["label_ids"].shape[-1] != self.sequence_length: + raise ValueError( + f"`labels` are incorrectly preprocessed. `labels` length is {result['label_ids'].shape[-1]}, but should be" + f" {self.sequence_length}." 
+ ) + + return result diff --git a/src/nanotron/data/dataloader_builder.py b/src/nanotron/data/dataloader_builder.py index b8bfb303..f9480029 100644 --- a/src/nanotron/data/dataloader_builder.py +++ b/src/nanotron/data/dataloader_builder.py @@ -1,6 +1,6 @@ import nanotron.distributed as dist from nanotron import logging -from nanotron.data.collator import NanosetDataCollatorForCLM +from nanotron.data.collator import MultilingualNanosetDataCollatorForCLM, NanosetDataCollatorForCLM from nanotron.dataloader import ( EmptyInfiniteDataset, get_dataloader_worker_init, @@ -20,6 +20,7 @@ def build_nanoset_dataloader( output_pp_rank: int, micro_batch_size: int, dataloader_num_workers: int, + is_multilingual: bool = False, consumed_train_samples: int = 0, dataloader_drop_last: bool = True, dataloader_pin_memory: bool = True, @@ -40,6 +41,14 @@ def build_nanoset_dataloader( parallel_context=parallel_context, ) + if is_multilingual: + data_collator = MultilingualNanosetDataCollatorForCLM( + sequence_length=sequence_length, + input_pp_rank=input_pp_rank, + output_pp_rank=output_pp_rank, + parallel_context=parallel_context, + ) + # Compute size and rank of dataloader workers dp_ranks_size = parallel_context.dp_pg.size() dp_rank = parallel_context.dp_pg.rank() diff --git a/src/nanotron/data/multilingual_nanoset.py b/src/nanotron/data/multilingual_nanoset.py index 7af57448..8eec5549 100644 --- a/src/nanotron/data/multilingual_nanoset.py +++ b/src/nanotron/data/multilingual_nanoset.py @@ -30,7 +30,6 @@ def __init__( dataset_folders: List[str], sequence_length: int, token_size: int, - dataset_tokens: List[int], train_split_num_samples: int = None, is_valid: bool = False, dataset_weights: Union[List[float], None] = None, @@ -47,7 +46,6 @@ def __init__( self.sequence_length = sequence_length self.token_size = token_size self.train_split_num_samples = train_split_num_samples - self.dataset_tokens = dataset_tokens self.is_valid = is_valid self.random_seed = random_seed self.datatrove_datasets = [] @@ -107,7 +105,7 @@ def __getitem__(self, idx: int) -> Dict[str, np.ndarray]: dataset_sample = self.dataset_sample_index[idx] tokens = self.datatrove_datasets[dataset][dataset_sample] - tokens["input_ids"][0] = self.dataset_tokens[dataset] # Prepend language token + tokens["lang_code"] = torch.tensor(dataset, dtype=torch.long) return tokens diff --git a/src/nanotron/distributed.py b/src/nanotron/distributed.py index 0156b1bb..0bc54f3e 100644 --- a/src/nanotron/distributed.py +++ b/src/nanotron/distributed.py @@ -52,10 +52,6 @@ def all_gather_into_tensor( # pylint: disable=function-redefined if group is None: group = dist.torch_dist.distributed_c10d._get_default_group() - assert ( - group.size() > 1 - ), "You should probably not call `all_gather_into_tensor` with a single rank, as it copies data over" - if torch_version_above_1_13: return dist.all_gather_into_tensor( output_tensor=output_tensor, input_tensor=input_tensor, group=group, async_op=async_op diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 8c4125b7..ec1b38c0 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -733,14 +733,20 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): - return self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask)[0] + return 
self.forward_with_hidden_states(input_ids=input_ids, input_mask=input_mask, lang_code=lang_code)[0] def forward_with_hidden_states( self, input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] input_mask: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + lang_code: Union[torch.Tensor, TensorPointer], # [batch_size, 1] ): + # NOTE(tj.solergibert) I bring `lang_code` till the forward of LlamaModel. Remember that + # to use it in the different pipeline blocks you need to also set the module_input_keys & module_output_keys + # of the necessary `PipelineBlock`'s defined in the LlamaModel init! + # all tensors are optional as most ranks don't need anything from the dataloader. output = self.token_position_embeddings(input_ids=input_ids, input_mask=input_mask) @@ -863,12 +869,14 @@ def forward( self, input_ids: Union[torch.Tensor, TensorPointer], input_mask: Union[torch.Tensor, TensorPointer], + lang_code: Union[torch.Tensor, TensorPointer], label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: sharded_logits = self.model( input_ids=input_ids, input_mask=input_mask, + lang_code=lang_code, ) outputs = self.loss( sharded_logits=sharded_logits, diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 549ef5eb..9b548e35 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -140,7 +140,7 @@ def validate_batch_iter( self.nb_microbatches = nb_microbatches outputs = [] - lang_ids = [] + lang_codes = [] with attach_pipeline_state_to_model(model=model, pipeline_state=state): # All forward @@ -166,9 +166,9 @@ def validate_batch_iter( outputs.extend( list(output["sample_loss"]) ) # NOTE(tj.solergibert) Yes, it might look useless to do list + extend but it's necessary to split the output["sample_loss"] tensor into multiple tensors - lang_ids.extend(micro_batch["input_ids"][:, 0].tolist()) + lang_codes.extend(micro_batch["lang_code"].flatten().tolist()) - return outputs, lang_ids + return outputs, lang_codes class AllForwardAllBackwardPipelineEngine(PipelineEngine): diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 62fe6bcc..a17f9849 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -80,6 +80,7 @@ from nanotron.sanity_checks import ( after_optim_step_sanity_checks, after_tbi_sanity_checks, + assert_tensor_synced_across_pg, before_optim_step_sanity_checks, before_tbi_sanity_checks, ) @@ -667,37 +668,54 @@ def training_step( return outputs, loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: - outputs, lang_ids = self.pipeline_engine.validate_batch_iter( + outputs, lang_codes = self.pipeline_engine.validate_batch_iter( model=self.model, batch=(next(dataloader) for _ in range(self.current_validation_dataloader_lenght)), nb_microbatches=self.current_validation_dataloader_lenght, ) lang_losses = { - lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.lang_to_ids.keys() + lang: [] for lang in self.config.data_stages[self.metadata.last_stage_idx].data.dataset.languages } - # WARNING(tj.solergibert) This mechanism will fail in the following [corner] case: - # If the lang_losses dict for a given lang IS EMPTY aka in the validation step in a Data Parallel Group - # we have 0 SAMPLES of a given lang, lang_losses[lang] will 
be a empty python list so the toch.stack call - # will fail with "stack expects a non-empty TensorList". I've tried setting this lang_losses[lang] to torch.empty - # but of course it doesn't works as we then do the average across the DP group. - # We will fix this issue in the future if we encounter this problem again. - # A bit of inspo https://blog.speechmatics.com/Sparse-All-Reduce-Part-1 + lang_losses_list = list(lang_losses.keys()) # Compute losses if isinstance(outputs[0], torch.Tensor): # Multilingual losses - for loss, lang_id in zip(outputs, lang_ids): - lang_losses[ - self.config.data_stages[self.metadata.last_stage_idx].data.dataset.ids_to_lang[lang_id] - ].append(loss) + for loss, lang_code in zip(outputs, lang_codes): + lang_losses[lang_losses_list[lang_code]].append(loss) # Global loss global_loss_avg = torch.mean(torch.stack(outputs)) - # Sync losses across DP + # Sync multilingual losses across DP for lang in lang_losses.keys(): - lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) - dist.all_reduce(lang_losses[lang], group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + if not lang_losses[ + lang + ]: # If the list is empty --> Set local language loss to -1 to exclude it from the global computation + lang_losses[lang] = torch.tensor(-1, dtype=torch.float32) + else: # If we have at least 1 loss from a given language --> compute local language loss mean + lang_losses[lang] = torch.mean(torch.stack(lang_losses[lang])) + + # NOTE(tj.solergibert) We create a (DP SIZE, LANGS) tensor to aggregate ALL local losses across DP groups. + # Then we compute the mean of each lang in each and every rank and finally copy back the result to the + # `lang_losses` dict for logging + lang_losses_tensor_out = torch.zeros( + (self.parallel_context.dp_pg.size(), len(lang_losses.keys())), dtype=torch.float, device="cuda" + ) # (DP SIZE, LANGS) + lang_losses_tensor_local = torch.stack(list(lang_losses.values())).unsqueeze(0) # (1, LANGS) + dist.all_gather_into_tensor(lang_losses_tensor_out, lang_losses_tensor_local, self.parallel_context.dp_pg) + mask = lang_losses_tensor_out != -1 + lang_losses_tensor_local = (lang_losses_tensor_out * mask).sum(dim=0) / mask.sum(dim=0) # (1, LANGS) + for idx, lang in enumerate(lang_losses.keys()): + lang_losses[lang] = lang_losses_tensor_local[idx] + + # Sync global losses across DP dist.all_reduce(global_loss_avg, group=self.parallel_context.dp_pg, op=dist.ReduceOp.AVG) + + # TODO(tj.solergibert) Delete this testing assertions + for lang in lang_losses.keys(): + assert_tensor_synced_across_pg(tensor=lang_losses[lang], pg=self.parallel_context.dp_pg) + assert_tensor_synced_across_pg(tensor=global_loss_avg, pg=self.parallel_context.dp_pg) + else: global_loss_avg = None lang_losses = None From 8e6f8ab2f2876d97e10aadbf4120c99cf1cade5f Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Tue, 27 Aug 2024 17:32:53 +0000 Subject: [PATCH 23/24] Optional validation --- run_train.py | 6 +++++- src/nanotron/config/config.py | 25 +++++++++++++++++-------- src/nanotron/trainer.py | 7 +++++-- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/run_train.py b/run_train.py index 809d8d41..f39d1d58 100644 --- a/run_train.py +++ b/run_train.py @@ -354,7 +354,11 @@ def get_args(): # Load trainer and data trainer = DistributedTrainer(config_file) train_dataloader = get_dataloader(trainer) - valid_dataloader = get_valid_dataloader(trainer) + + # NOTE(tj.solergibert) Build validation dataloaders only if necessary + valid_dataloader = None + if 
trainer.config.tokens.val_check_interval != -1: + valid_dataloader = get_valid_dataloader(trainer) # Train trainer.train(train_dataloader, valid_dataloader) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index b3c755a5..fd95d27a 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -110,13 +110,13 @@ def __post_init__(self): @dataclass class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] - validation_folder: Union[str, List[str]] - languages: List[str] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB + validation_folder: Optional[Union[str, List[str]]] + languages: Optional[List[str]] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder self.training_folder = [self.training_folder] - self.validation_folder = [self.validation_folder] + self.validation_folder = [self.validation_folder] if self.validation_folder is not None else None self.dataset_weights = [1] elif isinstance(self.training_folder, List): # Case 2: > 1 Dataset folder self.dataset_weights = None # Set to None so we consume all the samples randomly @@ -125,20 +125,23 @@ def __post_init__(self): self.training_folder = list(tmp_training_folder.keys()) self.dataset_weights = list(tmp_training_folder.values()) - assert len(self.training_folder) == len( - self.languages + assert ( + len(self.training_folder) == len(self.languages) if self.languages else True ), f"The sizes of training_folder and languages mismatch ({len(self.training_folder)} vs {len(self.languages)})" - assert len(self.training_folder) == len( - self.validation_folder + assert ( + len(self.training_folder) == len(self.validation_folder) if self.validation_folder else True ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" + if not self.languages and self.validation_folder: + raise ValueError(f"You must specify languages to perform the validation step w/ {self.validation_folder}") + @dataclass class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[PretrainDatasetsArgs, NanosetDatasetsArgs, MultilingualNanosetDatasetsArgs] + dataset: Union[MultilingualNanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1 @@ -416,6 +419,12 @@ def __post_init__(self): # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None + if not self.data_stages[0].data.dataset.validation_folder: + # NOTE(tj.solergibert) We use print NOT log_rank because at this moment the process group is not + # initialized + print("No validation data provided, skipping validation step") + self.tokens.val_check_interval = -1 + @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index a17f9849..a5fb819c 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -507,7 +507,7 @@ def train( ], valid_dataloader_or_dls: Dict[ str, Union[Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]], Tuple[Iterator, ...]] - ], + ] = None, **kwargs, ) -> None: self.pre_training(**kwargs) @@ -543,7 +543,10 @@ def train( self.training_step_time = time.time() # Validation stage - if self.iteration_step % 
self.config.tokens.val_check_interval == 0: + if ( + self.iteration_step % self.config.tokens.val_check_interval == 0 + and self.config.tokens.val_check_interval != -1 + ): self._prepare_dataloader_for_validation_stage(valid_dataloader_or_dls) val_global_loss, val_lang_losses = self.validation_step( dataloader=self.current_validation_dataloader From 1969526c0e9ea569acce5059abcc2a9e83deaafe Mon Sep 17 00:00:00 2001 From: tj-solergibert Date: Wed, 4 Sep 2024 16:56:46 +0000 Subject: [PATCH 24/24] Fix eval check --- src/nanotron/config/config.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index fd95d27a..d359dabe 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -111,7 +111,9 @@ def __post_init__(self): class MultilingualNanosetDatasetsArgs: training_folder: Union[str, dict, List[str]] validation_folder: Optional[Union[str, List[str]]] - languages: Optional[List[str]] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Reporting to WANDB + languages: List[ + str + ] # NOTE(tj.solergibert) Required for 1. Aggregating the result 2. Embed lang information into the model 3. Reporting to WANDB def __post_init__(self): if isinstance(self.training_folder, str): # Case 1: 1 Dataset folder @@ -133,15 +135,15 @@ def __post_init__(self): len(self.training_folder) == len(self.validation_folder) if self.validation_folder else True ), f"The sizes of training_folder and validation_folder mismatch ({len(self.training_folder)} vs {len(self.validation_folder)})" - if not self.languages and self.validation_folder: - raise ValueError(f"You must specify languages to perform the validation step w/ {self.validation_folder}") + if not self.languages: + raise ValueError("You must specify the languages of each dataset") @dataclass class DataArgs: """Arguments related to the data and data files processing""" - dataset: Union[MultilingualNanosetDatasetsArgs] + dataset: Union[PretrainDatasetsArgs, MultilingualNanosetDatasetsArgs] seed: Optional[int] num_loading_workers: Optional[int] = 1
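
Taken together, the last three commits settle on two scheduling rules for validation: the config rejects a val_check_interval that is not a multiple of the logging interval (so validation metrics always land on a logging step), and when no validation_folder is given val_check_interval is forced to -1, which skips both dataloader construction in run_train.py and the validation branch in DistributedTrainer.train. A condensed sketch of those rules (illustrative helper names, not from the codebase, and the checks are ordered slightly differently than in the actual Config.__post_init__):

# Illustration only -- not part of the patch. Condenses the checks added in PATCH 19 and PATCH 23/24.
def resolve_val_check_interval(
    val_check_interval: int,
    iteration_step_info_interval: int,
    has_validation_folder: bool,
) -> int:
    if not has_validation_folder:
        return -1  # validation disabled entirely, as in Config.__post_init__
    if val_check_interval % iteration_step_info_interval != 0:
        raise ValueError(
            f"It is necessary to run the validation stage during a logging step. "
            f"Validation interval: {val_check_interval}, Logging interval: {iteration_step_info_interval}"
        )
    return val_check_interval


def should_run_validation(iteration_step: int, val_check_interval: int) -> bool:
    # Mirrors the guard in DistributedTrainer.train after the "Optional validation" commit.
    return val_check_interval != -1 and iteration_step % val_check_interval == 0

For example, with the example config's val_check_interval of 100, validation runs at steps 100, 200, and so on, provided the logging interval divides 100; leaving validation_folder unset (the config only inspects the first data stage) disables validation without further changes.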