Optional validation loop #6

Open · wants to merge 25 commits into base: main

Commits (25)
7aa3940
XGLM work in progress: Causal Attention and Positional Embeddings work
AleHD Jun 26, 2024
78dd53c
WIP: GPT arch almost done, hf->nt converters working perfectly for no…
AleHD Jun 26, 2024
a74c71a
Added hf2nt frontend + tested training
AleHD Jul 9, 2024
04eaef9
Added nt2hf conversion + tests :)
AleHD Jul 11, 2024
138da5f
precommit
AleHD Jul 11, 2024
35c43f7
Merge pull request #1 from swiss-ai/gpt
negar-foroutan Jul 15, 2024
0485fd6
Added MultilingualNanoset Config
TJ-Solergibert Jul 16, 2024
539832a
Added MultilingualNanoset
TJ-Solergibert Jul 16, 2024
d9f0670
Added Language token
TJ-Solergibert Jul 16, 2024
efe8720
Forgot the trainer ups
TJ-Solergibert Jul 16, 2024
25ad39b
Fix minor errors. Everything works
TJ-Solergibert Jul 16, 2024
d91f9e1
Updated config file with GPT2 tokenized datasets in RCP
TJ-Solergibert Jul 16, 2024
d0c14e3
Before lunch
TJ-Solergibert Jul 17, 2024
9cfc5ea
After lunch
TJ-Solergibert Jul 17, 2024
eed7bce
Ready
TJ-Solergibert Jul 18, 2024
27133e1
Just in case
TJ-Solergibert Jul 22, 2024
5c09e11
just in case
TJ-Solergibert Jul 23, 2024
94d6c2a
This looks good
TJ-Solergibert Jul 24, 2024
5cccf16
This looks better
TJ-Solergibert Jul 24, 2024
d75038d
last fixes
TJ-Solergibert Jul 24, 2024
ab1dd83
Fixed tokenizer config
TJ-Solergibert Jul 24, 2024
2d91154
deleted comments
TJ-Solergibert Jul 24, 2024
ce068fd
Last fixes
TJ-Solergibert Aug 7, 2024
8e6f8ab
Optional validation
TJ-Solergibert Aug 27, 2024
1969526
Fix eval check
TJ-Solergibert Sep 4, 2024
135 changes: 135 additions & 0 deletions examples/config_multilingual_nanoset.yaml
@@ -0,0 +1,135 @@
checkpoints:
  checkpoint_interval: 1000000
  checkpoints_path: checkpoints/
  checkpoints_path_is_shared_file_system: false
  resume_checkpoint_path: null
  save_initial_state: false
data_stages:
- data:
    dataset:
      training_folder:
      - datasets/c4-es/train
      - datasets/c4-en/train
      - datasets/c4-fr/train
      validation_folder:
      - datasets/c4-es/validation
      - datasets/c4-en/validation
      - datasets/c4-fr/validation
      languages:
      - es
      - en
      - fr
    num_loading_workers: 1
    seed: 42
  name: General purpose training (Blended dataset)
  start_training_step: 1
- data:
    dataset:
      training_folder:
      - datasets/c4-es/train
      validation_folder:
      - datasets/c4-es/validation
      languages:
      - es
    num_loading_workers: 1
    seed: 42
  name: Second purpose training (Single dataset)
  start_training_step: 1000
- data:
    dataset:
      training_folder:
      - datasets/c4-es/train
      - datasets/c4-en/train
      - datasets/c4-fr/train
      validation_folder:
      - datasets/c4-es/validation
      - datasets/c4-en/validation
      - datasets/c4-fr/validation
      languages:
      - es
      - en
      - fr
    num_loading_workers: 1
    seed: 42
  name: Third purpose training (>1 dataset)
  start_training_step: 2000
general:
  benchmark_csv_path: null
  consumed_train_samples: null
  ignore_sanity_checks: true
  project: MultilingualV2
  run: llama
  seed: 42
  step: null
lighteval: null
logging:
  iteration_step_info_interval: 1
  log_level: info
  log_level_replica: info
model:
  ddp_bucket_cap_mb: 25
  dtype: bfloat16
  init_method:
    std: 0.025
  make_vocab_size_divisible_by: 1
  model_config:
    bos_token_id: 1
    eos_token_id: 2
    hidden_act: silu
    hidden_size: 4096
    initializer_range: 0.02
    intermediate_size: 14336
    is_llama_config: true
    max_position_embeddings: 4096
    num_hidden_layers: 32
    num_attention_heads: 32
    num_key_value_heads: 8
    pad_token_id: null
    pretraining_tp: 1
    rope_interleaved: false
    rope_theta: 500000.0
    rms_norm_eps: 1.0e-06
    rope_scaling: null
    tie_word_embeddings: false
    use_cache: true
    vocab_size: 128256
optimizer:
  accumulate_grad_in_fp32: true
  clip_grad: 1.0
  learning_rate_scheduler:
    learning_rate: 0.0003
    lr_decay_starting_step: null
    lr_decay_steps: 98
    lr_decay_style: cosine
    lr_warmup_steps: 2
    lr_warmup_style: linear
    min_decay_lr: 1.0e-05
  optimizer_factory:
    adam_beta1: 0.9
    adam_beta2: 0.95
    adam_eps: 1.0e-08
    name: adamW
    torch_adam_is_fused: true
  weight_decay: 0.01
  zero_stage: 0
parallelism:
  dp: 2
  expert_parallel_size: 1
  pp: 1
  pp_engine: 1f1b
  tp: 4
  tp_linear_async_communication: false
  tp_mode: REDUCE_SCATTER
profiler: null
tokenizer:
  tokenizer_max_length: null
  tokenizer_name_or_path: meta-llama/Meta-Llama-3-8B
  tokenizer_revision: null
tokens:
  batch_accumulation_per_replica: 1
  limit_test_batches: 0
  limit_val_batches: 10
  micro_batch_size: 3
  sequence_length: 4096
  train_steps: 500
  val_check_interval: 100
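Each data stage in this config pairs a list of training folders with a matching list of validation folders and language tags, while the `tokens` section (`val_check_interval: 100`, `limit_val_batches: 10`) presumably controls how often and how much the optional validation loop runs. Below is a minimal pre-flight sketch, not part of this PR, that loads the YAML and checks the three lists stay aligned per stage; the file path and the alignment rule are assumptions.

```python
# Minimal config sanity check (illustrative only): every data stage should have
# aligned training_folder / validation_folder / languages lists.
import yaml

with open("examples/config_multilingual_nanoset.yaml") as f:
    config = yaml.safe_load(f)

for stage in config["data_stages"]:
    dataset = stage["data"]["dataset"]
    n_train = len(dataset["training_folder"])
    n_valid = len(dataset["validation_folder"])
    n_langs = len(dataset["languages"])
    assert n_train == n_valid == n_langs, (
        f"Stage {stage['name']!r}: {n_train} train folders, "
        f"{n_valid} validation folders, {n_langs} languages"
    )
    print(f"{stage['name']}: {n_langs} language(s), starts at step {stage['start_training_step']}")
```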
18 changes: 18 additions & 0 deletions examples/xglm/README.md
@@ -0,0 +1,18 @@
# How to use XGLM?

1. First, convert the weights from Hugging Face, for instance:
   ```
   torchrun --nproc-per-node=1 examples/xglm/convert_hf2nt.py --checkpoint-path=facebook/xglm-564M --save-path=$SCRATCH/checkpoints/xglm-564M
   ```

2. Now you are ready to use XGLM.
   Make sure you use a `.yaml` configuration with a proper GPT3 model config, then run, for instance:
   ```
   torchrun --nproc-per-node=4 run_train.py --config-file=examples/xglm/example_config.yaml
   ```
   If you use this configuration file, make sure to modify at least the loading path in `model.init_method.path`.

3. If you want to convert your finetuned checkpoint back to Hugging Face, use:
   ```
   torchrun --nproc-per-node=1 examples/xglm/convert_nt2hf.py --checkpoint-path=checkpoints/xglm --save-path=$SCRATCH/checkpoints/huggingface/xglm-564M --tokenizer-name=facebook/xglm-564M
   ```
   A quick sanity check for the round trip is sketched right after this README.
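One way to sanity-check the nanotron → Hugging Face round trip from step 3 is to compare logits of the original checkpoint against the converted one. The sketch below is not part of this PR; it assumes the converted model landed in `checkpoints/huggingface/xglm-564M` (the example save path above) and uses a dummy prompt.

```python
# Hedged round-trip check: load the original HF checkpoint and the re-exported one,
# run the same input through both, and report the maximum logit difference.
import torch
from transformers import AutoTokenizer
from transformers.models.xglm.modeling_xglm import XGLMForCausalLM

original = XGLMForCausalLM.from_pretrained("facebook/xglm-564M").eval()
roundtrip = XGLMForCausalLM.from_pretrained("checkpoints/huggingface/xglm-564M").eval()  # assumed path

tokenizer = AutoTokenizer.from_pretrained("facebook/xglm-564M")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

with torch.no_grad():
    logits_a = original(**inputs).logits
    logits_b = roundtrip(**inputs).logits

print("max abs logit diff:", (logits_a - logits_b).abs().max().item())
```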
Empty file added examples/xglm/__init__.py
Empty file.
131 changes: 131 additions & 0 deletions examples/xglm/convert_hf2nt.py
@@ -0,0 +1,131 @@
"""
Converts a HF model to nanotron format
Command:
torchrun --nproc-per-node=1 convert_hf2nt.py --checkpoint-path=hf_weights --save-path=nanotron_weights
"""

import dataclasses
import json
import warnings
from argparse import ArgumentParser
from pathlib import Path

import nanotron
import torch
from nanotron.config.models_config import GPT3Config
from nanotron.models.gpt3 import MLP, CausalSelfAttention, GPT3ForTraining, GPTBlock
from transformers.models.xglm.modeling_xglm import XGLMAttention, XGLMConfig, XGLMDecoderLayer, XGLMForCausalLM

from examples.xglm.convert_utils import convert_generic, create_nt_model


def convert_config(config: XGLMConfig) -> GPT3Config:
    # These settings seem to be unused:
    #   layerdrop=0.0,
    #   init_std=0.02,
    #   use_cache=True,
    #   pad_token_id=1,
    #   bos_token_id=0,
    if config.dropout != config.attention_dropout:
        warnings.warn(
            f"huggingface.dropout = {config.dropout} does not match with "
            f"huggingface.attention_dropout = {config.attention_dropout}. "
            "Nanotron implementation needs these two values to be equal "
            "for correct conversion."
        )
    return GPT3Config(
        activation_function=config.activation_function,
        attn_pdrop=config.attention_dropout,
        embd_pdrop=config.dropout,
        eos_token_id=config.eos_token_id,
        hidden_size=config.d_model,
        intermediate_size=config.ffn_dim,
        layer_norm_epsilon=1e-05,
        max_position_embeddings=config.max_position_embeddings,
        num_attention_heads=config.attention_heads,
        num_hidden_layers=config.num_layers,
        resid_pdrop=config.dropout,
        scale_attention_softmax_in_fp32=True,
        scale_attn_weights=True,
        vocab_size=config.vocab_size,
        sinusoidal_position_embedding=True,
        position_embedding_offset=config.decoder_start_token_id,
        use_spda=False,
        act_pdrop=config.activation_dropout,
        scale_embedding=config.scale_embedding,
    )


def convert_attention(attn_nt: CausalSelfAttention, attn_hf: XGLMAttention):
    q_ws = torch.chunk(attn_hf.q_proj.weight, attn_hf.num_heads)
    k_ws = torch.chunk(attn_hf.k_proj.weight, attn_hf.num_heads)
    v_ws = torch.chunk(attn_hf.v_proj.weight, attn_hf.num_heads)

    q_bs = torch.chunk(attn_hf.q_proj.bias, attn_hf.num_heads)
    k_bs = torch.chunk(attn_hf.k_proj.bias, attn_hf.num_heads)
    v_bs = torch.chunk(attn_hf.v_proj.bias, attn_hf.num_heads)

    qkv_w = []
    qkv_b = []
    for q_w, k_w, v_w, q_b, k_b, v_b in zip(q_ws, k_ws, v_ws, q_bs, k_bs, v_bs):
        qkv_w += [q_w, k_w, v_w]
        qkv_b += [q_b, k_b, v_b]
    qkv_w = torch.cat(qkv_w)
    qkv_b = torch.cat(qkv_b)

    with torch.no_grad():
        attn_nt.query_key_value.weight.data = qkv_w.clone()
        attn_nt.query_key_value.bias.data = qkv_b.clone()
        attn_nt.dense.weight.data = attn_hf.out_proj.weight.clone()
        attn_nt.dense.bias.data = attn_hf.out_proj.bias.clone()


def convert_mlp(mlp_nt: MLP, block_hf: XGLMDecoderLayer):
    convert_generic(mlp_nt.c_fc, block_hf.fc1)
    convert_generic(mlp_nt.c_proj, block_hf.fc2)


def convert_decoder(block_nt: GPTBlock, block_hf: XGLMDecoderLayer):
    convert_generic(block_nt.ln_1, block_hf.self_attn_layer_norm)
    convert_attention(block_nt.attn, block_hf.self_attn)
    convert_generic(block_nt.ln_2, block_hf.final_layer_norm)
    convert_mlp(block_nt.ff, block_hf)


def convert(model_nt: GPT3ForTraining, model_hf: XGLMForCausalLM):
    convert_generic(model_nt.model.token_embeddings.pp_block.token_embedding, model_hf.model.embed_tokens)
    for layer_nt, layer_hf in zip(model_nt.model.decoder, model_hf.model.layers):
        convert_decoder(layer_nt.pp_block, layer_hf)
    convert_generic(model_nt.model.final_layer_norm.pp_block, model_hf.model.layer_norm)
    convert_generic(model_nt.model.lm_head.pp_block, model_hf.lm_head)


def main(hf_path: str, save_path: Path):
    # Load hf.
    print("Loading hf...")
    model_hf = XGLMForCausalLM.from_pretrained(hf_path)

    # Init nanotron.
    print("Initializing nt...")
    config_nt = convert_config(model_hf.config)
    model_nt = create_nt_model(config_nt)

    # Copy weights and save model.
    print("Copying weights...")
    convert(model_nt, model_hf)
    nanotron.serialize.save_weights(model=model_nt, parallel_context=model_nt.parallel_context, root_folder=save_path)
    with open(save_path / "model_config.json", "w+") as f:
        json.dump(dataclasses.asdict(config_nt), f)
    print(f"Model saved to {save_path}")


if __name__ == "__main__":
    parser = ArgumentParser(description="Convert HF weights to nanotron format")
    parser.add_argument(
        "--checkpoint-path", default="facebook/xglm-7.5B", help="Name or path to the huggingface checkpoint"
    )
    parser.add_argument(
        "--save-path", type=Path, default="checkpoints/xglm-7.5B", help="Path to save the nanotron model"
    )
    args = parser.parse_args()
    main(args.checkpoint_path, args.save_path)
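`convert_attention` above builds nanotron's fused `query_key_value` weight by interleaving the per-head query, key, and value rows as `[q_h, k_h, v_h]`, head by head. The toy sketch below uses plain tensors with made-up sizes rather than the actual `CausalSelfAttention`/`XGLMAttention` modules; it only illustrates that layout and how to slice a given head back out.

```python
# Toy illustration of the fused QKV layout: rows are grouped per head as [q_head, k_head, v_head].
import torch

num_heads, head_dim, hidden = 4, 8, 32  # illustrative sizes, not the real model's
q_w = torch.randn(num_heads * head_dim, hidden)
k_w = torch.randn(num_heads * head_dim, hidden)
v_w = torch.randn(num_heads * head_dim, hidden)

qkv = []
for q_h, k_h, v_h in zip(q_w.chunk(num_heads), k_w.chunk(num_heads), v_w.chunk(num_heads)):
    qkv += [q_h, k_h, v_h]
qkv_w = torch.cat(qkv)  # shape: (3 * num_heads * head_dim, hidden)

# Head 0's query rows are the first head_dim rows of the fused matrix.
assert torch.equal(qkv_w[:head_dim], q_w[:head_dim])
# Head 1's value rows sit after head 0's full group plus head 1's q and k blocks.
start = 3 * head_dim + 2 * head_dim
assert torch.equal(qkv_w[start:start + head_dim], v_w[head_dim:2 * head_dim])
print("fused qkv shape:", tuple(qkv_w.shape))
```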