[misc] feat: support rmpad/data-packing in FSDP with transformers (#91)
* init commit of rmpad

* add rmpad test

* support rmpad in actor model

* add test for value model

* support rmpad in critic and rm

* fix actor return and fix num_labels and clean not used rmpad

* fix critic and benchmark

* update script

* fix critic

* lint

* fix util issue

* fix unnecessary unpad

* address issues

* fix args

* update test and update rmpad support model list

* fix typo

* fix typo and fix name

* rename rmpad to remove padding

* fix arch to model_type

* add ci for e2e rmpad and fix typo

* lint

* fix ci

* fix typo

* update tests for customize tokenizer in actor

* fix rmpad test

* update the transformers requirement, as hf_rollout may have issues
PeterSH6 authored Jan 11, 2025
1 parent e88cf81 commit 569210e
Showing 19 changed files with 413 additions and 45 deletions.
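
In short, remove padding ("rmpad", i.e. data packing) drops the pad tokens of a padded batch, concatenates the remaining valid tokens of all sequences into one packed sequence of length total_nnz, runs the model with flash-attention's varlen kernels, and finally scatters the outputs back to the padded (batch, seqlen) layout. A minimal sketch of that round trip using flash-attn's bert_padding helpers, matching the four-value unpad_input signature used in the new test (shapes are illustrative; this is not verl's exact call site):

    import torch
    from flash_attn.bert_padding import unpad_input, pad_input

    batch_size, seqlen, hidden = 4, 8, 16
    hidden_states = torch.randn(batch_size, seqlen, hidden)
    attention_mask = (torch.rand(batch_size, seqlen) > 0.3).long()  # 1 = valid token, 0 = pad

    # pack: keep only the valid tokens -> (total_nnz, hidden)
    packed, indices, cu_seqlens, max_seqlen = unpad_input(hidden_states, attention_mask)

    # ... run the transformer on the packed tokens with flash-attention varlen ...

    # unpack: scatter back to the padded (batch, seqlen, hidden) layout; pad positions become zeros
    restored = pad_input(packed, indices, batch_size, seqlen)
    mask = attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    assert torch.equal(restored * mask, hidden_states * mask)

Because attention is computed only over the total_nnz real tokens, per-step compute no longer scales with the amount of padding in the batch.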
6 changes: 6 additions & 0 deletions .github/workflows/e2e_gpu.yml
@@ -23,6 +23,7 @@ jobs:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
HF_HUB_ENABLE_HF_TRANSFER: 1
container:
image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
options: --gpus all --shm-size=10g
@@ -32,7 +33,12 @@ jobs:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install hf_transfer
pip3 install -e .[test]
- name: Running digit completion e2e training tests on 8 L20 GPUs
run: |
bash tests/e2e/run_ray_trainer.sh
- name: Running digit completion e2e training tests on 8 L20 GPUs (with rmpad)
run: |
pip3 install --upgrade transformers
bash tests/e2e/run_ray_trainer_rmpad.sh
39 changes: 39 additions & 0 deletions .github/workflows/model.yml
@@ -0,0 +1,39 @@
name: model_rmpad

on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/model.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/model.yml

jobs:
e2e_gpu:
runs-on: [self-hosted, l20-1]
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
container:
image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository and upgrade to latest transformers
run: |
pip3 install -e .[test]
pip3 install --upgrade transformers
- name: Running digit completion e2e training tests on 8 L20 GPUs
run: |
pytest -s tests/model/test_transformer.py
2 changes: 2 additions & 0 deletions examples/ppo_trainer/run_deepseek7b_llm.sh
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main_ppo \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -21,6 +22,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
2 changes: 2 additions & 0 deletions examples/ppo_trainer/run_gemma.sh
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main_ppo \
data.max_response_length=512 \
actor_rollout_ref.model.path=google/gemma-2-2b-it \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size=4 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -21,6 +22,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=google/gemma-2-2b-it \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=4 \
2 changes: 2 additions & 0 deletions examples/ppo_trainer/run_qwen2-7b.sh
@@ -17,6 +17,7 @@ python3 -m verl.trainer.main_ppo \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=16 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -29,6 +30,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=16 \
3 changes: 3 additions & 0 deletions examples/ppo_trainer/run_qwen2-7b_rm.sh
@@ -18,6 +18,7 @@ python3 -m verl.trainer.main_ppo \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=16 \
@@ -31,6 +32,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.model.enable_gradient_checkpointing=False \
@@ -40,6 +42,7 @@ python3 -m verl.trainer.main_ppo \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=sfairXC/FsfairX-Gemma2-RM-v0.1 \
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \
2 changes: 2 additions & 0 deletions examples/ppo_trainer/run_qwen2.5-32b.sh
@@ -18,6 +18,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
actor_rollout_ref.model.enable_gradient_checkpointing=False \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=16 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -30,6 +31,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2.5-32B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -39,7 +39,7 @@ dependencies = [
"pybind11",
"ray",
"tensordict",
"transformers",
"transformers<4.48",
"vllm<=0.6.3",
]

4 changes: 2 additions & 2 deletions requirements.txt
@@ -8,6 +8,6 @@ pandas
pybind11
ray
tensordict<0.6
transformers
transformers<4.48
vllm<=0.6.3
wandb
wandb
4 changes: 4 additions & 0 deletions tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
@@ -14,9 +14,11 @@ actor_rollout_ref:
hybrid_engine: True
model:
path: ~/verl/tests/e2e/arithmetic_sequence/model
tokenizer_path: ${actor_rollout_ref.model.path}
external_lib: tests.e2e.envs.digit_completion
override_config: {}
enable_gradient_checkpointing: False
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 200
@@ -76,6 +78,7 @@ critic:
override_config: {}
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
use_remove_padding: False
fsdp_config:
param_offload: False
grad_offload: False
@@ -104,6 +107,7 @@ reward_model:
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
offload: False
use_remove_padding: False
fsdp_config:
min_num_params: 0
micro_batch_size: 8
2 changes: 1 addition & 1 deletion tests/e2e/arithmetic_sequence/rl/main_trainer.py
@@ -119,7 +119,7 @@ def main(config):
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values

# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.tokenizer_path)
local_path = os.path.expanduser(local_path)
# instantiate the tokenizer
tokenizer = AutoTokenizer.from_pretrained(local_path)
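
The new tokenizer_path key (added to ray_trainer.yaml above) defaults to the model path through OmegaConf interpolation, so the trainer resolves the tokenizer location from the same entry unless it is overridden. A small stand-alone illustration of how that interpolation resolves (a reduced config, not the full trainer config):

    from omegaconf import OmegaConf

    cfg = OmegaConf.create({
        "actor_rollout_ref": {
            "model": {
                "path": "tests/e2e/arithmetic_sequence/model",
                "tokenizer_path": "${actor_rollout_ref.model.path}",
            }
        }
    })
    resolved = OmegaConf.to_container(cfg, resolve=True)
    print(resolved["actor_rollout_ref"]["model"]["tokenizer_path"])
    # -> tests/e2e/arithmetic_sequence/model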
14 changes: 14 additions & 0 deletions tests/e2e/run_ray_trainer_rmpad.sh
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

set -e -x

python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.model.tokenizer_path=tests/e2e/arithmetic_sequence/model \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.use_remove_padding=True \
trainer.total_epochs=1
129 changes: 129 additions & 0 deletions tests/model/test_transformer.py
@@ -0,0 +1,129 @@
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForTokenClassification, AutoTokenizer

import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange

from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
LlamaConfig(num_hidden_layers=1),
MistralConfig(num_hidden_layers=1),
GemmaConfig(num_hidden_layers=1),
Qwen2Config(num_hidden_layers=1)
]
# test_cases = ['deepseek-ai/deepseek-llm-7b-chat', 'Qwen/Qwen2-7B-Instruct']


def test_hf_casual_models():
batch_size = 4
seqlen = 128
response_length = 127

for config in test_configs:
# config = AutoConfig.from_pretrained(test_case)
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here

input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)

# feed input_ids_rmpad and position_ids to enable flash-attention varlen
logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, vocab_size)

origin_logits = model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
origin_logits_rmpad, origin_logits_indices, _, _ = unpad_input(origin_logits, attention_mask)

logits_rmpad = logits_rmpad.squeeze(0)
log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=logits_rmpad,
indices=indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)
origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=origin_logits_rmpad,
indices=origin_logits_indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)

torch.testing.assert_close(masked_mean(log_probs, attention_mask[:, -response_length - 1:-1]),
masked_mean(origin_log_probs, attention_mask[:, -response_length - 1:-1]),
atol=1e-2,
rtol=1e-5)
print(f'Check pass')


def test_hf_value_models():
batch_size = 4
seqlen = 128

for config in test_configs:
# config = AutoConfig.from_pretrained(test_case)
config.num_labels = 1
setattr(config, 'classifier_dropout', 0)
setattr(config, 'hidden_dropout', 0)
with torch.device('cuda'):
model = AutoModelForTokenClassification.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here

input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)

origin_logits = model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits

# feed input_ids_rmpad and position_ids to enable flash-attention varlen
rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, 1)
rmpad_logits = rmpad_logits.squeeze(0)
pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)

torch.testing.assert_close(masked_mean(pad_logits, attention_mask[:, :, None]),
masked_mean(origin_logits, attention_mask[:, :, None]),
atol=1e-2,
rtol=1e-5)
print('Value model check pass')


if __name__ == '__main__':
test_hf_casual_models()
test_hf_value_models()
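
For reference, the attention_mask[:, -response_length - 1:-1] slicing in the causal-model test follows the usual shift-by-one convention: the logit at position t scores the token at position t + 1, so the log-probs of the last response_length tokens come from the logits at the positions just before them. A tiny stand-alone illustration of that indexing in plain torch (hypothetical shapes; it does not use verl's helpers):

    import torch

    S, R, V = 8, 3, 11                     # seqlen, response length, vocab size
    logits = torch.randn(1, S, V)
    input_ids = torch.randint(0, V, (1, S))

    resp_logits = logits[:, -R - 1:-1]     # positions that predict the R response tokens
    resp_labels = input_ids[:, -R:]        # the response tokens themselves
    log_probs = torch.log_softmax(resp_logits, dim=-1)
    resp_log_probs = torch.gather(log_probs, dim=-1, index=resp_labels.unsqueeze(-1)).squeeze(-1)  # (1, R)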
15 changes: 15 additions & 0 deletions verl/models/registry.py
@@ -17,6 +17,21 @@

import torch.nn as nn

# Supported models using HF Rmpad
# TODO(sgm): HF may support more models than listed here; we should add more after testing
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config

_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}


def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
if not model_type in _REOVEPAD_MODELS.keys():
raise ValueError(f"Model architecture {model_type} is not supported for now. "
f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}")


# Supported models in Megatron-LM
# Architecture -> (module, class).
_MODELS = {
"LlamaForCausalLM":
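
A sketch of how callers can use the new check, assuming the Hugging Face config is loaded with transformers.AutoConfig (the model name is only an example):

    from transformers import AutoConfig
    from verl.models.registry import check_model_support_rmpad

    config = AutoConfig.from_pretrained("Qwen/Qwen2-7B-Instruct")
    # config.model_type is "qwen2", which is in the supported list, so this passes;
    # an unsupported model_type raises ValueError instead.
    check_model_support_rmpad(config.model_type)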
3 changes: 3 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -17,6 +17,7 @@ actor_rollout_ref:
external_lib: null
override_config: { }
enable_gradient_checkpointing: False
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
@@ -83,6 +84,7 @@ critic:
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
use_remove_padding: False
fsdp_config:
param_offload: False
grad_offload: False
@@ -105,6 +107,7 @@ reward_model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
min_num_params: 0
param_offload: False