Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix RMSNormGated in Zamba2 #35943

Open
wants to merge 94 commits into
base: main
Choose a base branch
from
Open

Fix RMSNormGated in Zamba2 #35943

wants to merge 94 commits into from

Conversation

pglorio
Copy link
Contributor

@pglorio pglorio commented Jan 28, 2025

What does this PR do?

This PR extends Zamba2RMSNormGated to allow config.mamba_ngroups>1. The Zamba2 7B checkpoints have config.mamba_ngroups=2 so this change is necessary to have the correct forward pass.

I defined Zamba2RMSNormGated inside modular_zamba2.py instead of importing it, as this differs from the definition in modeling_mamba2.py. The implementation in this PR is the torch version of the mamba-ssm implementation of the original mamba2 (used here and torch implementation given here).

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

@ArthurZucker @Cyrilvallez

@pglorio pglorio changed the title Zamba2 Fix RMSNormGated in Zamba2 Jan 28, 2025
class Zamba2RMSNormGated(MambaRMSNormGated):
pass
class Zamba2RMSNormGated(torch.nn.Module):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this will also affect the mamba2 code then (as codestral mamba also uses ngroups > 1) - so I'd be for implementing this in the mamba2 code and use modular then.

cc @molbap

Copy link
Contributor Author

@pglorio pglorio Jan 29, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vasqu @molbap sounds good. Should I go ahead and update mamba2?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think so, but as I'm no maintainer I leave the decision to the others 👀

@Rocketknight1
Copy link
Member

cc @molbap I think!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants