
[GRPO] add reward weight in multi-reward settings #2676

Open
wants to merge 1 commit into main

Conversation

hesamsheikh

What does this PR do?

As stated in the documentation, in multi-reward-function settings the final reward is the sum of the individual rewards. This PR adds the ability to specify reward weights in multi-reward settings, which provides much more control and flexibility over which rewards receive more emphasis. The reward weights can be floats (summing to 1 or not) or ints.

from trl import GRPOTrainer

trainer = GRPOTrainer(
    reward_funcs=[reward_func1, reward_func2],
    reward_weights=[1, 2],
    ...,
)

Before submitting

  • [✅] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • [✅] Did you read the contributor guideline, Pull Request section?
  • [ ] Was this discussed/approved via a GitHub issue? Please add a link to it if that's the case.
  • [✅] Did you make sure to update the documentation with your changes? Here are the documentation guidelines.
  • [✅] Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@qgallouedec
Member

Thanks @hesamsheikh! Do you have any references that show this method can be useful?

@Superskyyy
Contributor

What about cases where a verifiable reward should be nulled if a primary reward turns out to be unsatisfactory? For example, if you rate a code snippet by executing it and the code is not runnable at all, then all other rewards should be nulled, right? That would need more than a weight: a notion of a primary reward, if that generalizes to more use cases. A hypothetical sketch of that gating idea, written as a single custom reward function, is below.
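The helpers runs_ok and style_score are illustrative placeholders, not part of TRL; this is only a sketch of the gating behaviour, not a proposed API:

def runs_ok(code):
    """Placeholder primary check: does the snippet at least compile?"""
    try:
        compile(code, "<completion>", "exec")
        return True
    except SyntaxError:
        return False

def style_score(code):
    """Placeholder secondary reward, e.g. favouring shorter snippets."""
    return 1.0 / (1.0 + len(code))

def gated_reward(completions, **kwargs):
    """Single reward function: secondary rewards only count if the primary check passes."""
    rewards = []
    for completion in completions:
        if not runs_ok(completion):
            rewards.append(0.0)  # primary reward unsatisfied: null everything
        else:
            rewards.append(1.0 + style_score(completion))  # primary + secondary
    return rewards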

@hesamsheikh
Author

Thanks @hesamsheikh! Do you have any references that show this method can be useful?

The paper actually provides a sneak peek at how their rewards are aggregated:

Finally, we combine the accuracy of
reasoning tasks and the reward for language consistency by directly summing them to form the
final reward. We then apply RL training on the fine-tuned model until it achieves convergence
on reasoning tasks.

However, they only specify summing in the case of the accuracy reward and the language consistency reward. In the current implementation, we can provide multiple rewards in the same scope (e.g. format of the output, or accuracy), so it makes sense that a weighted rewarding system can be beneficial. In the example provided in the test file:

def reward_func1(completions, **kwargs):
    """Reward function that rewards longer completions."""
    return [float(len(completion)) for completion in completions]

def reward_func2(completions, **kwargs):
    """Reward function that rewards completions with more unique letters."""
    return [float(len(set(completion))) for completion in completions]

Both rewards are related to the output format, but a simple summation doesn't give us control over which is more important (e.g. when making the completions longer is much more important than having more unique letters).

The implementation of the weighted sum is straightforward:

# Sum the rewards from all reward functions
rewards = rewards_per_func.sum(dim=1)

is replaced by

rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)

This applies custom weights instead of the implicit equal weights (replacing 1 x r1 + 1 x r2 with w1 x r1 + w2 x r2), and only at the final stage (the summation of the rewards). It doesn't break the advantages or the loss function, since the weighting is applied uniformly to all the rewards. This weighted sum makes the reward aggregation much more flexible when multiple rewards are used.
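A minimal standalone sketch of that change, assuming rewards_per_func has shape (num_samples, num_reward_funcs) and reward_weights is the list passed to the trainer (the numbers are made up for illustration):

import torch

# Two samples, two reward functions.
rewards_per_func = torch.tensor([[10.0, 4.0],
                                 [20.0, 3.0]])
reward_weights = torch.tensor([1.0, 2.0])

# Current behaviour: plain sum over reward functions -> tensor([14., 23.])
rewards_unweighted = rewards_per_func.sum(dim=1)

# Proposed behaviour: w1 * r1 + w2 * r2 per sample -> tensor([18., 26.])
rewards_weighted = (rewards_per_func * reward_weights.unsqueeze(0)).sum(dim=1)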

@qgallouedec
Member

Thanks for this detailed explanation. I understand the underlying motivation. I'm wondering if it really helps to get better results, or whether a naive sum, as is done in the paper, is actually enough to get similar results.

@hesamsheikh
Author

Thanks for this detailed explanation. I understand the underlying motivation. I'm wondering if it really helps to get better results, or whether a naive sum, as is done in the paper, is actually enough to get similar results.

In cases where multiple rewards with different priorities need to be tuned, the weighted reward should be handier. I'm happy to run some experiments if you suggest some.

@Benjoyo

Benjoyo commented Jan 31, 2025

I mean, you can just pass a single aggregate reward function and do arbitrary weighting there, no?
I don’t quite understand the need for a list of functions anyway, except that it is slightly more self-documenting. But you can do the same and more with a single function, and I don’t think we should add additional parameters to fix the problems with separate reward functions. What do you think? A rough sketch of what I mean is below.
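This reuses the two example reward functions quoted earlier in the thread; the weights 1.0 and 2.0 are arbitrary and only illustrate the idea:

def aggregate_reward(completions, **kwargs):
    """Single aggregate reward: weight and combine the sub-rewards internally."""
    r1 = reward_func1(completions, **kwargs)  # the length-based reward above
    r2 = reward_func2(completions, **kwargs)  # the unique-letters reward above
    return [1.0 * a + 2.0 * b for a, b in zip(r1, r2)]

# trainer = GRPOTrainer(reward_funcs=aggregate_reward, ...)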

@Superskyyy
Contributor

Superskyyy commented Jan 31, 2025

I mean, you can just pass a single aggregate reward function and do arbitrary weighting there, no?

I don’t quite understand the need for a list of functions anyway, except that it is slightly more self-documenting. But you can do the same and more with a single function, and I don’t think we should add additional parameters to fix the problems with separate reward functions. What do you think?

Right, a custom aggregating function to rule them all, passed into the trainer, sounds like the best way to abstract this need for weighting however many functions. The trainer doesn't really need to know about so many reward functions.

@qgallouedec
Member

I don’t quite understand the need for a list of functions anyway

It allows

  1. to be compatible with reward models (you can mix functions and models)
  2. to log each reward separately
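For illustration, a rough sketch of mixing the two, in the placeholder style of the example at the top of this PR ("some-org/some-reward-model" is an illustrative reward-model id, not a real checkpoint):

from trl import GRPOTrainer

def length_reward(completions, **kwargs):
    """Plain Python reward function, logged separately from the reward model."""
    return [float(len(completion)) for completion in completions]

trainer = GRPOTrainer(
    # a pretrained reward model (by id) mixed with a custom function
    reward_funcs=["some-org/some-reward-model", length_reward],
    ...,
)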

@Benjoyo

Benjoyo commented Feb 1, 2025

I don’t quite understand the need for a list of functions anyway

It allows

  1. to be compatible with reward models (you can mix functions and models)
  2. to log each reward separately

Ok, fair points!
