diff --git a/src/scripts/train/OLMo2-32B-anneal.py b/src/scripts/train/OLMo2-32B-anneal.py new file mode 100644 index 00000000..cc08534d --- /dev/null +++ b/src/scripts/train/OLMo2-32B-anneal.py @@ -0,0 +1,348 @@ +import json +import os +import sys +from dataclasses import dataclass +from typing import List, cast + +import rich +import torch +from rich import print + +from olmo_core.config import Config, DType +from olmo_core.data import ( + DataMix, + NumpyDataLoaderConfig, + NumpyDatasetConfig, + TokenizerConfig, +) +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.distributed.utils import get_local_rank +from olmo_core.internal.common import build_launch_config, get_root_dir, get_work_dir +from olmo_core.io import resource_path +from olmo_core.launch.beaker import BeakerLaunchConfig +from olmo_core.nn.transformer import ( + TransformerActivationCheckpointingConfig, + TransformerActivationCheckpointingMode, + TransformerConfig, + TransformerDataParallelConfig, +) +from olmo_core.optim import ( + CosWithWarmup, + LinearWithWarmup, + OptimConfig, + OptimGroupOverride, + SkipStepAdamWConfig, +) +from olmo_core.train import ( + Duration, + TrainerConfig, + prepare_training_environment, + teardown_training_environment, +) +from olmo_core.train.callbacks import ( + CheckpointerCallback, + CometCallback, + ConfigSaverCallback, + DownstreamEvaluatorCallbackConfig, + GarbageCollectorCallback, + GPUMemoryMonitorCallback, + GradClipperCallback, + SchedulerCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig +from olmo_core.utils import get_default_device, prepare_cli_environment, seed_all + +# The max number of pretraining steps configured for the purpose of setting the learning rate +# schedule. I'm hard-coding this here based on the number found in the logs. It only changes +# if batch size changes, which we're not planning on changing over the course of the run. 
+MAX_PRETRAIN_STEPS = 774861 + + +@dataclass +class AnnealingConfig(Config): + run_name: str + launch: BeakerLaunchConfig + model: TransformerConfig + optim: OptimConfig + dataset: NumpyDatasetConfig + data_loader: NumpyDataLoaderConfig + trainer: TrainerConfig + init_seed: int = 12536 + + +def build_config( + *, script: str, cmd: str, run_name: str, checkpoint: str, cluster: str, overrides: List[str] +) -> AnnealingConfig: + root_dir = get_root_dir(cluster) + + tokenizer_config = TokenizerConfig.dolma2() + + # Try to guess step number to infer where the learning rate left off. + last_pretrain_step: int + if (basename := os.path.basename(checkpoint)).startswith("step"): + last_pretrain_step = int(basename.replace("step", "")) + else: + last_pretrain_step = torch.load( + resource_path(f"{checkpoint}/train", "rank0.pt"), weights_only=False + )["global_step"] + + # Now infer the learning rate. + with resource_path(checkpoint, "config.json").open() as f: + config = json.load(f) + base_lr = config["optim"]["lr"] + scheduler_config = config["trainer"]["callbacks"]["lr_scheduler"]["scheduler"] + assert scheduler_config.pop("_CLASS_") == CosWithWarmup.__name__ + scheduler = CosWithWarmup(**scheduler_config) + starting_lr = float(scheduler.get_lr(base_lr, last_pretrain_step, MAX_PRETRAIN_STEPS)) + + return AnnealingConfig( + run_name=run_name, + launch=build_launch_config( + name=run_name, + root_dir=root_dir, + cmd=[script, cmd, run_name, checkpoint, cluster, *overrides], + cluster=cluster, + nccl_debug=False, + ), + model=TransformerConfig.olmo2_32B( + vocab_size=tokenizer_config.padded_vocab_size(), + compile=True, + fused_ops=False, + use_flash=False, + dp_config=TransformerDataParallelConfig( + name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 + ), + ac_config=TransformerActivationCheckpointingConfig( + mode=TransformerActivationCheckpointingMode.selected_modules, + modules=["blocks.*.feed_forward"], + ), + ), + optim=SkipStepAdamWConfig( + 
lr=starting_lr, + weight_decay=0.1, + betas=(0.9, 0.95), + group_overrides=[ + OptimGroupOverride(params=["embeddings.weight"], opts=dict(weight_decay=0.0)) + ], + compile=True, + ), + dataset=NumpyDatasetConfig.from_data_mix( + DataMix.OLMoE_mix_0824, # TODO: change this to annealing mix + tokenizer=tokenizer_config, + mix_base_dir=root_dir, + sequence_length=4096, + work_dir=get_work_dir(root_dir), + ), + data_loader=NumpyDataLoaderConfig( + global_batch_size=2048 * 4096, # NOTE: this is specified in TOKENS, not instances. + seed=34521, # Can update this to change data order. + num_workers=4, + ), + trainer=TrainerConfig( + save_folder="gs://ai2-llm/checkpoints/peteish32-anneal", + rank_microbatch_size=2 * 4096, # NOTE: again this is specified in tokens. + checkpointer=CheckpointerConfig( + save_thread_count=1, load_thread_count=32, throttle_uploads=True + ), + save_overwrite=True, + metrics_collect_interval=10, + cancel_check_interval=10, + z_loss_multiplier=1e-5, + compile_loss=False, + fused_loss=True, + max_duration=Duration.tokens(int(6.5e12)), + ) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=1000, + save_async=True, + ), + ) + .with_callback( + "comet", + CometCallback( + name=run_name, + workspace="ai2", + project="peteish32", + enabled=True, + cancel_check_interval=10, + ), + ) + .with_callback( + "lr_scheduler", + SchedulerCallback( + scheduler=LinearWithWarmup( + warmup_steps=0, + alpha_f=0.1, # TODO: change this to 0.0 if you want to go down to 0 + ) + ), + ) + .with_callback( + "gpu_monitor", + GPUMemoryMonitorCallback(), + ) + .with_callback("grad_clipper", GradClipperCallback(max_grad_norm=1.0)) + .with_callback("config_saver", ConfigSaverCallback()) + .with_callback("garbage_collector", GarbageCollectorCallback()) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=[ + # MMLU for backwards compatibility + "mmlu_stem_mc_5shot", + "mmlu_humanities_mc_5shot", + 
"mmlu_social_sciences_mc_5shot", + "mmlu_other_mc_5shot", + # MMLU test + "mmlu_stem_mc_5shot_test", + "mmlu_humanities_mc_5shot_test", + "mmlu_social_sciences_mc_5shot_test", + "mmlu_other_mc_5shot_test", + ## Core 12 tasks for backwards compatibility + # "arc_challenge", + # "arc_easy", + # "basic_arithmetic", + # "boolq", + # "commonsense_qa", + # "copa", + # "hellaswag", + # "openbook_qa", + # "piqa", + # "sciq", + # "social_iqa", + # "winogrande", + ## Core 12 tasks 5-shot + # "arc_challenge_rc_5shot", + # "arc_easy_rc_5shot", + ## "basic_arithmetic_rc_5shot", # doesn't exist + ## "boolq_rc_5shot", # we don't like it + # "csqa_rc_5shot", + ## "copa_rc_5shot", # doesn't exist + # "hellaswag_rc_5shot", + # "openbookqa_rc_5shot", + # "piqa_rc_5shot", + ## "sciq_rc_5shot", # doesn't exist + # "socialiqa_rc_5shot", + # "winogrande_rc_5shot", + ## New in-loop evals + # "arc_challenge_val_rc_5shot", + # "arc_challenge_val_mc_5shot", + "arc_challenge_test_rc_5shot", + # "arc_challenge_test_mc_5shot", + # "arc_easy_val_rc_5shot", + # "arc_easy_val_mc_5shot", + "arc_easy_test_rc_5shot", + # "arc_easy_test_mc_5shot", + # "boolq_val_rc_5shot", + # "boolq_val_mc_5shot", + "csqa_val_rc_5shot", + # "csqa_val_mc_5shot", + "hellaswag_val_rc_5shot", + # "hellaswag_val_mc_5shot", + # "openbookqa_val_rc_5shot", + # "openbookqa_val_mc_5shot", + "openbookqa_test_rc_5shot", + # "openbookqa_test_mc_5shot", + "piqa_val_rc_5shot", + # "piqa_val_mc_5shot", + "socialiqa_val_rc_5shot", + # "socialiqa_val_mc_5shot", + # "winogrande_val_rc_5shot", + # "winogrande_val_mc_5shot", + # "mmlu_stem_val_rc_5shot", + # "mmlu_stem_val_mc_5shot", + # "mmlu_humanities_val_rc_5shot", + # "mmlu_humanities_val_mc_5shot", + # "mmlu_social_sciences_val_rc_5shot", + # "mmlu_social_sciences_val_mc_5shot", + # "mmlu_other_val_rc_5shot", + # "mmlu_other_val_mc_5shot", + ], + tokenizer=tokenizer_config, + eval_interval=1000, + ), + ), + ).merge(overrides) + + +def train(config: AnnealingConfig): + # Set RNG 
states on all devices. + seed_all(config.init_seed) + + device = get_default_device() + + # Build mesh, if needed. + world_mesh = config.model.build_mesh(device=device) + + # Build components. + model = config.model.build( + init_device="meta", + device=device, + max_seq_len=config.dataset.sequence_length, + mesh=world_mesh, + ) + optim = config.optim.build(model) + dataset = config.dataset.build() + data_loader = config.data_loader.build(dataset, mesh=world_mesh) + trainer = config.trainer.build(model, optim, data_loader, mesh=world_mesh) + + # Record the config to W&B/Comet and each checkpoint dir. + config_dict = config.as_config_dict() + cast(CometCallback, trainer.callbacks["comet"]).config = config_dict + cast(ConfigSaverCallback, trainer.callbacks["config_saver"]).config = config_dict + + # Train. + trainer.fit() + + +if __name__ == "__main__": + USAGE = f""" +[yellow]Usage:[/] [i blue]python[/] [i cyan]{sys.argv[0]}[/] [i b magenta]launch|train|dry_run[/] [i b]RUN_NAME PRETRAIN_CHECKPOINT CLUSTER[/] [i][OVERRIDES...][/] + +[b]Subcommands[/] +[b magenta]launch:[/] Launch the script on Beaker with the [b magenta]train[/] subcommand. +[b magenta]train:[/] Run the trainer. You usually shouldn't invoke the script with this subcommand directly. + Instead use [b magenta]launch[/] or run it with torchrun. +[b magenta]dry_run:[/] Print the config for debugging. 
+ +[b]Examples[/] +$ [i]python {sys.argv[0]} launch run01 gs://ai2-llm/checkpoints/peteish32/step419000 ai2/jupiter-cirrascale-2 --launch.num_nodes=2[/] +""".strip() + + if len(sys.argv) < 5 or sys.argv[1] not in ("launch", "train", "dry_run"): + rich.get_console().print(USAGE, highlight=False) + sys.exit(1) + + script, cmd, run_name, checkpoint, cluster, *overrides = sys.argv + + if cmd in ("launch", "dry_run"): + prepare_cli_environment() + elif cmd == "train": + prepare_training_environment() + else: + raise NotImplementedError(cmd) + + config = build_config( + script=script, + cmd="train", + run_name=run_name, + checkpoint=checkpoint, + cluster=cluster, + overrides=overrides, + ) + + if get_local_rank() == 0: + print(config) + + if cmd == "dry_run": + pass + elif cmd == "launch": + config.launch.launch(follow=True) + elif cmd == "train": + try: + train(config) + finally: + teardown_training_environment() + else: + raise NotImplementedError(cmd) diff --git a/src/scripts/train/OLMo2-32B.ipynb b/src/scripts/train/OLMo2-32B.ipynb new file mode 100644 index 00000000..50ebe4c5 --- /dev/null +++ b/src/scripts/train/OLMo2-32B.ipynb @@ -0,0 +1,1799 @@ +{ + "cells": [ + { + "cell_type": "code", + "id": "initial_id", + "metadata": { + "collapsed": true, + "ExecuteTime": { + "end_time": "2025-01-20T05:43:00.396563Z", + "start_time": "2025-01-20T05:42:59.655736Z" + } + }, + "source": [ + "import os\n", + "from comet_ml.api import API\n", + "\n", + "comet_api = API(os.environ[\"COMETML_API_KEY\"])\n" + ], + "outputs": [], + "execution_count": 1 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T05:43:01.198704Z", + "start_time": "2025-01-20T05:43:00.406534Z" + } + }, + "cell_type": "code", + "source": [ + "exps = {\n", + " \"peteish32\": comet_api.get_experiments(\"ai2\", \"peteish32\", \"peteish32\"),\n", + " \"peteish13\": comet_api.get_experiments(\"ai2\", \"olmo-2-1124-13b\", \"OLMo-2-1124-13B-stage-1\"),\n", + " \"peteish7\": 
comet_api.get_experiments(\"ai2\", \"olmo-core-7b\", \"peteish7\")\n", + "}\n", + "\n", + "print(repr({k: len(v) for k, v in exps.items()}))" + ], + "id": "2c17abe415dabf07", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'peteish32': 50, 'peteish13': 75, 'peteish7': 13}\n" + ] + } + ], + "execution_count": 2 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T05:43:16.577611Z", + "start_time": "2025-01-20T05:43:01.261999Z" + } + }, + "cell_type": "code", + "source": [ + "# print available metrics\n", + "\n", + "for name, es in exps.items():\n", + " metrics = set()\n", + " for exp in es:\n", + " for summary in exp.get_metrics_summary():\n", + " metrics.add(summary[\"name\"])\n", + " metrics = list(metrics)\n", + " metrics.sort()\n", + "\n", + " print(f\"{name}:\")\n", + " for metric in metrics:\n", + " print(\"\\t\", metric)" + ], + "id": "dc7c5e3c92741b89", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "peteish32:\n", + "\t data/sequence length\n", + "\t eval/downstream/arc_challenge (BPB)\n", + "\t eval/downstream/arc_challenge (CE loss)\n", + "\t eval/downstream/arc_challenge (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge (log soft loss)\n", + "\t eval/downstream/arc_challenge (soft loss)\n", + "\t eval/downstream/arc_challenge_rc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_test_mc_5shot (soft loss)\n", + "\t 
eval/downstream/arc_challenge_test_rc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_test_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_val_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (BPB)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_challenge_val_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy (BPB)\n", + "\t eval/downstream/arc_easy (CE loss)\n", + "\t eval/downstream/arc_easy (accuracy)\n", + "\t eval/downstream/arc_easy (log soft loss)\n", + "\t eval/downstream/arc_easy (soft loss)\n", + "\t eval/downstream/arc_easy_rc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_rc_5shot (accuracy)\n", + "\t eval/downstream/arc_easy_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_test_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (length-normalized accuracy)\n", + "\t 
eval/downstream/arc_easy_test_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_test_rc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (accuracy)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_val_mc_5shot (soft loss)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (BPB)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (CE loss)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/arc_easy_val_rc_5shot (soft loss)\n", + "\t eval/downstream/basic_arithmetic (BPB)\n", + "\t eval/downstream/basic_arithmetic (CE loss)\n", + "\t eval/downstream/basic_arithmetic (accuracy)\n", + "\t eval/downstream/basic_arithmetic (log soft loss)\n", + "\t eval/downstream/basic_arithmetic (soft loss)\n", + "\t eval/downstream/boolq (BPB)\n", + "\t eval/downstream/boolq (CE loss)\n", + "\t eval/downstream/boolq (accuracy)\n", + "\t eval/downstream/boolq (log soft loss)\n", + "\t eval/downstream/boolq (soft loss)\n", + "\t eval/downstream/boolq_val_mc_5shot (BPB)\n", + "\t eval/downstream/boolq_val_mc_5shot (CE loss)\n", + "\t eval/downstream/boolq_val_mc_5shot (accuracy)\n", + "\t eval/downstream/boolq_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/boolq_val_mc_5shot (soft loss)\n", + "\t eval/downstream/boolq_val_rc_5shot (BPB)\n", + "\t eval/downstream/boolq_val_rc_5shot (CE loss)\n", + "\t eval/downstream/boolq_val_rc_5shot (accuracy)\n", + "\t eval/downstream/boolq_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/boolq_val_rc_5shot (soft loss)\n", + "\t eval/downstream/commonsense_qa (BPB)\n", + "\t eval/downstream/commonsense_qa (CE loss)\n", + "\t eval/downstream/commonsense_qa (length-normalized accuracy)\n", + "\t eval/downstream/commonsense_qa (log soft loss)\n", + 
"\t eval/downstream/commonsense_qa (soft loss)\n", + "\t eval/downstream/copa (BPB)\n", + "\t eval/downstream/copa (CE loss)\n", + "\t eval/downstream/copa (accuracy)\n", + "\t eval/downstream/copa (log soft loss)\n", + "\t eval/downstream/copa (soft loss)\n", + "\t eval/downstream/csqa_rc_5shot (BPB)\n", + "\t eval/downstream/csqa_rc_5shot (CE loss)\n", + "\t eval/downstream/csqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/csqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/csqa_rc_5shot (soft loss)\n", + "\t eval/downstream/csqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/csqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/csqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/csqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/csqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/csqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/csqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/csqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/csqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/csqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/hellaswag (BPB)\n", + "\t eval/downstream/hellaswag (CE loss)\n", + "\t eval/downstream/hellaswag (length-normalized accuracy)\n", + "\t eval/downstream/hellaswag (log soft loss)\n", + "\t eval/downstream/hellaswag (soft loss)\n", + "\t eval/downstream/hellaswag_rc_5shot (BPB)\n", + "\t eval/downstream/hellaswag_rc_5shot (CE loss)\n", + "\t eval/downstream/hellaswag_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/hellaswag_rc_5shot (log soft loss)\n", + "\t eval/downstream/hellaswag_rc_5shot (soft loss)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (BPB)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (CE loss)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (accuracy)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/hellaswag_val_mc_5shot (soft loss)\n", + "\t eval/downstream/hellaswag_val_rc_5shot 
(BPB)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (CE loss)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/hellaswag_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_humanities_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_other_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test 
(BPB)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_other_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_other_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_val_mc_5shot (log soft loss)\n", + "\t 
eval/downstream/mmlu_social_sciences_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_social_sciences_val_rc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (BPB)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (CE loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (log soft loss)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (soft loss)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (BPB)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_stem_val_mc_5shot (soft loss)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (BPB)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (CE loss)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/mmlu_stem_val_rc_5shot (soft loss)\n", + "\t eval/downstream/openbook_qa (BPB)\n", + "\t eval/downstream/openbook_qa (CE loss)\n", + "\t eval/downstream/openbook_qa (length-normalized accuracy)\n", + "\t eval/downstream/openbook_qa (log soft loss)\n", + "\t eval/downstream/openbook_qa (soft loss)\n", + "\t 
eval/downstream/openbookqa_rc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_rc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/openbookqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_rc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (accuracy)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_test_mc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_test_rc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/openbookqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/piqa (BPB)\n", + "\t eval/downstream/piqa (CE loss)\n", + "\t eval/downstream/piqa (length-normalized accuracy)\n", + "\t eval/downstream/piqa (log soft loss)\n", + "\t eval/downstream/piqa (soft loss)\n", + "\t eval/downstream/piqa_rc_5shot (BPB)\n", + "\t eval/downstream/piqa_rc_5shot (CE loss)\n", + "\t eval/downstream/piqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/piqa_rc_5shot (log 
soft loss)\n", + "\t eval/downstream/piqa_rc_5shot (soft loss)\n", + "\t eval/downstream/piqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/piqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/piqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/piqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/piqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/piqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/piqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/piqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/piqa_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/piqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/sciq (BPB)\n", + "\t eval/downstream/sciq (CE loss)\n", + "\t eval/downstream/sciq (accuracy)\n", + "\t eval/downstream/sciq (log soft loss)\n", + "\t eval/downstream/sciq (soft loss)\n", + "\t eval/downstream/social_iqa (BPB)\n", + "\t eval/downstream/social_iqa (CE loss)\n", + "\t eval/downstream/social_iqa (length-normalized accuracy)\n", + "\t eval/downstream/social_iqa (log soft loss)\n", + "\t eval/downstream/social_iqa (soft loss)\n", + "\t eval/downstream/socialiqa_rc_5shot (BPB)\n", + "\t eval/downstream/socialiqa_rc_5shot (CE loss)\n", + "\t eval/downstream/socialiqa_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/socialiqa_rc_5shot (log soft loss)\n", + "\t eval/downstream/socialiqa_rc_5shot (soft loss)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (BPB)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (CE loss)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (accuracy)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/socialiqa_val_mc_5shot (soft loss)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (BPB)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (CE loss)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/socialiqa_val_rc_5shot (log soft loss)\n", + "\t 
eval/downstream/socialiqa_val_rc_5shot (soft loss)\n", + "\t eval/downstream/winogrande (BPB)\n", + "\t eval/downstream/winogrande (CE loss)\n", + "\t eval/downstream/winogrande (accuracy)\n", + "\t eval/downstream/winogrande (log soft loss)\n", + "\t eval/downstream/winogrande (soft loss)\n", + "\t eval/downstream/winogrande_rc_5shot (BPB)\n", + "\t eval/downstream/winogrande_rc_5shot (CE loss)\n", + "\t eval/downstream/winogrande_rc_5shot (accuracy)\n", + "\t eval/downstream/winogrande_rc_5shot (log soft loss)\n", + "\t eval/downstream/winogrande_rc_5shot (soft loss)\n", + "\t eval/downstream/winogrande_val_mc_5shot (BPB)\n", + "\t eval/downstream/winogrande_val_mc_5shot (CE loss)\n", + "\t eval/downstream/winogrande_val_mc_5shot (accuracy)\n", + "\t eval/downstream/winogrande_val_mc_5shot (log soft loss)\n", + "\t eval/downstream/winogrande_val_mc_5shot (soft loss)\n", + "\t eval/downstream/winogrande_val_rc_5shot (BPB)\n", + "\t eval/downstream/winogrande_val_rc_5shot (CE loss)\n", + "\t eval/downstream/winogrande_val_rc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/winogrande_val_rc_5shot (log soft loss)\n", + "\t eval/downstream/winogrande_val_rc_5shot (soft loss)\n", + "\t eval/lm/c4_en-validation/CE loss\n", + "\t eval/lm/c4_en-validation/PPL\n", + "\t eval/lm/dolma_books-validation/CE loss\n", + "\t eval/lm/dolma_books-validation/PPL\n", + "\t eval/lm/dolma_common-crawl-validation/CE loss\n", + "\t eval/lm/dolma_common-crawl-validation/PPL\n", + "\t eval/lm/dolma_pes2o-validation/CE loss\n", + "\t eval/lm/dolma_pes2o-validation/PPL\n", + "\t eval/lm/dolma_reddit-validation/CE loss\n", + "\t eval/lm/dolma_reddit-validation/PPL\n", + "\t eval/lm/dolma_stack-validation/CE loss\n", + "\t eval/lm/dolma_stack-validation/PPL\n", + "\t eval/lm/dolma_wiki-validation/CE loss\n", + "\t eval/lm/dolma_wiki-validation/PPL\n", + "\t eval/lm/ice-validation/CE loss\n", + "\t eval/lm/ice-validation/PPL\n", + "\t eval/lm/m2d2_s2orc-validation/CE loss\n", + "\t 
eval/lm/m2d2_s2orc-validation/PPL\n", + "\t eval/lm/pile-validation/CE loss\n", + "\t eval/lm/pile-validation/PPL\n", + "\t eval/lm/wikitext_103-validation/CE loss\n", + "\t eval/lm/wikitext_103-validation/PPL\n", + "\t optim/LR (group 0)\n", + "\t optim/LR (group 1)\n", + "\t optim/step skipped\n", + "\t optim/total grad norm\n", + "\t sys.compute.overall\n", + "\t sys.compute.utilized\n", + "\t sys.cpu.percent.avg\n", + "\t sys.disk.read_bps\n", + "\t sys.disk.root.percent.used\n", + "\t sys.disk.root.used\n", + "\t sys.disk.write_bps\n", + "\t sys.gpu.0.free_memory\n", + "\t sys.gpu.0.gpu_utilization\n", + "\t sys.gpu.0.memory_utilization\n", + "\t sys.gpu.0.percent.used_memory\n", + "\t sys.gpu.0.power_usage\n", + "\t sys.gpu.0.temperature\n", + "\t sys.gpu.0.total_memory\n", + "\t sys.gpu.0.used_memory\n", + "\t sys.gpu.1.free_memory\n", + "\t sys.gpu.1.gpu_utilization\n", + "\t sys.gpu.1.memory_utilization\n", + "\t sys.gpu.1.percent.used_memory\n", + "\t sys.gpu.1.power_usage\n", + "\t sys.gpu.1.temperature\n", + "\t sys.gpu.1.total_memory\n", + "\t sys.gpu.1.used_memory\n", + "\t sys.gpu.2.free_memory\n", + "\t sys.gpu.2.gpu_utilization\n", + "\t sys.gpu.2.memory_utilization\n", + "\t sys.gpu.2.percent.used_memory\n", + "\t sys.gpu.2.power_usage\n", + "\t sys.gpu.2.temperature\n", + "\t sys.gpu.2.total_memory\n", + "\t sys.gpu.2.used_memory\n", + "\t sys.gpu.3.free_memory\n", + "\t sys.gpu.3.gpu_utilization\n", + "\t sys.gpu.3.memory_utilization\n", + "\t sys.gpu.3.percent.used_memory\n", + "\t sys.gpu.3.power_usage\n", + "\t sys.gpu.3.temperature\n", + "\t sys.gpu.3.total_memory\n", + "\t sys.gpu.3.used_memory\n", + "\t sys.gpu.4.free_memory\n", + "\t sys.gpu.4.gpu_utilization\n", + "\t sys.gpu.4.memory_utilization\n", + "\t sys.gpu.4.percent.used_memory\n", + "\t sys.gpu.4.power_usage\n", + "\t sys.gpu.4.temperature\n", + "\t sys.gpu.4.total_memory\n", + "\t sys.gpu.4.used_memory\n", + "\t sys.gpu.5.free_memory\n", + "\t sys.gpu.5.gpu_utilization\n", + 
"\t sys.gpu.5.memory_utilization\n", + "\t sys.gpu.5.percent.used_memory\n", + "\t sys.gpu.5.power_usage\n", + "\t sys.gpu.5.temperature\n", + "\t sys.gpu.5.total_memory\n", + "\t sys.gpu.5.used_memory\n", + "\t sys.gpu.6.free_memory\n", + "\t sys.gpu.6.gpu_utilization\n", + "\t sys.gpu.6.memory_utilization\n", + "\t sys.gpu.6.percent.used_memory\n", + "\t sys.gpu.6.power_usage\n", + "\t sys.gpu.6.temperature\n", + "\t sys.gpu.6.total_memory\n", + "\t sys.gpu.6.used_memory\n", + "\t sys.gpu.7.free_memory\n", + "\t sys.gpu.7.gpu_utilization\n", + "\t sys.gpu.7.memory_utilization\n", + "\t sys.gpu.7.percent.used_memory\n", + "\t sys.gpu.7.power_usage\n", + "\t sys.gpu.7.temperature\n", + "\t sys.gpu.7.total_memory\n", + "\t sys.gpu.7.used_memory\n", + "\t sys.load.avg\n", + "\t sys.network.receive_bps\n", + "\t sys.network.send_bps\n", + "\t sys.ram.available\n", + "\t sys.ram.percent.used\n", + "\t sys.ram.total\n", + "\t sys.ram.used\n", + "\t system/GPU active mem (%)\n", + "\t system/GPU active mem (GiB)\n", + "\t system/GPU reserved mem (%)\n", + "\t system/GPU reserved mem (GiB)\n", + "\t throughput/device/BPS\n", + "\t throughput/device/BPS (actual avg)\n", + "\t throughput/device/TPS\n", + "\t throughput/device/TPS (actual avg)\n", + "\t throughput/device/data loading (%)\n", + "\t throughput/device/data loading (s)\n", + "\t throughput/total tokens\n", + "\t train/CE loss\n", + "\t train/PPL\n", + "\t train/Z loss\n", + "peteish13:\n", + "\t eval/downstream/arc_challenge (length-normalized accuracy)\n", + "\t eval/downstream/arc_easy (accuracy)\n", + "\t eval/downstream/basic_arithmetic (accuracy)\n", + "\t eval/downstream/boolq (accuracy)\n", + "\t eval/downstream/commonsense_qa (length-normalized accuracy)\n", + "\t eval/downstream/copa (accuracy)\n", + "\t eval/downstream/hellaswag (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_mc_5shot_test 
(length-normalized accuracy)\n", + "\t eval/downstream/mmlu_humanities_var (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_other_var (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_social_sciences_var (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_mc_5shot_test (length-normalized accuracy)\n", + "\t eval/downstream/mmlu_stem_var (length-normalized accuracy)\n", + "\t eval/downstream/openbook_qa (length-normalized accuracy)\n", + "\t eval/downstream/piqa (length-normalized accuracy)\n", + "\t eval/downstream/sciq (accuracy)\n", + "\t eval/downstream/social_iqa (length-normalized accuracy)\n", + "\t eval/downstream/winogrande (accuracy)\n", + "\t optim/LR (group 0)\n", + "\t optim/LR (group 1)\n", + "\t optim/total grad norm\n", + "\t sys.compute.overall\n", + "\t sys.compute.utilized\n", + "\t sys.cpu.percent.avg\n", + "\t sys.disk.read_bps\n", + "\t sys.disk.root.percent.used\n", + "\t sys.disk.root.used\n", + "\t sys.disk.write_bps\n", + "\t sys.gpu.0.free_memory\n", + "\t sys.gpu.0.gpu_utilization\n", + "\t sys.gpu.0.memory_utilization\n", + "\t sys.gpu.0.percent.used_memory\n", + "\t sys.gpu.0.power_usage\n", + "\t sys.gpu.0.temperature\n", + "\t sys.gpu.0.total_memory\n", + "\t sys.gpu.0.used_memory\n", + "\t sys.gpu.1.free_memory\n", + "\t sys.gpu.1.gpu_utilization\n", + "\t sys.gpu.1.memory_utilization\n", + "\t sys.gpu.1.percent.used_memory\n", + "\t sys.gpu.1.power_usage\n", + "\t sys.gpu.1.temperature\n", + "\t sys.gpu.1.total_memory\n", + "\t sys.gpu.1.used_memory\n", + "\t sys.gpu.2.free_memory\n", + "\t 
sys.gpu.2.gpu_utilization\n", + "\t sys.gpu.2.memory_utilization\n", + "\t sys.gpu.2.percent.used_memory\n", + "\t sys.gpu.2.power_usage\n", + "\t sys.gpu.2.temperature\n", + "\t sys.gpu.2.total_memory\n", + "\t sys.gpu.2.used_memory\n", + "\t sys.gpu.3.free_memory\n", + "\t sys.gpu.3.gpu_utilization\n", + "\t sys.gpu.3.memory_utilization\n", + "\t sys.gpu.3.percent.used_memory\n", + "\t sys.gpu.3.power_usage\n", + "\t sys.gpu.3.temperature\n", + "\t sys.gpu.3.total_memory\n", + "\t sys.gpu.3.used_memory\n", + "\t sys.load.avg\n", + "\t sys.network.receive_bps\n", + "\t sys.network.send_bps\n", + "\t sys.ram.available\n", + "\t sys.ram.percent.used\n", + "\t sys.ram.total\n", + "\t sys.ram.used\n", + "\t throughput/device/BPS\n", + "\t throughput/device/TPS\n", + "\t train/CE loss\n", + "\t train/PPL\n", + "\t train/Z loss\n", + "peteish7:\n", + "\t optim/LR (group 0)\n", + "\t optim/LR (group 1)\n", + "\t optim/total grad norm\n", + "\t sys.compute.overall\n", + "\t sys.compute.utilized\n", + "\t sys.cpu.percent.avg\n", + "\t sys.disk.read_bps\n", + "\t sys.disk.root.percent.used\n", + "\t sys.disk.root.used\n", + "\t sys.disk.write_bps\n", + "\t sys.load.avg\n", + "\t sys.network.receive_bps\n", + "\t sys.network.send_bps\n", + "\t sys.ram.available\n", + "\t sys.ram.percent.used\n", + "\t sys.ram.total\n", + "\t sys.ram.used\n", + "\t throughput/device/BPS\n", + "\t throughput/device/TPS\n", + "\t train/CE loss\n", + "\t train/PPL\n", + "\t train/Z loss\n" + ] + } + ], + "execution_count": 3 + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T05:44:08.292588Z", + "start_time": "2025-01-20T05:43:16.581296Z" + } + }, + "cell_type": "code", + "source": [ + "from tqdm.notebook import tqdm\n", + "\n", + "def download_metric(exps, metric_name):\n", + " result = {}\n", + " for exp in tqdm(exps):\n", + " metrics = exp.get_metrics(metric_name)\n", + " for values in metrics:\n", + " result[values['step']] = float(values['metricValue'])\n", + " result = 
dict(sorted(result.items()))\n", + " return result\n", + "\n", + "loss = {\n", + " name: download_metric(es, \"train/CE loss\")\n", + " for name, es in exps.items()\n", + "}\n", + "\n", + "skipped_steps = {\n", + " name: download_metric(es, \"optim/step skipped\")\n", + " for name, es in exps.items()\n", + "}\n", + "\n", + "c4loss = {\n", + " name: download_metric(es, \"eval/lm/c4_en-validation/CE loss\")\n", + " for name, es in exps.items()\n", + "}" + ], + "id": "6aa86a5638253061", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/50 [00:00 0])" + ], + "id": "277e0e889edb7b16", + "outputs": [ + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-19T21:44:08.518243\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Steps skipped for the 32B: 86\n", + 
"[80788, 81072, 84048, 85129, 87386, 92844, 107316, 111491, 113030, 114230, 118668, 121925, 126863, 127493, 128136, 129747, 134843, 136385, 142362, 142815, 144303, 144548, 147139, 147455, 148216, 148703, 150206, 154267, 159678, 159881, 160407, 163682, 167141, 167784, 175621, 187888, 188783, 194308, 204820, 205830, 206617, 212691, 217589, 226667, 230116, 231534, 232070, 232547, 233702, 241716, 242968, 246500, 249425, 250814, 251497, 256147, 257243, 259138, 262122, 263662, 264763, 266531, 267012, 283650, 290117, 290727, 291531, 294977, 295352, 297826, 298508, 310444, 311075, 314537, 319684, 320934, 323994, 325632, 327859, 332049, 333763, 341230, 348251, 379042, 381547, 382692]\n" + ] + } + ], + "execution_count": 5 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Downstream", + "id": "83cbde8bd1160629" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T05:48:15.449968Z", + "start_time": "2025-01-20T05:44:08.686031Z" + } + }, + "cell_type": "code", + "source": [ + "aggregate_metric_definitions = {\n", + " \"MMLU 5-shot MC\": {\n", + " \"eval/downstream/mmlu_stem_mc_5shot (length-normalized accuracy)\": 0.215,\n", + " \"eval/downstream/mmlu_humanities_mc_5shot (length-normalized accuracy)\": 0.335,\n", + " \"eval/downstream/mmlu_social_sciences_mc_5shot (length-normalized accuracy)\": 0.219,\n", + " \"eval/downstream/mmlu_other_mc_5shot (length-normalized accuracy)\": 0.231\n", + " },\n", + " \"Average of core 12\": {\n", + " \"eval/downstream/arc_challenge (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/arc_easy (accuracy)\": 1 / 12,\n", + " \"eval/downstream/basic_arithmetic (accuracy)\": 1 / 12,\n", + " \"eval/downstream/boolq (accuracy)\": 1 / 12,\n", + " \"eval/downstream/commonsense_qa (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/copa (accuracy)\": 1 / 12,\n", + " \"eval/downstream/hellaswag (length-normalized accuracy)\": 1 / 12,\n", + " \"eval/downstream/openbook_qa 
(length-normalized accuracy)\": 1 / 12,\n", +        "        \"eval/downstream/piqa (length-normalized accuracy)\": 1 / 12,\n", +        "        \"eval/downstream/sciq (accuracy)\": 1 / 12,\n", +        "        \"eval/downstream/social_iqa (length-normalized accuracy)\": 1 / 12,\n", +        "        \"eval/downstream/winogrande (accuracy)\": 1 / 12,\n", +        "    },\n", +        "    \"Hellaswag\": {\n", +        "        \"eval/downstream/hellaswag (length-normalized accuracy)\": 1\n", +        "    }\n", +        "}\n", +        "\n", +        "import matplotlib.pyplot as plt\n", +        "%matplotlib inline\n", +        "%config InlineBackend.figure_format = 'svg'\n", +        "import numpy as np\n", +        "\n", +        "fig, axs = plt.subplots(nrows=len(aggregate_metric_definitions), sharex=True, figsize=(10, len(aggregate_metric_definitions)*3))\n", +        "\n", +        "for ax, agg_metric_name in zip(axs, aggregate_metric_definitions):\n", +        "    metric_to_weight = aggregate_metric_definitions[agg_metric_name]\n", +        "    for run_name, run_exps in exps.items():\n", +        "        metric_to_values = {}\n", +        "        for metric in metric_to_weight.keys():\n", +        "            metric_to_values[metric] = download_metric(run_exps, metric)\n", +        "\n", +        "        all_steps = set.union(*[set(v.keys()) for v in metric_to_values.values()])\n", +        "        minimal_steps = set.intersection(*[set(v.keys()) for v in metric_to_values.values()])\n", +        "        if all_steps != minimal_steps:\n", +        "            print(f\"Missing steps for {run_name} / {agg_metric_name}: {all_steps - minimal_steps}\")\n", +        "\n", +        "        aggregated_values = {}\n", +        "        for step in minimal_steps:\n", +        "            value = 0.0\n", +        "            for metric, weight in metric_to_weight.items():\n", +        "                value += metric_to_values[metric][step] * weight\n", +        "            aggregated_values[step] = value\n", +        "        if len(aggregated_values) == 0:\n", +        "            continue\n", +        "\n", +        "        print(f\"{run_name} / {agg_metric_name} max: {max(aggregated_values.values())}\")\n", +        "\n", +        "        xs = np.array(list(aggregated_values.keys()))\n", +        "        ys = np.array(list(aggregated_values.values()))\n", +        "        order = np.argsort(xs)\n", +        "        xs = xs[order]\n", +        "        ys = ys[order]\n", +        "        xs *= 
(2048 * 4096)\n", + " ax.plot(xs, ys, linewidth=0.5)\n", + " ax.set_ylabel(agg_metric_name)\n", + "\n", + "plt.xlabel(\"step\")\n", + "plt.show()" + ], + "id": "8b310d9cc68ad856", + "outputs": [ + { + "data": { + "text/plain": [ + " 0%| | 0/50 [00:00" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-19T21:48:15.428828\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 6 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "## Spike Analysis", + "id": "744574cd19bbe369" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-01-20T05:48:16.039797Z", + "start_time": "2025-01-20T05:48:15.479419Z" + } + }, + "cell_type": "code", + "source": [ + "window_size = 128\n", + "losses = np.array(list(loss[\"peteish32\"].values()))\n", + "steps = np.array(list(loss[\"peteish32\"].keys()))\n", + "\n", + "from numpy.lib.stride_tricks import sliding_window_view\n", + "windows = sliding_window_view(losses, window_size)\n", + "\n", + "stds = windows.std(axis=1)\n", + "means = windows.mean(axis=1)\n", + "losses = losses[window_size - 1 :]\n", + "steps = steps[window_size - 1 :]\n", + "spike_steps = steps[np.argwhere(losses > means + stds * 6)].flatten()\n", + "print(f\"Steps with spikes: {spike_steps}\")\n", + "\n", + "fig, axes = plt.subplots(\n", + " nrows=len(spike_steps),\n", + " figsize=(7, len(spike_steps)*3),\n", + " sharex=False\n", + ")\n", + "\n", + "for ax, spike in zip(axes, spike_steps):\n", + " for name, values in loss.items():\n", + " xs = np.array(list(values.keys()))\n", + " ys = np.array(list(values.values()))\n", + " ax.plot(xs, ys, linewidth=0.5)\n", + " ax.set_ylim(2.1, 2.5)\n", + " ax.set_xlim(spike-1000, spike+1000)\n", + " plt.yscale('log')\n", + " plt.xlabel(\"step\")\n", + " plt.ylabel(\"loss\")\n", + "\n", + "plt.tight_layout()\n", + "plt.show()\n" + ], + "id": "6eb5abfb647663a5", + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Steps with spikes: [ 29645 38677 49089 54503 66257 73019 144302]\n" + ] + }, + { + "data": { + "text/plain": [ + "
" + ], + "image/svg+xml": "\n\n\n \n \n \n \n 2025-01-19T21:48:15.977016\n image/svg+xml\n \n \n Matplotlib v3.9.2, https://matplotlib.org/\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n 
\n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n \n\n" + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "execution_count": 7 + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/src/scripts/train/OLMo2-32B.py b/src/scripts/train/OLMo2-32B.py new file mode 100644 index 00000000..ea847522 --- /dev/null +++ b/src/scripts/train/OLMo2-32B.py @@ -0,0 +1,204 @@ +""" +Train a 32B OLMo model. Run this script without any arguments to see usage info. 
+""" + +import logging + +from olmo_core.config import DType +from olmo_core.distributed.parallel import DataParallelType +from olmo_core.float8 import Float8Config +from olmo_core.internal.experiment import CommonComponents, main +from olmo_core.nn.transformer import ( + TransformerActivationCheckpointingConfig, + TransformerActivationCheckpointingMode, + TransformerConfig, + TransformerDataParallelConfig, +) +from olmo_core.optim import OptimGroupOverride, SkipStepAdamWConfig +from olmo_core.train import Duration, DurationUnit, TrainerConfig +from olmo_core.train.callbacks import ( + CheckpointerCallback, + CometCallback, + DownstreamEvaluatorCallbackConfig, + ProfilerCallback, + WandBCallback, +) +from olmo_core.train.checkpoint import CheckpointerConfig + +log = logging.getLogger(__name__) + + +def build_model_config(common: CommonComponents) -> TransformerConfig: + compile = True + return TransformerConfig.olmo2_32B( + vocab_size=common.tokenizer.padded_vocab_size(), + compile=compile, + fused_ops=False, + use_flash=not compile, + dp_config=TransformerDataParallelConfig( + name=DataParallelType.fsdp, param_dtype=DType.bfloat16, reduce_dtype=DType.float32 + ), + # dp_config=TransformerDataParallelConfig( + # name=DataParallelType.hsdp, + # param_dtype=DType.bfloat16, + # reduce_dtype=DType.float32, + # num_replicas=64 // 16, # common.launch.num_nodes // 2, + # ), + # ac_config=TransformerActivationCheckpointingConfig(TransformerActivationCheckpointingMode.full), + ac_config=TransformerActivationCheckpointingConfig( + mode=TransformerActivationCheckpointingMode.selected_modules, + modules=[f"blocks.{i}.feed_forward" for i in range(64)], + ), + float8_config=Float8Config(compile=compile, enabled=False), + ) + + +def build_optim_config(common: CommonComponents) -> SkipStepAdamWConfig: + del common + return SkipStepAdamWConfig( + lr=6e-4, + weight_decay=0.1, + betas=(0.9, 0.95), + group_overrides=[ + OptimGroupOverride(params=["embeddings.weight"], 
opts=dict(weight_decay=0.0)) + ], + # fused=True, + compile=True, + ) + + +def build_trainer_config(common: CommonComponents) -> TrainerConfig: + project_name = "peteish32" + return ( + TrainerConfig( + save_folder=f"gs://ai2-llm/checkpoints/{project_name}/", + rank_microbatch_size=2 * 4096, + checkpointer=CheckpointerConfig( + save_thread_count=1, load_thread_count=32, throttle_uploads=True + ), + save_overwrite=True, + metrics_collect_interval=10, + cancel_check_interval=10, + z_loss_multiplier=1e-5, + compile_loss=False, + fused_loss=True, + max_duration=Duration(int(6.5e12), DurationUnit.tokens), + ) + .with_callback( + "checkpointer", + CheckpointerCallback( + save_interval=1000, + save_async=True, + ), + ) + .with_callback( + "profiler", ProfilerCallback(skip_first=3, wait=10, warmup=2, active=3, repeat=1) + ) + .with_callback( + "comet", + CometCallback( + name=common.run_name, + workspace="ai2", + project=project_name, + enabled=True, + cancel_check_interval=10, + ), + ) + .with_callback( + "wandb", + WandBCallback( + name=common.run_name, + entity="ai2-llm", + project=project_name, + enabled=False, + cancel_check_interval=10, + ), + ) + .with_callback( + "downstream_evaluator", + DownstreamEvaluatorCallbackConfig( + tasks=[ + # MMLU for backwards compatibility + "mmlu_stem_mc_5shot", + "mmlu_humanities_mc_5shot", + "mmlu_social_sciences_mc_5shot", + "mmlu_other_mc_5shot", + # MMLU test + "mmlu_stem_mc_5shot_test", + "mmlu_humanities_mc_5shot_test", + "mmlu_social_sciences_mc_5shot_test", + "mmlu_other_mc_5shot_test", + ## Core 12 tasks for backwards compatibility + # "arc_challenge", + # "arc_easy", + # "basic_arithmetic", + # "boolq", + # "commonsense_qa", + # "copa", + # "hellaswag", + # "openbook_qa", + # "piqa", + # "sciq", + # "social_iqa", + # "winogrande", + ## Core 12 tasks 5-shot + # "arc_challenge_rc_5shot", + # "arc_easy_rc_5shot", + ## "basic_arithmetic_rc_5shot", # doesn't exist + ## "boolq_rc_5shot", # we don't like it + # "csqa_rc_5shot", + 
## "copa_rc_5shot", # doesn't exist + # "hellaswag_rc_5shot", + # "openbookqa_rc_5shot", + # "piqa_rc_5shot", + ## "sciq_rc_5shot", # doesn't exist + # "socialiqa_rc_5shot", + # "winogrande_rc_5shot", + ## New in-loop evals + # "arc_challenge_val_rc_5shot", + # "arc_challenge_val_mc_5shot", + "arc_challenge_test_rc_5shot", + # "arc_challenge_test_mc_5shot", + # "arc_easy_val_rc_5shot", + # "arc_easy_val_mc_5shot", + "arc_easy_test_rc_5shot", + # "arc_easy_test_mc_5shot", + # "boolq_val_rc_5shot", + # "boolq_val_mc_5shot", + "csqa_val_rc_5shot", + # "csqa_val_mc_5shot", + "hellaswag_val_rc_5shot", + # "hellaswag_val_mc_5shot", + # "openbookqa_val_rc_5shot", + # "openbookqa_val_mc_5shot", + "openbookqa_test_rc_5shot", + # "openbookqa_test_mc_5shot", + "piqa_val_rc_5shot", + # "piqa_val_mc_5shot", + "socialiqa_val_rc_5shot", + # "socialiqa_val_mc_5shot", + # "winogrande_val_rc_5shot", + # "winogrande_val_mc_5shot", + # "mmlu_stem_val_rc_5shot", + # "mmlu_stem_val_mc_5shot", + # "mmlu_humanities_val_rc_5shot", + # "mmlu_humanities_val_mc_5shot", + # "mmlu_social_sciences_val_rc_5shot", + # "mmlu_social_sciences_val_mc_5shot", + # "mmlu_other_val_rc_5shot", + # "mmlu_other_val_mc_5shot", + ], + tokenizer=common.tokenizer, + eval_interval=1000, + ), + ) + ) + + +if __name__ == "__main__": + main( + global_batch_size=2048 * 4096, + model_config_builder=build_model_config, + optim_config_builder=build_optim_config, + trainer_config_builder=build_trainer_config, + )