Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[misc] chore: refactor and add several metrics #111

Merged
merged 5 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions scripts/format.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
#!/bin/bash
pip3 install --upgrade yapf
yapf -ir -vv --style ./.style.yapf verl tests single_controller examples
2 changes: 1 addition & 1 deletion verl/trainer/config/ppo_trainer.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ trainer:
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
test_freq: 2
test_freq: -1
critic_warmup: 0
default_hdfs_dir: ~/experiments/gsm8k/ppo/${trainer.experiment_name}
default_local_dir: checkpoints/${trainer.project_name}/${trainer.experiment_name}
256 changes: 152 additions & 104 deletions verl/trainer/ppo/ray_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,40 +164,80 @@ def compute_data_metrics(batch):
returns = batch.batch['returns']
values = batch.batch['values']

max_response_length = batch.batch['responses'].shape[-1]

prompt_mask = batch.batch['attention_mask'][:, :-max_response_length].bool()
response_mask = batch.batch['attention_mask'][:, -max_response_length:].bool()

max_prompt_length = prompt_mask.size(-1)

response_info = _compute_response_info(batch)
response_mask = response_info['response_mask']
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']

valid_adv = torch.masked_select(advantages, response_mask)
valid_returns = torch.masked_select(returns, response_mask)
valid_values = torch.masked_select(values, response_mask)

return_diff_var = torch.var(valid_returns - valid_values)
return_var = torch.var(valid_returns)

metrics = {
# score
'critic/score/mean': torch.mean(sequence_score).detach().item(),
'critic/score/max': torch.max(sequence_score).detach().item(),
'critic/score/min': torch.min(sequence_score).detach().item(),
'critic/score/mean':
torch.mean(sequence_score).detach().item(),
'critic/score/max':
torch.max(sequence_score).detach().item(),
'critic/score/min':
torch.min(sequence_score).detach().item(),
# reward
'critic/rewards/mean': torch.mean(sequence_reward).detach().item(),
'critic/rewards/max': torch.max(sequence_reward).detach().item(),
'critic/rewards/min': torch.min(sequence_reward).detach().item(),
'critic/rewards/mean':
torch.mean(sequence_reward).detach().item(),
'critic/rewards/max':
torch.max(sequence_reward).detach().item(),
'critic/rewards/min':
torch.min(sequence_reward).detach().item(),
# adv
'critic/advantages/mean': masked_mean(advantages, response_mask).detach().item(),
'critic/advantages/max': torch.max(advantages[response_mask.bool()]).detach().item(),
'critic/advantages/min': torch.min(advantages[response_mask.bool()]).detach().item(),
'critic/advantages/mean':
torch.mean(valid_adv).detach().item(),
'critic/advantages/max':
torch.max(valid_adv).detach().item(),
'critic/advantages/min':
torch.min(valid_adv).detach().item(),
# returns
'critic/returns/mean': masked_mean(returns, response_mask).detach().item(),
'critic/returns/max': torch.max(returns[response_mask.bool()]).detach().item(),
'critic/returns/min': torch.min(returns[response_mask.bool()]).detach().item(),
'critic/returns/mean':
torch.mean(valid_returns).detach().item(),
'critic/returns/max':
torch.max(valid_returns).detach().item(),
'critic/returns/min':
torch.min(valid_returns).detach().item(),
# values
'critic/values/mean': masked_mean(values, response_mask).detach().item(),
'critic/values/max': torch.max(values[response_mask.bool()]).detach().item(),
'critic/values/min': torch.min(values[response_mask.bool()]).detach().item(),
'critic/values/mean':
torch.mean(valid_values).detach().item(),
'critic/values/max':
torch.max(valid_values).detach().item(),
'critic/values/min':
torch.min(valid_values).detach().item(),
# response length
'response_length/mean': torch.mean(response_length).detach().item(),
'response_length/max': torch.max(response_length).detach().item(),
'response_length/min': torch.min(response_length).detach().item(),
'response_length/mean':
torch.mean(response_length).detach().item(),
'response_length/max':
torch.max(response_length).detach().item(),
'response_length/min':
torch.min(response_length).detach().item(),
'response_length/clip_ratio':
torch.mean(torch.eq(response_length, max_response_length).float()).detach().item(),
# prompt length
'prompt_length/mean': torch.mean(prompt_length).detach().item(),
'prompt_length/max': torch.max(prompt_length).detach().item(),
'prompt_length/min': torch.min(prompt_length).detach().item(),
'prompt_length/mean':
torch.mean(prompt_length).detach().item(),
'prompt_length/max':
torch.max(prompt_length).detach().item(),
'prompt_length/min':
torch.min(prompt_length).detach().item(),
'prompt_length/clip_ratio':
torch.mean(torch.eq(prompt_length, max_prompt_length).float()).detach().item(),
# vf explained var
'critic/vf_explained_var': (1.0 - return_diff_var / (return_var + 1e-5)).detach().item(),
}
return metrics

