diff --git a/.github/workflows/dataset.yml b/.github/workflows/dataset.yml index 45f26bef..a138337e 100644 --- a/.github/workflows/dataset.yml +++ b/.github/workflows/dataset.yml @@ -16,6 +16,8 @@ on: - "**/*.py" - .github/workflows/dataset.yml + + jobs: ray: runs-on: [self-hosted, gpu] diff --git a/.github/workflows/e2e_digit_completion.yml b/.github/workflows/e2e_digit_completion.yml index bed26203..7b8678e6 100644 --- a/.github/workflows/e2e_digit_completion.yml +++ b/.github/workflows/e2e_digit_completion.yml @@ -17,6 +17,8 @@ on: - .github/workflows/e2e_digit_completion.yml - "tests/e2e/*.sh" + + jobs: e2e_digit_completion: runs-on: [self-hosted, l20-0] diff --git a/.github/workflows/e2e_gsm8k.yml b/.github/workflows/e2e_gsm8k.yml index 3d16d771..1295f689 100644 --- a/.github/workflows/e2e_gsm8k.yml +++ b/.github/workflows/e2e_gsm8k.yml @@ -17,6 +17,8 @@ on: - .github/workflows/e2e_gsm8k.yml - "tests/e2e/*.sh" + + jobs: e2e_gsm8k: runs-on: [self-hosted, l20-1] diff --git a/.github/workflows/e2e_lora.yml b/.github/workflows/e2e_lora.yml new file mode 100644 index 00000000..b2163b5f --- /dev/null +++ b/.github/workflows/e2e_lora.yml @@ -0,0 +1,48 @@ +name: e2e_lora + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/e2e_lora.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/e2e_lora.yml + - "tests/e2e/*.sh" + + + +jobs: + e2e_lora: + runs-on: [self-hosted, l20-1] + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer peft + pip3 install -e .[test] + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running gsm8k e2e training tests with LoRA + run: | + ray stop --force + bash tests/sft/run_sft_qwen05_peft.sh 8 $HOME/ckpts/ \ No newline at end of file diff --git a/.github/workflows/e2e_sft.yml b/.github/workflows/e2e_sft.yml new file mode 100644 index 00000000..4cd6fbe7 --- /dev/null +++ b/.github/workflows/e2e_sft.yml @@ -0,0 +1,56 @@ +name: e2e_sft + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + push: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/e2e_sft.yml + pull_request: + branches: + - main + paths: + - "**/*.py" + - .github/workflows/e2e_sft.yml + - "tests/e2e/*.sh" + + + +jobs: + e2e_sft: + runs-on: [self-hosted, l20-1] + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1" + HF_HUB_ENABLE_HF_TRANSFER: 1 + container: + image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3 + options: --gpus all --shm-size=10g + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install hf_transfer + pip3 install -e .[test] + - name: Prepare gsm8k dataset + run: | + ray stop --force + python3 examples/data_preprocess/gsm8k.py + - name: Running gsm8k e2e training tests on 8 L20 GPUs with rmpad using function rm + run: | + ray stop 
--force
        bash tests/sft/run_sft.sh
    - name: Running gsm8k e2e training tests on 8 L20 GPUs with sequence parallelism
      run: |
        ray stop --force
        bash examples/sft/gsm8k/run_qwen_05_sp2.sh 8 $HOME/ckpts/
    - name: Check loss difference between sequence parallel vs. default implementation
      run: |
        ray stop --force
        bash tests/sft/run_sft_sp_loss_match.sh
diff --git a/.github/workflows/model.yml b/.github/workflows/model.yml
index d634c241..6ff7aacb 100644
--- a/.github/workflows/model.yml
+++ b/.github/workflows/model.yml
@@ -16,6 +16,8 @@ on:
       - "**/*.py"
       - .github/workflows/model.yml
 
+
+
 jobs:
   model_rmpad:
     runs-on: [self-hosted, l20-1]
diff --git a/.github/workflows/ray_test.yml b/.github/workflows/ray_test.yml
index 83ec8711..8c63f9d2 100644
--- a/.github/workflows/ray_test.yml
+++ b/.github/workflows/ray_test.yml
@@ -16,6 +16,8 @@ on:
       - "**/*.py"
       - .github/workflows/ray_test.yml
 
+
+
 jobs:
   ray:
     runs-on: [self-hosted, l20-0]
diff --git a/README.md b/README.md
index 74e43554..cca42503 100644
--- a/README.md
+++ b/README.md
@@ -39,12 +39,15 @@ veRL is fast with:
 - **vLLM** and **TGI** for rollout generation, **SGLang** support coming soon.
 - huggingface models support
 - Supervised fine-tuning
-- Reward model training
-- Reinforcement learning from human feedback with PPO
-- flash-attention integration, sequence packing, and long context support
+- Reinforcement learning from human feedback with [PPO](https://github.com/volcengine/verl/tree/main/examples/ppo_trainer) and [GRPO](https://github.com/volcengine/verl/tree/main/examples/grpo_trainer)
+  - Support model-based reward and function-based reward (verifiable reward)
+- flash-attention integration, sequence packing, and long context support via DeepSpeed Ulysses
 - scales up to 70B models and hundreds of GPUs
 - experiment tracking with wandb and mlflow
 
+## Upcoming Features
+- Reward model training
+- DPO training
 
 ## Getting Started
 
@@ -54,7 +57,7 @@ Checkout this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/ex
 - [Installation](https://verl.readthedocs.io/en/latest/start/install.html)
 - [Quickstart](https://verl.readthedocs.io/en/latest/start/quickstart.html)
 
-**Running an PPO example step-by-step:**
+**Running a PPO example step-by-step:**
 - Data and Reward Preparation
   - [Prepare Data (Parquet) for Post-Training](https://verl.readthedocs.io/en/latest/preparation/prepare_data.html)
   - [Implement Reward Function for Dataset](https://verl.readthedocs.io/en/latest/preparation/reward_function.html)
@@ -77,6 +80,8 @@ Checkout this [Jupyter Notebook](https://github.com/volcengine/verl/tree/main/ex
 - [Add models with the FSDP backend](https://verl.readthedocs.io/en/latest/advance/fsdp_extension.html)
 - [Add models with the Megatron-LM backend](https://verl.readthedocs.io/en/latest/advance/megatron_extension.html)
 
+## Performance Tuning Guide
+Performance is essential for on-policy RL algorithms. We provide a detailed performance tuning guide to help users tune their workloads. See [here](https://verl.readthedocs.io/en/latest/perf/perf_tuning.html) for more details.
 
 ## Citation and acknowledgement
 
@@ -95,9 +100,10 @@ If you find the project helpful, please cite:
 verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The project is adopted and supported by Anyscale, Bytedance, LMSys.org, Shanghai AI Lab, Tsinghua University, UC Berkeley, UCLA, UIUC, and University of Hong Kong.
-## Publications Using veRL +## Awesome work using veRL - [Enhancing Multi-Step Reasoning Abilities of Language Models through Direct Q-Function Optimization](https://arxiv.org/abs/2410.09302) - [Flaming-hot Initiation with Regular Execution Sampling for Large Language Models](https://arxiv.org/abs/2410.21236) - [Process Reinforcement Through Implicit Rewards](https://github.com/PRIME-RL/PRIME/) +- [TinyZero](https://github.com/Jiayi-Pan/TinyZero): a reproduction of DeepSeek R1 Zero in countdown and multiplication tasks We are HIRING! Send us an [email](mailto:haibin.lin@bytedance.com) if you are interested in internship/FTE opportunities in MLSys/LLM reasoning/multimodal alignment. diff --git a/docs/examples/config.rst b/docs/examples/config.rst index 3fc1906b..b5ccd284 100644 --- a/docs/examples/config.rst +++ b/docs/examples/config.rst @@ -59,60 +59,79 @@ Actor/Rollout/Reference Policy .. code:: yaml actor_rollout_ref: - hybrid_engine: True - model: - path: ~/models/deepseek-llm-7b-chat - external_lib: null - override_config: {} - enable_gradient_checkpointing: False - actor: - strategy: fsdp # This is for backward-compatibility - ppo_mini_batch_size: 256 - ppo_micro_batch_size: 64 - grad_clip: 1.0 - clip_ratio: 0.2 - entropy_coeff: 0.001 - ppo_epochs: 1 - shuffle: True - optim: - lr: 1e-6 - lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime - min_lr_ratio: null # only useful for warmup with cosine - warmup_style: constant # select from constant/cosine - total_training_steps: -1 # must be override by program - fsdp_config: - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - param_offload: False - grad_offload: False - optimizer_offload: False - ref: - fsdp_config: - param_offload: False - wrap_policy: - # transformer_layer_cls_to_wrap: None - min_num_params: 0 - log_prob_micro_batch_size: 128 - rollout: - name: vllm - temperature: 1.0 - top_k: -1 # 0 for hf rollout, -1 for vllm rollout - top_p: 1 - response_length: ${data.max_response_length} - # for vllm rollout - dtype: bfloat16 # should align with FSDP - gpu_memory_utilization: 0.5 - ignore_eos: False - enforce_eager: True - free_cache_engine: True - load_format: dummy_dtensor # or dummy_hf or dummy_megatron - tensor_model_parallel_size: 2 - max_num_batched_tokens: 8192 - max_num_seqs: 1024 - log_prob_micro_batch_size: 128 - # for vllm and hf rollout - do_sample: True + hybrid_engine: True + model: + path: ~/models/deepseek-llm-7b-chat + external_lib: null + override_config: { } + enable_gradient_checkpointing: False + use_remove_padding: False + actor: + strategy: fsdp # This is for backward-compatibility + ppo_mini_batch_size: 256 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 8 + use_dynamic_bsz: False + ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} + grad_clip: 1.0 + clip_ratio: 0.2 + entropy_coeff: 0.001 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo + ppo_epochs: 1 + shuffle: False + ulysses_sequence_parallel_size: 1 # sp size + optim: + lr: 1e-6 + lr_warmup_steps_ratio: 0. 
# the total steps will be injected during runtime
+        min_lr_ratio: null # only useful for warmup with cosine
+        warmup_style: constant # select from constant/cosine
+        total_training_steps: -1 # must be override by program
+      fsdp_config:
+        wrap_policy:
+          # transformer_layer_cls_to_wrap: None
+          min_num_params: 0
+        param_offload: False
+        grad_offload: False
+        optimizer_offload: False
+        fsdp_size: -1
+    ref:
+      fsdp_config:
+        param_offload: False
+        wrap_policy:
+          # transformer_layer_cls_to_wrap: None
+          min_num_params: 0
+      log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
+      log_prob_micro_batch_size_per_gpu: 16
+      log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+      log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+      ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size
+    rollout:
+      name: vllm
+      temperature: 1.0
+      top_k: -1 # 0 for hf rollout, -1 for vllm rollout
+      top_p: 1
+      prompt_length: ${data.max_prompt_length} # not use for opensource
+      response_length: ${data.max_response_length}
+      # for vllm rollout
+      dtype: bfloat16 # should align with FSDP
+      gpu_memory_utilization: 0.5
+      ignore_eos: False
+      enforce_eager: True
+      free_cache_engine: True
+      load_format: dummy_dtensor
+      tensor_model_parallel_size: 2
+      max_num_batched_tokens: 8192
+      max_num_seqs: 1024
+      log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
+      log_prob_micro_batch_size_per_gpu: 16
+      log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
+      log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+      # for hf rollout
+      do_sample: True
+      # number of responses (i.e. num sample times)
+      n: 1 # > 1 for grpo
 
 **Common config for actor, rollout and reference model**
 
@@ -136,11 +155,15 @@ Actor/Rollout/Reference Policy
 
 - ``actor_rollout_ref.actor.ppo_mini_batch_size``: One sample is split
   into multiple sub-batches with batch_size=ppo_mini_batch_size for PPO
-  updates
+  updates. The ppo_mini_batch_size is a global size across all workers/GPUs.
+
+- ``actor_rollout_ref.actor.ppo_micro_batch_size``: [Will be deprecated, use ppo_micro_batch_size_per_gpu]
+  Similar to gradient accumulation, the micro_batch_size_per_gpu for one forward pass,
+  trading speed for GPU memory. The value represents the global view.
 
-- ``actor_rollout_ref.actor.ppo_micro_batch_size``: Similar to gradient
-  accumulation, the micro_batch_size for one forward pass, trading speed
-  for GPU memory
+- ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu``: Similar to gradient
+  accumulation, the micro_batch_size_per_gpu for one forward pass, trading speed
+  for GPU memory. The value represents the local number per GPU.
 
 - ``actor_rollout_ref.actor.grad_clip``: Gradient clipping for actor
   updates
@@ -176,8 +199,12 @@ Actor/Rollout/Reference Policy
 
 - ``actor_rollout_ref.ref``: FSDP config same as actor. **For models
   larger than 7B, it's recommended to turn on offload for ref by
   default**
-- ``actor_rollout_ref.ref.log_prob_micro_batch_size``: The batch size
-  for one forward pass in the computation of ``ref_log_prob``.
+
+- ``actor_rollout_ref.ref.log_prob_micro_batch_size``: [Will be deprecated, use log_prob_micro_batch_size_per_gpu]
+  The batch size for one forward pass in the computation of ``ref_log_prob``. The value represents the global number.
+
+- ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``: The batch size
+  for one forward pass in the computation of ``ref_log_prob``. The value represents the local number per GPU.
 
 **Rollout Model**
 
@@ -201,8 +228,11 @@ Actor/Rollout/Reference Policy
 - ``tensor_model_parallel_size``: TP size for rollout. Only effective
   for vllm.
 
-- ``log_prob_micro_batch_size``: Micro_batch_size (The batch size for
-  one forward pass) for recalculating log_prob.
+- ``actor_rollout_ref.rollout.log_prob_micro_batch_size``: [Will be deprecated, use log_prob_micro_batch_size_per_gpu]
+  The batch size for one forward pass in the computation of ``log_prob``. The value represents the global number.
+
+- ``log_prob_micro_batch_size_per_gpu``: Micro batch size per GPU (the batch size for
+  one forward pass) for recalculating ``log_prob``. The value represents the local number per GPU.
 
 - ``do_sample``: Whether to sample. If set to False, the rollout model
   will perform greedy sampling. We disable ``do_sample`` during
@@ -260,7 +290,7 @@ Reward Model
     fsdp_config:
       min_num_params: 0
       param_offload: False
-  micro_batch_size: 64
+  micro_batch_size_per_gpu: 16
   max_length: null
 
 - ``reward_model.enable``: Whether to enable reward model. If False, we
diff --git a/docs/examples/gsm8k_example.rst b/docs/examples/gsm8k_example.rst
index de694cfd..ac4550df 100644
--- a/docs/examples/gsm8k_example.rst
+++ b/docs/examples/gsm8k_example.rst
@@ -85,7 +85,7 @@ We also provide various training scripts for SFT on GSM8K dataset in `gsm8k sft
     data.val_files=$HOME/data/gsm8k/test.parquet \
     data.prompt_key=question \
     data.response_key=answer \
-    data.micro_batch_size=8 \
+    data.micro_batch_size_per_gpu=8 \
     model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \
     trainer.default_hdfs_dir=hdfs://user/verl/experiments/gsm8k/deepseek-coder-6.7b-instruct/ \
     trainer.project_name=gsm8k-sft \
@@ -136,21 +136,20 @@ The script of run_deepseek7b_llm.sh
     actor_rollout_ref.model.path=~/models/deepseek-llm-7b-chat \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.actor.ppo_mini_batch_size=256 \
-    actor_rollout_ref.actor.ppo_micro_batch_size=64 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.actor.fsdp_config.param_offload=False \
     actor_rollout_ref.actor.fsdp_config.grad_offload=False \
     actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
-    actor_rollout_ref.rollout.micro_batch_size=256 \
-    actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
-    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \
     actor_rollout_ref.ref.fsdp_config.param_offload=True \
     critic.optim.lr=1e-5 \
     critic.model.path=~/models/deepseek-llm-7b-chat \
     critic.model.enable_gradient_checkpointing=False \
-    critic.ppo_micro_batch_size=64 \
+    critic.ppo_micro_batch_size_per_gpu=16 \
     critic.model.fsdp_config.param_offload=False \
     critic.model.fsdp_config.grad_offload=False \
     critic.model.fsdp_config.optimizer_offload=False \
diff --git a/docs/faq/faq.rst b/docs/faq/faq.rst
index 40bbb07d..56725834 100644
--- a/docs/faq/faq.rst
+++ b/docs/faq/faq.rst
@@ -17,3 +17,5 @@ How to run multi-node post-training with Ray?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 You can start a ray cluster and submit a ray job, following the official guide from Ray: https://docs.ray.io/en/latest/ray-core/starting-ray.html
+
+If your cluster is managed by Slurm, please refer to the guide for deploying Ray on Slurm: https://docs.ray.io/en/latest/cluster/vms/user-guides/community/slurm.html
diff --git a/docs/index.rst b/docs/index.rst
index d44e2a71..4fa5f69c 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -62,6 +62,12 @@ veRL is fast with:
    workers/fsdp_workers
    workers/megatron_workers
 
+.. toctree::
+   :maxdepth: 1
+   :caption: Performance Tuning Guide
+
+   perf/perf_tuning
+
 .. toctree::
    :maxdepth: 1
    :caption: Experimental Results
diff --git a/docs/perf/perf_tuning.rst b/docs/perf/perf_tuning.rst
new file mode 100644
index 00000000..238078bb
--- /dev/null
+++ b/docs/perf/perf_tuning.rst
@@ -0,0 +1,121 @@
+Performance Tuning Guide
+=========================
+
+In this section, we discuss how to tune the performance of all the stages in veRL, including:
+
+1. Rollout generation throughput.
+
+2. Batch size tuning for forward and backward computation.
+
+3. Enabling ``use_dynamic_bsz=True`` for higher throughput.
+
+4. Utilizing Ulysses Sequence Parallel for long context training.
+
+Rollout Generation Tuning
+--------------------------
+
+veRL currently supports two rollout backends: vLLM and TGI (with SGLang support coming soon).
+
+Below are key factors for tuning vLLM-based rollout. Before tuning, we recommend setting ``actor_rollout_ref.rollout.disable_log_stats=False`` so that rollout statistics are logged.
+
+- Increase ``gpu_memory_utilization``. vLLM pre-allocates the GPU KVCache using gpu_memory_utilization% of the remaining memory.
+  However, if model parameters and optimizer states are not offloaded, using too high a fraction can lead to OOM.
+  A value between 0.5 and 0.7 often strikes a good balance between high throughput and avoiding OOM.
+
+- Adjust ``max_num_seqs`` or ``max_num_batched_tokens``.
+  If the GPU cache utilization is relatively low in the log, increasing ``max_num_seqs`` or ``max_num_batched_tokens``
+  can enlarge the effective batch size in the decoding stage, allowing more concurrent requests per batch.
+  We recommend setting ``max_num_batched_tokens > 2048`` for higher throughput.
+
+- Use a smaller ``tensor_parallel_size``.
+  When GPU resources allow, a smaller tensor parallel size spawns more vLLM replicas.
+  Data parallelism (DP) can yield higher throughput than tensor parallelism (TP), but also increases KVCache consumption.
+  Carefully balance the trade-off between more replicas and higher memory usage.
+  Our experiments in Sec. 8.4 of the `HybridFlow paper `_ evaluate this trade-off.
+
+More tuning details, such as dealing with preemption and chunked-prefill,
+can be found in the `vLLM official tuning guide `_.
+
+
+Batch Size Tuning
+-----------------
+
+To achieve higher throughput in experience preparation (i.e., model fwd) and model update (i.e., actor/critic fwd/bwd),
+users may need to tune the ``*micro_batch_size_per_gpu`` for different computations.
+
+In veRL, the core principle for setting batch sizes is:
+
+- **Algorithmic metrics** (train batch size, PPO mini-batch size) are *global* (from a single-controller perspective),
+  and are normalized in each worker. See the `normalization code `_.
+
+- **Performance-related parameters** (micro batch size, max token length for dynamic batch size) are *local* parameters that define the per-GPU data allocations.
+  See the `normalization code `_.
+
+.. note:: In your training script, please use ``*micro_batch_size_per_gpu`` instead of ``*micro_batch_size``.
+   This way you don't need to handle the normalization of ``micro_batch_size``, which will be deprecated.
+
+Batch Size Tuning tips
+""""""""""""""""""""""
+
+Therefore, users may need to tune the ``*micro_batch_size_per_gpu`` to accelerate training. Here are some tips:
+
+1. **Enable gradient checkpointing**:
+   Set ``actor_rollout_ref.model.enable_gradient_checkpointing=True`` and ``critic.model.enable_gradient_checkpointing=True``.
+   This often allows for larger micro-batch sizes and is beneficial for large mini-batch training.
+
+2. Increase the ``*micro_batch_size_per_gpu`` as much as possible, until it equals the normalized ``mini_batch_size``.
+
+3. **Use larger forward-only parameters**:
+   Forward-only parameters, such as ``actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu``,
+   ``actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu``, ``critic.forward_micro_batch_size_per_gpu``, can be larger (e.g., 2x) than training-related micro batch sizes,
+   such as ``actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu`` and ``critic.ppo_micro_batch_size_per_gpu``.
+
+4. **Allow larger micro-batch sizes for Critic and Reward models**:
+   The micro batch sizes of the Critic and Reward models can be larger than that of the Actor model, because the Actor model has a much larger vocabulary-sized output in its final layer.
+
+
+Tuning for Dynamic Batch Size
+-----------------------------
+
+Dynamic batch size is a technique that allows the model to process a similar number of tokens in a single forward pass (with different actual batch sizes).
+This can significantly improve training efficiency and reduce memory usage.
+
+To utilize this technique, users can set ``use_dynamic_bsz=True`` in the actor, ref, critic and reward models.
+With ``use_dynamic_bsz=True``, users don't need to tune ``*micro_batch_size_per_gpu``.
+Instead, users should tune the following parameters:
+
+- ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``, ``critic.ppo_max_token_len_per_gpu``:
+  The maximum number of tokens to be processed in the fwd and bwd of ``update_policy`` and ``update_critic``.
+
+- ``actor_rollout_ref.ref.log_prob_max_token_len_per_gpu`` and ``actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu``:
+  The maximum number of tokens to be processed in the fwd computation of ``compute_log_prob`` and ``compute_ref_log_prob``.
+
+- ``critic.forward_micro_batch_size_per_gpu``, ``reward_model.forward_micro_batch_size_per_gpu``:
+  The maximum number of tokens to be processed in the fwd computation of ``compute_values`` and ``compute_rm_score``.
+
+Dynamic Batch Size Tuning tips
+""""""""""""""""""""""""""""""
+
+Here are some tips to tune the above parameters:
+
+1. **Increase** ``actor_rollout_ref.actor.ppo_max_token_len_per_gpu``:
+   Make it at least 2 x (max_prompt_length + max_response_length). We set it to 3x in `run_qwen2-7b_rm_seq_balance.sh `_.
+   Try to increase it to get higher throughput.
+
+2. **Forward-only parameters can be larger**:
+   Similar to the non-dynamic-batch scenario, forward-only token limits can exceed those used in forward/backward operations.
+
+3. **Use larger limits for Critic and Reward models**:
+   Critic and Reward parameters can be set to at least 2× the Actor's limits. For instance, we set them to 4× here:
+   `run_qwen2-7b_rm_seq_balance.sh `_
+
+.. :math:`\text{critic.ppo_max_token_len_per_gpu} = 2 \times \text{actor.ppo_max_token_len_per_gpu}`.
+
+Ulysses Sequence Parallel for Long Context Training
+----------------------------------------------------
+
+To utilize this technique, users can set ``ulysses_sequence_parallel_size>1`` in the actor, ref, critic and reward models.
+
+We support using a different ``ulysses_sequence_parallel_size`` for different models.
+
+To train with long sequences (>32k), users may need to decrease ``*micro_batch_size_per_gpu`` and ``*max_token_len_per_gpu`` to avoid OOM.
\ No newline at end of file
diff --git a/docs/start/quickstart.rst b/docs/start/quickstart.rst
index 881839af..5e0da4a7 100644
--- a/docs/start/quickstart.rst
+++ b/docs/start/quickstart.rst
@@ -92,14 +92,14 @@ Set the ``data.train_files`` ,\ ``data.val_files``, ``actor_rollout_ref.model.pa
     actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
     actor_rollout_ref.actor.optim.lr=1e-6 \
     actor_rollout_ref.actor.ppo_mini_batch_size=64 \
-    actor_rollout_ref.actor.ppo_micro_batch_size=4 \
-    actor_rollout_ref.rollout.log_prob_micro_batch_size=8 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
+    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
     actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
-    actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
+    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
     critic.optim.lr=1e-5 \
     critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
-    critic.ppo_micro_batch_size=4 \
+    critic.ppo_micro_batch_size_per_gpu=4 \
     algorithm.kl_ctrl.kl_coef=0.001 \
     trainer.logger=['console'] \
     +trainer.val_before_train=False \
@@ -133,8 +133,8 @@ If you encounter out of memory issues with HBM less than 32GB, enable the follow
 
 .. code-block:: bash
 
-    actor_rollout_ref.actor.ppo_micro_batch_size=1 \
-    critic.ppo_micro_batch_size=1 \
+    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
+    critic.ppo_micro_batch_size_per_gpu=1 \
 
 For the full set of configs, please refer to :ref:`config-explain-page` for detailed explanation and performance tuning.
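A minimal sketch (not part of the patch) of how the dynamic-batch-size and Ulysses settings described in the tuning guide above might be combined on a ``main_ppo`` command line. The token-length values, the Qwen2.5-0.5B model choice, and the explicit ``critic.use_dynamic_bsz`` override are illustrative assumptions, not tuned recommendations; all other keys appear in the config and scripts in this patch.

.. code-block:: bash

    # Sketch only: ppo_max_token_len_per_gpu ~ 3 x (512 + 1024); forward-only limits ~ 2 x that;
    # sequence parallelism requires use_remove_padding=True, as in the sp2 example scripts.
    python3 -m verl.trainer.main_ppo \
        data.train_files=$HOME/data/gsm8k/train.parquet \
        data.val_files=$HOME/data/gsm8k/test.parquet \
        data.max_prompt_length=512 \
        data.max_response_length=1024 \
        actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
        actor_rollout_ref.model.use_remove_padding=True \
        actor_rollout_ref.actor.use_dynamic_bsz=True \
        actor_rollout_ref.actor.ppo_max_token_len_per_gpu=4608 \
        actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=9216 \
        actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=9216 \
        actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \
        critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
        critic.model.use_remove_padding=True \
        critic.use_dynamic_bsz=True \
        critic.ppo_max_token_len_per_gpu=9216 \
        critic.ulysses_sequence_parallel_size=2 \
        trainer.n_gpus_per_node=8 \
        trainer.nnodes=1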
diff --git a/examples/grpo_trainer/run_deepseek7b_llm.sh b/examples/grpo_trainer/run_deepseek7b_llm.sh new file mode 100644 index 00000000..912f6a34 --- /dev/null +++ b/examples/grpo_trainer/run_deepseek7b_llm.sh @@ -0,0 +1,39 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh new file mode 100644 index 00000000..adec112e --- /dev/null +++ b/examples/grpo_trainer/run_deepseek7b_llm_seq_balance.sh @@ -0,0 +1,38 @@ +set -x + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=512 \ + actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + 
trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='deepseek_llm_7b_function_rm_seq_packing' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2-7b.sh b/examples/grpo_trainer/run_qwen2-7b.sh new file mode 100644 index 00000000..a082c368 --- /dev/null +++ b/examples/grpo_trainer/run_qwen2-7b.sh @@ -0,0 +1,41 @@ +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=80 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm' \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh new file mode 100644 index 00000000..09eac41d --- /dev/null +++ b/examples/grpo_trainer/run_qwen2-7b_seq_balance.sh @@ -0,0 +1,41 @@ +set -x + +export VLLM_ATTENTION_BACKEND=XFORMERS + +python3 -m verl.trainer.main_ppo \ + algorithm.adv_estimator=grpo \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.train_batch_size=1024 \ + data.val_batch_size=1312 \ + data.max_prompt_length=512 \ + data.max_response_length=1024 \ + actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.ppo_mini_batch_size=256 \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \ + actor_rollout_ref.actor.use_kl_loss=True \ + actor_rollout_ref.actor.kl_loss_coef=0.001 \ + actor_rollout_ref.actor.kl_loss_type=low_var_kl \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=False \ + actor_rollout_ref.actor.fsdp_config.grad_offload=False \ + 
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ + actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ + actor_rollout_ref.rollout.name=vllm \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.rollout.n=5 \ + actor_rollout_ref.ref.fsdp_config.param_offload=True \ + algorithm.kl_ctrl.kl_coef=0.001 \ + trainer.critic_warmup=0 \ + trainer.logger=['console','wandb'] \ + trainer.project_name='verl_grpo_example_gsm8k' \ + trainer.experiment_name='qwen2_7b_function_rm_kl1e-3' \ + +trainer.val_before_train=False \ + trainer.n_gpus_per_node=8 \ + trainer.nnodes=1 \ + trainer.save_freq=-1 \ + trainer.test_freq=5 \ + trainer.total_epochs=15 $@ \ No newline at end of file diff --git a/examples/ppo_trainer/run_deepseek7b_llm.sh b/examples/ppo_trainer/run_deepseek7b_llm.sh index d4ebdd8a..a34f67a5 100644 --- a/examples/ppo_trainer/run_deepseek7b_llm.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm.sh @@ -11,21 +11,22 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size_per_gpu=32 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ @@ -37,4 +38,5 @@ python3 -m verl.trainer.main_ppo \ trainer.n_gpus_per_node=8 \ trainer.nnodes=1 \ trainer.save_freq=-1 \ + trainer.test_freq=1 \ trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh index 23996412..ef4db1a7 100644 --- a/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh +++ b/examples/ppo_trainer/run_deepseek7b_llm_sp2.sh @@ -11,23 +11,24 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \ actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=64 \ 
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=64 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.ulysses_sequence_parallel_size=2 \ critic.model.use_remove_padding=True \ critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=64 \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size_per_gpu=64 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ @@ -37,7 +38,7 @@ python3 -m verl.trainer.main_ppo \ trainer.project_name='verl_example_gsm8k' \ trainer.experiment_name='deepseek_llm_7b_function_rm_sp2' \ trainer.n_gpus_per_node=8 \ - +trainer.val_before_train=False \ trainer.nnodes=1 \ trainer.save_freq=-1 \ + trainer.test_freq=5 \ trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh index bd2c0bc8..a7b16a7c 100644 --- a/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh +++ b/examples/ppo_trainer/run_deepseek_full_hh_rlhf.sh @@ -13,21 +13,21 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size=16 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.ref.param_offload=False \ critic.optim.lr=1e-5 \ critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=16 \ + critic.ppo_micro_batch_size_per_gpu=4 \ reward_model.enable=True \ reward_model.megatron.tensor_model_parallel_size=4 \ reward_model.model.path=deepseek-ai/deepseek-llm-7b-chat \ - reward_model.micro_batch_size=16 \ + reward_model.micro_batch_size_per_gpu=4 \ reward_model.param_offload=False \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ diff --git a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh index c342d526..17b170a1 100644 --- a/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_math_gsm8k_megatron.sh @@ -18,16 +18,16 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-6.7b-instruct \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ 
actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=32 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ critic.optim.lr=1e-5 \ critic.model.path=deepseek-ai/deepseek-coder-6.7b-instruct \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=4 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ diff --git a/examples/ppo_trainer/run_deepseek_megatron.sh b/examples/ppo_trainer/run_deepseek_megatron.sh index 1f0f51e7..c838a1bb 100644 --- a/examples/ppo_trainer/run_deepseek_megatron.sh +++ b/examples/ppo_trainer/run_deepseek_megatron.sh @@ -10,16 +10,16 @@ python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megat actor_rollout_ref.model.path=deepseek-ai/deepseek-coder-6.7b-instruct \ actor_rollout_ref.actor.optim.lr=2e-6 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=64 \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=64 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ critic.optim.lr=2e-5 \ critic.model.path=deepseek-ai/deepseek-coder-6.7b-instruct \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=64 \ + critic.ppo_micro_batch_size_per_gpu=8 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ diff --git a/examples/ppo_trainer/run_gemma.sh b/examples/ppo_trainer/run_gemma.sh index 9fb455a6..5072e04e 100644 --- a/examples/ppo_trainer/run_gemma.sh +++ b/examples/ppo_trainer/run_gemma.sh @@ -11,21 +11,21 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=128 \ - actor_rollout_ref.actor.ppo_micro_batch_size=4 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=4 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=google/gemma-2-2b-it \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=4 \ + critic.ppo_micro_batch_size_per_gpu=4 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ diff --git a/examples/ppo_trainer/run_qwen2-7b.sh b/examples/ppo_trainer/run_qwen2-7b.sh index 8e6bb16c..80dcd922 100644 --- a/examples/ppo_trainer/run_qwen2-7b.sh +++ b/examples/ppo_trainer/run_qwen2-7b.sh @@ -19,21 
+19,22 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=16 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=16 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2-7B-Instruct \ - critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=16 \ + critic.model.enable_gradient_checkpointing=True \ + critic.ppo_micro_batch_size_per_gpu=32 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm.sh b/examples/ppo_trainer/run_qwen2-7b_rm.sh index fc64d36b..35d030ad 100644 --- a/examples/ppo_trainer/run_qwen2-7b_rm.sh +++ b/examples/ppo_trainer/run_qwen2-7b_rm.sh @@ -23,22 +23,22 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=16 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=16 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=1 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=16 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.optim.lr_warmup_steps_ratio=0.05 \ critic.model.path=Qwen/Qwen2-7B-Instruct \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=16 \ + critic.ppo_micro_batch_size_per_gpu=16 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ @@ -46,7 +46,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ reward_model.model.use_remove_padding=True \ reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ + reward_model.micro_batch_size_per_gpu=32 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ trainer.logger=['console','wandb'] \ diff --git a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh 
b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh index eb470e43..c626e67c 100644 --- a/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_rm_seq_balance.sh @@ -45,7 +45,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.model.path=sfairXC/FsfairX-LLaMA3-RM-v0.1\ reward_model.model.use_remove_padding=True \ reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ + reward_model.micro_batch_size_per_gpu=32 \ reward_model.use_dynamic_bsz=True \ reward_model.forward_max_token_len_per_gpu=98304 \ algorithm.kl_ctrl.kl_coef=0.001 \ @@ -58,4 +58,4 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=100 $@ + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh index 0987c0e1..0c00b24c 100644 --- a/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh +++ b/examples/ppo_trainer/run_qwen2-7b_seq_balance.sh @@ -49,4 +49,4 @@ python3 -m verl.trainer.main_ppo \ trainer.nnodes=1 \ trainer.save_freq=-1 \ trainer.test_freq=5 \ - trainer.total_epochs=100 $@ + trainer.total_epochs=15 $@ diff --git a/examples/ppo_trainer/run_qwen2.5-32b.sh b/examples/ppo_trainer/run_qwen2.5-32b.sh index 1f3bdc3a..5841d5b1 100644 --- a/examples/ppo_trainer/run_qwen2.5-32b.sh +++ b/examples/ppo_trainer/run_qwen2.5-32b.sh @@ -20,21 +20,22 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ + actor_rollout_ref.model.enable_gradient_checkpointing=True \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ - actor_rollout_ref.rollout.gpu_memory_utilization=0.3 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-32B-Instruct \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ diff --git a/examples/ppo_trainer/verl_getting_started.ipynb b/examples/ppo_trainer/verl_getting_started.ipynb index afe9ca35..dfa93789 100644 --- a/examples/ppo_trainer/verl_getting_started.ipynb +++ b/examples/ppo_trainer/verl_getting_started.ipynb @@ -646,7 +646,7 @@ "\u001b[36m(main_task pid=28294)\u001b[0m 'path': '/teamspace/studios/this_studio/models/Qwen2.5-0.5B-Instruct'},\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'ref': {'fsdp_config': {'param_offload': False,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'wrap_policy': {'min_num_params': 0}},\n", - "\u001b[36m(main_task pid=28294)\u001b[0m 'log_prob_micro_batch_size': 4},\n", + "\u001b[36m(main_task 
pid=28294)\u001b[0m 'log_prob_micro_batch_size_per_gpu': 4},\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'rollout': {'do_sample': True,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'dtype': 'bfloat16',\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'enforce_eager': True,\n", @@ -654,7 +654,7 @@ "\u001b[36m(main_task pid=28294)\u001b[0m 'gpu_memory_utilization': 0.4,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'ignore_eos': False,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'load_format': 'dummy_dtensor',\n", - "\u001b[36m(main_task pid=28294)\u001b[0m 'log_prob_micro_batch_size': 1,\n", + "\u001b[36m(main_task pid=28294)\u001b[0m 'log_prob_micro_batch_size_per_gpu': 1,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'max_num_batched_tokens': 8192,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'max_num_seqs': 1024,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'n': 1,\n", @@ -671,7 +671,7 @@ "\u001b[36m(main_task pid=28294)\u001b[0m 'kl_penalty': 'kl',\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'lam': 1.0},\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'critic': {'cliprange_value': 0.5,\n", - "\u001b[36m(main_task pid=28294)\u001b[0m 'forward_micro_batch_size': 4,\n", + "\u001b[36m(main_task pid=28294)\u001b[0m 'forward_micro_batch_size_per_gpu': 4,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'grad_clip': 1.0,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'model': {'enable_gradient_checkpointing': False,\n", "\u001b[36m(main_task pid=28294)\u001b[0m 'external_lib': None,\n", @@ -1110,10 +1110,10 @@ " actor_rollout_ref.actor.optim.lr=1e-6 \\\n", " actor_rollout_ref.actor.ppo_mini_batch_size=64 \\\n", " actor_rollout_ref.actor.ppo_micro_batch_size=1 \\\n", - " actor_rollout_ref.rollout.log_prob_micro_batch_size=1 \\\n", + " actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \\\n", " actor_rollout_ref.rollout.tensor_model_parallel_size=1 \\\n", " actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \\\n", - " actor_rollout_ref.ref.log_prob_micro_batch_size=4 \\\n", + " actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \\\n", " critic.optim.lr=1e-5 \\\n", " critic.model.path=$HOME/models/Qwen2.5-0.5B-Instruct \\\n", " critic.ppo_micro_batch_size=1 \\\n", diff --git a/examples/sft/gsm8k/run_deepseek_6b7.sh b/examples/sft/gsm8k/run_deepseek_6b7.sh index 8e4d54c6..f11965a6 100644 --- a/examples/sft/gsm8k/run_deepseek_6b7.sh +++ b/examples/sft/gsm8k/run_deepseek_6b7.sh @@ -1,19 +1,29 @@ set -x -hdfs_path=hdfs://user/verl/experiments/gsm8k/deepseek-coder-6.7b-instruct/ # replace to your own hdfs/local path +if [ "$#" -lt 2 ]; then + echo "Usage: run_deepseek_6b7.sh [other_configs...]" + exit 1 +fi nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ -m verl.trainer.fsdp_sft_trainer \ data.train_files=$HOME/data/gsm8k/train.parquet \ data.val_files=$HOME/data/gsm8k/test.parquet \ - data.prompt_key=prompt \ - data.response_key=answer \ - data.micro_batch_size=8 \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ model.partial_pretrain=deepseek-ai/deepseek-coder-6.7b-instruct \ - trainer.default_hdfs_dir=$hdfs_path \ + trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-deepseek-coder-6.7b-instruct \ trainer.total_epochs=4 \ - trainer.logger=['console','wandb'] \ No newline at end of file + 
trainer.logger=['console','wandb'] \ + trainer.default_hdfs_dir=null $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_gemma_2b.sh b/examples/sft/gsm8k/run_gemma_2b.sh index 7ec85c09..6d7917d9 100644 --- a/examples/sft/gsm8k/run_gemma_2b.sh +++ b/examples/sft/gsm8k/run_gemma_2b.sh @@ -21,7 +21,7 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ data.response_key=extra_info \ +data.prompt_dict_keys=['question'] \ +data.response_dict_keys=['answer'] \ - data.micro_batch_size=8 \ + data.micro_batch_size_per_gpu=4 \ model.partial_pretrain=google/gemma-2b-it \ trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ diff --git a/examples/sft/gsm8k/run_gemma_7b.sh b/examples/sft/gsm8k/run_gemma_7b.sh index 9c357926..fdf4435b 100644 --- a/examples/sft/gsm8k/run_gemma_7b.sh +++ b/examples/sft/gsm8k/run_gemma_7b.sh @@ -1,8 +1,15 @@ set -x -hdfs_path=hdfs://user/verl/experiments/gsm8k/gemma-1.1-7b-it/ # replace to your own hdfs/local path +if [ "$#" -lt 2 ]; then + echo "Usage: run_gemma_7b.sh [other_configs...]" + exit 1 +fi nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ -m verl.trainer.fsdp_sft_trainer \ @@ -10,10 +17,11 @@ torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ data.val_files=$HOME/data/gsm8k/test.parquet \ data.prompt_key=prompt \ data.response_key=answer \ - data.micro_batch_size=8 \ + data.micro_batch_size_per_gpu=4 \ model.partial_pretrain=google/gemma-1.1-7b-it \ - trainer.default_hdfs_dir=$hdfs_path \ + trainer.default_local_dir=$save_path \ trainer.project_name=gsm8k-sft \ trainer.experiment_name=gsm8k-sft-gemma-1.1-7b-it \ trainer.total_epochs=4 \ - trainer.logger=['console','wandb'] \ No newline at end of file + trainer.logger=['console','wandb'] \ + trainer.default_hdfs_dir=null $@ \ No newline at end of file diff --git a/examples/sft/gsm8k/run_qwen_05_peft.sh b/examples/sft/gsm8k/run_qwen_05_peft.sh new file mode 100755 index 00000000..3ba61c3a --- /dev/null +++ b/examples/sft/gsm8k/run_qwen_05_peft.sh @@ -0,0 +1,38 @@ +# Tested with 2 & 4 GPUs + +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_peft.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ + trainer.logger=['console'] \ + trainer.total_epochs=1 \ + trainer.default_hdfs_dir=null $@ \ + model.lora_rank=32\ + model.lora_alpha=16 \ + model.target_modules=all-linear + + # Or you can do this: + # model.target_modules=[q_proj,v_proj] \ diff --git a/examples/sft/gsm8k/run_qwen_05_sp2.sh b/examples/sft/gsm8k/run_qwen_05_sp2.sh new file mode 100755 index 00000000..a27cef1d --- /dev/null +++ b/examples/sft/gsm8k/run_qwen_05_sp2.sh @@ -0,0 +1,32 @@ +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_sp2.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 
+save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct-sp2 \ + trainer.logger=['console'] \ + trainer.total_training_steps=1 \ + trainer.default_hdfs_dir=null $@ \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=true diff --git a/examples/split_placement/config/ppo_trainer_split.yaml b/examples/split_placement/config/ppo_trainer_split.yaml index 7984c45a..a475d7af 100644 --- a/examples/split_placement/config/ppo_trainer_split.yaml +++ b/examples/split_placement/config/ppo_trainer_split.yaml @@ -20,7 +20,8 @@ actor_rollout_ref: actor: strategy: fsdp # This is for backward-compatibility ppo_mini_batch_size: 256 - ppo_micro_batch_size: 64 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 64 grad_clip: 1.0 clip_ratio: 0.2 entropy_coeff: 0.001 @@ -45,7 +46,8 @@ actor_rollout_ref: wrap_policy: # transformer_layer_cls_to_wrap: None min_num_params: 0 - log_prob_micro_batch_size: 128 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 128 rollout: name: vllm temperature: 1.0 @@ -63,7 +65,8 @@ actor_rollout_ref: tensor_model_parallel_size: 2 max_num_batched_tokens: 8192 max_num_seqs: 1024 - log_prob_micro_batch_size: 128 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 128 # for hf rollout do_sample: True # number of responses (i.e. 
num sample times) @@ -91,7 +94,8 @@ critic: # transformer_layer_cls_to_wrap: None min_num_params: 0 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: 64 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: 64 ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 @@ -107,7 +111,8 @@ reward_model: fsdp_config: min_num_params: 0 param_offload: False - micro_batch_size: 64 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: 64 max_length: null algorithm: diff --git a/examples/split_placement/run_deepseek7b_llm.sh b/examples/split_placement/run_deepseek7b_llm.sh index a2db960a..c701de85 100644 --- a/examples/split_placement/run_deepseek7b_llm.sh +++ b/examples/split_placement/run_deepseek7b_llm.sh @@ -10,20 +10,20 @@ python3 main_ppo_split.py \ actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=16 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=8 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=32 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \ actor_rollout_ref.rollout.tensor_model_parallel_size=4 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=32 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=8 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.path=deepseek-ai/deepseek-llm-7b-chat \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=16 \ + critic.ppo_micro_batch_size_per_gpu=8 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ diff --git a/requirements.txt b/requirements.txt index 84f9df7a..bf1cadac 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ flash-attn hydra-core numpy pandas +peft pybind11 ray tensordict<0.6 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..7a7aadbc --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
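A recurring change throughout this patch is the move from global *_micro_batch_size keys to explicit *_micro_batch_size_per_gpu keys, with the old key kept only as a deprecated null placeholder. The snippet below is a minimal, stand-alone sketch of how such a shim can resolve the two keys; the function name and dict-style config are illustrative assumptions, not code from this patch.

# Hypothetical helper, not part of verl: resolves a per-GPU micro batch size
# when callers may still pass the deprecated global key.
def resolve_micro_batch_size_per_gpu(cfg: dict, dp_world_size: int) -> int:
    per_gpu = cfg.get('ppo_micro_batch_size_per_gpu')
    legacy = cfg.get('ppo_micro_batch_size')
    if per_gpu is None and legacy is None:
        raise ValueError('set ppo_micro_batch_size_per_gpu (preferred) or the deprecated ppo_micro_batch_size')
    if per_gpu is None:
        # Deprecated path: the old key was a global size, so split it across data-parallel ranks.
        assert legacy % dp_world_size == 0, f'{legacy} is not divisible by dp size {dp_world_size}'
        per_gpu = legacy // dp_world_size
    return per_gpu

For example, a legacy global value of 64 on an 8-GPU data-parallel group resolves to 8 per GPU.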
\ No newline at end of file diff --git a/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml b/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml index e6e36bc5..da1294e3 100644 --- a/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml +++ b/tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml @@ -22,12 +22,16 @@ actor_rollout_ref: actor: strategy: fsdp # This is for backward-compatibility ppo_mini_batch_size: 200 - ppo_micro_batch_size: 200 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 clip_ratio: 0.2 entropy_coeff: 0.0 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -40,13 +44,15 @@ actor_rollout_ref: param_offload: False grad_offload: False optimizer_offload: False + fsdp_size: -1 ref: fsdp_config: param_offload: False wrap_policy: # transformer_layer_cls_to_wrap: None min_num_params: 0 - micro_batch_size: 200 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size @@ -61,21 +67,25 @@ actor_rollout_ref: dtype: bfloat16 # should align with FSDP gpu_memory_utilization: 0.1 ignore_eos: False - micro_batch_size: 200 + micro_batch_size_per_gpu: 200 enforce_eager: True free_cache_engine: True load_format: dummy_dtensor tensor_model_parallel_size: 1 max_num_batched_tokens: 8192 max_num_seqs: 1024 - log_prob_micro_batch_size: 200 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: False # could get higher throughput # for hf rollout do_sample: True # number of responses (i.e. 
num sample times) n: 1 # > 1 for grpo + critic: strategy: fsdp optim: @@ -88,6 +98,7 @@ critic: enable_gradient_checkpointing: False use_remove_padding: False fsdp_config: + fsdp_size: -1 param_offload: False grad_offload: False optimizer_offload: False @@ -95,8 +106,10 @@ critic: # transformer_layer_cls_to_wrap: None min_num_params: 0 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: 200 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} @@ -122,7 +135,9 @@ reward_model: use_remove_padding: False fsdp_config: min_num_params: 0 - micro_batch_size: 8 + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number max_length: null ulysses_sequence_parallel_size: 1 # sp size diff --git a/tests/e2e/arithmetic_sequence/rl/main_trainer.py b/tests/e2e/arithmetic_sequence/rl/main_trainer.py index 18fdd457..90e9a9e2 100644 --- a/tests/e2e/arithmetic_sequence/rl/main_trainer.py +++ b/tests/e2e/arithmetic_sequence/rl/main_trainer.py @@ -105,14 +105,6 @@ def main(config): from omegaconf import OmegaConf pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values - dp_size = config.trainer.n_gpus_per_node * config.trainer.nnodes - # normalize batch_size - # TODO: move this inside each role - config.actor_rollout_ref.actor.ppo_mini_batch_size //= dp_size - config.actor_rollout_ref.actor.ppo_micro_batch_size //= dp_size - config.critic.ppo_micro_batch_size //= dp_size - config.actor_rollout_ref.rollout.micro_batch_size //= dp_size - # print the config # print initial config print('Config after normalizing batch_size') diff --git a/tests/e2e/check_results.py b/tests/e2e/check_results.py index bd3151f0..c12e2493 100644 --- a/tests/e2e/check_results.py +++ b/tests/e2e/check_results.py @@ -48,5 +48,5 @@ def extract_reward_from_line(line): best_reward = reward print(f'Best reward is {best_reward}') - assert best_reward > 0.2, f'Best reward must be greater than 0.3. best_reward: {best_reward}' + assert best_reward > 0.2, f'Best reward must be greater than 0.2. 
best_reward: {best_reward}' print('Check passes') diff --git a/tests/e2e/run_qwen_gsm8k_function_rm.sh b/tests/e2e/run_qwen_gsm8k_function_rm.sh index 459fbdb7..107674ea 100644 --- a/tests/e2e/run_qwen_gsm8k_function_rm.sh +++ b/tests/e2e/run_qwen_gsm8k_function_rm.sh @@ -13,21 +13,21 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=4 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ diff --git a/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh b/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh index 5250b813..9b628fbd 100644 --- a/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh +++ b/tests/e2e/run_qwen_gsm8k_function_rm_no_rmpad.sh @@ -13,21 +13,21 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.actor.optim.lr=1e-6 \ actor_rollout_ref.model.use_remove_padding=False \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=False \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=4 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm.sh b/tests/e2e/run_qwen_gsm8k_model_rm.sh index b7ef53ef..61da215f 100644 --- a/tests/e2e/run_qwen_gsm8k_model_rm.sh +++ b/tests/e2e/run_qwen_gsm8k_model_rm.sh @@ -15,22 +15,22 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ 
actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=True \ critic.optim.lr_warmup_steps_ratio=0.05 \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=4 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ @@ -38,7 +38,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.model.path=Qwen/Qwen2.5-0.5B\ reward_model.model.use_remove_padding=True \ reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ + reward_model.micro_batch_size_per_gpu=16 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ trainer.logger=['console'] \ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh b/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh index cd06c2f8..e54d5568 100644 --- a/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh +++ b/tests/e2e/run_qwen_gsm8k_model_rm_no_rmpad.sh @@ -15,22 +15,22 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_remove_padding=False \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.model.use_remove_padding=False \ critic.optim.lr_warmup_steps_ratio=0.05 \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=4 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ @@ -38,7 +38,7 @@ python3 -m verl.trainer.main_ppo \ reward_model.model.path=Qwen/Qwen2.5-0.5B\ reward_model.model.use_remove_padding=False \ reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ + reward_model.micro_batch_size_per_gpu=16 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ +trainer.val_before_train=False \ diff 
--git a/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh b/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh index e50a95fc..c4a686c6 100644 --- a/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh +++ b/tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh @@ -15,13 +15,11 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ actor_rollout_ref.actor.use_dynamic_bsz=True \ actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ @@ -33,7 +31,6 @@ python3 -m verl.trainer.main_ppo \ critic.optim.lr_warmup_steps_ratio=0.05 \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ critic.use_dynamic_bsz=True \ critic.ppo_max_token_len_per_gpu=98304 \ critic.model.fsdp_config.param_offload=False \ @@ -43,7 +40,6 @@ python3 -m verl.trainer.main_ppo \ reward_model.model.path=Qwen/Qwen2.5-0.5B\ reward_model.model.use_remove_padding=True \ reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ reward_model.use_dynamic_bsz=True \ reward_model.forward_max_token_len_per_gpu=98304 \ algorithm.kl_ctrl.kl_coef=0.001 \ diff --git a/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh b/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh index 18c3c2ad..7ab764f8 100644 --- a/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh +++ b/tests/e2e/run_qwen_gsm8k_model_rm_ulysses.sh @@ -15,16 +15,17 @@ python3 -m verl.trainer.main_ppo \ actor_rollout_ref.model.use_remove_padding=True \ actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \ actor_rollout_ref.actor.ppo_mini_batch_size=256 \ - actor_rollout_ref.actor.ppo_micro_batch_size=32 \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \ actor_rollout_ref.actor.ulysses_sequence_parallel_size=2 \ actor_rollout_ref.actor.fsdp_config.param_offload=False \ actor_rollout_ref.actor.fsdp_config.grad_offload=False \ actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \ - actor_rollout_ref.rollout.log_prob_micro_batch_size=128 \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=-1 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.rollout.tensor_model_parallel_size=2 \ actor_rollout_ref.rollout.name=vllm \ actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \ - actor_rollout_ref.ref.log_prob_micro_batch_size=128 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=16 \ actor_rollout_ref.ref.fsdp_config.param_offload=True \ critic.optim.lr=1e-5 \ critic.ulysses_sequence_parallel_size=2 \ @@ -32,16 +33,17 @@ python3 -m verl.trainer.main_ppo \ critic.optim.lr_warmup_steps_ratio=0.05 \ critic.model.path=Qwen/Qwen2.5-0.5B \ critic.model.enable_gradient_checkpointing=False \ - critic.ppo_micro_batch_size=32 \ + critic.ppo_micro_batch_size_per_gpu=4 \ critic.model.fsdp_config.param_offload=False \ critic.model.fsdp_config.grad_offload=False \ critic.model.fsdp_config.optimizer_offload=False \ + critic.model.fsdp_config.fsdp_size=-1 \ reward_model.enable=True \ 
reward_model.ulysses_sequence_parallel_size=2 \ reward_model.model.path=Qwen/Qwen2.5-0.5B\ reward_model.model.use_remove_padding=True \ reward_model.model.fsdp_config.param_offload=True \ - reward_model.micro_batch_size=16 \ + reward_model.micro_batch_size_per_gpu=16 \ algorithm.kl_ctrl.kl_coef=0.001 \ trainer.critic_warmup=0 \ +trainer.val_before_train=False \ diff --git a/tests/e2e/run_ray_trainer.sh b/tests/e2e/run_ray_trainer.sh index 51d18fcc..30457e64 100644 --- a/tests/e2e/run_ray_trainer.sh +++ b/tests/e2e/run_ray_trainer.sh @@ -11,6 +11,10 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \ data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \ data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \ actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=200 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=200 \ + critic.ppo_micro_batch_size_per_gpu=200 \ critic.model.path=tests/e2e/arithmetic_sequence/model | tee $OUTPUT_FILE; python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE diff --git a/tests/sft/run_sft.sh b/tests/sft/run_sft.sh new file mode 100644 index 00000000..89132832 --- /dev/null +++ b/tests/sft/run_sft.sh @@ -0,0 +1,22 @@ +# Tested with 2 & 4 GPUs + +set -x + +torchrun --standalone --nnodes=1 --nproc_per_node=8 \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=32 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$HOME/ckpts/ \ + trainer.project_name=qwen2.5-sft \ + trainer.experiment_name=gsm8k-sft-gemma-2b-it \ + trainer.total_training_steps=1 \ + trainer.logger=['console'] \ + trainer.default_hdfs_dir=null $@ + +rm -rf $HOME/ckpts/ \ No newline at end of file diff --git a/tests/sft/run_sft_qwen05_peft.sh b/tests/sft/run_sft_qwen05_peft.sh new file mode 100644 index 00000000..2d1744f6 --- /dev/null +++ b/tests/sft/run_sft_qwen05_peft.sh @@ -0,0 +1,38 @@ +# Tested with 2 & 4 GPUs + +set -x + +if [ "$#" -lt 2 ]; then + echo "Usage: run_qwen_05_peft.sh [other_configs...]" + exit 1 +fi + +nproc_per_node=$1 +save_path=$2 + +# Shift the arguments so $@ refers to the rest +shift 2 + +torchrun --standalone --nnodes=1 --nproc_per_node=$nproc_per_node \ + -m verl.trainer.fsdp_sft_trainer \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + optim.lr=1e-4 \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size_per_gpu=4 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + trainer.default_local_dir=$save_path \ + trainer.project_name=gsm8k-sft \ + trainer.experiment_name=gsm8k-sft-qwen-2.5-0.5b-instruct \ + trainer.logger=['console'] \ + trainer.total_training_steps=1 \ + trainer.default_hdfs_dir=null $@ \ + model.lora_rank=32\ + model.lora_alpha=16 \ + model.target_modules=all-linear + + # Or you can do this: + # model.target_modules=[q_proj,v_proj] \ diff --git a/tests/sft/run_sft_sp_loss_match.sh b/tests/sft/run_sft_sp_loss_match.sh new file mode 100644 index 00000000..a63328ec --- /dev/null +++ 
b/tests/sft/run_sft_sp_loss_match.sh @@ -0,0 +1,24 @@ +# Tested with 2 & 4 GPUs + +set -x + +torchrun --standalone --nnodes=1 --nproc_per_node=8 \ + tests/sft/test_sp_loss_match.py \ + data.train_files=$HOME/data/gsm8k/train.parquet \ + data.val_files=$HOME/data/gsm8k/test.parquet \ + data.prompt_key=extra_info \ + data.response_key=extra_info \ + +data.prompt_dict_keys=['question'] \ + +data.response_dict_keys=['answer'] \ + data.micro_batch_size=32 \ + model.partial_pretrain=Qwen/Qwen2.5-0.5B-Instruct \ + ulysses_sequence_parallel_size=2 \ + use_remove_padding=True \ + trainer.default_local_dir=$HOME/ckpts/ \ + trainer.project_name=qwen2.5-sft \ + trainer.experiment_name=gsm8k-sft-gemma-2b-it \ + trainer.total_training_steps=1 \ + trainer.logger=['console'] \ + trainer.default_hdfs_dir=null $@ + +rm -rf $HOME/ckpts/ diff --git a/tests/sft/test_sp_loss_match.py b/tests/sft/test_sp_loss_match.py new file mode 100644 index 00000000..69223d3d --- /dev/null +++ b/tests/sft/test_sp_loss_match.py @@ -0,0 +1,128 @@ +# Copyright 2024 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed +from tensordict import TensorDict +from verl.trainer.fsdp_sft_trainer import FSDPSFTTrainer +from torch.distributed.device_mesh import init_device_mesh +from verl.utils.distributed import initialize_global_process_group + + +def test_trainer_forward_consistency(trainer: FSDPSFTTrainer, total_steps: int = 4): + """Test consistency between original forward pass and SP+rmpad forward passes. 
+ + Args: + trainer: The FSDPSFTTrainer instance to test + total_steps: Number of steps to test (default: 4) + """ + if trainer.device_mesh.get_rank() == 0: + print("\nStarting debug comparison between original and SP+rmpad forward passes...") + print(f"Sequence parallel size: {trainer.config.ulysses_sequence_parallel_size}") + print(f"Remove padding: {trainer.use_remove_padding}\n") + + steps_remaining = total_steps + + for epoch in range(1): # Just one epoch for testing + trainer.train_sampler.set_epoch(epoch=epoch) + for data in trainer.train_dataloader: + data = TensorDict(data, batch_size=trainer.config.data.train_batch_size).cuda() + trainer.fsdp_model.train() + micro_batches = data.split(trainer.config.data.micro_batch_size) + + for idx, micro_batch in enumerate(micro_batches): + if trainer.device_mesh.get_rank() == 0: + print(f"\nProcessing micro batch {idx + 1}/{len(micro_batches)}") + + # Compute losses using both methods + # Disable SP and rmpad + trainer.use_remove_padding = False + old_sp = trainer.config.ulysses_sequence_parallel_size + trainer.config.ulysses_sequence_parallel_size = 1 + loss_ref = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) + + # Do SP and rmpad + trainer.config.ulysses_sequence_parallel_size = old_sp + trainer.use_remove_padding = True + loss_sp = trainer._compute_loss_and_backward(micro_batch.copy(), do_backward=False) + + # Collect losses across all ranks + loss_ref_all = loss_ref.clone() + loss_sp_all = loss_sp.clone() + torch.distributed.all_reduce(loss_ref_all, op=torch.distributed.ReduceOp.AVG) + torch.distributed.all_reduce(loss_sp_all, op=torch.distributed.ReduceOp.AVG) + + # Calculate relative difference of averaged losses + rel_diff = torch.abs(loss_ref_all - loss_sp_all) / (torch.abs(loss_ref_all) + 1e-8) + + if trainer.device_mesh.get_rank() == 0: + print("\nComparison Results (Averaged across ranks):") + print(f"Reference Loss: {loss_ref_all.item():.6f}") + print(f"SP+rmpad Loss: {loss_sp_all.item():.6f}") + print(f"Relative Difference: {rel_diff.item():.6f}") + + assert rel_diff.item() < 1e-2, "Significant difference detected between averaged losses!" + print("Loss difference is within the acceptable range.") + + steps_remaining -= 1 + if steps_remaining == 0: + break + if steps_remaining == 0: + break + break + + if trainer.device_mesh.get_rank() == 0: + print("\nDebug comparison completed successfully.") + + +def create_trainer(config): + """Create and initialize a trainer instance with the given config. + + Args: + config: Configuration object with training parameters + + Returns: + FSDPSFTTrainer: Initialized trainer instance + """ + local_rank, rank, world_size = initialize_global_process_group() + + device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh(device_type='cuda', + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=('dp', 'sp')) + + return FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) + + +def main(config): + """Main function to run trainer tests. 
+ + Args: + config: Configuration object with training parameters + """ + trainer = create_trainer(config) + test_trainer_forward_consistency(trainer) + + +if __name__ == '__main__': + import hydra + from omegaconf import DictConfig + + @hydra.main(config_path="../../verl/trainer/config", config_name="sft_trainer") + def hydra_entry(cfg: DictConfig) -> None: + main(cfg) + + hydra_entry() diff --git a/tests/utility/test_tensor_dict_utilities.py b/tests/utility/test_tensor_dict_utilities.py index c952d5a9..344cf3a8 100644 --- a/tests/utility/test_tensor_dict_utilities.py +++ b/tests/utility/test_tensor_dict_utilities.py @@ -41,8 +41,12 @@ def test_union_tensor_dict(): data = union_tensor_dict(data1, data_with_copied_obs) data = np.random.random(100) - a = {'a': data} - b = {'a': data} + data2 = [float('nan') for _ in range(99)] + data2.append('nan') + data2 = np.array(data2, dtype=object) + data3 = np.tile(data2, (2, 1)) + a = {'a': data, 'b': data2, 'c': data3} + b = {'a': data, 'b': data2, 'c': data3} b_ = {'a': np.random.random(100)} union_numpy_dict(a, b) with pytest.raises(AssertionError): diff --git a/verl/protocol.py b/verl/protocol.py index 803da366..80626242 100644 --- a/verl/protocol.py +++ b/verl/protocol.py @@ -18,6 +18,7 @@ import pickle import numpy as np +import pandas as pd import copy from dataclasses import dataclass, field from typing import Callable, Dict, List, Union @@ -82,7 +83,8 @@ def union_numpy_dict(tensor_dict1: dict[np.ndarray], tensor_dict2: dict[np.ndarr if key in tensor_dict1: assert isinstance(tensor_dict2[key], np.ndarray) assert isinstance(tensor_dict1[key], np.ndarray) - assert np.all(tensor_dict2[key] == tensor_dict1[key]), \ + # to properly deal with nan and object type + assert pd.DataFrame(tensor_dict2[key]).equals(pd.DataFrame(tensor_dict1[key])), \ f'{key} in tensor_dict1 and tensor_dict2 are not the same object' tensor_dict1[key] = val diff --git a/verl/trainer/config/generation.yaml b/verl/trainer/config/generation.yaml index ed805a8c..27d92116 100644 --- a/verl/trainer/config/generation.yaml +++ b/verl/trainer/config/generation.yaml @@ -23,13 +23,13 @@ rollout: dtype: bfloat16 # should align with FSDP gpu_memory_utilization: 0.5 ignore_eos: False - micro_batch_size: 256 enforce_eager: True free_cache_engine: True load_format: dummy_dtensor tensor_model_parallel_size: 1 max_num_batched_tokens: 8192 max_num_seqs: 1024 - log_prob_micro_batch_size: 8 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: 8 # for hf rollout do_sample: True \ No newline at end of file diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml index 6ae26851..fce8a89a 100644 --- a/verl/trainer/config/ppo_megatron_trainer.yaml +++ b/verl/trainer/config/ppo_megatron_trainer.yaml @@ -20,7 +20,9 @@ actor_rollout_ref: actor: strategy: megatron # This is for backward-compatibility ppo_mini_batch_size: 256 - ppo_micro_batch_size: 64 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: False clip_ratio: 0.2 entropy_coeff: 0.001 ppo_epochs: 1 @@ -48,7 +50,8 @@ actor_rollout_ref: seed: 1 load_weight: True param_offload: False - log_prob_micro_batch_size: 32 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null rollout: name: vllm temperature: 1.0 @@ -66,7 +69,10 @@ actor_rollout_ref: 
tensor_model_parallel_size: 2 max_num_batched_tokens: 8192 max_num_seqs: 1024 - log_prob_micro_batch_size: 2 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null + disable_log_stats: True + enable_chunked_prefill: False # could get higher throughput # for hf rollout do_sample: True layer_name_map: @@ -98,7 +104,9 @@ critic: seed: 1 load_weight: True ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: 2 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null + use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} shuffle: ${actor_rollout_ref.actor.shuffle} cliprange_value: 0.5 @@ -121,7 +129,9 @@ reward_model: external_lib: ${actor_rollout_ref.model.external_lib} load_weight: True param_offload: False - micro_batch_size: 64 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null + use_dynamic_bsz: ${critic.use_dynamic_bsz} max_length: null algorithm: diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 15cb223e..04561cbe 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -21,12 +21,16 @@ actor_rollout_ref: actor: strategy: fsdp # This is for backward-compatibility ppo_mini_batch_size: 256 - ppo_micro_batch_size: 64 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null use_dynamic_bsz: False ppo_max_token_len_per_gpu: 16384 # n * ${data.max_prompt_length} + ${data.max_response_length} grad_clip: 1.0 clip_ratio: 0.2 entropy_coeff: 0.001 + use_kl_loss: False # True for GRPO + kl_loss_coef: 0.001 # for grpo + kl_loss_type: low_var_kl # for grpo ppo_epochs: 1 shuffle: False ulysses_sequence_parallel_size: 1 # sp size @@ -50,8 +54,8 @@ actor_rollout_ref: wrap_policy: # transformer_layer_cls_to_wrap: None min_num_params: 0 - fsdp_size: -1 - log_prob_micro_batch_size: 128 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} ulysses_sequence_parallel_size: ${actor_rollout_ref.actor.ulysses_sequence_parallel_size} # sp size @@ -72,9 +76,12 @@ actor_rollout_ref: tensor_model_parallel_size: 2 max_num_batched_tokens: 8192 max_num_seqs: 1024 - log_prob_micro_batch_size: 128 + log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu + log_prob_micro_batch_size_per_gpu: null log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu} + disable_log_stats: True + enable_chunked_prefill: False # could get higher throughput # for hf rollout do_sample: True # number of responses (i.e. 
num sample times) @@ -104,8 +111,10 @@ critic: min_num_params: 0 fsdp_size: -1 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} - ppo_micro_batch_size: 64 + ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu + ppo_micro_batch_size_per_gpu: null forward_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size_per_gpu: ${critic.ppo_micro_batch_size_per_gpu} use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz} ppo_max_token_len_per_gpu: 32768 # (${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}) * 2 forward_max_token_len_per_gpu: ${critic.ppo_max_token_len_per_gpu} @@ -126,7 +135,9 @@ reward_model: fsdp_config: min_num_params: 0 param_offload: False - micro_batch_size: 64 + fsdp_size: -1 + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: null # set a number max_length: null ulysses_sequence_parallel_size: 1 # sp size use_dynamic_bsz: ${critic.use_dynamic_bsz} diff --git a/verl/trainer/config/sft_trainer.yaml b/verl/trainer/config/sft_trainer.yaml index 1bf7b6ec..9ac707ad 100644 --- a/verl/trainer/config/sft_trainer.yaml +++ b/verl/trainer/config/sft_trainer.yaml @@ -1,6 +1,7 @@ data: train_batch_size: 256 - micro_batch_size: 16 # this is also val batch size + micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu + micro_batch_size_per_gpu: 4 # this is also val batch size train_files: ~/data/gsm8k/train.parquet val_files: ~/data/gsm8k/test.parquet prompt_key: question @@ -19,13 +20,17 @@ model: external_lib: null enable_gradient_checkpointing: False trust_remote_code: False + lora_rank: 0 # Set to positive value to enable LoRA (e.g., 32) + lora_alpha: 16 # LoRA scaling factor + target_modules: all-linear # Target modules for LoRA adaptation optim: lr: 1e-5 betas: [0.9, 0.95] weight_decay: 0.01 warmup_steps_ratio: 0.1 clip_grad: 1.0 - +ulysses_sequence_parallel_size: 1 +use_remove_padding: False trainer: default_local_dir: /tmp/sft_model default_hdfs_dir: hdfs://tmp/experiments/gsm8k/gemma-1.1-7b-it/ # change the hdfs path here @@ -33,6 +38,7 @@ trainer: project_name: gsm8k-sft experiment_name: test total_epochs: 4 + total_training_steps: null logger: ['console'] seed: 1 diff --git a/verl/trainer/fsdp_sft_trainer.py b/verl/trainer/fsdp_sft_trainer.py index bb876a73..51efb4b9 100644 --- a/verl/trainer/fsdp_sft_trainer.py +++ b/verl/trainer/fsdp_sft_trainer.py @@ -25,24 +25,32 @@ import logging import re +from contextlib import nullcontext import torch import torch.distributed from torch import nn, optim from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, CPUOffload +from tqdm import tqdm from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedModel, AutoConfig from verl.utils.torch_functional import get_cosine_schedule_with_warmup from tensordict import TensorDict from torch.utils.data import DataLoader, DistributedSampler +from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis from verl.utils.fsdp_utils import get_fsdp_wrap_policy, init_fn, get_init_weight_context_manager from verl.utils.dataset import SFTDataset from verl.utils.fs import copy_local_path_from_hdfs from verl.utils.tracking import Tracking - +from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group from torch.distributed.device_mesh import DeviceMesh import verl.utils.hdfs_io as hdfs_io from verl.utils.debug import 
log_gpu_memory_usage +from peft import LoraConfig, TaskType, get_peft_model + +from verl.workers.sharding_manager import FSDPUlyssesShardingManager +from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad +from verl import DataProto logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_SFT_LOGGING_LEVEL', 'WARN')) @@ -55,11 +63,25 @@ def extract_step(path): return None +def convert_to_regular_types(obj): + """Convert Hydra configs and other special types to regular Python types.""" + from omegaconf import ListConfig, DictConfig + if isinstance(obj, (ListConfig, DictConfig)): + return {k: convert_to_regular_types(v) for k, v in obj.items()} if isinstance(obj, DictConfig) else list(obj) + elif isinstance(obj, (list, tuple)): + return [convert_to_regular_types(x) for x in obj] + elif isinstance(obj, dict): + return {k: convert_to_regular_types(v) for k, v in obj.items()} + return obj + + class FSDPSFTTrainer(object): - def __init__(self, config, device_mesh: DeviceMesh): + def __init__(self, config, device_mesh: DeviceMesh, ulysses_device_mesh: DeviceMesh): self.config = config self.device_mesh = device_mesh + self.ulysses_device_mesh = ulysses_device_mesh + self.sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) # build tokenizer first local_model_path = copy_local_path_from_hdfs(src=self.config.model.partial_pretrain, verbose=True) from verl.utils import hf_tokenizer @@ -70,6 +92,13 @@ def __init__(self, config, device_mesh: DeviceMesh): # normalize dp size self._normalize_config_bsz() + # Set sequence parallel size + self.config.ulysses_sequence_parallel_size = getattr(self.config, 'ulysses_sequence_parallel_size', 1) + self.use_remove_padding = getattr(self.config, 'use_remove_padding', False) + if self.device_mesh.get_rank() == 0: + print(f'Using sequence parallel size: {self.config.ulysses_sequence_parallel_size}') + print(f'Using remove padding: {self.use_remove_padding}') + self._build_dataloader() # build model self._build_model_optimizer() @@ -79,15 +108,15 @@ def __init__(self, config, device_mesh: DeviceMesh): print(self.config) def _normalize_config_bsz(self): - dp_size = self.device_mesh.size() + dp_size = self.device_mesh.size(0) if not self.ulysses_device_mesh else self.ulysses_device_mesh.size(0) if self.device_mesh.get_rank() == 0: print(f'Normalize batch size by dp {dp_size}') - assert self.config.data.train_batch_size % dp_size == 0 - assert self.config.data.micro_batch_size % dp_size == 0 + assert self.config.data.train_batch_size % dp_size == 0, f"Global batch size {self.config.data.train_batch_size} is not divisible by dp size {dp_size}" self.config.data.train_batch_size //= dp_size - self.config.data.micro_batch_size //= dp_size + + assert self.config.data.train_batch_size % self.config.data.micro_batch_size_per_gpu == 0 def _build_dataloader(self): config = self.config @@ -110,8 +139,21 @@ def _build_dataloader(self): truncation=config.data.truncation) # build dataloader - rank = self.device_mesh.get_rank() - world_size = self.device_mesh.size() + # Use data parallel rank and size instead of global rank and world size + + # If doing SP, we need to use the local rank and size + if self.config.ulysses_sequence_parallel_size > 1: + rank = self.ulysses_device_mesh.get_local_rank('dp') + world_size = self.ulysses_device_mesh.size(0) + if self.ulysses_device_mesh.get_rank() == 0: + print(f'Using SP rank {rank} and size {world_size} for data distribution') + print(f'Each SP rank gets different data, but the 
same data WITHIN the same rank') + else: + rank = self.device_mesh.get_rank() + world_size = self.device_mesh.size() + if self.device_mesh.get_rank() == 0: + print(f'Using FSDP rank {rank} and size {world_size} for data distribution') + self.train_sampler = DistributedSampler(self.train_dataset, shuffle=True, num_replicas=world_size, @@ -130,7 +172,7 @@ def _build_dataloader(self): rank=rank, drop_last=True) self.val_dataloader = DataLoader(dataset=self.val_dataset, - batch_size=config.data.micro_batch_size, + batch_size=config.data.micro_batch_size_per_gpu, sampler=self.val_sampler, num_workers=8, pin_memory=True, @@ -152,6 +194,14 @@ def _build_model_optimizer(self): trust_remote_code = self.config.model.trust_remote_code # load config first config = AutoConfig.from_pretrained(local_model_path, trust_remote_code=trust_remote_code) + if self.config.ulysses_sequence_parallel_size > 1: + assert self.use_remove_padding, "Sequence parallel is only supported when remove_padding is enabled" + from verl.models.registry import check_model_support_rmpad + check_model_support_rmpad(config.model_type) + + if self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1: + from verl.models.transformers.monkey_patch import apply_monkey_patch + apply_monkey_patch(config, verbose=True) # This may be very large init_context = get_init_weight_context_manager(use_meta_tensor=not config.tie_word_embeddings) @@ -163,6 +213,18 @@ def _build_model_optimizer(self): attn_implementation='flash_attention_2', trust_remote_code=trust_remote_code) + if self.config.model.get('lora_rank', 0) > 0: + self.model.enable_input_require_grads() + # Convert config to regular Python types before creating PEFT model + lora_config = { + 'task_type': TaskType.CAUSAL_LM, + 'r': self.config.model.lora_rank, + 'lora_alpha': self.config.model.lora_alpha, + 'target_modules': convert_to_regular_types(self.config.model.target_modules), + 'bias': "none" + } + self.model = get_peft_model(self.model, LoraConfig(**lora_config)) + if self.config.model.enable_gradient_checkpointing: self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) @@ -172,7 +234,9 @@ def _build_model_optimizer(self): reduce_dtype=torch.float32, buffer_dtype=torch.float32) - auto_wrap_policy = get_fsdp_wrap_policy(self.model, config=self.config.model.fsdp_config.wrap_policy) + auto_wrap_policy = get_fsdp_wrap_policy(self.model, + config=self.config.model.fsdp_config.wrap_policy, + is_lora=self.config.model.get('lora_rank', 0) > 0) if self.device_mesh.get_rank() == 0: print(auto_wrap_policy) @@ -201,53 +265,116 @@ def _build_model_optimizer(self): log_gpu_memory_usage('After initialize optimizer', logger=logger) - steps_per_epoch = len(self.train_dataloader) - total_steps = steps_per_epoch * self.config.trainer.total_epochs + self.steps_per_epoch = len(self.train_dataloader) + self.total_steps = self.steps_per_epoch * self.config.trainer.total_epochs if self.device_mesh.get_rank() == 0: print( - f'Number of steps/epoch {steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {total_steps}' + f'Number of steps/epoch {self.steps_per_epoch}, number of epochs {self.config.trainer.total_epochs}, total number of steps {self.total_steps}' ) - num_warmup_steps = int(total_steps * self.config.optim.warmup_steps_ratio) + num_warmup_steps = int(self.total_steps * self.config.optim.warmup_steps_ratio) self.lr_scheduler = get_cosine_schedule_with_warmup(optimizer=self.optimizer, 
num_warmup_steps=num_warmup_steps, - num_training_steps=total_steps) - - def _compute_loss(self, batch): - loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() - labels = batch['input_ids'][:, 1:].cuda() - - with torch.autocast(device_type='cuda', dtype=torch.bfloat16): - output = self.fsdp_model(input_ids=batch['input_ids'], - attention_mask=batch['attention_mask'], - position_ids=batch['position_ids'], - use_cache=False) # prevent model thinks it it generating + num_training_steps=self.total_steps) - logits = output.logits + def _compute_loss_and_backward(self, batch, do_backward=True): + """Compute loss with optional sequence parallelism and remove padding features""" + use_sp = self.use_remove_padding and self.config.ulysses_sequence_parallel_size > 1 - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels.contiguous() - # Flatten the tokens + # Move inputs to GPU and prepare loss mask + input_ids = batch['input_ids'].cuda() + attention_mask = batch['attention_mask'].cuda() + position_ids = batch['position_ids'].cuda() + loss_mask = batch.pop('loss_mask')[:, :-1].reshape(-1).cuda() loss_fct = nn.CrossEntropyLoss(reduction='none') - shift_logits = shift_logits.view(-1, self.model.config.vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - loss = loss * loss_mask - - valid_token_this_rank = torch.sum(loss_mask) - - if self.config.data.balance_dp_token: - torch.distributed.all_reduce(valid_token_this_rank) # becomes total valid tokens in all ranks - dp_size = torch.distributed.get_world_size() - else: - dp_size = 1 - loss = torch.sum(loss) / valid_token_this_rank * dp_size # possible bugs here for dp - return loss + # Context manager for sequence parallel if needed + context = self.sharding_manager if use_sp else nullcontext() + with context: + with torch.autocast(device_type='cuda', dtype=torch.bfloat16): + if not use_sp: + # Standard forward pass without sequence parallel + labels = input_ids[:, 1:].contiguous() + output = self.fsdp_model(input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + use_cache=False) + logits = output.logits + + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels.contiguous() + # Flatten the tokens + shift_logits = shift_logits.view(-1, self.model.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + loss = loss * loss_mask.to(loss.device) + else: + # IMPORTANT: We have a big assumption here, so we can shard the SAME sequence across SP ranks + # i.e., each GPU has <1 sequence, and each SP group has 1 sequence + # 1. All SP ranks will receive the *SAME* batch + # 2. Different SP groups will receive *DIFFERENT* batches + # This is implemented by the DistributedSampler + + batch_size, seqlen = input_ids.shape + # Remove padding + input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1), + attention_mask) # input_ids_rmpad (total_nnz, ...) + input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz) + + # Unpad position_ids to align rotary + position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... 
-> (b s) ..."), + indices).transpose(0, 1) + + # Pad and slice inputs for sequence parallelism + input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs( + input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size()) + # For computing loss + input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1) # (1, total_nnz) + input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs( + input_ids_rmpad_rolled, None, get_ulysses_sequence_parallel_world_size()) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0) # ((total_nnz / sp) + pad) + + # Forward pass + output = self.fsdp_model( + input_ids=input_ids_rmpad_sliced, + attention_mask=None, # Not needed with flash attention varlen + position_ids=position_ids_rmpad_padded, + use_cache=False) + + # Compute loss locally then aggregate + logits_rmpad = output.logits.squeeze(0) + input_ids_rmpad_rolled = input_ids_rmpad_rolled.to(logits_rmpad.device) + loss = loss_fct(logits_rmpad, input_ids_rmpad_rolled) + # Gather and unpad for sequence parallelism + loss = gather_outpus_and_unpad(loss, gather_dim=0, unpad_dim=0, padding_size=pad_size) + + # This is the loss collected from all ulysses ranks + full_loss = pad_input(hidden_states=loss.unsqueeze(-1), + indices=indices, + batch=batch_size, + seqlen=seqlen) + full_loss = full_loss.squeeze(-1)[:, :-1] # Remove last token's loss + full_loss = full_loss.reshape(-1) + loss_mask = loss_mask.to(full_loss.device) + loss = full_loss * loss_mask + + valid_token_this_rank = torch.sum(loss_mask) + + if self.config.data.balance_dp_token: + torch.distributed.all_reduce(valid_token_this_rank) + dp_size = self.ulysses_device_mesh.size('dp') if use_sp else torch.distributed.get_world_size() + else: + dp_size = 1 + + loss = torch.sum(loss) / valid_token_this_rank * dp_size + + if do_backward: + loss.backward() + return loss def training_step(self, batch: TensorDict): self.fsdp_model.train() @@ -258,12 +385,11 @@ def training_step(self, batch: TensorDict): log_gpu_memory_usage('After optimizer zero_grad', logger=logger) - micro_batches = batch.split(self.config.data.micro_batch_size) + micro_batches = batch.split(self.config.data.micro_batch_size_per_gpu) n_micro_batches = len(micro_batches) step_loss = 0 for micro_batch in micro_batches: - loss = self._compute_loss(batch=micro_batch) / n_micro_batches - loss.backward() + loss = self._compute_loss_and_backward(batch=micro_batch) / n_micro_batches step_loss += loss.item() self.fsdp_model.clip_grad_norm_(max_norm=self.config.optim.clip_grad) @@ -288,7 +414,7 @@ def training_step(self, batch: TensorDict): def validation_step(self, batch: TensorDict): self.fsdp_model.eval() with torch.no_grad(): - loss = self._compute_loss(batch) + loss = self._compute_loss_and_backward(batch, do_backward=False) torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.AVG) return loss @@ -320,22 +446,51 @@ def fit(self): default_backend=self.config.trainer.logger) global_step = 0 + # compute the total training steps. + # the total training steps in SFT is mainly for early exit + total_training_steps = len(self.train_dataloader) * self.config.trainer.total_epochs + + if self.config.trainer.total_training_steps is not None: + total_training_steps = self.config.trainer.total_training_steps + + self.total_training_steps = total_training_steps + print(f'Total training steps: {self.total_training_steps}') # TODO (zhangchi.usc1992) add back checkpoint manager. Currently, it blocks when uploading to hdfs. 
So very slow. for epoch in range(self.config.trainer.total_epochs): self.train_sampler.set_epoch(epoch=epoch) - for data in self.train_dataloader: + for data in tqdm(self.train_dataloader, + total=self.steps_per_epoch, + desc=f"Epoch {epoch+1}/{self.config.trainer.total_epochs}"): data = TensorDict(data, batch_size=self.config.data.train_batch_size).cuda() metric = self.training_step(data) if rank == 0: tracking.log(data=metric, step=global_step) global_step += 1 + # for early exit validation + if global_step >= self.total_training_steps: + # Perform final validation + val_losses = [] + for val_data in self.val_dataloader: + val_data = TensorDict(val_data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() + val_loss = self.validation_step(val_data) + val_losses.append(val_loss) + if rank == 0: + avg_val_loss = torch.mean(torch.stack(val_losses)) + metric = {'val/loss': avg_val_loss.detach().item()} + tracking.log(data=metric, step=global_step) + torch.distributed.barrier() + + # Save final checkpoint + self.save_checkpoint(step=global_step) + return + # validation val_losses = [] for data in self.val_dataloader: - data = TensorDict(data, batch_size=self.config.data.micro_batch_size).cuda() + data = TensorDict(data, batch_size=self.config.data.micro_batch_size_per_gpu).cuda() val_loss = self.validation_step(data) val_losses.append(val_loss) if rank == 0: @@ -360,8 +515,12 @@ def fit(self): def main(config): local_rank, rank, world_size = initialize_global_process_group() - device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('dp',)) - trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh) + device_mesh = init_device_mesh(device_type='cuda', mesh_shape=(world_size,), mesh_dim_names=('fsdp',)) + dp_size = world_size // config.ulysses_sequence_parallel_size + ulysses_device_mesh = init_device_mesh(device_type='cuda', + mesh_shape=(dp_size, config.ulysses_sequence_parallel_size), + mesh_dim_names=('dp', 'sp')) + trainer = FSDPSFTTrainer(config=config, device_mesh=device_mesh, ulysses_device_mesh=ulysses_device_mesh) trainer.fit() diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py index 6e33b339..cf4ce7fe 100644 --- a/verl/trainer/ppo/core_algos.py +++ b/verl/trainer/ppo/core_algos.py @@ -20,6 +20,7 @@ import numpy as np import torch +from collections import defaultdict import verl.utils.torch_functional as verl_F @@ -106,6 +107,54 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc return advantages, returns +# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar. +def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor, + eos_mask: torch.Tensor, + index: torch.Tensor, + epsilon: float = 1e-6): + """ + Compute advantage for GRPO, operating only on Outcome reward + (with only one scalar reward for each response). 
+ Args: + token_level_rewards: `(torch.Tensor)` + shape: (bs, response_length) + eos_mask: `(torch.Tensor)` + shape: (bs, response_length) + + Returns: + advantages: `(torch.Tensor)` + shape: (bs, response_length) + Returns: `(torch.Tensor)` + shape: (bs, response_length) + """ + response_length = token_level_rewards.shape[-1] + non_zero_mask = (token_level_rewards != 0) + scores = (token_level_rewards * non_zero_mask).sum(dim=-1) + + id2score = defaultdict(list) + id2mean = {} + id2std = {} + + with torch.no_grad(): + bsz = scores.shape[0] + for i in range(bsz): + id2score[index[i]].append(scores[i]) + for idx in id2score: + if len(id2score[idx]) == 1: + id2mean[idx] = torch.tensor(0.0) + id2std[idx] = torch.tensor(1.0) + elif len(id2score[idx]) > 1: + id2mean[idx] = torch.mean(torch.tensor(id2score[idx])) + id2std[idx] = torch.std(torch.tensor([id2score[idx]])) + else: + raise ValueError(f"no score in prompt index: {idx}") + for i in range(bsz): + scores[i] = (scores[i] - id2mean[index[i]]) / (id2std[index[i]] + epsilon) + scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask + + return scores, scores + + def compute_rewards(token_level_scores, old_log_prob, ref_log_prob, kl_ratio): kl = old_log_prob - ref_log_prob return token_level_scores - kl * kl_ratio @@ -210,6 +259,14 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe if kl_penalty == "mse": return 0.5 * (logprob - ref_logprob).square() + # J. Schulman. Approximating kl divergence, 2020. + # # URL http://joschu.net/blog/kl-approx.html. + if kl_penalty == 'low_var_kl': + kl = ref_logprob - logprob + ratio = torch.exp(kl) + kld = (ratio - kl - 1).contiguous() + return torch.clamp(kld, min=-10, max=10) + if kl_penalty == "full": # so, here logprob and ref_logprob should contain the logits for every token in vocabulary raise NotImplementedError diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 1d481686..071562dc 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -17,6 +17,7 @@ """ import os +import uuid from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum @@ -112,16 +113,16 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController, return data, metrics -def compute_advantage(data: DataProto, gamma, lam, adv_estimator): - values = data.batch['values'] - responses = data.batch['responses'] - response_length = responses.size(1) - attention_mask = data.batch['attention_mask'] - response_mask = attention_mask[:, -response_length:] - token_level_rewards = data.batch['token_level_rewards'] - +def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1): + # prepare response group # TODO: add other ways to estimate advantages if adv_estimator == 'gae': + values = data.batch['values'] + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + token_level_rewards = data.batch['token_level_rewards'] advantages, returns = core_algos.compute_gae_advantage_return(token_level_rewards=token_level_rewards, values=values, eos_mask=response_mask, @@ -129,6 +130,18 @@ def compute_advantage(data: DataProto, gamma, lam, adv_estimator): lam=lam) data.batch['advantages'] = advantages data.batch['returns'] = returns + elif adv_estimator == 'grpo': + token_level_rewards = data.batch['token_level_rewards'] + index = 
data.non_tensor_batch['uid'] + responses = data.batch['responses'] + response_length = responses.size(-1) + attention_mask = data.batch['attention_mask'] + response_mask = attention_mask[:, -response_length:] + advantages, returns = core_algos.compute_grpo_outcome_advantage(token_level_rewards=token_level_rewards, + eos_mask=response_mask, + index=index) + data.batch['advantages'] = advantages + data.batch['returns'] = returns else: raise NotImplementedError return data @@ -156,14 +169,13 @@ def _compute_response_info(batch): ) -def compute_data_metrics(batch): +def compute_data_metrics(batch, use_critic=True): # TODO: add response length sequence_score = batch.batch['token_level_scores'].sum(-1) sequence_reward = batch.batch['token_level_rewards'].sum(-1) advantages = batch.batch['advantages'] returns = batch.batch['returns'] - values = batch.batch['values'] max_response_length = batch.batch['responses'].shape[-1] @@ -178,10 +190,12 @@ def compute_data_metrics(batch): valid_adv = torch.masked_select(advantages, response_mask) valid_returns = torch.masked_select(returns, response_mask) - valid_values = torch.masked_select(values, response_mask) - return_diff_var = torch.var(valid_returns - valid_values) - return_var = torch.var(valid_returns) + if use_critic: + values = batch.batch['values'] + valid_values = torch.masked_select(values, response_mask) + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) metrics = { # score @@ -212,13 +226,15 @@ def compute_data_metrics(batch): torch.max(valid_returns).detach().item(), 'critic/returns/min': torch.min(valid_returns).detach().item(), - # values - 'critic/values/mean': - torch.mean(valid_values).detach().item(), - 'critic/values/max': - torch.max(valid_values).detach().item(), - 'critic/values/min': - torch.min(valid_values).detach().item(), + **({ + # values + 'critic/values/mean': torch.mean(valid_values).detach().item(), + 'critic/values/max': torch.max(valid_values).detach().item(), + 'critic/values/min': torch.min(valid_values).detach().item(), + # vf explained var + 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), + } if use_critic else {}), + # response length 'response_length/mean': torch.mean(response_length).detach().item(), @@ -237,8 +253,6 @@ def compute_data_metrics(batch): torch.min(prompt_length).detach().item(), 'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), - # vf explained var - 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), } return metrics @@ -323,8 +337,13 @@ def __init__(self, else: self.kl_ctrl = core_algos.FixedKLController(kl_coef=0.) 
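# A minimal sketch (toy values) of the 'low_var_kl' penalty introduced in
# kl_penalty above: the k3 estimator from J. Schulman's "Approximating KL
# Divergence" note, kld = exp(ref - logp) - (ref - logp) - 1, which is
# element-wise non-negative and is clamped to [-10, 10] for numerical safety.
# It is one of the estimators the actor can use as its kl_loss when
# use_kl_loss is enabled.
import torch

logprob = torch.tensor([-1.2, -0.7, -2.3])       # log-probs under the current policy
ref_logprob = torch.tensor([-1.0, -0.9, -2.0])   # log-probs under the reference policy

kl = ref_logprob - logprob        # log(p_ref / p_theta), token-wise
ratio = torch.exp(kl)
kld = torch.clamp(ratio - kl - 1, min=-10, max=10)
# kld >= 0 element-wise before clamping, unlike the plain (logprob - ref_logprob) estimate.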
+ self._validate_config() self._create_dataloader() + def _validate_config(self): + from verl.utils.config import validate_config + validate_config(self.config) + def _create_dataloader(self): from torch.utils.data import DataLoader # TODO: we have to make sure the batch size is divisible by the dp size @@ -449,8 +468,9 @@ def init_workers(self): critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=self.config.critic) self.resource_pool_to_cls[resource_pool]['critic'] = critic_cls self.use_critic = True + elif self.config.algorithm.adv_estimator == 'grpo': + self.use_critic = False else: - # support GRPO and ReMax raise NotImplementedError # create reference policy if needed @@ -572,6 +592,8 @@ def fit(self): with _timer('gen', timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + batch.non_tensor_batch['uid'] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], + dtype=object) # repeat to align with repeated responses in rollout batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) batch = batch.union(gen_batch_output) @@ -584,6 +606,11 @@ def fit(self): # compute global_valid tokens batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist() + # recompute old_log_probs + with _timer('old_log_prob', timing_raw): + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + batch = batch.union(old_log_prob) + if self.use_reference_policy: # compute reference log_prob with _timer('ref', timing_raw): @@ -591,9 +618,10 @@ def fit(self): batch = batch.union(ref_log_prob) # compute values - with _timer('values', timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) + if self.use_critic: + with _timer('values', timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) with _timer('adv', timing_raw): # compute scores. Support both model and function-based. @@ -609,16 +637,20 @@ def fit(self): batch.batch['token_level_scores'] = reward_tensor # compute rewards. 
apply_kl_penalty if available - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) + if not self.config.actor_rollout_ref.actor.use_kl_loss: + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + else: + batch.batch['token_level_rewards'] = batch.batch['token_level_scores'] # compute advantages, executed on the driver process batch = compute_advantage(batch, - self.config.algorithm.gamma, - self.config.algorithm.lam, - adv_estimator=self.config.algorithm.adv_estimator) + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n) # update critic if self.use_critic: @@ -648,7 +680,7 @@ def fit(self): self._save_checkpoint() # collect metrics - metrics.update(compute_data_metrics(batch=batch)) + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: make a canonical logger that supports various backend diff --git a/verl/utils/config.py b/verl/utils/config.py index 5c9298c4..0dcd73a9 100644 --- a/verl/utils/config.py +++ b/verl/utils/config.py @@ -21,3 +21,69 @@ def update_dict_with_config(dictionary: Dict, config: DictConfig): for key in dictionary: if hasattr(config, key): dictionary[key] = getattr(config, key) + + +def validate_config(config): + # number of GPUs total + n_gpus = config.trainer.n_gpus_per_node * config.trainer.nnodes + + # 1. Check total batch size for data correctness + real_train_batch_size = config.data.train_batch_size * config.actor_rollout_ref.rollout.n + assert real_train_batch_size % n_gpus == 0, \ + f"real_train_batch_size ({real_train_batch_size}) must be divisible by total n_gpus ({n_gpus})." + + # A helper function to check "micro_batch_size" vs "micro_batch_size_per_gpu" + # We throw an error if the user sets both. The new convention is "..._micro_batch_size_per_gpu". + def check_mutually_exclusive(mbs, mbs_per_gpu, name: str): + if mbs is None and mbs_per_gpu is None: + raise ValueError(f"[{name}] Please set at least one of '{name}.micro_batch_size' or " + f"'{name}.micro_batch_size_per_gpu'.") + + if mbs is not None and mbs_per_gpu is not None: + raise ValueError(f"[{name}] You have set both '{name}.micro_batch_size' AND " + f"'{name}.micro_batch_size_per_gpu'. Please remove '{name}.micro_batch_size' " + f"because only '*_micro_batch_size_per_gpu' is supported (the former is deprecated).") + + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + # actor: ppo_micro_batch_size vs. ppo_micro_batch_size_per_gpu + check_mutually_exclusive(config.actor_rollout_ref.actor.ppo_micro_batch_size, + config.actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu, "actor_rollout_ref.actor") + + # reference: log_prob_micro_batch_size vs. log_prob_micro_batch_size_per_gpu + check_mutually_exclusive(config.actor_rollout_ref.ref.log_prob_micro_batch_size, + config.actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.ref") + + # The rollout section also has log_prob_micro_batch_size vs. 
log_prob_micro_batch_size_per_gpu + check_mutually_exclusive(config.actor_rollout_ref.rollout.log_prob_micro_batch_size, + config.actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu, + "actor_rollout_ref.rollout") + + if not config.critic.use_dynamic_bsz: + # Check for critic micro-batch size conflicts + check_mutually_exclusive(config.critic.ppo_micro_batch_size, config.critic.ppo_micro_batch_size_per_gpu, + "critic") + + # Check for reward model micro-batch size conflicts + if config.reward_model.enable and not config.reward_model.use_dynamic_bsz: + check_mutually_exclusive(config.reward_model.micro_batch_size, config.reward_model.micro_batch_size_per_gpu, + "reward_model") + + # Actor + # if NOT dynamic_bsz, we must ensure: + # ppo_mini_batch_size is divisible by ppo_micro_batch_size + # ppo_micro_batch_size * sequence_parallel_size >= n_gpus + if not config.actor_rollout_ref.actor.use_dynamic_bsz: + sp_size = config.actor_rollout_ref.actor.ulysses_sequence_parallel_size + if config.actor_rollout_ref.actor.ppo_micro_batch_size is not None: + assert config.actor_rollout_ref.actor.ppo_mini_batch_size % config.actor_rollout_ref.actor.ppo_micro_batch_size == 0 + assert config.actor_rollout_ref.actor.ppo_micro_batch_size * sp_size >= n_gpus + + # critic + if not config.critic.use_dynamic_bsz: + sp_size = config.critic.ulysses_sequence_parallel_size + if config.critic.ppo_micro_batch_size is not None: + assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0 + assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus + + print("[validate_config] All configuration checks passed successfully!") diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py index d4b18f05..48328d7a 100644 --- a/verl/utils/dataset/rl_dataset.py +++ b/verl/utils/dataset/rl_dataset.py @@ -142,4 +142,8 @@ def __getitem__(self, item): if self.return_raw_chat: row_dict['raw_prompt'] = chat.tolist() - return row_dict \ No newline at end of file + # add index for each prompt + index = row_dict.get("extra_info", {}).get("index", 0) + row_dict["index"] = index + + return row_dict diff --git a/verl/utils/fsdp_utils.py b/verl/utils/fsdp_utils.py index a4a8b8a2..90c534aa 100644 --- a/verl/utils/fsdp_utils.py +++ b/verl/utils/fsdp_utils.py @@ -12,10 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict import functools +import json +import math +import itertools +import os +from contextlib import contextmanager from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy from transformers.trainer_pt_utils import get_module_class_from_name import torch +import torch.nn as nn +import torch.distributed as dist def init_fn(x: torch.nn.Module): @@ -37,7 +45,14 @@ def get_init_weight_context_manager(use_meta_tensor=True): # Copyright 2020-present the HuggingFace Inc. team. # Adapted from https://github.com/huggingface/transformers/src/transformers/trainer.py -def get_fsdp_wrap_policy(module, config=None): +def get_fsdp_wrap_policy(module, config=None, is_lora=False): + """Get FSDP wrap policy for the module. 
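# A small, self-contained sketch of the lambda-policy criterion added below for
# is_lora=True: leaf modules that still carry a trainable weight (e.g. LoRA
# adapter matrices) match and get wrapped as their own FSDP units, while frozen
# leaves do not. Toy modules only; in the patch this policy is combined with
# the size-based / transformer policies through _or_policy.
import torch.nn as nn

def lambda_policy_fn(module):
    return (len(list(module.named_children())) == 0
            and getattr(module, "weight", None) is not None
            and module.weight.requires_grad)

frozen_proj = nn.Linear(16, 16)
frozen_proj.weight.requires_grad_(False)
lora_a = nn.Linear(16, 4, bias=False)   # stands in for a trainable LoRA "A" matrix

assert not lambda_policy_fn(frozen_proj)
assert lambda_policy_fn(lora_a)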
+ + Args: + module: The module to get wrap policy for + config: Configuration for wrap policy + is_lora: Whether to enable lambda policy for LoRA modules + """ if config is None: config = {} @@ -49,8 +64,26 @@ def get_fsdp_wrap_policy(module, config=None): default_transformer_cls_names_to_wrap) min_num_params = config.get('min_num_params', 0) auto_wrap_policy = None + + policies = [] + + from torch.distributed.fsdp.wrap import _or_policy, lambda_auto_wrap_policy, transformer_auto_wrap_policy + + # Add lambda policy for LoRA modules if is_lora is True + if is_lora: + + def lambda_policy_fn(module): + if (len(list(module.named_children())) == 0 and getattr(module, "weight", None) is not None and + module.weight.requires_grad): + return True + return False + + lambda_policy = functools.partial(lambda_auto_wrap_policy, lambda_fn=lambda_policy_fn) + policies.append(lambda_policy) + if min_num_params > 0: - auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + size_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_num_params) + policies.append(size_policy) elif fsdp_transformer_layer_cls_to_wrap is not None: transformer_cls_to_wrap = set() for layer_class in fsdp_transformer_layer_cls_to_wrap: @@ -60,11 +93,15 @@ def get_fsdp_wrap_policy(module, config=None): else: transformer_cls_to_wrap.add(transformer_cls) - auto_wrap_policy = functools.partial( + transformer_policy = functools.partial( transformer_auto_wrap_policy, - # Transformer layer class to wrap transformer_layer_cls=transformer_cls_to_wrap, ) + policies.append(transformer_policy) + + if len(policies) > 0: + auto_wrap_policy = functools.partial(_or_policy, policies=policies) + return auto_wrap_policy @@ -120,3 +157,173 @@ def load_fsdp_optimizer(optimizer, device_id): if isinstance(value, torch.Tensor): state[key] = value.to(device_id, non_blocking=True) torch.cuda.empty_cache() + + +@contextmanager +def meta_device_init(): + """ + Create model parameters with meta device. + + Note buffers in model will still be initialized in default device (e.g., CPU), + since the buffers can be non-persistent and filled with expected values that can + NOT be captured in meta device. 
+ """ + device = torch.device("meta") + old_register_parameter = nn.Module.register_parameter + registered = set() + + def register_empty_parameter(module, name, param): + old_register_parameter(module, name, param) + # we will skip register shared parameters as it + # is already registered previously + if param is not None and param not in registered: + param_cls = type(module._parameters[name]) + kwargs = module._parameters[name].__dict__ + kwargs["requires_grad"] = param.requires_grad + module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) + registered.add(module._parameters[name]) + + try: + nn.Module.register_parameter = register_empty_parameter + yield + finally: + registered.clear() + nn.Module.register_parameter = old_register_parameter + + +def parallel_load_safetensors(filepath): + """ + Parallel load safetensors from huggingface checkpoint + + Huggingface checkpoint contains: + + - config.json: a json file for model configuration + - model.safetensor.index.json: a json file for safetensors (parameters & buffers) index + - model-000x-of-ooxx.safetensors: a binary file for safetensors (parameters & buffers) chunks + + Or (when model is small), + + - model.safetensors: a binary file for all parameters and buffers + + Each rank will own a part of model chunks and load them directly into GPU memory. + """ + from safetensors.torch import load_file + + safetensors2param = {} + + index_file = os.path.join(filepath, "model.safetensors.index.json") + if os.path.exists(index_file): + index = json.load(open(index_file, "rb")) + for param_name, filename in index["weight_map"].items(): + safetensors2param.setdefault(filename, []).append(param_name) + else: + # in this case, the model is small and we can load it all at once + param_file = os.path.join(filepath, "model.safetensors") + assert os.path.exists(param_file), f"Cannot find {param_file}" + states = load_file(param_file) + for param_name in states: + safetensors2param.setdefault("model.safetensors", []).append(param_name) + del states + + total_files = len(safetensors2param) + ckpt_chunks = sorted(safetensors2param.keys()) + world_size = dist.get_world_size() + size = int(math.ceil(total_files / world_size)) + ckpt_chunks = [ckpt_chunks[rank * size:rank * size + size] for rank in range(world_size)] + + shard_states = {} + device = torch.cuda.current_device() + for rank, files in enumerate(ckpt_chunks): + if rank == dist.get_rank(): + for file in files: + file = os.path.join(filepath, file) + states = load_file(file, device=device) + # print(f"rank {rank} loading {file}...") + shard_states.update(states) + else: + for file in files: + for param_name in safetensors2param[file]: + shard_states[param_name] = rank + return shard_states + + +def parallel_init_module_fn(module: torch.nn.Module, shard_states: Dict[str, torch.nn.Parameter]): + """ + Generate a function to initialize sub-modules in the `module` with `shard_states` + from huggingface checkpoint. 
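# A rough composition sketch (not runnable as-is: it assumes an initialized
# torch.distributed process group, one CUDA device per rank, and a local
# HuggingFace checkpoint directory; the path below is a placeholder) showing
# how the helpers in this file are intended to fit together:
#
#   from transformers import AutoConfig, AutoModelForCausalLM
#   from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
#
#   ckpt_dir = "/path/to/hf_checkpoint"                       # placeholder
#   hf_config = AutoConfig.from_pretrained(ckpt_dir)
#   with meta_device_init():                                  # parameters created on the meta device
#       model = AutoModelForCausalLM.from_config(hf_config)
#   shard_states = parallel_load_safetensors(ckpt_dir)        # each rank loads its slice of shards
#   materialize = parallel_init_module_fn(model, shard_states)
#   fsdp_model = FSDP(model, param_init_fn=materialize)       # FSDP materializes meta params via init_fn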
+ + Args: + module (torch.nn.Module): the global module to be initialized + shard_states (Dict[str, torch.nn.Parameter]): the shard states from huggingface checkpoint + + Returns: + init_fn (Callable): a function to initialize sub-modules in the `module` with `shard_states` + """ + + state2fqn = {} + for name, state in itertools.chain(module.named_parameters(remove_duplicate=False), + module.named_buffers(remove_duplicate=False)): + state2fqn.setdefault(state, []).append(name) + # remove standalone parameters and buffers + shared = {s for s, names in state2fqn.items() if len(names) > 1} + materialized_states = {} + + @torch.no_grad() + def create_and_sync_state(param_name, state, is_param): + assert param_name in shard_states, f"{param_name} not loaded" + device = torch.cuda.current_device() + if is_param: + param = torch.nn.Parameter(torch.empty_like(state.data, device=device), requires_grad=state.requires_grad) + else: # buffer + param = torch.empty_like(state.data, device=device) + loaded = shard_states[param_name] + if isinstance(loaded, (torch.nn.Parameter, torch.Tensor)): + # NOTE: loaded.dtype can be different with param.dtype + param.data.copy_(loaded.data) + dist.broadcast(param.data, src=dist.get_rank()) + else: + assert isinstance(loaded, int) # the rank that holds the state + dist.broadcast(param.data, src=loaded) + shard_states.pop(param_name) + del loaded + return param + + def init_fn(sub_mod: torch.nn.Module, recurse: bool = True): + param_and_buffers = tuple(sub_mod.named_parameters(recurse=False)) + tuple(sub_mod.named_buffers(recurse=False)) + # param_and_buffers = sorted(sub_mod.named_parameters(recurse=False), key=lambda x: x[0]) + for name, state in param_and_buffers: + if not state.is_meta: + continue + is_param = name in sub_mod._parameters + fqn = state2fqn[state].pop(0) + # non-persistent buffers will not be saved in state dict, we can safely skip it + if (not is_param) and fqn not in shard_states: + if state.is_meta: + raise RuntimeError( + f"find a non-persistent buffer ({fqn}) initiated with device meta. 
" + "Such buffer is not saved in checkpoint and user should guarantee to init in CPU / GPU device.") + continue + # for shared parameter, we get it from the first time it is created + if state in shared: + if state not in materialized_states: + materialized_states[state] = create_and_sync_state(fqn, state, is_param) + else: + if fqn in shard_states: + shard_states.pop(fqn) + materialize_state = materialized_states[state] + # for not shared parameter, we create it directly + else: + materialize_state = create_and_sync_state(fqn, state, is_param) + if is_param: + sub_mod._parameters[name] = materialize_state + else: + sub_mod._buffers[name] = materialize_state + if recurse: + for module in sub_mod.children(): + init_fn(module, recurse=True) + + # for debug + # if len(shard_states) == 0: print("clear") + return sub_mod + + return init_fn diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index 3a618589..cedfd53c 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -26,7 +26,7 @@ from verl.trainer.ppo import core_algos from verl.workers.actor import BasePPOActor from verl.utils.py_functional import append_to_dict -from verl.utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad +from verl.utils.torch_functional import logprobs_from_logits, masked_mean from verl.utils.ulysses import ulysses_pad_and_slice_inputs, gather_outpus_and_unpad from verl.utils.seqlen_balancing import rearrange_micro_batches, get_reverse_idx import verl.utils.torch_functional as verl_F @@ -204,11 +204,11 @@ def update_policy(self, data: DataProto): # make sure we are in training mode self.actor_module.train() - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0 - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages'] + if self.config.use_kl_loss: + select_keys.append('ref_log_prob') batch = data.select(batch_keys=select_keys).batch # Split to make minibatch iterator for updating the actor @@ -223,8 +223,9 @@ def update_policy(self, data: DataProto): max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu # split batch into micro_batches - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size) + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) self.actor_optimizer.zero_grad() @@ -254,7 +255,23 @@ def update_policy(self, data: DataProto): # compute policy loss policy_loss = pg_loss - entropy_loss * entropy_coeff - loss = policy_loss / self.gradient_accumulation + if self.config.use_kl_loss: + ref_log_prob = data['ref_log_prob'] + # compute kl loss + kld = core_algos.kl_penalty(logprob=log_prob, + ref_logprob=ref_log_prob, + kl_penalty=self.config.kl_loss_type) + kl_loss = masked_mean(kld, response_mask) + + policy_loss = policy_loss - kl_loss * self.config.kl_loss_coef + metrics['actor/kl_loss'] = kl_loss.detach().item() + metrics['actor/kl_coef'] = self.config.kl_loss_coef + + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = policy_loss * (len(data) / 
self.config.ppo_mini_batch_size) + else: + loss = policy_loss / self.gradient_accumulation loss.backward() data = { diff --git a/verl/workers/actor/megatron_actor.py b/verl/workers/actor/megatron_actor.py index e674a28f..694185a3 100644 --- a/verl/workers/actor/megatron_actor.py +++ b/verl/workers/actor/megatron_actor.py @@ -54,7 +54,7 @@ def __init__(self, config, model_config, megatron_config: ModelParallelConfig, a Args: config (OmegaConf): the basic config that contains the hyper-parameters of PPO Actor. It must contain - ``ppo_micro_batch_size``: minibatch size when updating ppo. + ``ppo_micro_batch_size_per_gpu``: micro batch size when updating ppo. ``ppo_mini_batch_size``: minibatch size when updating ppo using the batch data. @@ -232,7 +232,7 @@ def forward_backward_batch(self, data: DataProto, forward_only=False, post_proce if data.meta_info.get('micro_batch_size', None) is not None: batch_size = data.meta_info['micro_batch_size'] else: - batch_size = self.config.ppo_micro_batch_size + batch_size = self.config.ppo_micro_batch_size_per_gpu batches = split_dict_tensor_into_batches(data.batch, batch_size=batch_size) # compute input shapes for pp stages input_shapes = compute_transformers_input_shapes( diff --git a/verl/workers/critic/dp_critic.py b/verl/workers/critic/dp_critic.py index 0842ff4a..f2eb44c2 100644 --- a/verl/workers/critic/dp_critic.py +++ b/verl/workers/critic/dp_critic.py @@ -45,9 +45,6 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt self.use_remove_padding = self.config.model.get('use_remove_padding', False) print(f'Critic use_remove_padding={self.use_remove_padding}') - assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0 - self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size - self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) def _forward_micro_batch(self, micro_batch): @@ -161,7 +158,8 @@ def update_critic(self, data: DataProto): max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len) else: - micro_batches = mini_batch.split(self.config.ppo_micro_batch_size) + micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu) + self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu self.critic_optimizer.zero_grad() @@ -186,7 +184,12 @@ def update_critic(self, data: DataProto): returns=returns, eos_mask=eos_mask, cliprange_value=self.config.cliprange_value) - loss = vf_loss / self.gradient_accumulation + if self.config.use_dynamic_bsz: + # relative to the dynamic bsz + loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size) + else: + loss = vf_loss / self.gradient_accumulation + loss.backward() data = { diff --git a/verl/workers/critic/megatron_critic.py b/verl/workers/critic/megatron_critic.py index a39ad4b4..4545db68 100644 --- a/verl/workers/critic/megatron_critic.py +++ b/verl/workers/critic/megatron_critic.py @@ -118,7 +118,7 @@ def forward_backward_batch(self, data: DataProto, forward_only=False): group=mpu.get_pipeline_model_parallel_group()) # split into micro-batches data.batch['attention_mask'] = data.batch['attention_mask'].to(bool) - batches = split_dict_tensor_into_batches(data.batch, batch_size=self.config.ppo_micro_batch_size) + batches = split_dict_tensor_into_batches(data.batch, 
batch_size=self.config.ppo_micro_batch_size_per_gpu) n_micro_batch = len(batches) seq_len = batches[0]['input_ids'].shape[1] @@ -182,7 +182,7 @@ def forward_step(batch_iter, model): model=self.critic_module, num_microbatches=n_micro_batch, input_shapes=input_shapes, # must set for flash-attn sequence packing - seq_length=self.config.ppo_micro_batch_size * seq_len, # no use when input_shapes was set + seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len, # no use when input_shapes was set hidden_size=self.model_config.hidden_size, # no use when input_shapes was set micro_batch_size=1, # no use when input_shapes was set forward_only=forward_only, @@ -193,7 +193,7 @@ def forward_step(batch_iter, model): data_iterator=batch_generator, model=self.critic_module, num_microbatches=n_micro_batch, - seq_length=self.config.ppo_micro_batch_size * seq_len, # in use for pp = 1 + seq_length=self.config.ppo_micro_batch_size_per_gpu * seq_len, # in use for pp = 1 hidden_size=self.model_config.hidden_size, # in use for pp = 1 micro_batch_size=1, # in use for pp = 1 forward_only=forward_only, diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index 2df08f8f..50fc04e9 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -21,6 +21,7 @@ import torch import torch.distributed +from torch.distributed.device_mesh import init_device_mesh import verl.utils.hdfs_io as hdfs_io import verl.utils.torch_functional as verl_F from omegaconf import DictConfig, open_dict @@ -44,6 +45,30 @@ logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) +def create_device_mesh(world_size, fsdp_size): + if fsdp_size < 0 or fsdp_size >= world_size: + device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + else: + raise ValueError( + 'HSDP is not supported yet because it produces incorrect results for now. 
Please set fsdp_size=-1') + assert world_size % fsdp_size == 0 + device_mesh = init_device_mesh('cuda', + mesh_shape=(world_size // fsdp_size, fsdp_size), + mesh_dim_names=['ddp', 'fsdp']) + return device_mesh + + +def get_sharding_strategy(device_mesh): + from torch.distributed.fsdp import ShardingStrategy + if device_mesh.ndim == 1: + sharding_strategy = ShardingStrategy.FULL_SHARD + elif device_mesh.ndim == 2: + sharding_strategy = ShardingStrategy.HYBRID_SHARD + else: + raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2") + return sharding_strategy + + class ActorRolloutRefWorker(Worker): """ This worker can be instantiated as a standalone actor or a standalone rollout or a standalone reference policy @@ -59,9 +84,8 @@ def __init__(self, config: DictConfig, role: str): # build device mesh for FSDP world_size = torch.distributed.get_world_size() - from torch.distributed.device_mesh import init_device_mesh # TODO(sgm): support FSDP hybrid shard for larger model - self.device_mesh = init_device_mesh('cuda', mesh_shape=(world_size,), mesh_dim_names=['fsdp']) + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=self.config.actor.fsdp_config.fsdp_size) # build device mesh for Ulysses Sequence Parallel self.ulysses_device_mesh = None @@ -94,16 +118,24 @@ def __init__(self, config: DictConfig, role: str): # normalize config if self._is_actor: - self.config.actor.ppo_mini_batch_size //= self.device_mesh.shape[0] - self.config.actor.ppo_micro_batch_size //= self.device_mesh.shape[0] + self.config.actor.ppo_mini_batch_size //= (self.device_mesh.shape[0] // self.ulysses_sequence_parallel_size) self.config.actor.ppo_mini_batch_size *= self.config.rollout.n - self.config.actor.ppo_micro_batch_size *= self.config.rollout.n - if self._is_rollout: - self.config.rollout.log_prob_micro_batch_size //= self.device_mesh.shape[0] - self.config.rollout.log_prob_micro_batch_size *= self.config.rollout.n - if self._is_ref: - self.config.ref.log_prob_micro_batch_size //= self.device_mesh.shape[0] - self.config.ref.log_prob_micro_batch_size *= self.config.rollout.n + # micro bsz + if self.config.actor.ppo_micro_batch_size is not None: + self.config.actor.ppo_micro_batch_size //= (self.device_mesh.shape[0] // + self.ulysses_sequence_parallel_size) + self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0 + # normalize rollout config + if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None: + self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.shape[0] // + self.ulysses_sequence_parallel_size) + self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + # normalize ref config + if self._is_ref and self.config.ref.log_prob_micro_batch_size is not None: + self.config.ref.log_prob_micro_batch_size //= (self.device_mesh.shape[0] // + self.ulysses_sequence_parallel_size) + self.config.ref.log_prob_micro_batch_size_per_gpu = self.config.ref.log_prob_micro_batch_size def _build_model_optimizer(self, model_path, @@ -112,13 +144,16 @@ def _build_model_optimizer(self, override_model_config, use_remove_padding=False, enable_gradient_checkpointing=False, - trust_remote_code=False): + trust_remote_code=False, + role='actor'): from verl.utils.model import print_model_size, update_model_config from verl.utils.torch_dtypes import PrecisionType from 
transformers import AutoModelForCausalLM, AutoConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, CPUOffload from torch import optim + assert role in ['actor', 'ref'] + log_gpu_memory_usage('Before init from HF AutoModel', logger=logger) local_path = copy_local_path_from_hdfs(model_path) @@ -188,9 +223,6 @@ def _build_model_optimizer(self, mixed_precision = MixedPrecision(param_dtype=param_dtype, reduce_dtype=reduce_dtype, buffer_dtype=buffer_dtype) - if self._is_ref: - mixed_precision = None - auto_wrap_policy = get_fsdp_wrap_policy(module=actor_module, config=fsdp_config.get('wrap_policy', None)) if self._is_rollout and self.config.rollout.name == 'hf': @@ -199,15 +231,16 @@ def _build_model_optimizer(self, print(f'wrap_policy: {auto_wrap_policy}') - # TODO(sgm): support hybrid - if auto_wrap_policy is None: - sharding_strategy = ShardingStrategy.SHARD_GRAD_OP - else: - sharding_strategy = ShardingStrategy.FULL_SHARD + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) # TODO: add transformer policy + # We force reference policy to use CPUOffload to save memory. + # We force turn off CPUOffload for actor because it causes incorrect results when using grad accumulation + cpu_offload = None if role == 'actor' else CPUOffload(offload_params=True) actor_module_fsdp = FSDP( actor_module, + cpu_offload=cpu_offload, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, @@ -221,7 +254,7 @@ def _build_model_optimizer(self, log_gpu_memory_usage('After Actor FSDP init', logger=logger) # TODO: add more optimizer args into config - if self._is_actor: + if role == 'actor': from verl.utils.torch_functional import get_constant_schedule_with_warmup actor_optimizer = optim.AdamW(actor_module_fsdp.parameters(), lr=optim_config.lr, @@ -304,7 +337,8 @@ def init_model(self): override_model_config=override_model_config, use_remove_padding=use_remove_padding, enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False), - trust_remote_code=self.config.model.get('trust_remote_code', False)) + trust_remote_code=self.config.model.get('trust_remote_code', False), + role='actor') # get the original unwrapped module self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module @@ -335,10 +369,8 @@ def init_model(self): override_model_config=override_model_config, use_remove_padding=use_remove_padding, trust_remote_code=self.config.model.get( - 'trust_remote_code', False))[0] - if self._is_offload_param: - offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad) - + 'trust_remote_code', False), + role='ref')[0] OmegaConf.set_struct(self.config.ref, True) with open_dict(self.config.ref): self.config.ref.use_remove_padding = use_remove_padding @@ -397,8 +429,6 @@ def update_actor(self, data: DataProto): @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def generate_sequences(self, prompts: DataProto): prompts = prompts.to('cuda') - # set to False if it is validation - recompute_log_prob = prompts.meta_info.get('recompute_log_prob', True) assert self._is_rollout if self._is_offload_param: @@ -419,19 +449,6 @@ def generate_sequences(self, prompts: DataProto): output = self.rollout_sharding_manager.postprocess_data(output) - if self._is_actor and recompute_log_prob: - # we should always recompute old_log_probs when it is HybridEngine 
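# The log-prob recomputation that used to happen inside generate_sequences
# (the block removed just below) is now exposed as a separate worker method,
# compute_log_prob, and the driver invokes it as its own timed step. A
# simplified driver-side sketch of the resulting flow in ray_trainer.fit,
# excerpted from this patch with details omitted:
#
#   gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)   # rollout only
#   batch = batch.union(gen_batch_output)
#   old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)             # recompute old_log_probs
#   batch = batch.union(old_log_prob)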
- output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size - output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu - output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz - output.meta_info['temperature'] = self.config.rollout.temperature - # perform recompute log_prob - with self.ulysses_sharding_manager: - output = self.ulysses_sharding_manager.preprocess_data(output) - old_log_probs = self.actor.compute_log_prob(data=output) - output.batch['old_log_probs'] = old_log_probs - output = self.ulysses_sharding_manager.postprocess_data(output) - output = output.to('cpu') if self._is_offload_param: @@ -442,18 +459,40 @@ def generate_sequences(self, prompts: DataProto): log_gpu_memory_usage('After recompute log prob', logger=logger) return output + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) + def compute_log_prob(self, data: DataProto): + assert self._is_actor + data = data.to('cuda') + # we should always recompute old_log_probs when it is HybridEngine + data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu + data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu + data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz + data.meta_info['temperature'] = self.config.rollout.temperature + # perform recompute log_prob + with self.ulysses_sharding_manager: + data = self.ulysses_sharding_manager.preprocess_data(data) + output = self.actor.compute_log_prob(data=data) + output = DataProto.from_dict(tensors={'old_log_probs': output}, + meta_info={'temperature': self.config.rollout.temperature}) + output = self.ulysses_sharding_manager.postprocess_data(output) + + output = output.to('cpu') + + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1: + self.actor.actor_module._handle.reshard(True) + + torch.cuda.empty_cache() + return output + @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO) def compute_ref_log_prob(self, data: DataProto): assert self._is_ref data = data.to('cuda') - if self._is_offload_param: - load_fsdp_param_and_grad(module=self.ref_module_fsdp, - device_id=torch.cuda.current_device(), - load_grad=self._is_offload_grad) - - micro_batch_size = self.config.ref.log_prob_micro_batch_size + micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size data.meta_info['temperature'] = self.config.rollout.temperature data.meta_info['max_token_len'] = self.config.ref.log_prob_max_token_len_per_gpu @@ -466,8 +505,11 @@ def compute_ref_log_prob(self, data: DataProto): output = output.to('cpu') - if self._is_offload_param: - offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad) + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + if self.world_size > 1: + self.ref_policy.actor_module._handle.reshard(True) + torch.cuda.empty_cache() return output @@ -513,6 +555,10 @@ def __init__(self, config): # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 
1) dp = world_size // self.ulysses_sequence_parallel_size @@ -529,9 +575,15 @@ def __init__(self, config): self._is_offload_optimizer = self.config.model.fsdp_config.optimizer_offload # normalize config - self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() - self.config.ppo_micro_batch_size //= torch.distributed.get_world_size() - self.config.forward_micro_batch_size //= torch.distributed.get_world_size() + self.config.ppo_mini_batch_size //= (torch.distributed.get_world_size() // self.ulysses_sequence_parallel_size) + if self.config.ppo_micro_batch_size is not None: + self.config.ppo_micro_batch_size //= (torch.distributed.get_world_size() // + self.ulysses_sequence_parallel_size) + self.config.forward_micro_batch_size //= (torch.distributed.get_world_size() // + self.ulysses_sequence_parallel_size) + self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size + self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size + assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0 def _build_critic_model_optimizer(self, config): # the following line is necessary @@ -615,15 +667,21 @@ def _build_critic_model_optimizer(self, config): log_gpu_memory_usage('Before critic FSDP', logger=None) + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + + # Note: We force turn off CPUOffload for critic because it causes incorrect results when using grad accumulation critic_module = FSDP(critic_module, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, + sharding_strategy=sharding_strategy, mixed_precision=mixed_precision, sync_module_states=True, - forward_prefetch=False) + forward_prefetch=False, + device_mesh=self.device_mesh, + cpu_offload=None) log_gpu_memory_usage('After critic FSDP', logger=None) @@ -674,7 +732,7 @@ def compute_values(self, data: DataProto): load_fsdp_param_and_grad(module=self.critic_module, device_id=torch.cuda.current_device(), load_grad=self._is_offload_grad) - micro_batch_size = self.config.forward_micro_batch_size + micro_batch_size = self.config.forward_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size data.meta_info['max_token_len'] = self.config.forward_max_token_len_per_gpu data.meta_info['use_dynamic_bsz'] = self.config.use_dynamic_bsz @@ -773,6 +831,10 @@ def __init__(self, config): # build device mesh for Ulysses Sequence Parallel world_size = torch.distributed.get_world_size() from torch.distributed.device_mesh import init_device_mesh + + fsdp_size = self.config.model.fsdp_config.fsdp_size + self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size) + self.ulysses_device_mesh = None self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1) dp = world_size // self.ulysses_sequence_parallel_size @@ -784,7 +846,11 @@ def __init__(self, config): self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh) self.use_remove_padding = self.config.model.get('use_remove_padding', False) - self.config.micro_batch_size //= torch.distributed.get_world_size() + + # normalize config + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= torch.distributed.get_world_size() + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size def _build_model(self, config): # the following line is necessary @@ 
-830,16 +896,20 @@ def _build_model(self, config): reward_module.to(torch.bfloat16) auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config) + fsdp_mesh = self.device_mesh + sharding_strategy = get_sharding_strategy(fsdp_mesh) + reward_module = FSDP( reward_module, param_init_fn=init_fn, use_orig_params=False, auto_wrap_policy=auto_wrap_policy, device_id=torch.cuda.current_device(), - sharding_strategy=ShardingStrategy.FULL_SHARD, # zero3 + sharding_strategy=sharding_strategy, # zero3 sync_module_states=True, - cpu_offload=CPUOffload(offload_params=self.config.model.fsdp_config.param_offload), - forward_prefetch=False) + cpu_offload=CPUOffload(offload_params=True), + forward_prefetch=False, + device_mesh=self.device_mesh) return reward_module @@ -996,7 +1066,7 @@ def compute_rm_score(self, data: DataProto): max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len) else: - micro_batches = rm_data.batch.split(self.config.micro_batch_size) + micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu) output = [] for micro_batch in micro_batches: rm_score = self._forward_micro_batch(micro_batch) @@ -1014,6 +1084,10 @@ def compute_rm_score(self, data: DataProto): output = DataProto.from_dict(tensors={'rm_scores': token_level_scores}) output = self.ulysses_sharding_manager.postprocess_data(data=output) + # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes + # unshard the root FSDP module + self.reward_module._handle.reshard(True) + output = output.to('cpu') torch.cuda.empty_cache() return output diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py index 180a7761..d0ae638e 100644 --- a/verl/workers/megatron_workers.py +++ b/verl/workers/megatron_workers.py @@ -112,13 +112,19 @@ def __init__(self, config: DictConfig, role: str): # normalize config if self._is_actor and self._is_rollout: self.config.actor.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() - self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() - self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() + if self.config.actor.ppo_micro_batch_size is not None: + self.config.actor.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.rollout.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size + self.config.rollout.log_prob_micro_batch_size_per_gpu = self.config.rollout.log_prob_micro_batch_size + self._is_offload_param = self.config.actor.get('param_offload', False) self._is_offload_grad = self.config.actor.get('grad_offload', False) self._is_offload_optimizer = self.config.actor.get('optimizer_offload', False) elif self._is_ref: - self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() + if self.config.ref.ppo_micro_batch_size is not None: + self.config.ref.log_prob_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.ppo_micro_batch_size self._is_offload_param = self.config.ref.get('param_offload', False) def _build_model_optimizer(self, @@ -361,7 +367,7 @@ def generate_sequences(self, prompts: DataProto): validate = prompts.meta_info.get('validate', False) if self._is_actor and not validate: # we should always recompute old_log_probs 
when it is HybridEngine - output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size + output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size_per_gpu output.meta_info['temperature'] = self.config.rollout.temperature old_log_probs = self.actor.compute_log_prob(data=output) output.batch['old_log_probs'] = old_log_probs @@ -380,7 +386,7 @@ def compute_ref_log_prob(self, data: DataProto): if self._is_offload_param: load_megatron_param_and_grad(self.ref_module, torch.cuda.current_device(), self._is_offload_grad) - micro_batch_size = self.config.rollout.log_prob_micro_batch_size + micro_batch_size = self.config.rollout.log_prob_micro_batch_size_per_gpu data.meta_info['micro_batch_size'] = micro_batch_size data.meta_info['temperature'] = self.config.rollout.temperature output = self.ref_policy.compute_log_prob(data=data) @@ -439,7 +445,9 @@ def __init__(self, config): # normalize config self.config.ppo_mini_batch_size //= mpu.get_data_parallel_world_size() - self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() + if self.config.ppo_micro_batch_size is not None: + self.config.ppo_micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size # TODO(sgm): support critic model offload @@ -609,7 +617,9 @@ def __init__(self, config): set_random_seed(seed=self.config.megatron.seed) # normalize config - self.config.micro_batch_size //= mpu.get_data_parallel_world_size() + if self.config.micro_batch_size is not None: + self.config.micro_batch_size //= mpu.get_data_parallel_world_size() + self.config.micro_batch_size_per_gpu = self.config.micro_batch_size def _build_rm_model(self, model_path, megatron_config: ModelParallelConfig, override_model_config): from megatron.core.models.gpt.gpt_model import ModelType diff --git a/verl/workers/reward_model/megatron/reward_model.py b/verl/workers/reward_model/megatron/reward_model.py index 2c4c1b60..1b58f42c 100644 --- a/verl/workers/reward_model/megatron/reward_model.py +++ b/verl/workers/reward_model/megatron/reward_model.py @@ -196,8 +196,8 @@ def forward_batch(self, data: DataProto): group=mpu.get_pipeline_model_parallel_group()) # split into micro-batches - if self.config is not None and 'ppo_micro_batch_size' in self.config: - infer_batch_size = self.config.ppo_micro_batch_size + if self.config is not None and 'ppo_micro_batch_size_per_gpu' in self.config: + infer_batch_size = self.config.ppo_micro_batch_size_per_gpu else: infer_batch_size = data.batch.batch_size[0] diff --git a/verl/workers/rollout/vllm_rollout/vllm_rollout.py b/verl/workers/rollout/vllm_rollout/vllm_rollout.py index 947d558f..7014ff1e 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_rollout.py +++ b/verl/workers/rollout/vllm_rollout/vllm_rollout.py @@ -74,6 +74,7 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model tensor_parallel_size = self.config.get('tensor_model_parallel_size', 1) assert tensor_parallel_size <= torch.distributed.get_world_size(), \ "tensor parallel size should be less than or equal to the world size" + max_num_batched_tokens = self.config.get('max_num_batched_tokens', 8192) if kwargs.get('train_tp', None) is not None: # deployed with megatron @@ -88,16 +89,21 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model assert model_hf_config.max_position_embeddings >= config.prompt_length + config.response_length, \ "model context length should be greater than total 
sequence length" - self.inference_engine = LLM(actor_module, - tokenizer=tokenizer, - model_hf_config=model_hf_config, - tensor_parallel_size=tensor_parallel_size, - dtype=config.dtype, - enforce_eager=config.enforce_eager, - gpu_memory_utilization=config.gpu_memory_utilization, - skip_tokenizer_init=False, - max_model_len=config.prompt_length + config.response_length, - load_format=config.load_format) + self.inference_engine = LLM( + actor_module, + tokenizer=tokenizer, + model_hf_config=model_hf_config, + tensor_parallel_size=tensor_parallel_size, + dtype=config.dtype, + enforce_eager=config.enforce_eager, + gpu_memory_utilization=config.gpu_memory_utilization, + skip_tokenizer_init=False, + max_model_len=config.prompt_length + config.response_length, + load_format=config.load_format, + disable_log_stats=config.disable_log_stats, + max_num_batched_tokens=max_num_batched_tokens, + enable_chunked_prefill=config.enable_chunked_prefill, + ) # Offload vllm model to reduce peak memory usage self.inference_engine.offload_model_weights()
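# A minimal sketch (illustrative values only) of the rollout-engine options the
# constructor above now reads from the rollout config: disable_log_stats,
# max_num_batched_tokens (defaulting to 8192 via config.get), and
# enable_chunked_prefill, alongside the previously existing fields.
from omegaconf import OmegaConf

rollout_cfg = OmegaConf.create({
    "prompt_length": 1024,
    "response_length": 512,
    "dtype": "bfloat16",
    "enforce_eager": True,
    "gpu_memory_utilization": 0.6,
    "load_format": "dummy_dtensor",      # assumed value, typical for verl rollout configs
    "tensor_model_parallel_size": 1,
    "disable_log_stats": True,           # silence vLLM's periodic stats logging
    "max_num_batched_tokens": 8192,      # cap on tokens batched per engine step
    "enable_chunked_prefill": False,     # opt-in chunked prefill for long prompts
})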