Add MicroLlama training support (#1457)
Co-authored-by: rasbt <[email protected]>
Showing 10 changed files with 249 additions and 72 deletions.
New file: the MicroLlama pretraining configuration (115 lines added).

```yaml
# The name of the model to pretrain. Choose from names in ``litgpt.config``. Mutually exclusive with
# ``model_config``. (type: Optional[str], default: null)
model_name: micro-llama-300M

# A ``litgpt.Config`` object to define the model architecture. Mutually exclusive with
# ``model_name``. (type: Optional[Config], default: null)
model_config:

# Directory in which to save checkpoints and logs. If running in a Lightning Studio Job, look for it in
# /teamspace/jobs/<job-name>/share. (type: <class 'Path'>, default: out/pretrain)
out_dir: out/pretrain/micro-llama

# The precision to use for pretraining. Possible choices: "bf16-true", "bf16-mixed", "32-true". (type: Optional[str], default: null)
precision: bf16-mixed

# Optional path to a checkpoint directory to initialize the model from.
# Useful for continued pretraining. Mutually exclusive with ``resume``. (type: Optional[Path], default: null)
initial_checkpoint_dir:

# Path to a checkpoint directory to resume from in case training was interrupted, or ``True`` to resume
# from the latest checkpoint in ``out_dir``. (type: Union[bool, Path], default: False)
resume: false

# Data-related arguments. If not provided, the default is ``litgpt.data.TinyLlama``.
data: MicroLlama

# Training-related arguments. See ``litgpt.args.TrainArgs`` for details
train:
  # Number of optimizer steps between saving checkpoints (type: Optional[int], default: 1000)
  save_interval: 1000

  # Number of iterations between logging calls (type: int, default: 1)
  log_interval: 1

  # Number of samples between optimizer steps across data-parallel ranks (type: int, default: 48)
  # Scale this number according to the number of GPUs and the memory size per GPU.
  # For example, we used 48 for 4 x 24 GB 4090 GPUs.
  global_batch_size: 48

  # Number of samples per data-parallel rank (type: int, default: 12)
  # Scale this number according to the memory size per GPU.
  # For example, we used 12 for a 24 GB 4090.
  micro_batch_size: 12

  # Number of iterations with learning rate warmup active (type: int, default: 2000)
  lr_warmup_steps: 2000

  # Number of epochs to train on (type: Optional[int], default: null)
  epochs:

  # Total number of tokens to train on (type: Optional[int], default: 3000000000000)
  max_tokens: 3000000000000

  # Limits the number of optimizer steps to run. (type: Optional[int], default: null)
  max_steps:

  # Limits the length of samples. Off by default (type: Optional[int], default: null)
  max_seq_length: 2048

  # Whether to tie the embedding weights with the language modeling head weights. (type: Optional[bool], default: False)
  tie_embeddings:

  # (type: Optional[float], default: 1.0)
  max_norm: 1.0

  # (type: float, default: 4e-05)
  min_lr: 4.0e-05

# Evaluation-related arguments. See ``litgpt.args.EvalArgs`` for details
eval:
  # Number of optimizer steps between evaluation calls (type: int, default: 1000)
  interval: 1000

  # Number of tokens to generate (type: Optional[int], default: null)
  max_new_tokens:

  # Number of iterations (type: int, default: 100)
  max_iters: 100

  # Whether to evaluate on the validation set at the beginning of training
  initial_validation: false

# Optimizer-related arguments
optimizer:
  class_path: torch.optim.AdamW
  init_args:
    # (type: float, default: 0.001)
    lr: 4e-4

    # (type: float, default: 0.01)
    weight_decay: 0.1

    # (type: tuple, default: (0.9,0.999))
    betas:
      - 0.9
      - 0.95

# How many devices/GPUs to use. Uses all GPUs by default. (type: Union[int, str], default: auto)
devices: auto

# Optional path to the tokenizer dir that was used for preprocessing the dataset. Only some data
# modules require this. (type: Optional[Path], default: null)
tokenizer_dir: checkpoints/meta-llama/Llama-2-7b-hf

# The name of the logger to send metrics to. (type: Literal['wandb', 'tensorboard', 'csv'], default: tensorboard)
logger_name: tensorboard

# The random seed to use for reproducibility. (type: int, default: 42)
seed: 42
```
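
The batch-size comments in the config imply a fixed relationship between `global_batch_size`, `micro_batch_size`, and the device count. A minimal sketch of that arithmetic, assuming the 4 x 4090 setup mentioned in the comments (the variable names are illustrative, not litgpt's internals):

```python
# Sketch of the batch-size arithmetic implied by the config above.
global_batch_size = 48  # samples per optimizer step, across all data-parallel ranks
micro_batch_size = 12   # samples per forward/backward pass on one GPU
devices = 4             # e.g. the 4 x 24 GB 4090 setup from the config comments

# Each rank is responsible for global_batch_size / devices samples per step;
# gradients are accumulated until that many samples have been processed.
per_rank_batch = global_batch_size // devices      # 12
accum_steps = per_rank_batch // micro_batch_size   # 1 -> no accumulation needed

assert global_batch_size % (micro_batch_size * devices) == 0, "batch sizes must divide evenly"
print(per_rank_batch, accum_steps)
```

On a single 24 GB GPU with the same settings, the same arithmetic gives 48 // 12 = 4 gradient-accumulation steps per optimizer step. Once saved, the config can typically be passed to litgpt's CLI, e.g. `litgpt pretrain --config path/to/this-file.yaml` (the path is a placeholder).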
New file: the MicroLlama data module (13 lines added).
```python
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
from dataclasses import dataclass
from pathlib import Path
from typing import Union

from litgpt.data import TinyLlama


@dataclass
class MicroLlama(TinyLlama):
    """The MicroLlama data module is composed of only SlimPajama data."""

    def __init__(self, data_path: Union[str, Path] = Path("data/"), seed: int = 42, num_workers: int = 8):
        super().__init__(data_path=data_path, seed=seed, num_workers=num_workers, use_starcoder=False)
```
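
Since the subclass only pins `use_starcoder=False`, using the new data module from Python is a one-liner. A hypothetical usage sketch, assuming the commit also exports `MicroLlama` from `litgpt.data` (which the config's `data: MicroLlama` entry implies):

```python
from litgpt.data import MicroLlama  # assumes the commit exports this name

# Same pipeline as TinyLlama, minus the StarCoder split: only the
# SlimPajama shards found under data_path are used for pretraining.
data = MicroLlama(data_path="data/", seed=42, num_workers=8)
```

Reusing TinyLlama via subclassing keeps the diff small: only the dataset mix changes, while the streaming, splitting, and dataloader logic are inherited unchanged.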