Expand All @@ -217,10 +257,10 @@ def compute_timing_metrics(batch, timing_raw):

return {
**{
f'timing/{name}': value for name, value in timing_raw.items()
f'timing(s)/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys(
f'timing_per_token(ms)/{name}': timing_raw[name] * 1000 / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys(
)) & set(timing_raw.keys())
},
}
Expand Down Expand Up @@ -382,7 +422,7 @@ def _validate(self):

metric_dict = {}
for data_source, rewards in data_source_reward.items():
metric_dict[f'test_score/{data_source}'] = np.mean(rewards)
metric_dict[f'val/test_score/{data_source}'] = np.mean(rewards)

return metric_dict

Expand Down Expand Up @@ -457,6 +497,20 @@ def init_workers(self):
self.actor_rollout_wg = all_wg['actor_rollout']
self.actor_rollout_wg.init_model()

def _save_checkpoint(self):
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
f'global_step_{self.global_steps}')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
f'global_step_{self.global_steps}')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

def fit(self):
"""
The training loop of PPO.
Expand All @@ -482,6 +536,9 @@ def fit(self):
if self.config.trainer.get('val_only', False):
return

# we start from step 1
self.global_steps += 1

for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
Expand All @@ -493,71 +550,76 @@ def fit(self):
# pop those keys for generation
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)

# compute advantages, executed on the driver process
batch = compute_advantage(batch,
self.config.algorithm.gamma,
self.config.algorithm.lam,
adv_estimator=self.config.algorithm.adv_estimator)

# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)

# validate
if self.val_reward_fn is not None and (self.global_steps + 1) % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
metrics.update(val_metrics)
with _timer('step', timing_raw):
# generate a batch
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)

# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)

if self.use_reference_policy:
# compute reference log_prob
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)

# compute values
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)

with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
if self.use_rm:
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)

# we combine with rule-based rm
reward_tensor = self.reward_fn(batch)
batch.batch['token_level_scores'] = reward_tensor

# compute rewards. apply_kl_penalty if available
batch, kl_metrics = apply_kl_penalty(batch,
kl_ctrl=self.kl_ctrl,
kl_penalty=self.config.algorithm.kl_penalty)
metrics.update(kl_metrics)

# compute advantages, executed on the driver process
batch = compute_advantage(batch,
self.config.algorithm.gamma,
self.config.algorithm.lam,
adv_estimator=self.config.algorithm.adv_estimator)

# update critic
if self.use_critic:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)

# validate
if self.val_reward_fn is not None and self.global_steps % self.config.trainer.test_freq == 0:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
metrics.update(val_metrics)

if self.config.trainer.save_freq > 0 and \
self.global_steps % self.config.trainer.save_freq == 0:
with _timer('save_checkpoint', timing_raw):
self._save_checkpoint()

# collect metrics
metrics.update(compute_data_metrics(batch=batch))
Expand All @@ -566,20 +628,6 @@ def fit(self):
# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=self.global_steps)

if self.config.trainer.save_freq > 0 and (self.global_steps + 1) % self.config.trainer.save_freq == 0:
actor_local_path = os.path.join(self.config.trainer.default_local_dir, 'actor',
f'global_step_{self.global_steps}')
actor_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'actor')
self.actor_rollout_wg.save_checkpoint(actor_local_path, actor_remote_path)

if self.use_critic:
critic_local_path = os.path.join(self.config.trainer.default_local_dir, 'critic',
f'global_step_{self.global_steps}')
critic_remote_path = None if self.config.trainer.default_hdfs_dir is None else os.path.join(
self.config.trainer.default_hdfs_dir, 'critic')
self.critic_wg.save_checkpoint(critic_local_path, critic_remote_path)

self.global_steps += 1

if self.global_steps >= self.total_training_steps:
Expand Down
Loading