Actor model didn't update correctly when upgrading Megatron to core-r0.6.0 #64
Hi @Wodswos, thanks for your detailed feedback! Your understanding of the issue is essentially correct.
Proposed Solutions
By implementing these solutions, the actor model's weights should update correctly in v0.6.0 and later. Let me know if you need further clarification or assistance!
Thank you so much for your patient and detailed response! The padding-logic alignment is a bit complex for me, so based on the hints you provided, I implemented a very rough temporary solution: it achieves the basic functionality, but with some performance trade-offs.
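The snippet from this comment isn't reproduced in the thread. Purely as an illustration of the kind of trade-off involved (the function name and arguments below are placeholders, not veRL APIs), a per-parameter sync that trades speed for layout-independence could look like:

```python
import torch
import torch.distributed as dist

def sync_params_one_by_one(model: torch.nn.Module,
                           group: dist.ProcessGroup,
                           src_rank: int) -> None:
    """Broadcast every parameter tensor separately from src_rank."""
    for _name, param in model.named_parameters():
        # One collective per tensor: slower than shipping a fused
        # buffer in a single call, but independent of how Megatron
        # lays out (and pads) its internal parameter buffers.
        dist.broadcast(param.data, src=src_rank, group=group)
```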
I think this issue can probably be marked as resolved now, or closed once official support for version 0.6 is available. Once again, thank you very much!
Cool, @Wodswos! So you have already added support for Megatron v0.6.0 and validated convergence? If so, would you like to submit a draft PR so that we can investigate how to support the padding logic?
Since I am completely new to RLHF, I am actually not entirely sure what constitutes "convergence"; I can only observe that the value of the metric climbs to around 0.5x. I am more than willing to discuss the upgrade to megatron-0.6 and have opened a PR here. However, at this stage, I genuinely lack a lightweight solution to align the arrangement order of the parameters in `MemoryBuffer.data` with `ParamAndGradBuffer.param_data`.
@Wodswos, it looks really nice, thanks! We'll take some time to check how to align the arrangement order of parameters in `MemoryBuffer.data` with `ParamAndGradBuffer.param_data`. By convergence, I mean whether the training log (loss, reward, value, validation score) with your current fix on MCore 0.6.0 aligns with the one using MCore 0.4.0; if so, I think your implementation is correct. Also, what do you mean by 0.5x? Is that the resharding overhead or the Megatron training throughput?
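For context on what such an alignment involves, here is a rough sketch (not a vetted fix) of reading the unpadded parameter slices out of a core-r0.6.0 buffer. It assumes the `ParamAndGradBuffer` instance exposes `param_data` and a `param_index_map` of `{param: (start, end, bucket_id)}`, as in the core-r0.6.0 source:

```python
import torch

def unpadded_flat_view(buffer) -> torch.Tensor:
    pieces = []
    for param, (start, end, _bucket_id) in buffer.param_index_map.items():
        # Slice by the recorded indices to skip the padding that 0.6
        # inserts so each bucket divides evenly across data-parallel ranks.
        pieces.append(buffer.param_data[start:end])
    # The result drops the padding, but its ordering still follows the
    # buffer's own allocation order, which is what has to be reconciled
    # with MemoryBuffer's ordering.
    return torch.cat(pieces)
```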
No, this is not about throughput; it's about the value of the metric itself. The metrics for MCore-0.4 and MCore-0.6 do not seem to be fully aligned, and I'm not sure whether this is due to randomness or other reasons.
@Wodswos Fully aligned may not be possible for several reasons. I wonder what `test_score/openai/gsm8k` values you get using MCore-0.4 and MCore-0.6, respectively?
I tested veRL with llama-2-7b (and a smaller `train_batch_size`), and the results seemed to meet expectations. I then wanted to migrate to Megatron 0.6, so I made the following modifications:

- Implemented `get_model_config` like in 0.4.0.
- In the `DistributedDataParallel` class, added `self.data_parallel_group = data_parallel_group` in `__init__`, and used `src=torch.distributed.get_process_group_ranks(self.data_parallel_group)[0]` on line 272, according to this PR.
- Changed `from megatron.utils import print_rank_0, unwrap_model` to `from megatron.training.utils import print_rank_0, unwrap_model`, and `from megatron.optimizer/timer ...` to `from megatron.core.optimizer/timer ...`.
- In `megatron_actor/critic.py`, updated the `zero_grad_buffer()` and `optimizer.step()` calls, and updated the `hidden_size` parameter when calling `forward_backward_func()` on line 223.
- Modified the `get_megatron_optimizer` function as well (a sketch of the 0.6-style call follows this list).

The code to reproduce is on this branch.
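The exact `get_megatron_optimizer` change is in the linked branch; as a minimal sketch of that call as it looks in core-r0.6.0, with illustrative hyperparameter values and `model_chunks` assumed to be the list of DDP-wrapped model chunks:

```python
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer

def build_optimizer(model_chunks):
    # core-r0.6.0 moved the optimizer to megatron.core.optimizer and
    # configures it via OptimizerConfig instead of the old args object.
    config = OptimizerConfig(
        optimizer="adam",
        lr=1e-6,
        weight_decay=0.01,
        adam_beta1=0.9,
        adam_beta2=0.95,
        clip_grad=1.0,
        bf16=True,
        use_distributed_optimizer=True,
    )
    return get_megatron_optimizer(config, model_chunks)
```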
When I reran the scripts, the metrics were weird: `critic/kl` is always zero.

If I understand correctly, `critic/kl` here compares the outputs of the actor policy and the ref policy. I printed the actor's parameter tensors and, upon comparison, the parameters indeed did not change before and after `optimizer.step()`, even though the gradients are not zero (what's even stranger is that the critic_module seems to be updating normally).

Perhaps this is not an issue with veRL, but rather me not using veRL and Megatron correctly. Still, this seems to be the most promising place to resolve my confusion. From this RFC, it seems that veRL is also trying to adapt to higher versions of Megatron. Do you have any suggestions?
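The before/after comparison described above can be reproduced with a small check like this (a sketch: `module` and `optimizer` stand for the actor module and its optimizer, and `main_grad` is probed because Megatron's DDP accumulates gradients there):

```python
import torch

def report_stuck_params(module: torch.nn.Module, optimizer) -> None:
    # Snapshot weights, step, then flag tensors whose grads are nonzero
    # but whose values did not move -- the symptom described above.
    before = {n: p.detach().clone() for n, p in module.named_parameters()}
    optimizer.step()
    for name, param in module.named_parameters():
        grad = getattr(param, "main_grad", None)
        if grad is None:
            grad = param.grad
        grad_norm = grad.norm().item() if grad is not None else 0.0
        if grad_norm > 0 and torch.equal(before[name], param.detach()):
            print(f"{name}: grad_norm={grad_norm:.3e}, weights unchanged")
```

If every actor parameter is flagged, the optimizer is most likely stepping on parameter shards that are no longer tied to the buffers the model actually reads from.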