lint
PeterSH6 committed Jan 17, 2025
1 parent 444ffe6 commit 68a764a
Showing 3 changed files with 44 additions and 35 deletions.
verl/utils/torch_functional.py (11 changes: 8 additions & 3 deletions)

@@ -325,8 +325,13 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
     return output


-def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices=None, batch_size=None,
-                                    seqlen=None, response_length=None, pad=True):
+def log_probs_from_logits_all_rmpad(input_ids_rmpad,
+                                    logits_rmpad,
+                                    indices=None,
+                                    batch_size=None,
+                                    seqlen=None,
+                                    response_length=None,
+                                    pad=True):
     """Compute the log_probs from logits with rmpad input_ids and logits. Note that
     logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between
     logits and input_ids.
@@ -355,7 +360,7 @@ def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices=None,
         output = full_output.squeeze(-1)[:, -response_length - 1:-1]  # [batch_size, response_length]
         return output
     else:
-        return full_log_probs_rmpad # (total_nnz,)
+        return full_log_probs_rmpad  # (total_nnz,)


 from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)
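
For readers of this diff, the function being re-wrapped computes per-token log-probs under the one-position shift the docstring describes: the logits at position t score the token at position t + 1. A minimal standalone sketch of that shifted gather (the tensor names and function here are illustrative, not verl's API):

    import torch
    import torch.nn.functional as F

    def shifted_logprobs(logits: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
        # logits[:, t] is the distribution over input_ids[:, t + 1], hence the shift
        logp = F.log_softmax(logits[:, :-1, :], dim=-1)   # drop the last position
        targets = input_ids[:, 1:].unsqueeze(-1)          # drop the first token
        return torch.gather(logp, dim=-1, index=targets).squeeze(-1)

The same shift explains the `[:, -response_length - 1:-1]` slice above: for a response of length R at the end of the sequence, the logits that score those R tokens sit one position to the left.
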
verl/workers/actor/dp_actor.py (56 changes: 28 additions & 28 deletions)

@@ -87,53 +87,53 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
                 logits_rmpad /= temperature
 
                 # compute entropy
-                entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad)
-
+                entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
+
                 # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
-                log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
-                                                            logits_rmpad=logits_rmpad,
-                                                            indices=indices,
-                                                            batch_size=batch_size,
-                                                            seqlen=seqlen,
-                                                            response_length=response_length,
-                                                            pad=not self.use_ulysses_sp,
-                                                            )
+                log_probs = log_probs_from_logits_all_rmpad(
+                    input_ids_rmpad=input_ids_rmpad,
+                    logits_rmpad=logits_rmpad,
+                    indices=indices,
+                    batch_size=batch_size,
+                    seqlen=seqlen,
+                    response_length=response_length,
+                    pad=not self.use_ulysses_sp,
+                )
 
                 # gather log_prob if sp > 1
                 if self.use_ulysses_sp:
                     # gather and unpad for the ulysses sp
                     full_log_probs_rmpad = gather_outpus_and_unpad(log_probs,
-                                                               gather_dim=0,
-                                                               unpad_dim=0,
-                                                               padding_size=pad_size)
+                                                                   gather_dim=0,
+                                                                   unpad_dim=0,
+                                                                   padding_size=pad_size)
                     full_entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
-                                                             gather_dim=0,
-                                                             unpad_dim=0,
-                                                             padding_size=pad_size)
+                                                                 gather_dim=0,
+                                                                 unpad_dim=0,
+                                                                 padding_size=pad_size)
                     # pad back to (bsz, seqlen)
                     full_entropy = pad_input(hidden_states=full_entropy_rmpad.unsqueeze(-1),
-                                         indices=indices,
-                                         batch=batch_size,
-                                         seqlen=seqlen)
+                                             indices=indices,
+                                             batch=batch_size,
+                                             seqlen=seqlen)
                     full_log_probs = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
-                                           indices=indices,
-                                           batch=batch_size,
-                                           seqlen=seqlen)
+                                               indices=indices,
+                                               batch=batch_size,
+                                               seqlen=seqlen)
                     # only return response part:
-                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
-                    log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
+                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
+                    log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
 
-            else: # not using rmpad and no ulysses sp
+            else:  # not using rmpad and no ulysses sp
                 output = self.actor_module(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            position_ids=position_ids,
                                            use_cache=False)  # prevent model thinks we are generating
                 logits = output.logits / temperature
-                logits = logits[:, -response_length - 1:-1] # (bsz, response_length)
+                logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)
                 log_probs = logprobs_from_logits(logits, micro_batch['responses'])
-                entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
+                entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
 
         return entropy, log_probs
 
     def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
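
Background for the gather/pad sequence above (the logic is unchanged by this commit, only re-wrapped): remove-padding code flattens a padded (batch, seqlen) batch into a packed (total_nnz, ...) tensor, and pad_input restores the original layout from the saved indices. A small round-trip sketch using flash-attn's helpers, with toy shapes invented for illustration; note that the number of values unpad_input returns differs across flash-attn versions:

    import torch
    from flash_attn.bert_padding import pad_input, unpad_input

    hidden = torch.randn(2, 4, 8)                  # 2 sequences, seqlen 4, hidden 8
    mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0]])            # 5 real tokens -> total_nnz == 5

    # flatten away the padding, keeping `indices` so the layout can be restored
    out = unpad_input(hidden, mask)                # return arity varies by flash-attn version
    hidden_rmpad, indices = out[0], out[1]         # (total_nnz, 8) and (total_nnz,)

    # ... the model, log-prob, and entropy math all run on the packed tensor ...

    restored = pad_input(hidden_rmpad, indices, batch=2, seqlen=4)
    assert restored.shape == (2, 4, 8)             # padded slots are zero-filled
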
verl/workers/sharding_manager/fsdp_vllm.py (12 changes: 8 additions & 4 deletions)

@@ -33,8 +33,12 @@

 class FSDPVLLMShardingManager(BaseShardingManager):
 
-    def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_params: bool = False,
-                 device_mesh: DeviceMesh=None):
+    def __init__(self,
+                 module: FSDP,
+                 inference_engine: LLM,
+                 model_config,
+                 full_params: bool = False,
+                 device_mesh: DeviceMesh = None):
         self.module = module
         self.inference_engine = inference_engine
         self.model_config = model_config
@@ -50,7 +54,7 @@ def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_param
         FSDP.set_state_dict_type(self.module,
                                  state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                  state_dict_config=ShardedStateDictConfig())
-
+
         # Note that torch_random_states may be different on each dp rank
         self.torch_random_states = torch.cuda.get_rng_state()
         # get a random rng states
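
Context for the set_state_dict_type call above: SHARDED_STATE_DICT keeps each rank's local shard, which scales better than gathering full tensors onto one rank. A hedged sketch of how the two modes are typically selected; the full_params branch is an illustration, not necessarily this file's exact logic:

    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp import FullStateDictConfig, ShardedStateDictConfig, StateDictType

    def configure_state_dict(module: FSDP, full_params: bool) -> None:
        if full_params:
            # gather complete tensors (expensive at scale, simple to consume)
            FSDP.set_state_dict_type(module,
                                     state_dict_type=StateDictType.FULL_STATE_DICT,
                                     state_dict_config=FullStateDictConfig())
        else:
            # keep each rank's shard local, as the manager does above
            FSDP.set_state_dict_type(module,
                                     state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                     state_dict_config=ShardedStateDictConfig())
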
@@ -80,7 +84,7 @@ def __enter__(self):
         # torch.cuda.empty_cache()
         # if torch.distributed.get_rank() == 0:
         #     print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
-
+
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = torch.cuda.get_rng_state()
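
The torch_random_states bookkeeping in this file follows a save/swap/restore pattern: one CUDA RNG state for normal training and a separate, deliberately seeded one used inside the context, so every tp rank samples identically during rollout. A minimal standalone sketch of the pattern; the class and seed below are illustrative, not verl's API:

    import torch

    class RNGSwap:
        """Swap in a dedicated CUDA RNG state on enter; restore the caller's on exit."""

        def __init__(self, seed: int):
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.manual_seed(seed)                    # derive the dedicated state once
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)

        def __enter__(self):
            self.torch_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.gen_random_states)

        def __exit__(self, exc_type, exc_val, exc_tb):
            self.gen_random_states = torch.cuda.get_rng_state()
            torch.cuda.set_rng_state(self.torch_random_states)
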
