Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ModernBERT FlexAttention #35423

Open
wants to merge 12 commits into
base: main
Choose a base branch
from

Conversation

staghado
Copy link
Contributor

What does this PR do?

This PR adds FlexAttention support for ModernBERT:

  • Combines sliding window and document masking to implement the alternating local/global attention pattern in ModernBERT
  • Mask creation is expensive so the two masks are cached at the model level and then re-used across layers.
  • Similar to the FA2 path, it works directly on the unpadded sequences
  • Re-uses the existing ModernBertRotaryEmbedding to avoid requiring FA2.

Note:
The current version requires one of the latest torch nightlies (e.g 2.6.0.dev20241112)
Currently transformers does not allow compiling the flex_attention function IIUC

@tomaarsen
Copy link
Member

I think this is very interesting, but only if the performance rivals that of e.g. SDPA. I see your issue here: pytorch-labs/attention-gym#95, which also shows that with compilation, Flex Attention outperforms SDPA in a lot of common cases, so we would have to introduce the compilation option for Flex Attention in transformers before this is viable I think.

I'd love to see that implemented, though!

  • Tom Aarsen

@staghado
Copy link
Contributor Author

Yeah when compiled FlexAttention is generally faster than SDPA and has much lower memory from my tests.
How would we go about adding the compilation option? would a flag suffice in this case?

@staghado
Copy link
Contributor Author

staghado commented Jan 7, 2025

Any ideas on how transformers plans to support FlexAttention compilation? I see that no models have it implemented for now(I might be wrong here)
cc @tomaarsen @ArthurZucker

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks alright appart from the nightly requirement!
Regarding compilation, we do support compile, and by default when you use generate with cache_implementation="static".

src/transformers/models/modernbert/modeling_modernbert.py Outdated Show resolved Hide resolved
src/transformers/models/modernbert/modeling_modernbert.py Outdated Show resolved Hide resolved
@staghado
Copy link
Contributor Author

staghado commented Jan 9, 2025

I think the nightly requirement might be resolved by specifying the BLOCK_SIZE argument in create_block_mask, will verify that.
for compilation, I meant supporting the flex_attention = torch.compile(flex_attention, dynamic=False) which is very important for performance as mentioned here.
It also looks like the flexattention in Gemma was removed?

@neavo
Copy link

neavo commented Jan 11, 2025

What does this PR do?

This PR adds FlexAttention support for ModernBERT:

  • Combines sliding window and document masking to implement the alternating local/global attention pattern in ModernBERT
  • Mask creation is expensive so the two masks are cached at the model level and then re-used across layers.
  • Similar to the FA2 path, it works directly on the unpadded sequences
  • Re-uses the existing ModernBertRotaryEmbedding to avoid requiring FA2.

Note: The current version requires one of the latest torch nightlies (e.g 2.6.0.dev20241112) Currently transformers does not allow compiling the flex_attention function IIUC

Torch 2.6 has released the RC version. The dependency issues should not be a big problem anymore. Looking forward to its merging.
Before that, could you please update the branch of this PR to make its code compatible with the current mainline version?
I would like to give it a try.

if is_torch_flex_attn_available():
from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

flex_attention = torch.compile(flex_attention)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

these are needed to actually use the flexattention kernels but the utils/modular_model_converter.py does not allow it in the converted file.
let me know if there is a better way to do this

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cc @ArthurZucker @tomaarsen (sorry for the ping)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any ideas on how to support compiling the flex_attention function in transformers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The failing CI is related to these two lines, there is no clear way how to support compile in a clean way

@staghado staghado requested a review from ArthurZucker January 15, 2025 15:50
@staghado
Copy link
Contributor Author

staghado commented Jan 30, 2025

With the release of PyTorch 2.6, it is now possible to use FlexAttention with ModernBERT without a nightly requirement.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants