32 b #121 (Draft)
Wants to merge 140 commits into main from the 32B branch.

Changes from all commits (140 commits)
b94e702
Save more often
dirkgr Dec 8, 2024
368abb8
Don't check for cancelation all the time
dirkgr Dec 8, 2024
c277d54
Make sure we use the same CE loss that we used for the 13B
dirkgr Dec 8, 2024
7c74d8b
We're going to 5T!
dirkgr Dec 8, 2024
53d61fe
We can live with a bigger eval batch size.
dirkgr Dec 8, 2024
514abb8
Add MMLU downstream eval
dirkgr Dec 9, 2024
011113e
Module isn't callable
dirkgr Dec 9, 2024
2577397
Qwen-ish
dirkgr Dec 9, 2024
93637a1
Make model bigger
dirkgr Dec 9, 2024
784377d
It's now a 32B.
dirkgr Dec 10, 2024
eec7e10
6T tokens
dirkgr Dec 10, 2024
bd5edee
Official save folder
dirkgr Dec 10, 2024
f516f09
6.5T tokens
dirkgr Dec 10, 2024
49264f5
Merge remote-tracking branch 'origin/main' into 32B
dirkgr Dec 10, 2024
4bb5d5c
Merged
dirkgr Dec 10, 2024
1ff1371
Change project name and location
dirkgr Dec 10, 2024
4375612
Revert "Merged"
dirkgr Dec 10, 2024
20b9b08
Revert "Module isn't callable"
dirkgr Dec 10, 2024
7736198
Revert "Make sure we use the same CE loss that we used for the 13B"
dirkgr Dec 10, 2024
8e0613f
We still want it fused!
dirkgr Dec 10, 2024
5652953
One-in-two activation checkpointing
dirkgr Dec 10, 2024
323c786
Merge remote-tracking branch 'origin/main' into 32B
dirkgr Dec 10, 2024
4f676e2
Smaller microbatch
dirkgr Dec 10, 2024
d4e63fa
Wrap 3 in 4 blocks
dirkgr Dec 10, 2024
7c22386
Don't compile the loss.
dirkgr Dec 10, 2024
f38bff4
Turn off broken eval
dirkgr Dec 11, 2024
3bf2440
Go back to mbsz of 4
dirkgr Dec 11, 2024
ab5afcf
Set drop_last for DownstreamEvaluator to False
2015aroras Dec 11, 2024
47f9545
Bring back Copa now that we have Shane's fix
dirkgr Dec 11, 2024
ee6aa90
Merge remote-tracking branch 'origin/32B' into 32B
dirkgr Dec 11, 2024
c656a41
Check if beaker loading issues are due to beaker changes by updating …
2015aroras Dec 11, 2024
7852e1e
Try hsdp with 2 nodes per replica
2015aroras Dec 11, 2024
b19e76d
Revert "Try hsdp with 2 nodes per replica"
2015aroras Dec 11, 2024
a02dd95
Try activation checkpointing 3 in 4
2015aroras Dec 12, 2024
6eaa5a3
Try activation checkpointing 3 in 4 + all feedforwards checkpointed
2015aroras Dec 12, 2024
b2a07de
Decrease microbatch size
2015aroras Dec 13, 2024
9985d31
Try activation checkpointing on just feed forwards
2015aroras Dec 13, 2024
4cc6a62
Fix name
dirkgr Dec 16, 2024
1060499
Try to run with hybrid sharding.
dirkgr Dec 16, 2024
fb2a274
More batch
dirkgr Dec 16, 2024
1073613
Revert "More batch"
dirkgr Dec 16, 2024
c553b98
There is something wrong with how the `common` object is set up.
dirkgr Dec 16, 2024
e49d4b7
We need a less sharded checkpoint and I guess this is the only way we…
dirkgr Dec 16, 2024
9608482
Revert "We need a less sharded checkpoint and I guess this is the onl…
dirkgr Dec 16, 2024
4804004
Async checkpointer may have problems with large checkpoints?
dirkgr Dec 16, 2024
fd4edb8
For loading checkpoints, it seems we need a longer timeout
dirkgr Dec 16, 2024
1f79446
Revert "Async checkpointer may have problems with large checkpoints?"
dirkgr Dec 16, 2024
072c616
Flight to safety
dirkgr Dec 16, 2024
6ba3e23
Increase microbatch size up to 2 * 4096
2015aroras Dec 17, 2024
07cc66c
Watching the 32B in a notebook
dirkgr Dec 18, 2024
18e9a32
Merge branch '32B' of https://github.com/allenai/OLMo-core into 32B
dirkgr Dec 18, 2024
2150b36
Merge branch 'main' into 32B
2015aroras Dec 19, 2024
c8cf403
Enable HSDP with pre-downloading
2015aroras Dec 19, 2024
d9cb6cf
Turn off hsdp
2015aroras Dec 19, 2024
5f2cf19
Revert "Turn off hsdp"
2015aroras Dec 19, 2024
19c8758
Add option to set thread_count
2015aroras Dec 19, 2024
9a12202
Run formatter
2015aroras Dec 19, 2024
d5e6e2b
Limit thread count
2015aroras Dec 19, 2024
ea0acce
Decrease microbatch size
2015aroras Dec 19, 2024
d2a00a7
Increase microbatch size, increase activation checkpointing
2015aroras Dec 19, 2024
016e426
Decrease microbatch size
2015aroras Dec 20, 2024
a28ca37
Decrease thread_count
2015aroras Dec 20, 2024
1c33794
Thread count 1
2015aroras Dec 20, 2024
484d01c
Back to FSDP
2015aroras Dec 20, 2024
275364c
Back to HSDP, but with less replicas
2015aroras Dec 20, 2024
54d5623
Merge branch 'main' into 32B
2015aroras Dec 20, 2024
4644e6e
Microbatch size back to 1
2015aroras Dec 20, 2024
d7ed30e
Revert "Microbatch size back to 1"
2015aroras Dec 20, 2024
0c47992
Back to FSDP
2015aroras Dec 20, 2024
246eff6
Revert "Back to FSDP"
2015aroras Dec 20, 2024
b956e3f
Enable NCCL debug
2015aroras Dec 20, 2024
f877907
More debug info
2015aroras Dec 20, 2024
58bef95
Merge branch 'main' into 32B
2015aroras Dec 20, 2024
c84708f
Disable pre_download, set higher thread count
2015aroras Dec 20, 2024
56c4ab3
FSDP with AC of selected ops
2015aroras Dec 20, 2024
b5f3a86
Back to AC of just feedforward layers
2015aroras Dec 21, 2024
3fbdeb0
Add new inloop evals
2015aroras Dec 21, 2024
b335cdf
Turn off NCCL debug
2015aroras Dec 21, 2024
30f8f59
Merge branch 'main' into 32B
2015aroras Dec 21, 2024
e17e4b8
Make checkpoint writing respect thread count config
2015aroras Dec 22, 2024
ba49cc4
Add skip step optimizer changes
2015aroras Dec 22, 2024
25ede33
Update 32B config with skip step adamw
2015aroras Dec 22, 2024
ac01e83
Try fix skip step optimizer
2015aroras Dec 22, 2024
ddd61ac
Try manual _std_mean impl
2015aroras Dec 22, 2024
973a26c
Add skip step fixes
2015aroras Dec 22, 2024
baf5700
Have separate save and load thread counts
2015aroras Dec 22, 2024
b6762d8
Decrease threads used for saving
2015aroras Dec 22, 2024
d98f06d
Skipped steps and automatic spike analysis
dirkgr Dec 22, 2024
4a68e9e
Use compile=True for optimizer
2015aroras Dec 22, 2024
d81cd12
Make gcs upload pass generation
2015aroras Dec 23, 2024
0a04034
Update CHANGELOG
2015aroras Dec 23, 2024
5acc7eb
Run formatter
2015aroras Dec 23, 2024
213b03e
Make generation 0 when object does not exist
2015aroras Dec 23, 2024
b4994b0
Merge branch 'shanea/fix-upload-retries' into 32B
2015aroras Dec 23, 2024
3b84351
Run formatting
2015aroras Dec 23, 2024
178d9ad
Remove unneeded import
2015aroras Dec 23, 2024
0b737aa
Add missing reload
2015aroras Dec 23, 2024
3e6f9f1
Updated notebook
dirkgr Dec 23, 2024
663d63a
Updated dashboard
dirkgr Dec 24, 2024
496919b
Update the notebook
dirkgr Dec 24, 2024
a1854bd
Updated notebook
dirkgr Dec 27, 2024
f2de5f4
Retry on bad request
dirkgr Dec 28, 2024
33c0f58
Add some more retries
dirkgr Dec 28, 2024
86afc43
Updated the notebook
dirkgr Dec 29, 2024
2e45a79
Update the dashboard
dirkgr Dec 30, 2024
e4e8fbb
Fix the way we use the step in the optimizer
dirkgr Dec 31, 2024
146caaf
Dashboard update
dirkgr Dec 31, 2024
393a462
Update dashboard
dirkgr Jan 3, 2025
d39c59d
New report
dirkgr Jan 6, 2025
16983c4
Dashboard update
dirkgr Jan 7, 2025
5e4d04f
No more ephemeral checkpoints
dirkgr Jan 8, 2025
eba0418
Don't eval so much
dirkgr Jan 8, 2025
5605001
When you wait on someone, you bring them water.
dirkgr Jan 8, 2025
7ce7efa
Updating the dashboard
dirkgr Jan 8, 2025
05aa94f
Reorder ranks in GCP
dirkgr Jan 9, 2025
9c86bf9
Rank 0 needs to remain rank 0
dirkgr Jan 9, 2025
e27b91d
Slightly less checkpointing
dirkgr Jan 9, 2025
52b9b77
Revert "Slightly less checkpointing"
dirkgr Jan 9, 2025
f045eee
Turn off failure propagation to make slack notifier work better
2015aroras Jan 13, 2025
ddb3084
New dashboard
dirkgr Jan 14, 2025
72e0ed1
Merge branch '32B' of https://github.com/allenai/OLMo-core into 32B
dirkgr Jan 14, 2025
d1d8dcb
hopefully make GCS client calls more robust
epwalsh Jan 14, 2025
a0700e8
Catch user exceptions as well as system exceptions when training fails
2015aroras Jan 15, 2025
0595cf8
Revert "Catch user exceptions as well as system exceptions when train…
2015aroras Jan 15, 2025
74c6960
Dashboard
dirkgr Jan 16, 2025
df46d5c
Merge remote-tracking branch 'origin/32B' into 32B
dirkgr Jan 16, 2025
985785c
Suppress Google checksum warnings
2015aroras Jan 16, 2025
6cc9e99
Setup kernel cache for PyTorch
2015aroras Jan 16, 2025
6c31495
Dashboard
dirkgr Jan 17, 2025
f47f6f5
minor clean up
epwalsh Jan 17, 2025
db0df12
Add profiler
dirkgr Jan 20, 2025
7f98496
Dashboard
dirkgr Jan 20, 2025
be4e788
clean up rank reordering
epwalsh Jan 21, 2025
bfa6a8d
move script to launch module so it's available in package
epwalsh Jan 21, 2025
b1ad693
remove old
epwalsh Jan 21, 2025
fde5f68
fix merge conflicts
epwalsh Jan 21, 2025
4264050
clean up
epwalsh Jan 21, 2025
a7b4507
Merge branch 'main' into 32B
epwalsh Jan 21, 2025
5fbc50e
throttle uploads
epwalsh Jan 21, 2025
7f6a6d0
Add annealing config
epwalsh Jan 22, 2025
348 changes: 348 additions & 0 deletions src/scripts/train/OLMo2-32B-anneal.py
@@ -0,0 +1,348 @@
import json
import os
import sys
from dataclasses import dataclass
from typing import List, cast

import rich
import torch
from rich import print

from olmo_core.config import Config, DType
from olmo_core.data import (
    DataMix,
    NumpyDataLoaderConfig,
    NumpyDatasetConfig,
    TokenizerConfig,
)
from olmo_core.distributed.parallel import DataParallelType
from olmo_core.distributed.utils import get_local_rank
from olmo_core.internal.common import build_launch_config, get_root_dir, get_work_dir
from olmo_core.io import resource_path
from olmo_core.launch.beaker import BeakerLaunchConfig
from olmo_core.nn.transformer import (
    TransformerActivationCheckpointingConfig,
    TransformerActivationCheckpointingMode,
    TransformerConfig,
    TransformerDataParallelConfig,
)
from olmo_core.optim import (
    CosWithWarmup,
    LinearWithWarmup,
    OptimConfig,
    OptimGroupOverride,
    SkipStepAdamWConfig,
)
from olmo_core.train import (
    Duration,
    TrainerConfig,
    prepare_training_environment,
    teardown_training_environment,
)
from olmo_core.train.callbacks import (
    CheckpointerCallback,
    CometCallback,
    ConfigSaverCallback,
    DownstreamEvaluatorCallbackConfig,
    GarbageCollectorCallback,
    GPUMemoryMonitorCallback,
    GradClipperCallback,
    SchedulerCallback,
)
from olmo_core.train.checkpoint import CheckpointerConfig
from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all

# The max number of pretraining steps configured for the purpose of setting the learning rate
# schedule. I'm hard-coding this here based on the number found in the logs. It only changes
# if batch size changes, which we're not planning on changing over the course of the run.
MAX_PRETRAIN_STEPS = 774861
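# Rough cross-check (assuming the pretraining run used the same 2048 * 4096-token global batch
# size configured below): 6.5e12 tokens / (2048 * 4096 tokens per step) ≈ 774,861 steps.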


@dataclass
class AnnealingConfig(Config):
    run_name: str
    launch: BeakerLaunchConfig
    model: TransformerConfig
    optim: OptimConfig
    dataset: NumpyDatasetConfig
    data_loader: NumpyDataLoaderConfig
    trainer: TrainerConfig
    init_seed: int = 12536


def build_config(
    *, script: str, cmd: str, run_name: str, checkpoint: str, cluster: str, overrides: List[str]
) -> AnnealingConfig:
    root_dir = get_root_dir(cluster)

    tokenizer_config = TokenizerConfig.dolma2()

    # Try to guess step number to infer where the learning rate left off.
    last_pretrain_step: int
    if (basename := os.path.basename(checkpoint)).startswith("step"):
        last_pretrain_step = int(basename.replace("step", ""))
    else:
        last_pretrain_step = torch.load(
            resource_path(f"{checkpoint}/train", "rank0.pt"), weights_only=False
        )["global_step"]

    # Now infer the learning rate.
    with resource_path(checkpoint, "config.json").open() as f:
        config = json.load(f)
    base_lr = config["optim"]["lr"]
    scheduler_config = config["trainer"]["callbacks"]["lr_scheduler"]["scheduler"]
    assert scheduler_config.pop("_CLASS_") == CosWithWarmup.__name__
    scheduler = CosWithWarmup(**scheduler_config)
    starting_lr = float(scheduler.get_lr(base_lr, last_pretrain_step, MAX_PRETRAIN_STEPS))
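    # The lr_scheduler callback below then decays linearly from this value toward
    # alpha_f * starting_lr.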

    return AnnealingConfig(
        run_name=run_name,
        launch=build_launch_config(
            name=run_name,
            root_dir=root_dir,
            cmd=[script, cmd, run_name, checkpoint, cluster, *overrides],
            cluster=cluster,
            nccl_debug=False,
        ),
        model=TransformerConfig.olmo2_32B(
            vocab_size=tokenizer_config.padded_vocab_size(),
            compile=True,
            fused_ops=False,
            use_flash=False,
            dp_config=TransformerDataParallelConfig(
                name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32
            ),
            ac_config=TransformerActivationCheckpointingConfig(
                mode=TransformerActivationCheckpointingMode.selected_modules,
                modules=["blocks.*.feed_forward"],
            ),
        ),
        optim=SkipStepAdamWConfig(
            lr=starting_lr,
            weight_decay=0.1,
            betas=(0.9, 0.95),
            group_overrides=[
                OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0))
            ],
            compile=True,
        ),
        dataset=NumpyDatasetConfig.from_data_mix(
            DataMix.OLMoE_mix_0824,  # TODO: change this to annealing mix
            tokenizer=tokenizer_config,
            mix_base_dir=root_dir,
            sequence_length=4096,
            work_dir=get_work_dir(root_dir),
        ),
        data_loader=NumpyDataLoaderConfig(
            global_batch_size=2048 * 4096,  # NOTE: this is specified in TOKENS, not instances.
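            # With the 4096-token sequences above, this works out to 2048 sequences per batch.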
            seed=34521,  # Can update this to change data order.
            num_workers=4,
        ),
        trainer=TrainerConfig(
            save_folder="gs://ai2-llm/checkpoints/peteish32-anneal",
            rank_microbatch_size=2 * 4096,  # NOTE: again this is specified in tokens.
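            # I.e. two 4096-token sequences per rank per micro-batch.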
            checkpointer=CheckpointerConfig(
                save_thread_count=1, load_thread_count=32, throttle_uploads=True
            ),
            save_overwrite=True,
            metrics_collect_interval=10,
            cancel_check_interval=10,
            z_loss_multiplier=1e-5,
            compile_loss=False,
            fused_loss=True,
            max_duration=Duration.tokens(int(6.5e12)),
        )
        .with_callback(
            "checkpointer",
            CheckpointerCallback(
                save_interval=1000,
                save_async=True,
            ),
        )
        .with_callback(
            "comet",
            CometCallback(
                name=run_name,
                workspace="ai2",
                project="peteish32",
                enabled=True,
                cancel_check_interval=10,
            ),
        )
        .with_callback(
            "lr_scheduler",
            SchedulerCallback(
                scheduler=LinearWithWarmup(
                    warmup_steps=0,
                    alpha_f=0.1,  # TODO: change this to 0.0 if you want to go down to 0
                )
            ),
        )
        .with_callback(
            "gpu_monitor",
            GPUMemoryMonitorCallback(),
        )
        .with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0))
        .with_callback("config_saver", ConfigSaverCallback())
        .with_callback("garbage_collector", GarbageCollectorCallback())
        .with_callback(
            "downstream_evaluator",
            DownstreamEvaluatorCallbackConfig(
                tasks=[
                    # MMLU for backwards compatibility
                    "mmlu_stem_mc_5shot",
                    "mmlu_humanities_mc_5shot",
                    "mmlu_social_sciences_mc_5shot",
                    "mmlu_other_mc_5shot",
                    # MMLU test
                    "mmlu_stem_mc_5shot_test",
                    "mmlu_humanities_mc_5shot_test",
                    "mmlu_social_sciences_mc_5shot_test",
                    "mmlu_other_mc_5shot_test",
                    ## Core 12 tasks for backwards compatibility
                    # "arc_challenge",
                    # "arc_easy",
                    # "basic_arithmetic",
                    # "boolq",
                    # "commonsense_qa",
                    # "copa",
                    # "hellaswag",
                    # "openbook_qa",
                    # "piqa",
                    # "sciq",
                    # "social_iqa",
                    # "winogrande",
                    ## Core 12 tasks 5-shot
                    # "arc_challenge_rc_5shot",
                    # "arc_easy_rc_5shot",
                    ## "basic_arithmetic_rc_5shot",  # doesn't exist
                    ## "boolq_rc_5shot",  # we don't like it
                    # "csqa_rc_5shot",
                    ## "copa_rc_5shot",  # doesn't exist
                    # "hellaswag_rc_5shot",
                    # "openbookqa_rc_5shot",
                    # "piqa_rc_5shot",
                    ## "sciq_rc_5shot",  # doesn't exist
                    # "socialiqa_rc_5shot",
                    # "winogrande_rc_5shot",
                    ## New in-loop evals
                    # "arc_challenge_val_rc_5shot",
                    # "arc_challenge_val_mc_5shot",
                    "arc_challenge_test_rc_5shot",
                    # "arc_challenge_test_mc_5shot",
                    # "arc_easy_val_rc_5shot",
                    # "arc_easy_val_mc_5shot",
                    "arc_easy_test_rc_5shot",
                    # "arc_easy_test_mc_5shot",
                    # "boolq_val_rc_5shot",
                    # "boolq_val_mc_5shot",
                    "csqa_val_rc_5shot",
                    # "csqa_val_mc_5shot",
                    "hellaswag_val_rc_5shot",
                    # "hellaswag_val_mc_5shot",
                    # "openbookqa_val_rc_5shot",
                    # "openbookqa_val_mc_5shot",
                    "openbookqa_test_rc_5shot",
                    # "openbookqa_test_mc_5shot",
                    "piqa_val_rc_5shot",
                    # "piqa_val_mc_5shot",
                    "socialiqa_val_rc_5shot",
                    # "socialiqa_val_mc_5shot",
                    # "winogrande_val_rc_5shot",
                    # "winogrande_val_mc_5shot",
                    # "mmlu_stem_val_rc_5shot",
                    # "mmlu_stem_val_mc_5shot",
                    # "mmlu_humanities_val_rc_5shot",
                    # "mmlu_humanities_val_mc_5shot",
                    # "mmlu_social_sciences_val_rc_5shot",
                    # "mmlu_social_sciences_val_mc_5shot",
                    # "mmlu_other_val_rc_5shot",
                    # "mmlu_other_val_mc_5shot",
                ],
                tokenizer=tokenizer_config,
                eval_interval=1000,
            ),
        ),
    ).merge(overrides)


def train(config: AnnealingConfig):
    # Set RNG states on all devices.
    seed_all(config.init_seed)

    device = get_default_device()

    # Build mesh, if needed.
    world_mesh = config.model.build_mesh(device=device)

    # Build components.
    model = config.model.build(
        init_device="meta",
        device=device,
        max_seq_len=config.dataset.sequence_length,
        mesh=world_mesh,
    )
    optim = config.optim.build(model)
    dataset = config.dataset.build()
    data_loader = config.data_loader.build(dataset, mesh=world_mesh)
    trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh)

    # Record the config to W&B/Comet and each checkpoint dir.
    config_dict = config.as_config_dict()
    cast(CometCallback, trainer.callbacks["comet"]).config = config_dict
    cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict

    # Train.
    trainer.fit()


if __name__ == "__main__":
    USAGE = f"""
[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]launch|train|dry_run[/] [i b]RUN_NAME PRETRAIN_CHECKPOINT CLUSTER[/] [i][OVERRIDES...][/]

[b]Subcommands[/]
[b magenta]launch:[/] Launch the script on Beaker with the [b magenta]train[/] subcommand.
[b magenta]train:[/] Run the trainer. You usually shouldn't invoke the script with this subcommand directly.
Instead use [b magenta]launch[/] or run it with torchrun.
[b magenta]dry_run:[/] Print the config for debugging.

[b]Examples[/]
$ [i]python {sys.argv[0]} launch run01 gs://ai2-llm/checkpoints/peteish32/step419000 ai2/jupiter-cirrascale-2 --launch.num_nodes=2[/]
""".strip()

    if len(sys.argv) < 5 or sys.argv[1] not in ("launch", "train", "dry_run"):
        rich.get_console().print(USAGE, highlight=False)
        sys.exit(1)

    script, cmd, run_name, checkpoint, cluster, *overrides = sys.argv

    if cmd in ("launch", "dry_run"):
        prepare_cli_environment()
    elif cmd == "train":
        prepare_training_environment()
    else:
        raise NotImplementedError(cmd)

    config = build_config(
        script=script,
        cmd="train",
        run_name=run_name,
        checkpoint=checkpoint,
        cluster=cluster,
        overrides=overrides,
    )

    if get_local_rank() == 0:
        print(config)

    if cmd == "dry_run":
        pass
    elif cmd == "launch":
        config.launch.launch(follow=True)
    elif cmd == "train":
        try:
            train(config)
        finally:
            teardown_training_environment()
    else:
        raise NotImplementedError(cmd)