diff --git a/verl/trainer/ppo/core_algos.py b/verl/trainer/ppo/core_algos.py
index b9790be6..cf4ce7fe 100644
--- a/verl/trainer/ppo/core_algos.py
+++ b/verl/trainer/ppo/core_algos.py
@@ -106,8 +106,9 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
         advantages = verl_F.masked_whiten(advantages, eos_mask)
     return advantages, returns
 
+
 # NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
-def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor, 
+def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
                                    eos_mask: torch.Tensor,
                                    index: torch.Tensor,
                                    epsilon: float = 1e-6):
@@ -258,14 +259,14 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
     if kl_penalty == "mse":
         return 0.5 * (logprob - ref_logprob).square()
 
-    # J. Schulman. Approximating kl divergence, 2020. 
+    # J. Schulman. Approximating kl divergence, 2020.
     # # URL http://joschu.net/blog/kl-approx.html.
     if kl_penalty == 'low_var_kl':
         kl = ref_logprob - logprob
         ratio = torch.exp(kl)
         kld = (ratio - kl - 1).contiguous()
         return torch.clamp(kld, min=-10, max=10)
-    
+
     if kl_penalty == "full":
         # so, here logprob and ref_logprob should contain the logits for every token in vocabulary
         raise NotImplementedError
diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py
index 68ded805..331757a8 100644
--- a/verl/trainer/ppo/ray_trainer.py
+++ b/verl/trainer/ppo/ray_trainer.py
@@ -113,7 +113,6 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,
 
 
 def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
-    from verl.protocol import fold_batch_dim, unfold_batch_dim
     # prepare response group
     # TODO: add other ways to estimate advantages
     if adv_estimator == 'gae':
@@ -143,7 +142,6 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
                                                                         index=index)
         data.batch['advantages'] = advantages
         data.batch['returns'] = returns
-        data = unfold_batch_dim(data, batch_dims=2)
     else:
         raise NotImplementedError
     return data
@@ -192,7 +190,7 @@ def compute_data_metrics(batch, use_critic=True):
 
     valid_adv = torch.masked_select(advantages, response_mask)
     valid_returns = torch.masked_select(returns, response_mask)
-    
+
     if use_critic:
         values = batch.batch['values']
         valid_values = torch.masked_select(values, response_mask)
@@ -222,19 +220,20 @@ def compute_data_metrics(batch, use_critic=True):
         'critic/advantages/min': torch.min(valid_adv).detach().item(),
         # returns
-        'critic/returns/mean': torch.mean(valid_returns).detach().item(),
-        'critic/returns/max': torch.max(valid_returns).detach().item(),
-        'critic/returns/min': torch.min(valid_returns).detach().item(),
-
+        'critic/returns/mean':
+            torch.mean(valid_returns).detach().item(),
+        'critic/returns/max':
+            torch.max(valid_returns).detach().item(),
+        'critic/returns/min':
+            torch.min(valid_returns).detach().item(),
         **({
-            # values
-            'critic/values/mean': torch.mean(valid_values).detach().item(),
-            'critic/values/max': torch.max(valid_values).detach().item(),
-            'critic/values/min': torch.min(valid_values).detach().item(),
-            # vf explained var
-            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
-        }
-         if use_critic else {}),
+            # values
+            'critic/values/mean': torch.mean(valid_values).detach().item(),
+            'critic/values/max': torch.max(valid_values).detach().item(),
+            'critic/values/min': torch.min(valid_values).detach().item(),
+            # vf explained var
+            'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
+        } if use_critic else {}),
 
         # response length
         'response_length/mean':
@@ -628,8 +627,8 @@ def fit(self):
                     # compute rewards. apply_kl_penalty if available
                     if not self.config.actor_rollout_ref.actor.use_kl_loss:
                         batch, kl_metrics = apply_kl_penalty(batch,
-                                                                 kl_ctrl=self.kl_ctrl,
-                                                                 kl_penalty=self.config.algorithm.kl_penalty)
+                                                             kl_ctrl=self.kl_ctrl,
+                                                             kl_penalty=self.config.algorithm.kl_penalty)
                         metrics.update(kl_metrics)
                     else:
                         batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
diff --git a/verl/utils/dataset/rl_dataset.py b/verl/utils/dataset/rl_dataset.py
index 543087ef..48328d7a 100644
--- a/verl/utils/dataset/rl_dataset.py
+++ b/verl/utils/dataset/rl_dataset.py
@@ -141,9 +141,9 @@ def __getitem__(self, item):
         # encode prompts without chat template
         if self.return_raw_chat:
            row_dict['raw_prompt'] = chat.tolist()
-        
+
         # add index for each prompt
         index = row_dict.get("extra_info", {}).get("index", 0)
         row_dict["index"] = index
 
-        return row_dict
\ No newline at end of file
+        return row_dict