[misc] chore: refactor and add several metrics (#111)
- Add a format script (scripts/format.sh)
- Move checkpoint saving into a separate _save_checkpoint function
- Add timing(s)/step, response_length/clip_ratio, prompt_length/clip_ratio,
  and critic/vf_explained_var metrics
- Start the training step counter from 1
vermouth1992 authored Jan 17, 2025
1 parent ff0c7cc commit 018b0d7
Showing 3 changed files with 156 additions and 105 deletions.
3 changes: 3 additions & 0 deletions scripts/format.sh
@@ -0,0 +1,3 @@
#!/bin/bash
pip3 install --upgrade yapf
yapf -ir -vv --style ./.style.yapf verl tests single_controller examples
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_trainer.yaml
@@ -132,7 +132,7 @@ trainer:
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
test_freq: 2
test_freq: -1
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
256 changes: 152 additions & 104 deletions verl/trainer/ppo/ray_trainer.py
@@ -164,40 +164,80 @@ def compute_data_metrics(batch):
returns = batch.batch['returns']
values = batch.batch['values']

max_response_length = batch.batch['responses'].shape[-1]

prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()

max_prompt_length = prompt_mask.size(-1)

response_info = _compute_response_info(batch)
response_mask = response_info['response_mask']
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']

valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
valid_values = torch.masked_select(values, response_mask)

return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)

metrics = {
# score
'critic/score/mean': torch.mean(sequence_score).detach().item(),
'critic/score/max': torch.max(sequence_score).detach().item(),
'critic/score/min': torch.min(sequence_score).detach().item(),
'critic/score/mean':
torch.mean(sequence_score).detach().item(),
'critic/score/max':
torch.max(sequence_score).detach().item(),
'critic/score/min':
torch.min(sequence_score).detach().item(),
# reward
'critic/rewards/mean': torch.mean(sequence_reward).detach().item(),
'critic/rewards/max': torch.max(sequence_reward).detach().item(),
'critic/rewards/min': torch.min(sequence_reward).detach().item(),
'critic/rewards/mean':
torch.mean(sequence_reward).detach().item(),
'critic/rewards/max':
torch.max(sequence_reward).detach().item(),
'critic/rewards/min':
torch.min(sequence_reward).detach().item(),
# adv
'critic/advantages/mean': masked_mean(advantages, response_mask).detach().item(),
'critic/advantages/max': torch.max(advantages[response_mask.bool()]).detach().item(),
'critic/advantages/min': torch.min(advantages[response_mask.bool()]).detach().item(),
'critic/advantages/mean':
torch.mean(valid_adv).detach().item(),
'critic/advantages/max':
torch.max(valid_adv).detach().item(),
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
# returns
'critic/returns/mean': masked_mean(returns, response_mask).detach().item(),
'critic/returns/max': torch.max(returns[response_mask.bool()]).detach().item(),
'critic/returns/min': torch.min(returns[response_mask.bool()]).detach().item(),
'critic/returns/mean':
torch.mean(valid_returns).detach().item(),
'critic/returns/max':
torch.max(valid_returns).detach().item(),
'critic/returns/min':
torch.min(valid_returns).detach().item(),
# values
'critic/values/mean': masked_mean(values, response_mask).detach().item(),
'critic/values/max': torch.max(values[response_mask.bool()]).detach().item(),
'critic/values/min': torch.min(values[response_mask.bool()]).detach().item(),
'critic/values/mean':
torch.mean(valid_values).detach().item(),
'critic/values/max':
torch.max(valid_values).detach().item(),
'critic/values/min':
torch.min(valid_values).detach().item(),
# response length
'response_length/mean': torch.mean(response_length).detach().item(),
'response_length/max': torch.max(response_length).detach().item(),
'response_length/min': torch.min(response_length).detach().item(),
'response_length/mean':
torch.mean(response_length).detach().item(),
'response_length/max':
torch.max(response_length).detach().item(),
'response_length/min':
torch.min(response_length).detach().item(),
'response_length/clip_ratio':
torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
# prompt length
'prompt_length/mean': torch.mean(prompt_length).detach().item(),
'prompt_length/max': torch.max(prompt_length).detach().item(),
'prompt_length/min': torch.min(prompt_length).detach().item(),
'prompt_length/mean':
torch.mean(prompt_length).detach().item(),
'prompt_length/max':
torch.max(prompt_length).detach().item(),
'prompt_length/min':
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
# vf explained var
'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
return metrics
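
The new per-token statistics rely on a _compute_response_info helper that sits outside this hunk. The sketch below is an assumption inferred only from how its outputs are consumed above (the response_mask, prompt_length, and response_length keys, and the use of torch.masked_select); the actual implementation in the repository may differ.

import torch

def _compute_response_info(batch):
    # Responses occupy the trailing positions of the concatenated prompt+response sequence.
    max_response_length = batch.batch['responses'].shape[-1]
    prompt_mask = batch.batch['attention_mask'][:, :-max_response_length]
    response_mask = batch.batch['attention_mask'][:, -max_response_length:]
    return dict(
        response_mask=response_mask.bool(),             # boolean mask for torch.masked_select
        prompt_length=prompt_mask.sum(-1).float(),      # non-padded prompt tokens per sample
        response_length=response_mask.sum(-1).float(),  # non-padded response tokens per sample
    )

The new critic/vf_explained_var entry is the standard explained-variance diagnostic for a value function, 1 - Var(returns - values) / Var(returns), computed over valid response tokens with a small 1e-5 term for numerical stability: values near 1 mean the critic tracks the empirical returns closely, while values near or below 0 mean it explains little of their variance.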

@@ -217,10 +257,10 @@ def compute_timing_metrics(batch, timing_raw):

return {
**{
f'timing/{name}': value for name, value in timing_raw.items()
f'timing(s)/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys(
f'timing_per_token(ms)/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys(
)) & set(timing_raw.keys())
},
}
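
To make the renamed timing keys concrete, here is a self-contained sketch of the two comprehensions above; the section names, timings, and token counts are hypothetical.

# Hypothetical raw timings (seconds) and token counts per section.
timing_raw = {'gen': 12.0, 'update_actor': 30.0, 'step': 45.0}
num_tokens_of_section = {'gen': 240_000, 'update_actor': 480_000}

timing_metrics = {
    **{f'timing(s)/{name}': value for name, value in timing_raw.items()},
    **{
        f'timing_per_token(ms)/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name]
        for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
    },
}
# timing(s)/step comes from the new _timer('step', ...) context added in fit() below.
# timing_per_token(ms)/gen == 0.05 here (12.0 s * 1000 / 240_000 tokens); 'step' has no
# token count, so it only appears under timing(s)/.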
@@ -382,7 +422,7 @@ def _validate(self):

metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f'test_score/{data_source}'] = np.mean(rewards)
metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)

return metric_dict
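
After the rename, validation metrics are grouped under the val/ prefix directly by _validate; a hypothetical return value (the data-source name and score are placeholders only):

metric_dict = {'val/test_score/openai/gsm8k': 0.71}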

@@ -457,6 +497,20 @@ def init_workers(self):
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()

def _save_checkpoint(self):
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
f'global_step_{self.global_steps}')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
f'global_step_{self.global_steps}')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)
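
Given the default_local_dir pattern in ppo_trainer.yaml, the new _save_checkpoint is expected to produce a layout like the one sketched below; the project name, experiment name, and step number are placeholders.

import os

# Hypothetical stand-ins for config.trainer.default_local_dir and the current step.
default_local_dir = 'checkpoints/my_project/my_experiment'
global_steps = 100

actor_local_path = os.path.join(default_local_dir, 'actor', f'global_step_{global_steps}')
critic_local_path = os.path.join(default_local_dir, 'critic', f'global_step_{global_steps}')
# -> checkpoints/my_project/my_experiment/actor/global_step_100
# -> checkpoints/my_project/my_experiment/critic/global_step_100
# When default_hdfs_dir is set, the actor/ and critic/ subdirectories are mirrored under it.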

def fit(self):
"""
The training loop of PPO.
@@ -482,6 +536,9 @@ def fit(self):
if self.config.trainer.get('val_only', False):
return

# we start from step 1
self.global_steps += 1
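
Starting the counter at 1 lets the test_freq and save_freq conditions later in the loop use a plain modulo on global_steps instead of the previous (global_steps + 1) form. A small sketch of the equivalence, assuming a positive frequency:

test_freq = 2  # hypothetical positive frequency

# Old scheme: steps counted from 0, check shifted by one.
old_trigger = [(step + 1) % test_freq == 0 for step in range(0, 6)]
# New scheme: steps counted from 1, plain modulo check.
new_trigger = [step % test_freq == 0 for step in range(1, 7)]

assert old_trigger == new_trigger  # both fire on the 2nd, 4th, 6th, ... training step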

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
@@ -493,71 +550,76 @@ def fit(self):
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)

# compute advantages, executed on the driver process
batch = compute_advantage(batch,
self.config.algorithm.gamma,
self.config.algorithm.lam,
adv_estimator=self.config.algorithm.adv_estimator)

# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)

# validate
if self.val_reward_fn is not None and (self.global_steps + 1) % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
metrics.update(val_metrics)
with _timer('step', timing_raw):
# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)

# compute advantages, executed on the driver process
batch = compute_advantage(batch,
self.config.algorithm.gamma,
self.config.algorithm.lam,
adv_estimator=self.config.algorithm.adv_estimator)

# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)

# validate
if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)

if self.config.trainer.save_freq > 0 and \
self.global_steps % self.config.trainer.save_freq == 0:
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()

# collect metrics
metrics.update(compute_data_metrics(batch=batch))
@@ -566,20 +628,6 @@ def fit(self):
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

if self.config.trainer.save_freq > 0 and (self.global_steps + 1) % self.config.trainer.save_freq == 0:
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
f'global_step_{self.global_steps}')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
f'global_step_{self.global_steps}')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

self.global_steps += 1

if self.global_steps >= self.total_training_steps:
