
Commit

[misc] fix: gradient accumulation in seq balance and modify default vllm log level (#141)

- Previously, the gradient accumulation factor was computed from micro_batch_size, which is wrong when using dynamic_bsz (a numerical sketch follows the commit header below).
- Fix the CI scripts so they no longer overlook this issue.
- Change the default of the vLLM `disable_log_stats` option to True, disabling the engine stats log.
- Check `self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0` after normalization in fsdp_workers instead of in dp_actor and dp_critic.
PeterSH6 authored Jan 27, 2025
1 parent 78a4606 commit 695bdbb
Showing 16 changed files with 41 additions and 16 deletions.
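
To make the first bullet concrete, here is a rough numerical sketch (plain PyTorch with made-up batch sizes, not verl's actual code): once dynamic_bsz produces micro-batches of unequal size, dividing each micro-batch loss by a fixed gradient_accumulation factor derived from micro_batch_size no longer reproduces the mini-batch mean, whereas scaling by len(micro_batch) / ppo_mini_batch_size does.

import torch

ppo_mini_batch_size = 8
per_sample_loss = torch.arange(1.0, 9.0)                     # one loss value per sample
micro_batches = [per_sample_loss[:5], per_sample_loss[5:]]   # dynamic bsz: sizes 5 and 3

# Pre-fix scaling: divide by an accumulation factor computed from a nominal micro_batch_size.
micro_batch_size = 4
gradient_accumulation = ppo_mini_batch_size // micro_batch_size   # = 2
biased = sum(mb.mean() / gradient_accumulation for mb in micro_batches)

# Post-fix scaling: weight each micro-batch by its share of the mini-batch.
exact = sum(mb.mean() * (len(mb) / ppo_mini_batch_size) for mb in micro_batches)

print(biased.item(), exact.item())   # 5.0 vs. 4.5; only the latter equals per_sample_loss.mean()
assert torch.isclose(exact, per_sample_loss.mean())

With equal-size micro-batches the two formulas coincide (len(mb) / ppo_mini_batch_size == 1 / gradient_accumulation), which is why the bug only surfaces when use_dynamic_bsz is enabled.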
2 changes: 2 additions & 0 deletions .github/workflows/dataset.yml
@@ -16,6 +16,8 @@ on:
- "**/*.py"
- .github/workflows/dataset.yml



jobs:
ray:
runs-on: [self-hosted, gpu]
2 changes: 2 additions & 0 deletions .github/workflows/e2e_digit_completion.yml
@@ -17,6 +17,8 @@ on:
- .github/workflows/e2e_digit_completion.yml
- "tests/e2e/*.sh"



jobs:
e2e_digit_completion:
runs-on: [self-hosted, l20-0]
2 changes: 2 additions & 0 deletions .github/workflows/e2e_gsm8k.yml
@@ -17,6 +17,8 @@ on:
- .github/workflows/e2e_gsm8k.yml
- "tests/e2e/*.sh"



jobs:
e2e_gsm8k:
runs-on: [self-hosted, l20-1]
2 changes: 2 additions & 0 deletions .github/workflows/e2e_lora.yml
@@ -17,6 +17,8 @@ on:
- .github/workflows/e2e_lora.yml
- "tests/e2e/*.sh"



jobs:
e2e_lora:
runs-on: [self-hosted, l20-1]
2 changes: 2 additions & 0 deletions .github/workflows/e2e_sft.yml
@@ -17,6 +17,8 @@ on:
- .github/workflows/e2e_sft.yml
- "tests/e2e/*.sh"



jobs:
e2e_sft:
runs-on: [self-hosted, l20-1]
2 changes: 2 additions & 0 deletions .github/workflows/model.yml
@@ -16,6 +16,8 @@ on:
- "**/*.py"
- .github/workflows/model.yml



jobs:
model_rmpad:
runs-on: [self-hosted, l20-1]
2 changes: 2 additions & 0 deletions .github/workflows/ray_test.yml
@@ -16,6 +16,8 @@ on:
- "**/*.py"
- .github/workflows/ray_test.yml



jobs:
ray:
runs-on: [self-hosted, l20-0]
2 changes: 2 additions & 0 deletions tests/e2e/arithmetic_sequence/rl/config/ray_trainer.yaml
@@ -78,6 +78,8 @@ actor_rollout_ref:
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ disable_log_stats: True
+ enable_chunked_prefill: False # could get higher throughput
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
4 changes: 0 additions & 4 deletions tests/e2e/run_qwen_gsm8k_model_rm_seq_balance.sh
@@ -15,13 +15,11 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
- actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=12000 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.grad_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
- actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=16 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
@@ -33,7 +31,6 @@ python3 -m verl.trainer.main_ppo \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.enable_gradient_checkpointing=False \
- critic.ppo_micro_batch_size_per_gpu=4 \
critic.use_dynamic_bsz=True \
critic.ppo_max_token_len_per_gpu=98304 \
critic.model.fsdp_config.param_offload=False \
@@ -43,7 +40,6 @@ python3 -m verl.trainer.main_ppo \
reward_model.model.path=Qwen/Qwen2.5-0.5B\
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
- reward_model.micro_batch_size_per_gpu=16 \
reward_model.use_dynamic_bsz=True \
reward_model.forward_max_token_len_per_gpu=98304 \
algorithm.kl_ctrl.kl_coef=0.001 \
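
These removals matter because with use_dynamic_bsz=True the fixed per-GPU micro batch sizes are not used; micro-batches are instead packed under a token budget (ppo_max_token_len_per_gpu). A simplified sketch of that packing idea, with a hypothetical helper and made-up sequence lengths (verl's actual rearrange_micro_batches may differ):

# Greedily pack sequences into micro-batches without exceeding a per-GPU token budget.
def pack_by_token_budget(seq_lens, max_token_len_per_gpu):
    micro_batches, current, current_tokens = [], [], 0
    for idx, n_tokens in enumerate(seq_lens):
        if current and current_tokens + n_tokens > max_token_len_per_gpu:
            micro_batches.append(current)
            current, current_tokens = [], 0
        current.append(idx)
        current_tokens += n_tokens
    if current:
        micro_batches.append(current)
    return micro_batches

print(pack_by_token_budget([900, 300, 700, 1200, 100], max_token_len_per_gpu=1500))
# [[0, 1], [2], [3, 4]] -> micro-batches of 2, 1, and 2 samples, hence a variable accumulation factor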
8 changes: 4 additions & 4 deletions tests/e2e/run_ray_trainer.sh
@@ -11,10 +11,10 @@ python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
- actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
- actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
- actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
- critic.ppo_micro_batch_size_per_gpu=1 \
+ actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=200 \
+ actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=200 \
+ actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=200 \
+ critic.ppo_micro_batch_size_per_gpu=200 \
critic.model.path=tests/e2e/arithmetic_sequence/model | tee $OUTPUT_FILE;

python3 tests/e2e/check_results.py --output_file=$OUTPUT_FILE
2 changes: 2 additions & 0 deletions verl/trainer/config/ppo_megatron_trainer.yaml
@@ -71,6 +71,8 @@ actor_rollout_ref:
max_num_seqs: 1024
log_prob_micro_batch_size: null # will be deprecated, use log_prob_micro_batch_size_per_gpu
log_prob_micro_batch_size_per_gpu: null
+ disable_log_stats: True
+ enable_chunked_prefill: False # could get higher throughput
# for hf rollout
do_sample: True
layer_name_map:
2 changes: 2 additions & 0 deletions verl/trainer/config/ppo_trainer.yaml
@@ -80,6 +80,8 @@ actor_rollout_ref:
log_prob_micro_batch_size_per_gpu: null
log_prob_use_dynamic_bsz: ${actor_rollout_ref.actor.use_dynamic_bsz}
log_prob_max_token_len_per_gpu: ${actor_rollout_ref.actor.ppo_max_token_len_per_gpu}
+ disable_log_stats: True
+ enable_chunked_prefill: False # could get higher throughput
# for hf rollout
do_sample: True
# number of responses (i.e. num sample times)
9 changes: 6 additions & 3 deletions verl/workers/actor/dp_actor.py
@@ -204,8 +204,6 @@ def update_policy(self, data: DataProto):
# make sure we are in training mode
self.actor_module.train()

- assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
- self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
temperature = data.meta_info['temperature'] # temperature must be in the data.meta_info to avoid slient error

select_keys = ['responses', 'input_ids', 'attention_mask', 'position_ids', 'old_log_probs', 'advantages']
@@ -225,6 +223,7 @@ def update_policy(self, data: DataProto):
max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
+ self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
# split batch into micro_batches
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

@@ -268,7 +267,11 @@ def update_policy(self, data: DataProto):
metrics['actor/kl_loss'] = kl_loss.detach().item()
metrics['actor/kl_coef'] = self.config.kl_loss_coef

- loss = policy_loss / self.gradient_accumulation
+ if self.config.use_dynamic_bsz:
+     # relative to the dynamic bsz
+     loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
+ else:
+     loss = policy_loss / self.gradient_accumulation
loss.backward()

data = {
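
A side note on the new scaling: inside the micro-batch loop, len(data) is the number of samples in the current micro-batch (assuming the loop rebinds data to each micro-batch, as the surrounding code suggests), so the per-micro-batch weights sum to one over a mini-batch and the accumulated gradient still corresponds to a mini-batch average. A tiny check with hypothetical micro-batch sizes:

ppo_mini_batch_size = 256
micro_batch_sizes = [90, 70, 60, 36]   # hypothetical sizes produced by token-budget packing
weights = [n / ppo_mini_batch_size for n in micro_batch_sizes]
assert abs(sum(weights) - 1.0) < 1e-9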
11 changes: 7 additions & 4 deletions verl/workers/critic/dp_critic.py
@@ -45,9 +45,6 @@ def __init__(self, config, critic_module: nn.Module, critic_optimizer: optim.Opt
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
print(f'Critic use_remove_padding={self.use_remove_padding}')

- assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0
- self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
-
self.ulysses_sequence_parallel_size = self.config.get('ulysses_sequence_parallel_size', 1)

def _forward_micro_batch(self, micro_batch):
@@ -162,6 +159,7 @@ def update_critic(self, data: DataProto):
micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
else:
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
+ self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu

self.critic_optimizer.zero_grad()

@@ -186,7 +184,12 @@ def update_critic(self, data: DataProto):
returns=returns,
eos_mask=eos_mask,
cliprange_value=self.config.cliprange_value)
- loss = vf_loss / self.gradient_accumulation
+ if self.config.use_dynamic_bsz:
+     # relative to the dynamic bsz
+     loss = vf_loss * (len(data) / self.config.ppo_mini_batch_size)
+ else:
+     loss = vf_loss / self.gradient_accumulation
+
loss.backward()

data = {
2 changes: 2 additions & 0 deletions verl/workers/fsdp_workers.py
@@ -125,6 +125,7 @@ def __init__(self, config: DictConfig, role: str):
self.config.actor.ppo_micro_batch_size //= (self.device_mesh.shape[0] //
self.ulysses_sequence_parallel_size)
self.config.actor.ppo_micro_batch_size_per_gpu = self.config.actor.ppo_micro_batch_size
+ assert self.config.actor.ppo_mini_batch_size % self.config.actor.ppo_micro_batch_size_per_gpu == 0
# normalize rollout config
if self._is_rollout and self.config.rollout.log_prob_micro_batch_size is not None:
self.config.rollout.log_prob_micro_batch_size //= (self.device_mesh.shape[0] //
@@ -582,6 +583,7 @@ def __init__(self, config):
self.ulysses_sequence_parallel_size)
self.config.ppo_micro_batch_size_per_gpu = self.config.ppo_micro_batch_size
self.config.forward_micro_batch_size_per_gpu = self.config.forward_micro_batch_size
+ assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size_per_gpu == 0

def _build_critic_model_optimizer(self, config):
# the following line is necessary
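
For context on the two new asserts: they run after the batch sizes have been normalized to per-GPU (per data-parallel rank) values, and they replace the checks removed from dp_actor and dp_critic. A rough sketch with made-up numbers (the mini-batch normalization step is assumed from the surrounding worker code and is not shown in this excerpt):

world_size = 8                                            # device_mesh.shape[0]
ulysses_sequence_parallel_size = 2
dp_size = world_size // ulysses_sequence_parallel_size    # 4 data-parallel ranks

ppo_mini_batch_size = 256 // dp_size                      # global mini batch -> 64 per rank
ppo_micro_batch_size = 16 // dp_size                      # global micro batch -> 4 per GPU
ppo_micro_batch_size_per_gpu = ppo_micro_batch_size

# The divisibility check is meaningful only on these normalized values; with
# dynamic batch sizing the per-GPU micro batch size stays unset, which is why
# the assert no longer lives in dp_actor/dp_critic.
assert ppo_mini_batch_size % ppo_micro_batch_size_per_gpu == 0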
3 changes: 2 additions & 1 deletion verl/workers/rollout/vllm_rollout/vllm_rollout.py
@@ -100,8 +100,9 @@ def __init__(self, actor_module: nn.Module, config: DictConfig, tokenizer, model
skip_tokenizer_init=False,
max_model_len=config.prompt_length + config.response_length,
load_format=config.load_format,
- disable_log_stats=False,
+ disable_log_stats=config.disable_log_stats,
max_num_batched_tokens=max_num_batched_tokens,
+ enable_chunked_prefill=config.enable_chunked_prefill,
)

# Offload vllm model to reduce peak memory usage
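
Since `disable_log_stats` and `enable_chunked_prefill` are now ordinary rollout config entries, they can be overridden per run in the same Hydra-override style used by the test scripts above, e.g. `actor_rollout_ref.rollout.disable_log_stats=False` to turn vLLM stats logging back on.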
