Add MicroLlama training support (#1457)
Co-authored-by: rasbt <[email protected]>
keeeeenw and rasbt authored Jun 4, 2024
1 parent e567dbe commit fa88952
Showing 10 changed files with 249 additions and 72 deletions.
47 changes: 24 additions & 23 deletions README.md
@@ -73,30 +73,31 @@ LitGPT has 🤯 **custom, from-scratch implementations** of [20+ LLMs](tutorials

| Model | Model size | Author | Reference |
|----|----|----|----|
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
| Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
| Danube2 | 1.8B | H2O.ai | [H2O.ai](https://h2o.ai/platform/danube-1-8b/) |
| Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
| Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
| FreeWilly2 (Stable Beluga 2) | 70B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stable-beluga-large-instruction-fine-tuned-models) |
| Function Calling Llama 2 | 7B | Trelis | [Trelis et al. 2023](https://huggingface.co/Trelis/Llama-2-7b-chat-hf-function-calling-v2) |
| Gemma | 2B, 7B | Google | [Google Team, Google Deepmind](https://storage.googleapis.com/deepmind-media/gemma/gemma-report.pdf) |
| Llama 2 | 7B, 13B, 70B | Meta AI | [Touvron et al. 2023](https://arxiv.org/abs/2307.09288) |
| Llama 3 | 8B, 70B | Meta AI | [Meta AI 2024](https://github.com/meta-llama/llama3) |
| LongChat | 7B, 13B | LMSYS | [LongChat Team 2023](https://lmsys.org/blog/2023-06-29-longchat/) |
| MicroLlama | 300M | Ken Wang | [MicroLlama repo](https://github.com/keeeeenw/MicroLlama) |
| Mixtral MoE | 8x7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/mixtral-of-experts/) |
| Mistral | 7B | Mistral AI | [Mistral AI 2023](https://mistral.ai/news/announcing-mistral-7b/) |
| Nous-Hermes | 7B, 13B, 70B | NousResearch | [Org page](https://huggingface.co/NousResearch) |
| OpenLLaMA | 3B, 7B, 13B | OpenLM Research | [Geng & Liu 2023](https://github.com/openlm-research/open_llama) |
| Phi | 1.3B, 2.7B | Microsoft Research | [Li et al. 2023](https://arxiv.org/abs/2309.05463) |
| Platypus | 7B, 13B, 70B | Lee et al. | [Lee, Hunter, and Ruiz 2023](https://arxiv.org/abs/2308.07317) |
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/)
| Pythia | {14,31,70,160,410}M, {1,1.4,2.8,6.9,12}B | EleutherAI | [Biderman et al. 2023](https://arxiv.org/abs/2304.01373) |
| RedPajama-INCITE | 3B, 7B | Together | [Together 2023](https://together.ai/blog/redpajama-models-v1) |
| StableCode | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| StableLM | 3B, 7B | Stability AI | [Stability AI 2023](https://github.com/Stability-AI/StableLM) |
| StableLM Zephyr | 3B | Stability AI | [Stability AI 2023](https://stability.ai/blog/stablecode-llm-generative-ai-coding) |
| TinyLlama | 1.1B | Zhang et al. | [Zhang et al. 2023](https://github.com/jzhang38/TinyLlama) |
| Vicuna | 7B, 13B, 33B | LMSYS | [Li et al. 2023](https://lmsys.org/blog/2023-03-30-vicuna/) |

**Tip**: You can list all available models by running the `litgpt download list` command.

115 changes: 115 additions & 0 deletions config_hub/pretrain/microllama.yaml
@@ -0,0 +1,115 @@

# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: micro-llama-300M

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/micro-llama

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: MicroLlama

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:

# Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
save_interval: 1000

# Number of iterations between logging calls (type: int, default: 1)
log_interval: 1

# Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
# Scale this number according to the number of GPUs and the memory size per GPU
# For example, we used 48 for 4 x 24 GB RTX 4090 GPUs
global_batch_size: 48

# Number of samples per data-parallel rank (type: int, default: 12)
# Scale this number according to the memory size per GPU
# For example, we used 12 for a 24 GB RTX 4090
micro_batch_size: 12

# Number of iterations with learning rate warmup active (type: int, default: 2000)
lr_warmup_steps: 2000

# Number of epochs to train on (type: Optional[int], default: null)
epochs:

# Total number of tokens to train on (type: Optional[int], default: 3000000000000)
max_tokens: 3000000000000

# Limits the number of optimizer steps to run. (type: Optional[int], default: null)
max_steps:

# Limits the length of samples. Off by default (type: Optional[int], default: null)
max_seq_length: 2048

# Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
tie_embeddings:

# (type: Optional[float], default: 1.0)
max_norm: 1.0

# (type: float, default: 4e-05)
min_lr: 4.0e-05

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:

# Number of optimizer steps between evaluation calls (type: int, default: 1000)
interval: 1000

# Number of tokens to generate (type: Optional[int], default: null)
max_new_tokens:

# Number of iterations (type: int, default: 100)
max_iters: 100

# Whether to evaluate on the validation set at the beginning of training
initial_validation: false

# Optimizer-related arguments
optimizer:

class_path: torch.optim.AdamW

init_args:

# (type: float, default: 0.001)
lr: 4e-4

# (type: float, default: 0.01)
weight_decay: 0.1

# (type: tuple, default: (0.9,0.999))
betas:
- 0.9
- 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# modules require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
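
For reference, `global_batch_size`, `micro_batch_size`, and `devices` jointly determine how many gradient-accumulation steps each optimizer update takes. The sketch below is illustrative and not part of this commit; it assumes LitGPT's usual convention of dividing the global batch evenly across data-parallel ranks and accumulating the remainder in micro-batches.

```python
# Back-of-envelope check of the batch-size settings in microllama.yaml
# (assumed convention: per-rank batch = global_batch_size // devices,
#  accumulation steps = per-rank batch // micro_batch_size).
global_batch_size = 48   # samples per optimizer step across all ranks
micro_batch_size = 12    # samples per forward/backward pass on one GPU

for devices in (1, 2, 4):
    per_rank = global_batch_size // devices
    grad_accum = per_rank // micro_batch_size
    print(f"{devices} GPU(s): {per_rank} samples per rank, {grad_accum} accumulation step(s)")
```

On the 4 x 24 GB RTX 4090 setup mentioned in the comments above, this works out to 12 samples per rank and a single accumulation step per optimizer update.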
28 changes: 27 additions & 1 deletion litgpt/config.py
@@ -1548,7 +1548,7 @@ def norm_class(self) -> Type:
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm", # original TinyLlama uses FusedRMSNorm
norm_class_name="RMSNorm", # original TinyLlama use FusedRMSNorm
norm_eps=1e-5,
mlp_class_name="LLaMAMLP",
intermediate_size=5632,
@@ -1563,6 +1563,32 @@ def norm_class(self) -> Type:
configs.append(copy)


############
# MicroLlama
############
micro_llama = [
dict(
name="micro-llama-300M",
hf_config=dict(org="keeeeenw", name="MicroLlama"),
block_size=2048,
vocab_size=32000,
padding_multiple=64,
n_layer=12,
n_head=16,
n_embd=1024,
rotary_percentage=1.0,
parallel_residual=False,
bias=False,
norm_class_name="RMSNorm", # original TinyLlama and MicroLlama use FusedRMSNorm
norm_eps=1e-5,
mlp_class_name="LLaMAMLP",
intermediate_size=5632,
n_query_groups=4,
)
]
configs.extend(micro_llama)


##########################
# Trelis Function Calling
##########################
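The new `micro-llama-300M` entry reuses the TinyLlama architecture with a narrower and shallower network; once registered here it should be constructible by name, e.g. via `litgpt.Config.from_name("micro-llama-300M")`. As a rough sanity check (an illustrative sketch, not code from this commit; it ignores normalization weights and assumes an untied output head), the values above land close to the advertised 300M parameters:

```python
# Rough parameter count for micro-llama-300M from the config values above.
vocab_size, n_embd, n_layer = 32000, 1024, 12
n_head, n_query_groups, intermediate_size = 16, 4, 5632

head_size = n_embd // n_head                       # 64
attn = n_embd * n_embd                             # query projection
attn += 2 * n_embd * (n_query_groups * head_size)  # grouped key/value projections
attn += n_embd * n_embd                            # output projection
mlp = 3 * n_embd * intermediate_size               # LLaMAMLP: gate/up projections + down projection
embedding = vocab_size * n_embd                    # token embeddings
lm_head = vocab_size * n_embd                      # output head (embeddings not tied)

total = n_layer * (attn + mlp) + embedding + lm_head
print(f"~{total / 1e6:.0f}M parameters")           # prints ~305M
```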
2 changes: 2 additions & 0 deletions litgpt/data/__init__.py
@@ -15,6 +15,7 @@
from litgpt.data.tinyllama import TinyLlama
from litgpt.data.tinystories import TinyStories
from litgpt.data.openwebtext import OpenWebText
from litgpt.data.microllama import MicroLlama


__all__ = [
@@ -34,5 +35,6 @@
"TextFiles",
"TinyLlama",
"TinyStories",
"MicroLlama"
"get_sft_collate_fn",
]
13 changes: 13 additions & 0 deletions litgpt/data/microllama.py
@@ -0,0 +1,13 @@
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from litgpt.data import TinyLlama

@dataclass
class MicroLlama(TinyLlama):
"""The MicroLlama data module is composed of only SlimPajama data."""

def __init__(self, data_path: Union[str, Path] = Path("data/"), seed: int = 42, num_workers: int = 8):
super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False)
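
A minimal usage sketch for the new data module (not part of the commit; it assumes the SlimPajama chunks have already been prepared under `data/` and that the Llama 2 tokenizer referenced in the YAML config above is downloaded):

```python
from pathlib import Path

from litgpt import Tokenizer
from litgpt.data import MicroLlama

data = MicroLlama(data_path="data/", seed=42, num_workers=8)
tokenizer = Tokenizer(Path("checkpoints/meta-llama/Llama-2-7b-hf"))
data.connect(tokenizer=tokenizer, batch_size=12, max_seq_length=2048)

data.prepare_data()                     # raises FileNotFoundError unless data/slimpajama/{train,val} exist
train_loader = data.train_dataloader()  # streams SlimPajama only; no Starcoder mixing
```

Because the class simply forwards `use_starcoder=False` to `TinyLlama`, everything else (token packing, shuffling, the validation split) behaves as in the TinyLlama data module.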
2 changes: 2 additions & 0 deletions litgpt/data/prepare_slimpajama.py
@@ -11,6 +11,8 @@


class SlimPajamaDataRecipe(DataChunkRecipe):
is_generator = True

def __init__(self, tokenizer: Tokenizer, chunk_size: int):
super().__init__(chunk_size)
self.tokenizer = tokenizer
2 changes: 2 additions & 0 deletions litgpt/data/prepare_starcoder.py
@@ -18,6 +18,8 @@


class StarcoderDataRecipe(DataChunkRecipe):
is_generator = True

def __init__(self, tokenizer: Tokenizer, chunk_size: int):
super().__init__(chunk_size)
self.tokenizer = tokenizer
55 changes: 33 additions & 22 deletions litgpt/data/tinyllama.py
@@ -24,6 +24,8 @@ class TinyLlama(DataModule):
"""The random seed for shuffling the dataset."""
num_workers: int = 8
"""How many DataLoader processes to use for loading."""
use_starcoder: bool = True
"""Toggle for using Starcoder data."""

batch_size: int = field(init=False, repr=False, default=1)
seq_length: int = field(init=False, repr=False, default=2048)
@@ -32,7 +34,11 @@ def __post_init__(self):
# Could be a remote path (s3://) or a local path
self.slimpajama_train = str(self.data_path).rstrip("/") + "/slimpajama/train"
self.slimpajama_val = str(self.data_path).rstrip("/") + "/slimpajama/val"
self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
self.required_paths = [self.slimpajama_train, self.slimpajama_val]

if self.use_starcoder:
self.starcoder_train = str(self.data_path).rstrip("/") + "/starcoder"
self.required_paths += [self.starcoder_train]

def connect(
self, tokenizer: Optional[Tokenizer] = None, batch_size: int = 1, max_seq_length: Optional[int] = None
@@ -41,7 +47,7 @@ def connect(
self.seq_length = max_seq_length + 1 # Increase by one because we need the next token as well

def prepare_data(self) -> None:
for path in (self.slimpajama_train, self.slimpajama_val, self.starcoder_train):
for path in self.required_paths:
if not path.startswith("s3://") and not Path(path).is_dir():
raise FileNotFoundError(
"The data path for TinyLlama is expected to be the directory containing these subdirectories:"
@@ -52,28 +58,33 @@ def prepare_data(self) -> None:
def train_dataloader(self) -> DataLoader:
from litdata.streaming import CombinedStreamingDataset, StreamingDataLoader, StreamingDataset, TokensLoader

train_datasets = [
StreamingDataset(
input_dir=self.slimpajama_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
),
StreamingDataset(
input_dir=self.starcoder_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
),
]

# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
combined_dataset = CombinedStreamingDataset(
datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
slim_train_data = StreamingDataset(
input_dir=self.slimpajama_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
)
train_data = slim_train_data

if self.use_starcoder:
train_datasets = [
slim_train_data,
StreamingDataset(
input_dir=self.starcoder_train,
item_loader=TokensLoader(block_size=self.seq_length),
shuffle=True,
drop_last=True,
),
]

# Mix SlimPajama data and Starcoder data with these proportions:
weights = (0.693584, 0.306416)
train_data = CombinedStreamingDataset(
datasets=train_datasets, seed=self.seed, weights=weights, iterate_over_all=False
)

train_dataloader = StreamingDataLoader(
combined_dataset, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
train_data, batch_size=self.batch_size, pin_memory=True, num_workers=self.num_workers, drop_last=True
)
return train_dataloader

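To illustrate what the new `use_starcoder` flag changes (an illustrative sketch, not part of the diff): it controls which directories `prepare_data` checks for and, in `train_dataloader`, whether a weighted `CombinedStreamingDataset` (SlimPajama + Starcoder) or a single SlimPajama `StreamingDataset` is built.

```python
from litgpt.data import MicroLlama, TinyLlama

tiny = TinyLlama(data_path="data/")    # default: use_starcoder=True
micro = MicroLlama(data_path="data/")  # forwards use_starcoder=False

print(tiny.required_paths)   # ['data/slimpajama/train', 'data/slimpajama/val', 'data/starcoder']
print(micro.required_paths)  # ['data/slimpajama/train', 'data/slimpajama/val']
```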
2 changes: 1 addition & 1 deletion litgpt/pretrain.py
@@ -20,7 +20,7 @@
from litgpt import Tokenizer
from litgpt.args import EvalArgs, TrainArgs
from litgpt.config import name_to_config
from litgpt.data import DataModule, TinyLlama
from litgpt.data import DataModule, TinyLlama, MicroLlama
from litgpt.model import GPT, Block, CausalSelfAttention, Config, LLaMAMLP
from litgpt.utils import (
CycleIterator,