diff --git a/verl/utils/torch_functional.py b/verl/utils/torch_functional.py
index f39407c4..afa3f155 100644
--- a/verl/utils/torch_functional.py
+++ b/verl/utils/torch_functional.py
@@ -325,8 +325,13 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
     return output
 
 
-def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices=None, batch_size=None,
-                                    seqlen=None, response_length=None, pad=True):
+def log_probs_from_logits_all_rmpad(input_ids_rmpad,
+                                    logits_rmpad,
+                                    indices=None,
+                                    batch_size=None,
+                                    seqlen=None,
+                                    response_length=None,
+                                    pad=True):
     """Compute the log_probs from logits with rmpad input_ids and logits. Note that
     logits_rmpad = model(input_ids_rmpad). For each sentences, there is a shift between
     logits and input_ids.
@@ -355,7 +360,7 @@ def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices=None,
         output = full_output.squeeze(-1)[:, -response_length - 1:-1]  # [batch_size, response_length]
         return output
     else:
-        return full_log_probs_rmpad # (total_nnz,)
+        return full_log_probs_rmpad  # (total_nnz,)
 
 
 from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)
diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py
index 01d329f7..9e8a56ba 100644
--- a/verl/workers/actor/dp_actor.py
+++ b/verl/workers/actor/dp_actor.py
@@ -87,53 +87,53 @@ def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor,
                 logits_rmpad /= temperature
 
                 # compute entropy
-                entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad) # ((total_nnz / sp) + pad)
-
+                entropy_rmpad = verl_F.entropy_from_logits(logits_rmpad)  # ((total_nnz / sp) + pad)
                 # if use_sp: ((total_nnz / sp) + pad) ; if not use_sp: (batch, seqlen)
-                log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
-                                                            logits_rmpad=logits_rmpad,
-                                                            indices=indices,
-                                                            batch_size=batch_size,
-                                                            seqlen=seqlen,
-                                                            response_length=response_length,
-                                                            pad=not self.use_ulysses_sp,
-                                                            )
+                log_probs = log_probs_from_logits_all_rmpad(
+                    input_ids_rmpad=input_ids_rmpad,
+                    logits_rmpad=logits_rmpad,
+                    indices=indices,
+                    batch_size=batch_size,
+                    seqlen=seqlen,
+                    response_length=response_length,
+                    pad=not self.use_ulysses_sp,
+                )
 
                 # gather log_prob if sp > 1
                 if self.use_ulysses_sp:
                     # gather and unpad for the ulysses sp
                     full_log_probs_rmpad = gather_outpus_and_unpad(log_probs,
-                                                                  gather_dim=0,
-                                                                  unpad_dim=0,
-                                                                  padding_size=pad_size)
+                                                                   gather_dim=0,
+                                                                   unpad_dim=0,
+                                                                   padding_size=pad_size)
                     full_entropy_rmpad = gather_outpus_and_unpad(entropy_rmpad,
-                                                                gather_dim=0,
-                                                                unpad_dim=0,
-                                                                padding_size=pad_size)
+                                                                 gather_dim=0,
+                                                                 unpad_dim=0,
+                                                                 padding_size=pad_size)
                     # pad back to (bsz, seqlen)
                     full_entropy = pad_input(hidden_states=full_entropy_rmpad.unsqueeze(-1),
-                                            indices=indices,
-                                            batch=batch_size,
-                                            seqlen=seqlen)
+                                             indices=indices,
+                                             batch=batch_size,
+                                             seqlen=seqlen)
                     full_log_probs = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
-                                              indices=indices,
-                                              batch=batch_size,
-                                              seqlen=seqlen)
+                                               indices=indices,
+                                               batch=batch_size,
+                                               seqlen=seqlen)
                     # only return response part:
-                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
-                    log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1] # (bsz, response_length)
+                    entropy = full_entropy.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
+                    log_probs = full_log_probs.squeeze(-1)[:, -response_length - 1:-1]  # (bsz, response_length)
 
-            else: # not using rmpad and no ulysses sp
+            else:  # not using rmpad and no ulysses sp
                 output = self.actor_module(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            position_ids=position_ids,
                                            use_cache=False)  # prevent model thinks we are generating
                 logits = output.logits / temperature
-                logits = logits[:, -response_length - 1:-1] # (bsz, response_length)
+                logits = logits[:, -response_length - 1:-1]  # (bsz, response_length)
                 log_probs = logprobs_from_logits(logits, micro_batch['responses'])
-                entropy = verl_F.entropy_from_logits(logits) # (bsz, response_length)
-
+                entropy = verl_F.entropy_from_logits(logits)  # (bsz, response_length)
+
             return entropy, log_probs
 
     def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
diff --git a/verl/workers/sharding_manager/fsdp_vllm.py b/verl/workers/sharding_manager/fsdp_vllm.py
index c651a29a..19490f4e 100644
--- a/verl/workers/sharding_manager/fsdp_vllm.py
+++ b/verl/workers/sharding_manager/fsdp_vllm.py
@@ -33,8 +33,12 @@ class FSDPVLLMShardingManager(BaseShardingManager):
 
-    def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_params: bool = False,
-                 device_mesh: DeviceMesh=None):
+    def __init__(self,
+                 module: FSDP,
+                 inference_engine: LLM,
+                 model_config,
+                 full_params: bool = False,
+                 device_mesh: DeviceMesh = None):
         self.module = module
         self.inference_engine = inference_engine
         self.model_config = model_config
@@ -50,7 +54,7 @@ def __init__(self, module: FSDP, inference_engine: LLM, model_config, full_param
             FSDP.set_state_dict_type(self.module,
                                      state_dict_type=StateDictType.SHARDED_STATE_DICT,
                                      state_dict_config=ShardedStateDictConfig())
-
+
         # Note that torch_random_states may be different on each dp rank
         self.torch_random_states = torch.cuda.get_rng_state()
         # get a random rng states
@@ -80,7 +84,7 @@ def __enter__(self):
         # torch.cuda.empty_cache()
         # if torch.distributed.get_rank() == 0:
         #     print(f'after model to cpu in sharding manager memory allocated: {torch.cuda.memory_allocated() / 1e9}GB, reserved: {torch.cuda.memory_reserved() / 1e9}GB')
-
+
         # important: need to manually set the random states of each tp to be identical.
         if self.device_mesh is not None:
             self.torch_random_states = torch.cuda.get_rng_state()
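A note on the `[:, -response_length - 1:-1]` slice that recurs in both files: `logits[:, t]` is the model's prediction for `input_ids[:, t + 1]`, so each response token's log-prob is read one position to the left of the token itself. Below is a minimal sketch of that shift with toy shapes; the names and sizes are illustrative only, not part of this patch:

```python
import torch
import torch.nn.functional as F

# Toy stand-ins (assumptions for illustration): a (bsz, seqlen) batch whose
# last `response_length` tokens are the response, plus random logits in place
# of model(input_ids).logits.
bsz, seqlen, vocab, response_length = 2, 8, 16, 3
logits = torch.randn(bsz, seqlen, vocab)
input_ids = torch.randint(0, vocab, (bsz, seqlen))
responses = input_ids[:, -response_length:]

# logits[:, t] predicts input_ids[:, t + 1]. The response occupies positions
# [seqlen - response_length, seqlen), so its predictions live at
# [seqlen - response_length - 1, seqlen - 1), i.e. [:, -response_length - 1:-1].
response_logits = logits[:, -response_length - 1:-1]  # (bsz, response_length, vocab)
log_probs = torch.gather(F.log_softmax(response_logits, dim=-1),
                         dim=-1,
                         index=responses.unsqueeze(-1)).squeeze(-1)  # (bsz, response_length)
```

The non-rmpad branch in `dp_actor.py` relies on the same convention when it slices `logits` before calling `logprobs_from_logits(logits, micro_batch['responses'])`.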