diff --git a/examples_deepspeed/compile/README.md b/examples_deepspeed/compile/README.md
new file mode 100644
index 0000000000..d811b7f272
--- /dev/null
+++ b/examples_deepspeed/compile/README.md
@@ -0,0 +1,38 @@
+# Run with torch.compile (Experimental)
+
+This folder contains an example that enables `torch.compile`.
+
+
+## Enable `torch.compile`
+
+This example was tested with the following:
+
+- PyTorch [f614ed7](https://github.com/pytorch/pytorch/commit/f614ed78b8b8521900385542d515e60520915ca3)
+  - The nightly revision as of Dec 8, 2023
+  - We recommend using the latest nightly revision
+- DeepSpeed [tohtana/compile-zero](https://github.com/microsoft/DeepSpeed/tree/tohtana/compile-zero)
+
+You need to add a `compile` section to your DeepSpeed config. The fields in the section are passed to `torch.compile`.
+
+```json
+    "compile": {
+        "disable": false,
+        "backend": "inductor"
+    },
+```
+
+The only small extension is to `backend`: to enable a custom backend, you can pass the fully qualified name of the backend function. For example, if you have a backend function `my_backend` in `my_backend.py` in the current directory, you can enable it with `"backend": "my_backend.my_backend"`.
+
+`run.sh` in this folder shows how to run with `torch.compile`.
+`run_validation.sh` runs different configurations to verify the results.
+
+## Verification
+
+The following chart shows the TFLOPS and loss curves obtained without ZeRO and with ZeRO stages 1/2/3.
+
+- Sequence length: 2048
+- Global batch size: 4
+- Model: GPT-1.3B
+- GPUs: 4x A100 (80GB)
+
+![verification](loss_verification.png)
diff --git a/examples_deepspeed/compile/ds_config_gpt_TEMPLATE.json b/examples_deepspeed/compile/ds_config_gpt_TEMPLATE.json
new file mode 100644
index 0000000000..91a7a7376c
--- /dev/null
+++ b/examples_deepspeed/compile/ds_config_gpt_TEMPLATE.json
@@ -0,0 +1,27 @@
+{
+  "train_batch_size": GBSIZE,
+  "train_micro_batch_size_per_gpu": MBSIZE,
+  "steps_per_print": LOG_INTERVAL,
+
+  "zero_optimization": {
+    "stage": ZERO_STAGE
+  },
+
+  "gradient_clipping": 1.0,
+  "prescale_gradients": PRESCALE_GRAD,
+
+  "fp16": {
+    "enabled": true,
+    "loss_scale": 0,
+    "loss_scale_window": 500,
+    "hysteresis": 2,
+    "min_loss_scale": 1,
+    "initial_scale_power": 11
+  },
+
+  "compile": {
+    "disable": DISABLE_COMPILE
+  },
+
+  "wall_clock_breakdown" : false
+}
diff --git a/examples_deepspeed/compile/loss_verification.png b/examples_deepspeed/compile/loss_verification.png
new file mode 100644
index 0000000000..489445ec5f
Binary files /dev/null and b/examples_deepspeed/compile/loss_verification.png differ
diff --git a/examples_deepspeed/compile/run.sh b/examples_deepspeed/compile/run.sh
new file mode 100644
index 0000000000..53337fae0f
--- /dev/null
+++ b/examples_deepspeed/compile/run.sh
@@ -0,0 +1,355 @@
+#!/bin/bash
+dir=`pwd`
+###############################################################################
+### Main configs
+## GPT-3 models use 2K sequence length/context window
+seq_len=2048
+
+## The "GPT-3 XXX" below are configs from the GPT-3 paper
+## https://arxiv.org/abs/2005.14165, choose based on
+## your desired model size or build your own configs
+
+## init_std is the standard deviation for weight initialization. Usually a larger
+## model needs a lower std. We used the heuristic equation sqrt(1/3/hidden_size)
+## from the MT-NLG 530B work (https://arxiv.org/pdf/2201.11990.pdf)
+
+## We changed min_lr to a lower number (1.0e-6), which we found provides
+## better zero-shot eval results.
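+## For example, with hidden_size=2048 (the GPT-3 XL config used below) the
+## heuristic gives init_std = sqrt(1/3/2048) ~= 0.0128, rounded to 0.013 below.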
+
+## GPT-3 Small 125M
+# model_size=0.125
+# num_layers=12
+# hidden_size=768
+# num_attn_heads=12
+# global_batch_size=256
+# lr=6.0e-4
+# min_lr=1.0e-6
+# init_std=0.02
+
+## GPT-3 Medium 350M
+# model_size=0.35
+# num_layers=24
+# hidden_size=1024
+# num_attn_heads=16
+# global_batch_size=256
+# lr=3.0e-4
+# min_lr=1.0e-6
+# init_std=0.018
+
+## GPT-3 Large 760M
+# model_size=0.76
+# num_layers=24
+# hidden_size=1536
+# num_attn_heads=16
+# global_batch_size=256
+# lr=2.5e-4
+# min_lr=1.0e-6
+# init_std=0.015
+
+## GPT-3 XL 1.3B
+model_size=1.3
+num_layers=24
+hidden_size=2048
+num_attn_heads=16
+global_batch_size=16
+lr=2.0e-4
+min_lr=1.0e-6
+init_std=0.013
+
+## GPT-3 2.7B
+# model_size=2.7
+# num_layers=32
+# hidden_size=2560
+# num_attn_heads=32
+# global_batch_size=512
+# lr=1.6e-4
+# min_lr=1.0e-6
+# init_std=0.011
+
+## GPT-3 6.7B
+# model_size=6.7
+# num_layers=32
+# hidden_size=4096
+# num_attn_heads=32
+# global_batch_size=1024
+# lr=1.2e-4
+# min_lr=1.0e-6
+# init_std=0.009
+
+## GPT-3 13B
+# model_size=13
+# num_layers=40
+# hidden_size=5120
+# num_attn_heads=40
+# global_batch_size=1024
+# lr=1.0e-4
+# min_lr=1.0e-6
+# init_std=0.008
+
+## GPT-3 175B
+# model_size=175
+# num_layers=96
+# hidden_size=12288
+# num_attn_heads=96
+# global_batch_size=1536
+# lr=0.6e-4
+# min_lr=1.0e-6
+# init_std=0.005
+###############################################################################
+### Training duration configs
+## The main termination condition, original GPT-3 paper trains for 300B tokens.
+train_tokens_in_billion=300
+train_tokens=$((${train_tokens_in_billion} * 1000000000))
+
+## train_samples is another termination condition and also affects the number of
+## data samples to be indexed. Since we want to reach the train_tokens
+## above, and data efficiency techniques may change the number of tokens in some
+## samples, we just set this config large enough to make sure we have enough
+## processed data and don't terminate by train_samples.
+train_samples=$(( 300 * 1000000000 * 2 / ${seq_len} ))
+
+## Another wall-clock time termination condition in minutes. Set it large
+## enough to avoid undesired early termination.
+exit_duration=30000000
+###############################################################################
+### lr configs
+## lr warmup and decay duration.
+## Original GPT-3 paper uses 375M warmup tokens and 260B cosine decay tokens.
+## Here we increase the warmup tokens to 3B since when batch size warmup is not
+## used, there are more tokens per step. Thus we need to increase warmup tokens
+## to make sure there are enough warmup steps, which is important for training
+## stability.
+lr_warmup_tokens_in_million=3000
+lr_warmup_tokens=$((${lr_warmup_tokens_in_million} * 1000000))
+## Here we changed the LR decay tokens to align with the total train tokens, since
+## related works (e.g., https://arxiv.org/abs/2203.15556) find that setting the
+## learning rate schedule to match the number of training tokens results in the
+## best final model quality.
+lr_decay_tokens_in_billion=${train_tokens_in_billion}
+lr_decay_tokens=$((${lr_decay_tokens_in_billion} * 1000000000))
+lr_decay_style="cosine"
+###############################################################################
+### Parallelism configs
+## Model parallelism, 1 is no MP
+## Currently we only support MP=1 with SP>1
+mp_size=1
+
+## Sequence parallelism, 1 is no SP
+sp_size=1
+
+## Pipeline parallelism. To disable PP, set pp_size to 1 and no_pp to true.
+## Note that currently both curriculum learning and random-LTD are NOT
+## compatible with pipeline parallelism.
+pp_size=1
+no_pp="true"
+
+## ZeRO-based data parallelism, stage=0 will disable ZeRO
+zero_stage=${ZERO_STAGE:-1}
+
+COMPILE=${COMPILE:-"false"}
+DEBUG_STEPS=${DEBUG_STEPS:-"0"}
+
+## Total number of GPUs. ds_ssh is from the DeepSpeed library.
+num_gpus=$(($(ds_ssh nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)-2))
+num_gpus_pernode=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+num_node=$(( ${num_gpus} / ${num_gpus_pernode} ))
+
+## Data parallel size.
+dp_size=$(( ${num_gpus} / ${pp_size} / ${mp_size} / ${sp_size} ))
+
+## Micro batch size per GPU
+## Make sure that batch_size <= global_batch_size*pp_size*mp_size/num_gpus
+## Reduce it manually if GPU OOM
+# batch_size=$(( ${global_batch_size} / ${dp_size} ))
+batch_size=4
+
+###############################################################################
+### Misc configs
+log_interval=1
+eval_iters=10
+eval_interval=100000
+# num_save controls how frequently to save checkpoints. num_save=20 means that a
+# checkpoint will be saved every 5% of training. For longer training you would
+# want a larger num_save to save more frequently, and vice versa.
+num_save=100
+estimated_train_iter=$((${train_tokens} / ${seq_len} / ${global_batch_size}))
+# save_interval=$((${estimated_train_iter} / ${num_save}))
+save_interval=100000
+
+## Activation checkpointing saves GPU memory, but reduces training speed
+# activation_checkpoint="true"
+activation_checkpoint="false"
+
+## Whether or not to log optimizer states (norms, max abs values) to tensorboard.
+## This is not required for training and might save GPU memory when turned off.
+log_optimizer_state="true"
+###############################################################################
+### Output and data configs
+current_time=$(date "+%Y.%m.%d_%H.%M.%S")
+host="${HOSTNAME}"
+seed=1234
+num_workers=0
+
+# data_path="BookCorpusDataset_text_document"
+data_path="/data/users/mtanaka/work/tcomp/bs_MDS/my-gpt2_text_document"
+# if [ ! -f "BookCorpusDataset_text_document.bin" ]; then
+#     wget https://the-eye.eu/public/AI/pile_neox/data/BookCorpusDataset_text_document.bin
+# fi
+# if [ ! -f "BookCorpusDataset_text_document.idx" ]; then
+#     wget https://the-eye.eu/public/AI/pile_neox/data/BookCorpusDataset_text_document.idx
+# fi
+
+vocab_path="/data/users/mtanaka/work/tcomp/bs_MDS/gpt2-vocab.json"
+# if [ ! -f "$vocab_path" ]; then
+#     wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json
+# fi
+merge_path="/data/users/mtanaka/work/tcomp/bs_MDS/gpt2-merges.txt"
+# if [ ! -f "$merge_path" ]; then
+#     wget https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt
+# fi
+
+prescale_grad="true"
+jobname="gpt_${model_size}B_tok${train_tokens_in_billion}B"
+jobname="${jobname}_lr${lr}_min${min_lr}_w${lr_warmup_tokens_in_million}M_d${lr_decay_tokens_in_billion}B_${lr_decay_style}"
+jobname="${jobname}_gbs${global_batch_size}_mbs${batch_size}_g${num_gpus}"
+if [[ $zero_stage -gt 0 ]]; then
+    jobname="${jobname}_z${zero_stage}"
+    prescale_grad="false"
+fi
+if [[ $sp_size -gt 1 ]]; then
+    jobname="${jobname}_sp${sp_size}"
+fi
+if [[ $mp_size -gt 1 ]]; then
+    jobname="${jobname}_mp${mp_size}"
+fi
+if [ "${no_pp}" = "false" ]; then
+    jobname="${jobname}_pp${pp_size}"
+fi
+
+disable_compile="true"
+if [ "${COMPILE}" = "true" ]; then
+disable_compile="false"
+fi
+
+jobname="${jobname}_gbs${global_batch_size}_mbs${batch_size}_comp${COMPILE}"
+jobname="${jobname}_seed${seed}_rebase"
+
+username=$(whoami)
+output_home="output"
+log_path="${output_home}/log/"
+checkpoint_path="${output_home}/checkpoint/${jobname}"
+tensorboard_dir="${output_home}/tensorboard/"
+tensorboard_path="${tensorboard_dir}${jobname}_${host}_${current_time}"
+mkdir -p ${log_path}
+mkdir -p ${checkpoint_path}
+mkdir -p ${tensorboard_path}
+###############################################################################
+data_options=" \
+    --vocab-file ${vocab_path} \
+    --merge-file ${merge_path} \
+    --data-path ${data_path} \
+    --data-impl mmap"
+
+## If CL is used, make sure to set "--split" the same as what you used during
+## offline data analysis & indexing.
+megatron_options=" \
+    --override-opt_param-scheduler \
+    --adam-beta1 0.9 \
+    --adam-beta2 0.95 \
+    --tensor-model-parallel-size 1 \
+    --ds-sequence-parallel-size 1 \
+    --init-method-std ${init_std} \
+    --lr-decay-tokens ${lr_decay_tokens} \
+    --lr-warmup-tokens ${lr_warmup_tokens} \
+    --micro-batch-size ${batch_size} \
+    --exit-duration-in-mins ${exit_duration} \
+    --global-batch-size ${global_batch_size} \
+    --num-layers ${num_layers} \
+    --hidden-size ${hidden_size} \
+    --num-attention-heads ${num_attn_heads} \
+    --seq-length ${seq_len} \
+    --max-position-embeddings ${seq_len} \
+    --train-tokens ${train_tokens} \
+    --train-samples ${train_samples} \
+    --lr ${lr} \
+    --min-lr ${min_lr} \
+    --lr-decay-style ${lr_decay_style} \
+    --split 949,50,1 \
+    --log-interval ${log_interval} \
+    --eval-interval ${eval_interval} \
+    --eval-iters ${eval_iters} \
+    --save-interval ${save_interval} \
+    --weight-decay 0.1 \
+    --clip-grad 1.0 \
+    --hysteresis 2 \
+    --num-workers ${num_workers} \
+    --fp16 \
+    --seed ${seed} \
+    --load ${checkpoint_path} \
+    --save ${checkpoint_path} \
+    --no-async-tensor-model-parallel-allreduce \
+    --use-flash-attn-v2 \
+    --tensorboard-queue-size 1 \
+    --log-timers-to-tensorboard \
+    --log-batch-size-to-tensorboard \
+    --log-validation-ppl-to-tensorboard \
+    --tensorboard-dir ${tensorboard_path}"
+
+if [ "${activation_checkpoint}" = "true" ]; then
+megatron_options="${megatron_options} \
+    --checkpoint-activations"
+fi
+
+if [ "${log_optimizer_state}" = "true" ]; then
+megatron_options="${megatron_options} \
+    --log-optimizer-states-to-tensorboard"
+fi
+
+config_json="ds_config_gbs${global_batch_size}_mbs${batch_size}_log${log_interval}_zero${zero_stage}.json"
+template_json="ds_config_gpt_TEMPLATE.json"
+sed "s/GBSIZE/${global_batch_size}/" ${template_json} \
+    | sed "s/MBSIZE/${batch_size}/" \
+    | sed "s/LOG_INTERVAL/${log_interval}/" \
+    | sed "s/ZERO_STAGE/${zero_stage}/" \
+    | sed "s/PRESCALE_GRAD/${prescale_grad}/" \
+    | sed "s/DISABLE_COMPILE/${disable_compile}/" \
+    > ${config_json}
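+# Note: with the defaults above (ZERO_STAGE=1, COMPILE=false) the generated config
+# contains "zero_optimization": {"stage": 1} and "compile": {"disable": true};
+# setting COMPILE=true flips "disable" to false so torch.compile is enabled.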
"s/PRESCALE_GRAD/${prescale_grad}/" \ + | sed "s/DISABLE_COMPILE/${disable_compile}/" \ + > ${config_json} + +deepspeed_options=" \ + --deepspeed \ + --deepspeed_config ${config_json} \ + --zero-stage ${zero_stage} \ + --pipeline-model-parallel-size ${pp_size} \ + --debug-steps ${DEBUG_STEPS}" + +if [[ "${no_pp}" = "true" ]]; then +deepspeed_options="${deepspeed_options} \ + --no-pipeline-parallel" +fi + +if [ "${activation_checkpoint}" = "true" ]; then +deepspeed_options="${deepspeed_options} \ + --deepspeed-activation-checkpointing" +fi + + +## When saving checkpoint to a storage with cache, their could be consistency +## issue of the pointer to latest checkpoint. Here we find the correct pointer +## and broadcast it to all nodes. +iteration_file="$checkpoint_path/latest_checkpointed_iteration.txt" +iteration_file_2="$checkpoint_path/latest" +iteration=0 +for (( node = 0; node <= num_node-1; node++ )) +do + if $(ssh -q worker-"$node" "test -f \"$iteration_file\""); then + local_iteration=$(ssh -q worker-"$node" cat $iteration_file) + iteration=$(( ${local_iteration} > ${iteration} ? ${local_iteration} : ${iteration} )) + fi +done +if [[ $iteration -gt 0 ]]; then + iteration_2="global_step${iteration}" + ds_ssh "echo $iteration > $iteration_file" + ds_ssh "echo $iteration_2 > $iteration_file_2" +fi + +deepspeed ${dir}/../../pretrain_gpt.py ${megatron_options} ${data_options} ${deepspeed_options} 2>&1 | tee ${log_path}/${jobname}_${host}_${current_time}.log diff --git a/examples_deepspeed/compile/run_validation.sh b/examples_deepspeed/compile/run_validation.sh new file mode 100644 index 0000000000..81848ff841 --- /dev/null +++ b/examples_deepspeed/compile/run_validation.sh @@ -0,0 +1,9 @@ +export DEBUG_STEPS=1000 +ZERO_STAGE=0 COMPILE=false bash ./run.sh +ZERO_STAGE=0 COMPILE=true bash ./run.sh +ZERO_STAGE=1 COMPILE=false bash ./run.sh +ZERO_STAGE=1 COMPILE=true bash ./run.sh +ZERO_STAGE=2 COMPILE=false bash ./run.sh +ZERO_STAGE=2 COMPILE=true bash ./run.sh +ZERO_STAGE=3 COMPILE=false bash ./run.sh +ZERO_STAGE=3 COMPILE=true bash ./run.sh diff --git a/megatron/arguments.py b/megatron/arguments.py index d5e5970865..0115e5ecc1 100644 --- a/megatron/arguments.py +++ b/megatron/arguments.py @@ -888,6 +888,7 @@ def _add_training_args(parser): 'https://arxiv.org/abs/2307.08691') group.add_argument('--use-flash-attn-triton', action='store_true', help='use FlashAttention implementation of attention using Triton.') + group.add_argument('--debug-steps', type=int, default=0) group.add_argument('--disable-bias-linear', action='store_false', help='Disable bias in the linear layers', dest='add_bias_linear') diff --git a/megatron/core/tensor_parallel/mappings.py b/megatron/core/tensor_parallel/mappings.py index ae8d63ab2c..524c9d55b2 100644 --- a/megatron/core/tensor_parallel/mappings.py +++ b/megatron/core/tensor_parallel/mappings.py @@ -251,30 +251,37 @@ def backward(ctx, grad_output): # Helper functions. 
 # -----------------
 
 
+@torch.compiler.disable
 def copy_to_tensor_model_parallel_region(input_):
     return _CopyToModelParallelRegion.apply(input_)
 
 
+@torch.compiler.disable
 def reduce_from_tensor_model_parallel_region(input_):
     return _ReduceFromModelParallelRegion.apply(input_)
 
 
+@torch.compiler.disable
 def scatter_to_tensor_model_parallel_region(input_):
     return _ScatterToModelParallelRegion.apply(input_)
 
 
+@torch.compiler.disable
 def gather_from_tensor_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
 
 
+@torch.compiler.disable
 def scatter_to_sequence_parallel_region(input_):
     return _ScatterToSequenceParallelRegion.apply(input_)
 
 
+@torch.compiler.disable
 def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
     return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
 
 
+@torch.compiler.disable
 def reduce_scatter_to_sequence_parallel_region(input_):
     return _ReduceScatterToSequenceParallelRegion.apply(input_)
diff --git a/megatron/training.py b/megatron/training.py
index 0f05d7c7af..792b99e87a 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -40,7 +40,7 @@ from megatron.utils import calc_params_l2_norm
 from megatron.core.pipeline_parallel import get_forward_backward_func
 from megatron.utils import report_memory, throughput_calculator, checkpoint_throughput_calculator, update_rotary_pos_emb
-from megatron.model.vision.knn_monitor import compute_feature_bank
+# from megatron.model.vision.knn_monitor import compute_feature_bank
 from megatron.arguments import core_transformer_config_from_args
 
 import deepspeed
@@ -1282,6 +1282,8 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
             print_datetime('exiting program at iteration {}'.format(iteration))
             sys.exit()
 
+        if args.debug_steps > 0 and iteration > args.debug_steps:
+            sys.exit()
 
     return iteration
 
@@ -1295,8 +1297,8 @@ def evaluate(forward_step_func,
     """Evaluation."""
     args = get_args()
 
-    if args.vision_pretraining and args.vision_pretraining_type == "dino":
-        compute_feature_bank(model)
+    # if args.vision_pretraining and args.vision_pretraining_type == "dino":
+    #     compute_feature_bank(model)
 
     # Turn on evaluation mode which disables dropout.
     for model_module in model:
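For reference, the custom-backend option described in the README above could be exercised with a module like the following. This is only a minimal sketch: the module and function names (`my_backend.py`, `my_backend`) mirror the README's example and are not part of this patch, and a real backend would typically transform or lower the captured FX graph rather than just inspecting it.

```python
# my_backend.py -- minimal sketch of a custom torch.compile backend (illustrative only).
from typing import List

import torch


def my_backend(gm: torch.fx.GraphModule, example_inputs: List[torch.Tensor]):
    # A torch.compile backend receives the captured FX graph and example inputs
    # and returns a callable. Here we just log the graph size and fall back to
    # running the captured graph eagerly.
    print(f"my_backend: compiling a graph with {len(list(gm.graph.nodes))} nodes")
    return gm.forward
```

With this file importable from the working directory, the DeepSpeed config from the README would select it via `"compile": {"disable": false, "backend": "my_backend.my_backend"}`.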