From c413414f0227c293c354590c7df717c453abcc2c Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Fri, 17 Jan 2025 18:41:14 +0800 Subject: [PATCH 1/4] add critical metrics --- verl/trainer/config/ppo_trainer.yaml | 2 +- verl/trainer/ppo/ray_trainer.py | 205 +++++++++++++++------------ 2 files changed, 114 insertions(+), 93 deletions(-) diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 5000592e..57ce6da9 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -132,7 +132,7 @@ trainer: nnodes: 1 n_gpus_per_node: 8 save_freq: -1 - test_freq: 2 + test_freq: -1 critic_warmup: 0 default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name} default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name} diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 05a54883..0fc1959c 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -164,11 +164,23 @@ def compute_data_metrics(batch): returns = batch.batch['returns'] values = batch.batch['values'] - response_info = _compute_response_info(batch) - response_mask = response_info['response_mask'] + prompt_mask = batch.batch['attention_mask'][:, :-response_length].bool() + response_mask = batch.batch['attention_mask'][:, -response_length:].bool() + + max_prompt_length = prompt_mask.size(-1) + max_response_length = response_mask.shape[-1] + + response_info = _compute_response_info(batch) prompt_length = response_info['prompt_length'] response_length = response_info['response_length'] + valid_adv = torch.masked_select(advantages, response_mask) + valid_returns = torch.masked_select(returns, response_mask) + valid_values = torch.masked_select(values, response_mask) + + return_diff_var = torch.var(valid_returns - valid_values) + return_var = torch.var(valid_returns) + metrics = { # score 'critic/score/mean': torch.mean(sequence_score).detach().item(), @@ -179,25 +191,29 @@ def compute_data_metrics(batch): 'critic/rewards/max': torch.max(sequence_reward).detach().item(), 'critic/rewards/min': torch.min(sequence_reward).detach().item(), # adv - 'critic/advantages/mean': masked_mean(advantages, response_mask).detach().item(), - 'critic/advantages/max': torch.max(advantages[response_mask.bool()]).detach().item(), - 'critic/advantages/min': torch.min(advantages[response_mask.bool()]).detach().item(), + 'critic/advantages/mean': torch.mean(valid_adv).detach().item(), + 'critic/advantages/max': torch.max(valid_adv).detach().item(), + 'critic/advantages/min': torch.min(valid_adv).detach().item(), # returns - 'critic/returns/mean': masked_mean(returns, response_mask).detach().item(), - 'critic/returns/max': torch.max(returns[response_mask.bool()]).detach().item(), - 'critic/returns/min': torch.min(returns[response_mask.bool()]).detach().item(), + 'critic/returns/mean': torch.mean(valid_returns).detach().item(), + 'critic/returns/max': torch.max(valid_returns).detach().item(), + 'critic/returns/min': torch.min(valid_returns).detach().item(), # values - 'critic/values/mean': masked_mean(values, response_mask).detach().item(), - 'critic/values/max': torch.max(values[response_mask.bool()]).detach().item(), - 'critic/values/min': torch.min(values[response_mask.bool()]).detach().item(), + 'critic/values/mean': torch.mean(valid_values).detach().item(), + 'critic/values/max': torch.max(valid_values).detach().item(), + 'critic/values/min': torch.min(valid_values).detach().item(), # response length 'response_length/mean': torch.mean(response_length).detach().item(), 'response_length/max': torch.max(response_length).detach().item(), 'response_length/min': torch.min(response_length).detach().item(), + 'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), # prompt length 'prompt_length/mean': torch.mean(prompt_length).detach().item(), 'prompt_length/max': torch.max(prompt_length).detach().item(), 'prompt_length/min': torch.min(prompt_length).detach().item(), + 'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + # vf explained var + 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), } return metrics @@ -217,10 +233,10 @@ def compute_timing_metrics(batch, timing_raw): return { **{ - f'timing/{name}': value for name, value in timing_raw.items() + f'timing(s)/{name}': value for name, value in timing_raw.items() }, **{ - f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( + f'timing_per_token(ms)/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( )) & set(timing_raw.keys()) }, } @@ -457,6 +473,20 @@ def init_workers(self): self.actor_rollout_wg = all_wg['actor_rollout'] self.actor_rollout_wg.init_model() + def _save_checkpoint(self): + actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', + f'global_step_{self.global_steps}') + actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, 'actor') + self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path) + + if self.use_critic: + critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', + f'global_step_{self.global_steps}') + critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( + self.config.trainer.default_hdfs_dir, 'critic') + self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) + def fit(self): """ The training loop of PPO. @@ -493,71 +523,76 @@ def fit(self): # pop those keys for generation gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) - # generate a batch - with _timer('gen', timing_raw): - gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - - # repeat to align with repeated responses in rollout - batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) - batch = batch.union(gen_batch_output) - - if self.use_reference_policy: - # compute reference log_prob - with _timer('ref', timing_raw): - ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) - batch = batch.union(ref_log_prob) - - # compute values - with _timer('values', timing_raw): - values = self.critic_wg.compute_values(batch) - batch = batch.union(values) - - with _timer('adv', timing_raw): - # compute scores. Support both model and function-based. - # We first compute the scores using reward model. Then, we call reward_fn to combine - # the results from reward model and rule-based results. - if self.use_rm: - # we first compute reward model score - reward_tensor = self.rm_wg.compute_rm_score(batch) - batch = batch.union(reward_tensor) - - # we combine with rule-based rm - reward_tensor = self.reward_fn(batch) - batch.batch['token_level_scores'] = reward_tensor - - # compute rewards. apply_kl_penalty if available - batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty) - metrics.update(kl_metrics) - - # compute advantages, executed on the driver process - batch = compute_advantage(batch, - self.config.algorithm.gamma, - self.config.algorithm.lam, - adv_estimator=self.config.algorithm.adv_estimator) - - # update critic - if self.use_critic: - with _timer('update_critic', timing_raw): - critic_output = self.critic_wg.update_critic(batch) - critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) - metrics.update(critic_output_metrics) - - # implement critic warmup - if self.config.trainer.critic_warmup <= self.global_steps: - # update actor - with _timer('update_actor', timing_raw): - actor_output = self.actor_rollout_wg.update_actor(batch) - actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) - metrics.update(actor_output_metrics) - - # validate - if self.val_reward_fn is not None and (self.global_steps + 1) % self.config.trainer.test_freq == 0: - with _timer('testing', timing_raw): - val_metrics: dict = self._validate() - val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} - metrics.update(val_metrics) + with _timer('step', timing_raw): + # generate a batch + with _timer('gen', timing_raw): + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if self.use_reference_policy: + # compute reference log_prob + with _timer('ref', timing_raw): + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + with _timer('values', timing_raw): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with _timer('adv', timing_raw): + # compute scores. Support both model and function-based. + # We first compute the scores using reward model. Then, we call reward_fn to combine + # the results from reward model and rule-based results. + if self.use_rm: + # we first compute reward model score + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + # we combine with rule-based rm + reward_tensor = self.reward_fn(batch) + batch.batch['token_level_scores'] = reward_tensor + + # compute rewards. apply_kl_penalty if available + batch, kl_metrics = apply_kl_penalty(batch, + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) + metrics.update(kl_metrics) + + # compute advantages, executed on the driver process + batch = compute_advantage(batch, + self.config.algorithm.gamma, + self.config.algorithm.lam, + adv_estimator=self.config.algorithm.adv_estimator) + + # update critic + if self.use_critic: + with _timer('update_critic', timing_raw): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with _timer('update_actor', timing_raw): + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) + metrics.update(actor_output_metrics) + + # validate + if self.val_reward_fn is not None and (self.global_steps + 1) % self.config.trainer.test_freq == 0: + with _timer('testing', timing_raw): + val_metrics: dict = self._validate() + val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} + metrics.update(val_metrics) + + if self.config.trainer.save_freq > 0 and (self.global_steps + 1) % self.config.trainer.save_freq == 0: + with _timer('save_checkpoint', timing_raw): + self._save_checkpoint() # collect metrics metrics.update(compute_data_metrics(batch=batch)) @@ -566,20 +601,6 @@ def fit(self): # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=self.global_steps) - if self.config.trainer.save_freq > 0 and (self.global_steps + 1) % self.config.trainer.save_freq == 0: - actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', - f'global_step_{self.global_steps}') - actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, 'actor') - self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path) - - if self.use_critic: - critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', - f'global_step_{self.global_steps}') - critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( - self.config.trainer.default_hdfs_dir, 'critic') - self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) - self.global_steps += 1 if self.global_steps >= self.total_training_steps: From 073f96f12d366dcbdd0c474254d8c95d8e0247f7 Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Fri, 17 Jan 2025 18:46:41 +0800 Subject: [PATCH 2/4] fix format --- scripts/format.sh | 3 ++ verl/trainer/ppo/ray_trainer.py | 88 +++++++++++++++++++++------------ 2 files changed, 59 insertions(+), 32 deletions(-) create mode 100644 scripts/format.sh diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100644 index 00000000..311aec74 --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,3 @@ +#!/bin/bash +pip3 install --upgrade yapf +yapf -ir -vv --style ./.style.yapf verl tests single_controller examples \ No newline at end of file diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 0fc1959c..63e469d7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -170,7 +170,7 @@ def compute_data_metrics(batch): max_prompt_length = prompt_mask.size(-1) max_response_length = response_mask.shape[-1] - response_info = _compute_response_info(batch) + response_info = _compute_response_info(batch) prompt_length = response_info['prompt_length'] response_length = response_info['response_length'] @@ -183,35 +183,58 @@ def compute_data_metrics(batch): metrics = { # score - 'critic/score/mean': torch.mean(sequence_score).detach().item(), - 'critic/score/max': torch.max(sequence_score).detach().item(), - 'critic/score/min': torch.min(sequence_score).detach().item(), + 'critic/score/mean': + torch.mean(sequence_score).detach().item(), + 'critic/score/max': + torch.max(sequence_score).detach().item(), + 'critic/score/min': + torch.min(sequence_score).detach().item(), # reward - 'critic/rewards/mean': torch.mean(sequence_reward).detach().item(), - 'critic/rewards/max': torch.max(sequence_reward).detach().item(), - 'critic/rewards/min': torch.min(sequence_reward).detach().item(), + 'critic/rewards/mean': + torch.mean(sequence_reward).detach().item(), + 'critic/rewards/max': + torch.max(sequence_reward).detach().item(), + 'critic/rewards/min': + torch.min(sequence_reward).detach().item(), # adv - 'critic/advantages/mean': torch.mean(valid_adv).detach().item(), - 'critic/advantages/max': torch.max(valid_adv).detach().item(), - 'critic/advantages/min': torch.min(valid_adv).detach().item(), + 'critic/advantages/mean': + torch.mean(valid_adv).detach().item(), + 'critic/advantages/max': + torch.max(valid_adv).detach().item(), + 'critic/advantages/min': + torch.min(valid_adv).detach().item(), # returns - 'critic/returns/mean': torch.mean(valid_returns).detach().item(), - 'critic/returns/max': torch.max(valid_returns).detach().item(), - 'critic/returns/min': torch.min(valid_returns).detach().item(), + 'critic/returns/mean': + torch.mean(valid_returns).detach().item(), + 'critic/returns/max': + torch.max(valid_returns).detach().item(), + 'critic/returns/min': + torch.min(valid_returns).detach().item(), # values - 'critic/values/mean': torch.mean(valid_values).detach().item(), - 'critic/values/max': torch.max(valid_values).detach().item(), - 'critic/values/min': torch.min(valid_values).detach().item(), + 'critic/values/mean': + torch.mean(valid_values).detach().item(), + 'critic/values/max': + torch.max(valid_values).detach().item(), + 'critic/values/min': + torch.min(valid_values).detach().item(), # response length - 'response_length/mean': torch.mean(response_length).detach().item(), - 'response_length/max': torch.max(response_length).detach().item(), - 'response_length/min': torch.min(response_length).detach().item(), - 'response_length/clip_ratio': torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), + 'response_length/mean': + torch.mean(response_length).detach().item(), + 'response_length/max': + torch.max(response_length).detach().item(), + 'response_length/min': + torch.min(response_length).detach().item(), + 'response_length/clip_ratio': + torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(), # prompt length - 'prompt_length/mean': torch.mean(prompt_length).detach().item(), - 'prompt_length/max': torch.max(prompt_length).detach().item(), - 'prompt_length/min': torch.min(prompt_length).detach().item(), - 'prompt_length/clip_ratio': torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), + 'prompt_length/mean': + torch.mean(prompt_length).detach().item(), + 'prompt_length/max': + torch.max(prompt_length).detach().item(), + 'prompt_length/min': + torch.min(prompt_length).detach().item(), + 'prompt_length/clip_ratio': + torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(), # vf explained var 'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(), } @@ -475,14 +498,14 @@ def init_workers(self): def _save_checkpoint(self): actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor', - f'global_step_{self.global_steps}') + f'global_step_{self.global_steps}') actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( self.config.trainer.default_hdfs_dir, 'actor') self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path) if self.use_critic: critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic', - f'global_step_{self.global_steps}') + f'global_step_{self.global_steps}') critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join( self.config.trainer.default_hdfs_dir, 'critic') self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path) @@ -558,15 +581,15 @@ def fit(self): # compute rewards. apply_kl_penalty if available batch, kl_metrics = apply_kl_penalty(batch, - kl_ctrl=self.kl_ctrl, - kl_penalty=self.config.algorithm.kl_penalty) + kl_ctrl=self.kl_ctrl, + kl_penalty=self.config.algorithm.kl_penalty) metrics.update(kl_metrics) # compute advantages, executed on the driver process batch = compute_advantage(batch, - self.config.algorithm.gamma, - self.config.algorithm.lam, - adv_estimator=self.config.algorithm.adv_estimator) + self.config.algorithm.gamma, + self.config.algorithm.lam, + adv_estimator=self.config.algorithm.adv_estimator) # update critic if self.use_critic: @@ -590,7 +613,8 @@ def fit(self): val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} metrics.update(val_metrics) - if self.config.trainer.save_freq > 0 and (self.global_steps + 1) % self.config.trainer.save_freq == 0: + if self.config.trainer.save_freq > 0 and \ + (self.global_steps + 1) % self.config.trainer.save_freq == 0: with _timer('save_checkpoint', timing_raw): self._save_checkpoint() From 4016569635d8ec4ef82834e9ea88c8c443242d40 Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Fri, 17 Jan 2025 19:06:24 +0800 Subject: [PATCH 3/4] fix --- verl/trainer/ppo/ray_trainer.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 63e469d7..b778a9ba 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -164,11 +164,12 @@ def compute_data_metrics(batch): returns = batch.batch['returns'] values = batch.batch['values'] - prompt_mask = batch.batch['attention_mask'][:, :-response_length].bool() - response_mask = batch.batch['attention_mask'][:, -response_length:].bool() + max_response_length = batch.batch['responses'].shape[-1] + + prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool() + response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool() max_prompt_length = prompt_mask.size(-1) - max_response_length = response_mask.shape[-1] response_info = _compute_response_info(batch) prompt_length = response_info['prompt_length'] @@ -535,6 +536,9 @@ def fit(self): if self.config.trainer.get('val_only', False): return + # we start from step 1 + self.global_steps += 1 + for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: metrics = {} @@ -607,14 +611,14 @@ def fit(self): metrics.update(actor_output_metrics) # validate - if self.val_reward_fn is not None and (self.global_steps + 1) % self.config.trainer.test_freq == 0: + if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0: with _timer('testing', timing_raw): val_metrics: dict = self._validate() val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and \ - (self.global_steps + 1) % self.config.trainer.save_freq == 0: + self.global_steps % self.config.trainer.save_freq == 0: with _timer('save_checkpoint', timing_raw): self._save_checkpoint() From 2fc9b3a3905563f83fff9eaea4304235d2bdee6d Mon Sep 17 00:00:00 2001 From: "zhangchi.usc1992" Date: Fri, 17 Jan 2025 21:55:09 +0800 Subject: [PATCH 4/4] unify metrics name --- verl/trainer/ppo/ray_trainer.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index b778a9ba..a12c7250 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -422,7 +422,7 @@ def _validate(self): metric_dict = {} for data_source, rewards in data_source_reward.items(): - metric_dict[f'test_score/{data_source}'] = np.mean(rewards) + metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards) return metric_dict @@ -614,7 +614,6 @@ def fit(self): if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0: with _timer('testing', timing_raw): val_metrics: dict = self._validate() - val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} metrics.update(val_metrics) if self.config.trainer.save_freq > 0 and \