[misc] feat: support rmpad/data-packing in FSDP with transformers (#91)

* init commit of rmpad
* add rmpad test
* support rmpad in actor model
* add test for value model
* support rmpad in critic and rm
* fix actor return, fix num_labels, and clean up unused rmpad code
* fix critic and benchmark
* update script
* fix critic
* lint
* fix util issue
* fix unnecessary unpad
* address issues
* fix args
* update test and update the rmpad-supported model list
* fix typo
* fix typo and fix name
* rename rmpad to remove padding
* fix arch to model_type
* add CI for e2e rmpad and fix typo
* lint
* fix CI
* fix typo
* update tests for customized tokenizer in actor
* fix rmpad test
* update requirement on transformers as hf_rollout may have issues
Showing 19 changed files with 413 additions and 45 deletions.
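For context on the feature: remove-padding ("rmpad", also called data packing) drops the pad tokens from a padded batch, runs the model once over the packed valid tokens via flash-attention's varlen path, and scatters the outputs back to the padded layout. The sketch below shows only the pack/unpack round trip with flash-attn's padding helpers; it assumes the 4-value unpad_input return used by the tests in this commit (newer flash-attn releases return an extra value), and the tensor shapes are made up for illustration.

    import torch
    from flash_attn.bert_padding import unpad_input, pad_input

    batch_size, seqlen, hidden = 4, 128, 16
    hidden_states = torch.randn(batch_size, seqlen, hidden)

    # attention_mask: 1 = real token, 0 = padding; each sequence keeps a different number of tokens
    attention_mask = torch.zeros(batch_size, seqlen, dtype=torch.int64)
    for i, valid in enumerate([128, 96, 64, 32]):
        attention_mask[i, :valid] = 1

    # pack: keep only valid tokens -> (total_nnz, hidden), plus bookkeeping for flash-attn varlen
    states_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(hidden_states, attention_mask)

    # ... the model forward would run here on the packed tokens only ...

    # unpack: scatter packed outputs back to (batch_size, seqlen, hidden); pad positions become zeros
    states_padded = pad_input(states_rmpad, indices, batch_size, seqlen)
    assert torch.equal(states_padded * attention_mask[:, :, None],
                       hidden_states * attention_mask[:, :, None])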
.github/workflows/model.yml
@@ -0,0 +1,39 @@
name: model_rmpad

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/model.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/model.yml

jobs:
  e2e_gpu:
    runs-on: [self-hosted, l20-1]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository and upgrade to latest transformers
        run: |
          pip3 install -e .[test]
          pip3 install --upgrade transformers
      - name: Running digit completion e2e training tests on 8 L20 GPUs
        run: |
          pytest -s tests/model/test_transformer.py
requirements.txt
@@ -8,6 +8,6 @@
 pandas
 pybind11
 ray
 tensordict<0.6
-transformers
+transformers<4.48
 vllm<=0.6.3
 wandb
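The pin above follows the last bullet of the commit message (hf_rollout may have issues with newer transformers). To check an already-installed environment against the new constraint, a small sketch using the standard importlib.metadata and packaging libraries:

    from importlib.metadata import version
    from packaging.version import Version

    installed = Version(version("transformers"))
    # mirror the requirements.txt constraint added in this commit
    assert installed < Version("4.48"), f"transformers {installed} does not satisfy the <4.48 pin"
    print(f"transformers {installed} satisfies the <4.48 pin")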
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

set -e -x

python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
    data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
    data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
    actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.model.tokenizer_path=tests/e2e/arithmetic_sequence/model \
    critic.model.path=Qwen/Qwen2.5-0.5B \
    critic.model.use_remove_padding=True \
    trainer.total_epochs=1
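The new critic.model.use_remove_padding=True override turns on the packed (rmpad) code path for the critic. The following is only a hypothetical sketch of the pattern that flag is assumed to enable, mirroring the unit test below rather than verl's actual critic implementation; value_forward and its arguments are made up for illustration, and the 4-value unpad_input return is assumed as in the tests.

    import torch
    from flash_attn.bert_padding import unpad_input, pad_input

    def value_forward(model, input_ids, attention_mask, position_ids, use_remove_padding=True):
        """Hypothetical value-model forward: pack tokens when use_remove_padding is set."""
        if not use_remove_padding:
            return model(input_ids=input_ids, attention_mask=attention_mask,
                         position_ids=position_ids, use_cache=False).logits  # (batch, seqlen, 1)

        batch_size, seqlen = input_ids.shape
        # pack token ids and position ids into a single (1, total_nnz) sequence
        input_ids_rmpad, indices, _, _ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)
        position_ids_rmpad, _, _, _ = unpad_input(position_ids.unsqueeze(-1), attention_mask)
        position_ids_rmpad = position_ids_rmpad.transpose(0, 1)

        # flash-attention varlen forward over the packed tokens only
        values_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad, use_cache=False).logits
        # scatter the per-token values back to the padded (batch, seqlen, 1) layout
        return pad_input(values_rmpad.squeeze(0), indices, batch_size, seqlen)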
tests/model/test_transformer.py
@@ -0,0 +1,129 @@
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForTokenClassification, AutoTokenizer

import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange

from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config

# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
    LlamaConfig(num_hidden_layers=1),
    MistralConfig(num_hidden_layers=1),
    GemmaConfig(num_hidden_layers=1),
    Qwen2Config(num_hidden_layers=1)
]
# test_cases = ['deepseek-ai/deepseek-llm-7b-chat', 'Qwen/Qwen2-7B-Instruct']


def test_hf_casual_models():
    batch_size = 4
    seqlen = 128
    response_length = 127

    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        with torch.device('cuda'):
            model = AutoModelForCausalLM.from_config(config=config,
                                                     torch_dtype=torch.bfloat16,
                                                     attn_implementation='flash_attention_2')
            model = model.to(device='cuda')
        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
        attention_mask = create_random_mask(input_ids=input_ids,
                                            max_ratio_of_left_padding=0.1,
                                            max_ratio_of_valid_token=0.8,
                                            min_ratio_of_valid_token=0.5)
        position_ids = compute_position_id_with_mask(
            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
            input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align with the rotary embeddings
        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                              indices).transpose(0, 1)

        # feed input_ids_rmpad and position_ids_rmpad to enable the flash-attention varlen path
        logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad,
                             use_cache=False).logits  # (1, total_nnz, vocab_size)

        origin_logits = model(input_ids=input_ids,
                              attention_mask=attention_mask,
                              position_ids=position_ids,
                              use_cache=False).logits
        origin_logits_rmpad, origin_logits_indices, _, _ = unpad_input(origin_logits, attention_mask)

        logits_rmpad = logits_rmpad.squeeze(0)
        log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                    logits_rmpad=logits_rmpad,
                                                    indices=indices,
                                                    batch_size=batch_size,
                                                    seqlen=seqlen,
                                                    response_length=response_length)  # (batch, seqlen)
        origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                           logits_rmpad=origin_logits_rmpad,
                                                           indices=origin_logits_indices,
                                                           batch_size=batch_size,
                                                           seqlen=seqlen,
                                                           response_length=response_length)  # (batch, seqlen)

        torch.testing.assert_close(masked_mean(log_probs, attention_mask[:, -response_length - 1:-1]),
                                   masked_mean(origin_log_probs, attention_mask[:, -response_length - 1:-1]),
                                   atol=1e-2,
                                   rtol=1e-5)
        print('Check pass')


def test_hf_value_models():
    batch_size = 4
    seqlen = 128

    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        config.num_labels = 1
        setattr(config, 'classifier_dropout', 0)
        setattr(config, 'hidden_dropout', 0)
        with torch.device('cuda'):
            model = AutoModelForTokenClassification.from_config(config=config,
                                                                torch_dtype=torch.bfloat16,
                                                                attn_implementation='flash_attention_2')
            model = model.to(device='cuda')
        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
        attention_mask = create_random_mask(input_ids=input_ids,
                                            max_ratio_of_left_padding=0.1,
                                            max_ratio_of_valid_token=0.8,
                                            min_ratio_of_valid_token=0.5)
        position_ids = compute_position_id_with_mask(
            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
            input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align with the rotary embeddings
        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                              indices).transpose(0, 1)

        origin_logits = model(input_ids=input_ids,
                              attention_mask=attention_mask,
                              position_ids=position_ids,
                              use_cache=False).logits

        # feed input_ids_rmpad and position_ids_rmpad to enable the flash-attention varlen path
        rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad,
                             use_cache=False).logits  # (1, total_nnz, 1)
        rmpad_logits = rmpad_logits.squeeze(0)
        pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)

        torch.testing.assert_close(masked_mean(pad_logits, attention_mask[:, :, None]),
                                   masked_mean(origin_logits, attention_mask[:, :, None]),
                                   atol=1e-2,
                                   rtol=1e-5)
        print('Value model check pass')


if __name__ == '__main__':
    test_hf_casual_models()
    test_hf_value_models()
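Both tests reduce the compared tensors with verl's masked_mean before asserting closeness. Its assumed semantics (a standard masked average that ignores pad positions; this is a sketch, not the verl implementation) look like:

    import torch

    def masked_mean_sketch(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean of `values` over positions where `mask` is nonzero, ignoring padded positions."""
        mask = mask.to(values.dtype)
        return (values * mask).sum() / mask.sum()

    # e.g. the value-model check reduces (batch, seqlen, 1) logits with a broadcastable mask
    values = torch.randn(4, 128, 1)
    mask = torch.ones(4, 128, dtype=torch.int64)[:, :, None]
    print(masked_mean_sketch(values, mask))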