[GRPO] add reward weight in multi-reward settings #2676
Conversation
Thanks @hesamsheikh! Do you have any references that show this method can be useful?
What about cases where a verifiable reward should be nulled if a primary reward turns out to be unsatisfactory? For example, if you rate a code snippet by executing it and the code is not runnable at all, then all other rewards should be nulled, right? That would need more than a weight; it would need a notion of a primary reward. Worth considering whether that generalizes to more use cases.
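For illustration, a gating aggregate along those lines could look like the sketch below. This is a toy example, not anything from the PR: `runs_successfully` is a stand-in that only checks the snippet compiles (a real implementation would execute it in a sandbox), and the secondary signal is made up.

```python
def runs_successfully(code: str) -> bool:
    """Toy primary check: does the completion at least compile as Python?

    Stand-in for real sandboxed execution; it exists only to make the
    sketch self-contained.
    """
    try:
        compile(code, "<completion>", "exec")
        return True
    except SyntaxError:
        return False


def gated_reward(completions, **kwargs):
    """Aggregate reward where a failed primary check nulls all other signals."""
    rewards = []
    for completion in completions:
        if not runs_successfully(completion):
            rewards.append(0.0)  # primary reward failed: null everything else
        else:
            # secondary signal (toy): prefer shorter runnable snippets
            rewards.append(1.0 + 1.0 / (1 + len(completion)))
    return rewards
```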
The paper actually provides a sneak peek at how the rewards are aggregated (screenshot of the relevant excerpt omitted):
However, they only specify summing in the case of combining an accuracy reward with a language consistency reward. In the current implementation, we can provide multiple rewards in the same scope (e.g. format of the output, or accuracy), so a weighted rewarding system can be beneficial. Take the example provided in the test file:

```python
def reward_func1(completions, **kwargs):
    """Reward function that rewards longer completions."""
    return [float(len(completion)) for completion in completions]

def reward_func2(completions, **kwargs):
    """Reward function that rewards completions with more unique letters."""
    return [float(len(set(completion))) for completion in completions]
```

Both rewards are related to the output format, but a simple summation doesn't give us control over which is more important (e.g. making completions longer may matter much more than having more unique letters). The implementation of the weighted sum is straightforward:

```python
# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)
```

is replaced by

```python
rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
```

This allows custom weights instead of implicit unit weights (replacing 1 * r1 + 1 * r2 with w1 * r1 + w2 * r2), applied only at the final stage (the summation of the rewards). It doesn't break the advantages or the loss function, since the weighting is applied uniformly across all rewards before aggregation. This weighted sum makes reward aggregation much more flexible when multiple rewards are used.
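To make the effect concrete, here is a minimal sketch of that weighted sum on toy numbers; the tensor values and weights are made up for illustration:

```python
import torch

# One row per completion, one column per reward function
# (here: length reward, unique-letter reward)
rewards_per_func = torch.tensor([[120.0, 18.0],
                                 [ 80.0, 25.0]])

# Hypothetical weights: length matters 10x more than unique letters
reward_weights = torch.tensor([1.0, 0.1])

# Weighted sum, as in the patched line above
rewards = (rewards_per_func * reward_weights.unsqueeze(0)).sum(dim=1)
print(rewards)  # tensor([121.8000,  82.5000])
```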
Thanks for this detailed explanation. I understand the underlying motivation. I'm wondering whether it really helps to get better results, or whether a naive sum, as is done in the paper, is actually enough to get similar results.
In cases where multiple rewards with different priorities need to be tuned, the weighted reward should be handier. I'm happy to run some experiments if you suggest any.
I mean, you can just pass a single aggregate reward function and do arbitrary weighting there, no?
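A minimal sketch of that alternative, reusing the two toy reward functions from the test file: the aggregate applies the weights itself, so the trainer only ever sees one reward function. The weights are arbitrary illustrative values.

```python
def reward_func1(completions, **kwargs):
    """Rewards longer completions."""
    return [float(len(completion)) for completion in completions]

def reward_func2(completions, **kwargs):
    """Rewards completions with more unique letters."""
    return [float(len(set(completion))) for completion in completions]

def aggregate_reward(completions, **kwargs):
    """Single reward function that weights the two signals itself."""
    w1, w2 = 1.0, 0.1  # weights chosen by the user, not by the trainer
    r1 = reward_func1(completions, **kwargs)
    r2 = reward_func2(completions, **kwargs)
    return [w1 * a + w2 * b for a, b in zip(r1, r2)]
```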
Right, one custom aggregating function to rule them all, passed into the trainer, sounds like the best way to abstract this weighting over however many functions. The trainer doesn't really need to know about so many reward functions.
It allows […]
Ok, fair points!
What does this PR do?
As stated in the documentation, in multi-reward-function settings the final reward is the sum of the individual rewards. This PR provides the ability to specify reward weights in multi-reward settings, which gives much more control and flexibility over which rewards deserve more emphasis. The reward weights can be floats (whether or not they sum to 1) or ints.
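For illustration, usage could look something like the sketch below. The `reward_weights` parameter on `GRPOConfig` matches TRL's current API, but treat the exact surface as an assumption rather than something confirmed in this thread; the model, dataset, and weight values are placeholders.

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

def reward_len(completions, **kwargs):
    return [float(len(c)) for c in completions]

def reward_unique(completions, **kwargs):
    return [float(len(set(c))) for c in completions]

dataset = load_dataset("trl-lib/tldr", split="train")  # any prompt dataset works

config = GRPOConfig(
    output_dir="grpo-weighted",
    reward_weights=[1.0, 0.1],  # emphasize length 10x over unique letters
)
trainer = GRPOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    args=config,
    reward_funcs=[reward_len, reward_unique],
    train_dataset=dataset,
)
```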
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.