support different flash_attn version with variable num returns
PeterSH6 committed Jan 12, 2025
1 parent 3563f24 commit f855ea3
Showing 7 changed files with 13 additions and 15 deletions.
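Background: older flash_attn releases return four values from flash_attn.bert_padding.unpad_input (the unpadded tensor, indices, cu_seqlens, max_seqlen_in_batch), while newer releases append an extra value. This commit switches every call site to starred unpacking so both signatures work. A minimal sketch of the pattern, assuming flash_attn is installed (the exact count of extra returns in newer versions is an assumption here):

import torch
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen, hidden = 2, 8, 16
hidden_states = torch.randn(batch_size, seqlen, hidden)
attention_mask = torch.ones(batch_size, seqlen, dtype=torch.int64)
attention_mask[0, 5:] = 0  # pad out the tail of the first sequence

# Call sites that only need the unpadded tensor and the indices let *_
# absorb everything else, whether unpad_input returns 4 or 5 values:
hidden_rmpad, indices, *_ = unpad_input(hidden_states, attention_mask)

# Call sites that still use cu_seqlens / max_seqlen_in_batch (e.g. in
# modeling_llama_megatron.py) keep naming them explicitly; *_ then only
# swallows the surplus return of newer flash_attn (empty on older ones):
hidden_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(
    hidden_states, attention_mask)

# pad_input restores the original (batch_size, seqlen, hidden) layout.
restored = pad_input(hidden_rmpad, indices, batch_size, seqlen)

Starred unpacking keeps the call sites forward-compatible without pinning flash_attn or branching on its version.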
10 changes: 4 additions & 6 deletions tests/model/test_transformer.py
@@ -1,11 +1,10 @@
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForTokenClassification, AutoTokenizer

import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange

from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, AutoModelForSequenceClassification
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
@@ -14,7 +13,6 @@
GemmaConfig(num_hidden_layers=1),
Qwen2Config(num_hidden_layers=1)
]
# test_cases = ['deepseek-ai/deepseek-llm-7b-chat', 'Qwen/Qwen2-7B-Instruct']


def test_hf_casual_models():
@@ -37,7 +35,7 @@ def test_hf_casual_models():
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here

- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+ input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

@@ -53,7 +51,7 @@ def test_hf_casual_models():
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
- origin_logits_rmpad, origin_logits_indices, _, _ = unpad_input(origin_logits, attention_mask)
+ origin_logits_rmpad, origin_logits_indices, *_ = unpad_input(origin_logits, attention_mask)

logits_rmpad = logits_rmpad.squeeze(0)
log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
@@ -98,7 +96,7 @@ def test_hf_value_models():
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here

- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+ input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

4 changes: 2 additions & 2 deletions verl/models/llama/megatron/modeling_llama_megatron.py
@@ -323,7 +323,7 @@ def forward(
batch_size, sequence_length = input_ids.shape

# remove padding here
- input_ids, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1),
+ input_ids, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)

# pad input_ids to multiple of tp for all tp ranks
@@ -581,7 +581,7 @@ def forward(
# In the first pp, input_ids will be used, in other pp layers hidden_states will be used inside self.model
batch_size, sequence_length = input_ids.shape
# remove padding here
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(dim=-1),
+ input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch, *_ = unpad_input(input_ids.unsqueeze(dim=-1),
attention_mask) # (total_nnz, 1)

# pad input_ids to multiple of tp for all tp ranks
2 changes: 1 addition & 1 deletion verl/utils/megatron/tensor_parallel.py
@@ -152,7 +152,7 @@ def vocab_parallel_log_probs_from_logits_response_rmpad(input_ids, attention_mas
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen = input_ids.shape
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(-1),
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask=attention_mask)
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
4 changes: 2 additions & 2 deletions verl/utils/torch_functional.py
@@ -313,8 +313,8 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
from flash_attn.bert_padding import pad_input, unpad_input

batch_size, seqlen = input_ids.shape
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(input_ids.unsqueeze(-1),
- attention_mask=attention_mask)
+ input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
+ attention_mask=attention_mask)
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
4 changes: 2 additions & 2 deletions verl/workers/actor/dp_actor.py
@@ -55,7 +55,7 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
position_ids = micro_batch['position_ids']

if self.use_remove_padding:
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+ input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

@@ -181,7 +181,7 @@ def update_policy(self, data: DataProto):
if self.use_remove_padding:
full_response_mask = attention_mask.clone()
full_response_mask[:, :-response_length] = 0 # set the prompt part to zero
- full_response_mask_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+ full_response_mask_rmpad, *_ = unpad_input(
full_response_mask.unsqueeze(-1), attention_mask=attention_mask)
full_response_mask_rmpad = full_response_mask_rmpad.squeeze(-1) # (total_nnz)
entropy_loss = core_algos.compute_entropy_loss(logits, full_response_mask_rmpad) # (total_nnz,)
2 changes: 1 addition & 1 deletion verl/workers/critic/dp_critic.py
@@ -55,7 +55,7 @@ def _forward_micro_batch(self, micro_batch):
position_ids = micro_batch['position_ids']

if self.use_remove_padding:
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+ input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

2 changes: 1 addition & 1 deletion verl/workers/fsdp_workers.py
@@ -738,7 +738,7 @@ def _forward_micro_batch(self, micro_batch):
position_ids = micro_batch['position_ids']

if self.use_remove_padding:
- input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
+ input_ids_rmpad, indices, *_ = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)

