validate use_remove_padding when applying sequence parallelism #153

Open · wants to merge 3 commits into main

Conversation

chujiezheng

This is because Ulysses sequence parallelism (ulysses_sp) is only activated when use_remove_padding is enabled:

if self.use_remove_padding:
    input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
                                               attention_mask)  # input_ids_rmpad (total_nnz, ...)
    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

    # unpad the position_ids to align the rotary
    position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                          indices).transpose(0, 1)

    # for compute the log_prob
    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=1)  # (1, total_nnz)

    # pad and slice the inputs if sp > 1
    if self.use_ulysses_sp:
        input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(input_ids_rmpad, \
                                                                                     position_ids_rmpad, \
                                                                                     sp_size=self.ulysses_sequence_parallel_size)
        input_ids_rmpad_rolled, _, _ = ulysses_pad_and_slice_inputs(input_ids_rmpad_rolled, None,
                                                                    self.ulysses_sequence_parallel_size)

    input_ids_rmpad_rolled = input_ids_rmpad_rolled.squeeze(0)  # ((total_nnz / sp) + pad)

    # only pass input_ids and position_ids to enable flash_attn_varlen
    output = self.actor_module(input_ids=input_ids_rmpad,
                               attention_mask=None,
                               position_ids=position_ids_rmpad,
                               use_cache=False)  # prevent model thinks we are generating
    logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)

    logits_rmpad.div_(temperature)

    # compute entropy
    entropy_rmpad = self.compute_entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)

    # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
    log_probs = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)

    # gather log_prob if sp > 1
    if self.use_ulysses_sp:
        # gather and unpad for the ulysses sp
        log_probs = gather_outpus_and_unpad(log_probs, gather_dim=0, unpad_dim=0, padding_size=pad_size)
        entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
                                                gather_dim=0,
                                                unpad_dim=0,
                                                padding_size=pad_size)

    # pad back to (bsz, seqlen)
    full_entropy = pad_input(hidden_states=entropy_rmpad.unsqueeze(-1),
                             indices=indices,
                             batch=batch_size,
                             seqlen=seqlen)
    full_log_probs = pad_input(hidden_states=log_probs.unsqueeze(-1),
                               indices=indices,
                               batch=batch_size,
                               seqlen=seqlen)

    # only return response part:
    entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
    log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)

Without this check, users may encounter OOM issues when they set sp_size > 1 but use_remove_padding is mistakenly disabled: the Ulysses padding and slicing above only happens inside the remove-padding branch, so every rank still processes the full padded sequence.
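To make the memory argument concrete, here is a rough back-of-envelope sketch (hypothetical sizes, not taken from the PR) of how many tokens each rank ends up holding in the two configurations:

# Hypothetical sizes for illustration only.
batch_size, seqlen, sp_size = 8, 4096, 4
valid_ratio = 0.5                                   # assumed fraction of non-pad tokens
total_nnz = int(batch_size * seqlen * valid_ratio)  # tokens that survive unpad_input

# With use_remove_padding + Ulysses SP: tokens are split across sp ranks.
tokens_per_rank_rmpad = total_nnz // sp_size        # 4096 tokens per rank

# Without use_remove_padding: every rank keeps the full padded batch,
# so sp_size > 1 gives no memory saving.
tokens_per_rank_padded = batch_size * seqlen        # 32768 tokens per rank

print(tokens_per_rank_rmpad, tokens_per_rank_padded)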

@vermouth1992
Collaborator

Could you format the code by running bash script/format.sh?

@chujiezheng
Author

Done!

@@ -86,4 +86,13 @@ def check_mutually_exclusive(mbs, mbs_per_gpu, name: str):
    assert config.critic.ppo_mini_batch_size % config.critic.ppo_micro_batch_size == 0
    assert config.critic.ppo_micro_batch_size * sp_size >= n_gpus

    # Check if use_remove_padding is enabled when using sequence parallelism
    if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1:

Collaborator

It seems that the correct keys are config.actor_rollout_ref.model.use_remove_padding, critic.model.use_remove_padding and reward_model.model.use_remove_padding

Author

Fixed now!
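For reference, a minimal sketch of the kind of guard being discussed, using the keys pointed out in the review above (an illustration only, not necessarily the exact code merged in this PR; analogous checks would apply to the critic and reward model, assuming they expose a similar ulysses_sequence_parallel_size field):

# Sketch of the intended validation inside the config check; names follow the
# diff and review comments above, exact messages and placement may differ.
if config.actor_rollout_ref.actor.ulysses_sequence_parallel_size > 1:
    assert config.actor_rollout_ref.model.use_remove_padding, \
        "use_remove_padding must be enabled when ulysses_sequence_parallel_size > 1"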
