
Actor model didn't update correctly when upgrade megatron to core-r0.6.0 #64

Wodswos opened this issue Dec 24, 2024 · 7 comments

Wodswos commented Dec 24, 2024

I tested veRL with llama-2-7b (using a smaller train_batch_size), and the results seem to meet expectations:

(main_task pid=27483) step:3 - timing/gen:12.597 - timing/ref:0.522 - timing/values:0.457 - critic/kl:0.016 - critic/kl_coeff:0.001 - timing/adv:0.020 - timing/update_critic:1.744 - critic/vf_loss:0.317 - critic/vf_clipfrac:0.111 - critic/vpred_mean:-0.001 - timing/update_actor:1.537 - actor/entropy_loss:0.130 - actor/pg_loss:0.005 - actor/pg_clipfrac:0.002 - actor/ppo_kl:0.000 - timing/testing:11.108 - val/test_score/openai/gsm8k:0.188 - critic/score/mean:0.031 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.028 - critic/rewards/max:0.997 - critic/rewards/min:-0.015 - critic/advantages/mean:-0.000 - critic/advantages/max:3.696 - critic/advantages/min:-6.231 - critic/returns/mean:0.023 - critic/returns/max:1.000 - critic/returns/min:-0.015 - critic/values/mean:0.441 - critic/values/max:4.250 - critic/values/min:-1.859 - response_length/mean:236.156 - response_length/max:443.000 - response_length/min:65.000 - prompt_length/mean:92.688 - prompt_length/max:146.000 - prompt_length/min:48.000

Then I wanted to migrate to Megatron 0.6, so I made the following modifications.

  1. Modifications to Megatron-LM
    1. patch get_model_config like in 0.4.0
    2. Modify the DistributedDataParallel class
      1. Add self.data_parallel_group = data_parallel_group in __init__
      2. Modify src=torch.distributed.get_process_group_ranks(self.data_parallel_group)[0] on line 272, according to this PR.
  2. Modifications in veRL
    1. import paths
      1. change from megatron.utils import print_rank_0, unwrap_model to from megatron.training.utils import print_rank_0, unwrap_model
      2. change from megatron.optimizer/timer ... to from megatron.core.optimizer/timer ...
    2. modify megatron_actor/critic.py
      1. modify the parameters passed to zero_grad_buffer() & optimizer.step() (a sketch of this change follows after the code block below)
      2. remove the hidden_size parameter when calling forward_backward_func() on line 223.
    3. modify the get_megatron_optimizer function as follows
def get_megatron_optimizer(
    model,
    config: OptimizerConfig,
    no_weight_decay_cond=None,
    scale_lr_cond=None,
    lr_mult=1.0,
    check_for_nan_in_loss_and_grad=False,
    overlap_param_gather=False,  # added for veRL
):
    ...
    # optimizer and grad_scaler below are built in the elided code above.
    import torch
    from megatron.core import mpu

    def init_state_fn(opt):
        # Lazily create Adam state so the distributed optimizer can shard it.
        for group in opt.param_groups:
            for p in group['params']:
                if len(opt.state[p]) == 0:
                    opt.state[p]['exp_avg'] = torch.zeros_like(p.data)
                    opt.state[p]['exp_avg_sq'] = torch.zeros_like(p.data)

    # Collect each model chunk's ParamAndGradBuffer list for the distributed optimizer.
    per_model_buffers = {}
    for model_idx, model_chunk in enumerate(model):
        if hasattr(model_chunk, 'buffers'):
            per_model_buffers[model_idx] = model_chunk.buffers

    if config.use_distributed_optimizer:
        # Old (pre-0.6.0) constructor call, kept for reference:
        # return DistributedOptimizer(optimizer, config.clip_grad, config.log_num_zeros_in_grad,
        #                             check_for_nan_in_loss_and_grad, params_have_main_grad, config.fp16, config.bf16,
        #                             config.params_dtype, grad_scaler, model, overlap_param_gather)
        return DistributedOptimizer(
            optimizer=optimizer,
            config=config,
            grad_scaler=grad_scaler,
            init_state_fn=init_state_fn,
            per_model_buffers=per_model_buffers,
            data_parallel_group=mpu.get_data_parallel_group(with_context_parallel=True),
            data_parallel_group_gloo=mpu.get_data_parallel_group_gloo(with_context_parallel=True),
            data_parallel_group_idx=torch.distributed.get_rank(mpu.get_model_parallel_group()),
        )
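
For item 2.2.1, the shape of the call-site change is roughly as follows. This is only a sketch, not an excerpt of veRL's code: optimizer and model_chunks stand for the Megatron optimizer and the list of DDP-wrapped model chunks, and the exact signatures should be verified against your Megatron version.

def run_actor_update_step(optimizer, model_chunks):
    # Hypothetical helper sketching the core_r0.6.0-style call sites.
    # zero_grad_buffer()'s signature differs between Megatron releases, so drop
    # or adjust its arguments to match your DistributedDataParallel class.
    for chunk in model_chunks:
        chunk.zero_grad_buffer()

    # ... forward/backward via forward_backward_func() would happen here ...

    # core_r0.4.0-era (megatron.optimizer):  optimizer.step(args, timers)
    # core_r0.6.0 (megatron.core.optimizer): step() takes no arguments and
    # returns (update_successful, grad_norm, num_zeros_in_grad).
    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
    return update_successful, grad_norm, num_zeros_in_grad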

The code to reproduce is on this branch.

When I rerun the scripts, the metrics are weird; the critic/kl is always zero:

(main_task pid=32059) step:3 - timing/gen:13.604 - timing/ref:0.478 - timing/values:0.456 - critic/kl:0.000 - critic/kl_coeff:0.001 - timing/adv:0.020 - timing/update_critic:1.597 - critic/vf_loss:0.363 - critic/vf_clipfrac:0.147 - critic/vpred_mean:0.099 - timing/update_actor:1.583 - actor/entropy_loss:0.134 - actor/pg_loss:0.015 - actor/pg_clipfrac:0.001 - actor/ppo_kl:-0.000 - timing/testing:11.299 - val/test_score/openai/gsm8k:0.031 - critic/score/mean:0.094 - critic/score/max:1.000 - critic/score/min:0.000 - critic/rewards/mean:0.094 - critic/rewards/max:1.000 - critic/rewards/min:0.000 - critic/advantages/mean:0.000 - critic/advantages/max:4.966 - critic/advantages/min:-4.105 - critic/returns/mean:0.096 - critic/returns/max:1.000 - critic/returns/min:0.000 - critic/values/mean:0.578 - critic/values/max:3.250 - critic/values/min:-1.906 - response_length/mean:251.750 - response_length/max:479.000 - response_length/min:87.000 - prompt_length/mean:85.469 - prompt_length/max:154.000 - prompt_length/min:53.000

If I understand correctly, critic/kl here compares the outputs of the actor policy and the ref policy. I printed the actor's parameter tensors and, upon comparison, the parameters indeed did not change before and after optimizer.step(), even though the gradients are not zero (what's even stranger is that the critic_module seems to update normally).
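
For anyone hitting something similar, this kind of check can be done with plain PyTorch, roughly as below. This is a minimal sketch, not my exact code: module is assumed to be the DDP-wrapped actor and run_one_step a callable that runs forward/backward plus optimizer.step().

import torch

def check_param_update(module: torch.nn.Module, run_one_step):
    # Snapshot parameters, run one training step, then report how much each
    # parameter moved and how large its gradient was.
    before = {name: p.detach().clone() for name, p in module.named_parameters()}
    run_one_step()
    for name, p in module.named_parameters():
        delta = (p.detach() - before[name]).abs().max().item()
        grad = getattr(p, 'main_grad', p.grad)  # Megatron's DDP accumulates into main_grad
        gnorm = grad.norm().item() if grad is not None else float('nan')
        print(f'{name}: max |delta param| = {delta:.3e}, grad norm = {gnorm:.3e}')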

Perhaps this is not an issue with veRL, but rather that I am not using veRL and Megatron correctly. However, this seems to be the most promising place to resolve my confusion. From this RFC, it seems that veRL is also trying to adapt to higher versions of Megatron. Do you have any suggestions?

PeterSH6 (Collaborator) commented Dec 24, 2024

Hi @Wodswos, thanks for your detailed feedback!

Your understanding of critic/kl is correct, and I think the actor model weights are not updated after upgrading to v0.6.0.
This issue only affects the actor model (the critic model still updates correctly), and it is caused by a mismatch between the newer Megatron-LM and veRL's HybridEngine implementation. Here are some analyses:

  1. Since Megatron v0.6.0, the GradBuffer used by DistributedDataParallel in Megatron-LM has been replaced by ParamAndGradBuffer in param_and_grad_buffer.py. In this commit, the Megatron-LM authors map all the model parameters into a contiguous param_buffer.

  2. In v0.4.0 (and also v0.5.0), we implemented an AllGatherPPModel in megatron_vllm.py to support our 3D-HybridEngine design. This approach establishes a contiguous buffer for each model parameter and directly maps it to the param.data in Megatron-LM. However, in v0.6.0 (and later), directly binding our model memory buffer to param.data ignores the new ParamAndGradBuffer implementation. As a result, the model parameters are not updated correctly.
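
To illustrate the failure mode with plain PyTorch (a toy sketch, no Megatron APIs involved): once the optimizer writes updated weights into the contiguous buffer, a parameter only sees them while its .data is still a view of that buffer.

import torch

# Stand-in for ParamAndGradBuffer's contiguous parameter storage.
param_buffer = torch.zeros(4)

p = torch.nn.Parameter(torch.empty(0))
p.data = param_buffer[:4].view(2, 2)  # param.data is a view of the buffer

param_buffer += 1.0                   # an "optimizer step" writes into the buffer
print(p.data)                         # all ones: the param sees the update

p.data = torch.zeros(2, 2)            # rebind .data to a separate buffer instead
param_buffer += 1.0                   # later updates still land in the old buffer
print(p.data)                         # still zeros: the param no longer sees them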

Proposed Solutions
To address this issue in v0.6.0 and later versions, I recommend the following quick fixes:

  1. Bind veRL's memory_buffer to ParamAndGradBuffer:
    Instead of mapping veRL's buffer directly to param.data in the DistributedDataParallel module (as seen in memory_buffer.py), we should bind veRL's memory_buffer to the ParamAndGradBuffer for each model parameter. This ensures compatibility with Megatron-LM's updated parameter management logic (see the rough sketch after this list).

  2. Align Padding Logic:
    Verify that the padding logic in veRL's MemoryBuffer (defined in memory_buffer.py) matches the padding logic in ParamAndGradBuffer (defined in param_and_grad_buffer.py). Consistent padding ensures that the buffers are aligned correctly for parameter updates.
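
A rough, untested sketch of what fix 1 might look like. The attribute names buffers, param_data, and param_index_map reflect my reading of param_and_grad_buffer.py and may differ across Megatron releases; build_param_views is a hypothetical helper, not existing veRL code.

def build_param_views(ddp_module):
    # For each ParamAndGradBuffer owned by the DDP-wrapped module, expose a
    # per-parameter view into the contiguous param storage that the distributed
    # optimizer actually updates, so veRL's memory_buffer can alias it instead
    # of holding a separate copy bound to param.data.
    param_views = {}
    for buffer in ddp_module.buffers:                        # list of ParamAndGradBuffer (assumed)
        for param, index in buffer.param_index_map.items():  # param -> (start, end, bucket_id) (assumed)
            start, end = index[0], index[1]
            flat = buffer.param_data[start:end]               # slice backing this parameter
            param_views[param] = flat.view(param.shape)
    return param_views

For fix 2, the essential point is that veRL's MemoryBuffer has to reproduce the same padding so that the flat offsets line up, e.g. rounding each bucket up to a multiple of the data-parallel world size; the exact rule should be copied from param_and_grad_buffer.py rather than from this sketch.

import math

def pad_to_multiple(numel: int, divisor: int) -> int:
    # Round a flat length up to a multiple of divisor (e.g. the DP world size),
    # so that each data-parallel rank owns an equally sized shard.
    return int(math.ceil(numel / divisor)) * divisor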

By implementing these solutions, the actor model's weights should update correctly in v0.6.0 and later. Let me know if you need further clarification or assistance!

Wodswos (Author) commented Jan 6, 2025

Thank you so much for your patient and detailed response!

The padding logic alignment is a bit complex for me, so based on the hints you provided, I implemented a very rough temporary solution (it achieves the basic functionality, but with some performance trade-offs):

I think this issue can probably be marked as resolved now or closed after official support for version 0.6 is available.

Once again, thank you very much!

PeterSH6 (Collaborator) commented Jan 6, 2025

Cool! @Wodswos, so you have already added support for Megatron v0.6.0 and validated the convergence?

If so, I wonder if you would like to submit a draft PR so that we can investigate how to support the padding logic.
We're looking for contributions to upgrade Megatron, and your upgrade could become the official one after review and discussion.

Wodswos (Author) commented Jan 7, 2025

Since I am completely new to RLHF, I am actually not entirely sure what constitutes "convergence". I can only observe that the value of test_score/openai/gsm8k has significantly increased in the first few iteration steps and is approaching 0.5x (using Llama3.2-1B-Instruct).

I am more than willing to discuss the upgrade to megatron-0.6 and have opened a PR here. However, at this stage, I genuinely lack a lightweight solution to align the arrangement order of parameters in MemoryBuffer.data with ParamAndGradBuffer.param_data.

PeterSH6 (Collaborator) commented Jan 8, 2025

@Wodswos, it looks really nice, thanks! We'll take some time to check how to align the arrangement order of parameters in MemoryBuffer.data with ParamAndGradBuffer.param_data.

Regarding convergence, I mean whether the training log (loss, reward, value, validation score) with your current fix on MCore 0.6.0 aligns with that of MCore 0.4.0. If so, I think your implementation is correct.

Also, what do you mean by 0.5x? Do you mean the resharding overhead or the Megatron training throughput?

Wodswos (Author) commented Jan 8, 2025

No, this is not about throughput; it's about the value of test_score/openai/gsm8k being roughly 0.5. 😂

The metrics for MCore-0.4 and MCore-0.6 do not seem to be fully aligned, and I'm not sure if this is due to randomness or other reasons.

PeterSH6 (Collaborator) commented Jan 8, 2025

@Wodswos Full alignment may not be possible, for several reasons.

I wonder what test_score/openai/gsm8k values you get with MCore-0.4 and MCore-0.6, respectively?
If they're similar, it should be fine.
