Commit: lint

PeterSH6 committed Jan 23, 2025
1 parent f2e43e8 commit 557af54
Showing 3 changed files with 22 additions and 22 deletions.
verl/trainer/ppo/core_algos.py (7 changes: 4 additions & 3 deletions)
@@ -106,8 +106,9 @@ def compute_gae_advantage_return(token_level_rewards: torch.Tensor, values: torc
advantages = verl_F.masked_whiten(advantages, eos_mask)
return advantages, returns


# NOTE(sgm): this implementation only consider outcome supervision, where the reward is a scalar.
-def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
+def compute_grpo_outcome_advantage(token_level_rewards: torch.Tensor,
eos_mask: torch.Tensor,
index: torch.Tensor,
epsilon: float = 1e-6):
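The NOTE above says this function only handles outcome supervision, meaning each response receives a single scalar reward rather than per-token rewards. As a reading aid, here is a minimal, self-contained sketch of group-relative outcome normalization in that spirit; the body is an assumption for illustration (the names, the loop, and the exact normalization are not taken from verl's implementation) and matches only the signature shown in the diff.

import torch

def grpo_outcome_advantage_sketch(token_level_rewards: torch.Tensor,
                                  eos_mask: torch.Tensor,
                                  index: torch.Tensor,
                                  epsilon: float = 1e-6):
    """Illustrative only: group-normalize scalar outcome rewards by prompt index.

    token_level_rewards: (bs, response_len), nonzero only at the final response token.
    eos_mask:            (bs, response_len), 1 for valid response tokens, 0 for padding.
    index:               (bs,), identifies which prompt each sampled response belongs to.
    """
    # Outcome supervision: collapse token-level rewards to one scalar score per response.
    scores = token_level_rewards.sum(dim=-1)

    # Normalize each score against the other responses generated for the same prompt.
    # (Assumes every group has more than one response, e.g. num_repeat > 1.)
    advantages = torch.zeros_like(scores)
    for idx in index.unique():
        group = index == idx
        mean, std = scores[group].mean(), scores[group].std()
        advantages[group] = (scores[group] - mean) / (std + epsilon)

    # Broadcast the scalar advantage over the response tokens; with outcome
    # supervision the returns can simply reuse the same tensor.
    advantages = advantages.unsqueeze(-1) * eos_mask
    return advantages, advantages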
@@ -258,14 +259,14 @@ def kl_penalty(logprob: torch.FloatTensor, ref_logprob: torch.FloatTensor, kl_pe
if kl_penalty == "mse":
return 0.5 * (logprob - ref_logprob).square()

-# J. Schulman. Approximating kl divergence, 2020.
+# J. Schulman. Approximating kl divergence, 2020.
# # URL http://joschu.net/blog/kl-approx.html.
if kl_penalty == 'low_var_kl':
kl = ref_logprob - logprob
ratio = torch.exp(kl)
kld = (ratio - kl - 1).contiguous()
return torch.clamp(kld, min=-10, max=10)

if kl_penalty == "full":
# so, here logprob and ref_logprob should contain the logits for every token in vocabulary
raise NotImplementedError
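As a standalone illustration of the low_var_kl branch above: it computes Schulman's k3 estimator of KL(policy || reference), r - log r - 1 with r = exp(ref_logprob - logprob), which is nonnegative per token and typically has lower variance than the naive log-probability difference. The snippet below is a hedged sketch with toy inputs and illustrative names, not the module's API.

import torch

def low_var_kl_sketch(logprob: torch.Tensor, ref_logprob: torch.Tensor) -> torch.Tensor:
    """k3 estimator of per-token KL(policy || reference): r - log(r) - 1, with r = q/p."""
    log_ratio = ref_logprob - logprob               # log(q/p) per token
    kld = torch.exp(log_ratio) - log_ratio - 1.0    # pointwise >= 0
    # Clamp extreme values for numerical stability, mirroring the diff above.
    return torch.clamp(kld, min=-10, max=10)

# Toy check: identical log-probabilities give zero estimated KL.
p = torch.log(torch.tensor([0.2, 0.5, 0.3]))
print(low_var_kl_sketch(p, p))   # tensor([0., 0., 0.])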
verl/trainer/ppo/ray_trainer.py (33 changes: 16 additions & 17 deletions)
@@ -113,7 +113,6 @@ def apply_kl_penalty(data: DataProto, kl_ctrl: core_algos.AdaptiveKLController,


def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_repeat=1):
-from verl.protocol import fold_batch_dim, unfold_batch_dim
# prepare response group
# TODO: add other ways to estimate advantages
if adv_estimator == 'gae':
@@ -143,7 +142,6 @@ def compute_advantage(data: DataProto, adv_estimator, gamma=1.0, lam=1.0, num_re
index=index)
data.batch['advantages'] = advantages
data.batch['returns'] = returns
-data = unfold_batch_dim(data, batch_dims=2)
else:
raise NotImplementedError
return data
@@ -192,7 +190,7 @@ def compute_data_metrics(batch, use_critic=True):

valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)

if use_critic:
values = batch.batch['values']
valid_values = torch.masked_select(values, response_mask)
@@ -222,19 +220,20 @@ def compute_data_metrics(batch, use_critic=True):
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
# returns
-'critic/returns/mean': torch.mean(valid_returns).detach().item(),
-'critic/returns/max': torch.max(valid_returns).detach().item(),
-'critic/returns/min': torch.min(valid_returns).detach().item(),
-
+'critic/returns/mean':
+torch.mean(valid_returns).detach().item(),
+'critic/returns/max':
+torch.max(valid_returns).detach().item(),
+'critic/returns/min':
+torch.min(valid_returns).detach().item(),
**({
-# values
-'critic/values/mean': torch.mean(valid_values).detach().item(),
-'critic/values/max': torch.max(valid_values).detach().item(),
-'critic/values/min': torch.min(valid_values).detach().item(),
-# vf explained var
-'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
-}
-if use_critic else {}),
+# values
+'critic/values/mean': torch.mean(valid_values).detach().item(),
+'critic/values/max': torch.max(valid_values).detach().item(),
+'critic/values/min': torch.min(valid_values).detach().item(),
+# vf explained var
+'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
+} if use_critic else {}),

# response length
'response_length/mean':
@@ -628,8 +627,8 @@ def fit(self):
# compute rewards. apply_kl_penalty if available
if not self.config.actor_rollout_ref.actor.use_kl_loss:
batch, kl_metrics = apply_kl_penalty(batch,
-kl_ctrl=self.kl_ctrl,
-kl_penalty=self.config.algorithm.kl_penalty)
+kl_ctrl=self.kl_ctrl,
+kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)
else:
batch.batch['token_level_rewards'] = batch.batch['token_level_scores']
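The 'critic/vf_explained_var' metric in the hunk above is the standard explained-variance diagnostic, 1 - Var[R - V] / Var[R]. The diff does not show how return_var and return_diff_var are computed, so the sketch below reconstructs that formula from the masked returns and values that do appear; treat the helper name and shapes as assumptions.

import torch

def vf_explained_var_sketch(valid_returns: torch.Tensor, valid_values: torch.Tensor) -> float:
    """Explained variance of the value head: 1 - Var[returns - values] / Var[returns].

    Values near 1 mean the critic tracks the empirical returns closely; values near
    0 (or negative) mean it explains little. The 1e-5 guards against zero variance,
    matching the constant used in the metric above.
    """
    return_var = torch.var(valid_returns)
    return_diff_var = torch.var(valid_returns - valid_values)
    return (1.0 - return_diff_var / (return_var + 1e-5)).detach().item()

# Example with toy numbers:
# vf_explained_var_sketch(torch.tensor([1.0, 0.5, 2.0]), torch.tensor([0.9, 0.6, 1.8]))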
verl/utils/dataset/rl_dataset.py (4 changes: 2 additions & 2 deletions)
@@ -141,9 +141,9 @@ def __getitem__(self, item):
# encode prompts without chat template
if self.return_raw_chat:
row_dict['raw_prompt'] = chat.tolist()

# add index for each prompt
index = row_dict.get("extra_info", {}).get("index", 0)
row_dict["index"] = index

-return row_dict
+return row_dict
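The change above attaches an index to every row, read from extra_info and defaulting to 0. A plausible consumer, suggested by the index argument of compute_grpo_outcome_advantage earlier in this commit, is grouping repeated rollouts of the same prompt. The toy sketch below only illustrates how the lookup behaves and assumes a hypothetical row layout.

# Hypothetical row; only the "extra_info" -> "index" lookup is grounded in the diff.
row_dict = {
    "prompt": "What is 2 + 2?",
    "extra_info": {"index": 42},
}

index = row_dict.get("extra_info", {}).get("index", 0)  # 42 here; 0 if extra_info has no index
row_dict["index"] = index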
