ModernBERT FlexAttention #35423
base: main
Conversation
I think this is very interesting, but only if the performance rivals that of e.g. SDPA. I see your issue here: pytorch-labs/attention-gym#95, which also shows that with compilation, Flex Attention outperforms SDPA in a lot of common cases, so we would have to introduce a compilation option for Flex Attention. I'd love to see that implemented, though!
Yeah, when compiled, FlexAttention is generally faster than SDPA and has much lower memory usage in my tests.
Any ideas on how transformers plans to support FlexAttention compilation? I see that no models have it implemented for now (I might be wrong here).
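As a rough illustration of the kind of comparison being discussed here (this is not from the PR; the shapes, dtype, and CUDA device are arbitrary assumptions), a minimal micro-benchmark of SDPA versus a compiled flex_attention call might look like this:

```python
# Rough micro-benchmark sketch (not from this PR): SDPA vs. compiled flex_attention.
# Assumes a CUDA device and a PyTorch build with FlexAttention (2.6+ or a recent nightly).
import torch
import torch.nn.functional as F
from torch.nn.attention.flex_attention import flex_attention

q, k, v = (torch.randn(1, 8, 4096, 64, device="cuda", dtype=torch.float16) for _ in range(3))
compiled_flex = torch.compile(flex_attention)

def bench(fn, iters=20):
    # Warm up (this triggers compilation for the flex path), then time with CUDA events.
    for _ in range(3):
        fn()
    torch.cuda.synchronize()
    start, end = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("SDPA ms:", bench(lambda: F.scaled_dot_product_attention(q, k, v)))
print("flex ms:", bench(lambda: compiled_flex(q, k, v)))
```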
Looks alright apart from the nightly requirement!
Regarding compilation, we do support compile, and it is used by default when you call generate with cache_implementation="static".
I think the nightly requirement might be resolved by specifying the …
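For reference, the existing compile path mentioned above is exposed through generate; a minimal sketch (not ModernBERT-specific, and using gpt2 purely as a placeholder decoder checkpoint) would be:

```python
# Sketch of transformers' existing compile support: passing
# cache_implementation="static" to generate() uses a static KV cache,
# which allows the forward pass to be torch.compile'd.
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

inputs = tok("FlexAttention is", return_tensors="pt")
out = model.generate(**inputs, max_new_tokens=20, cache_implementation="static")
print(tok.decode(out[0], skip_special_tokens=True))
```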
Torch 2.6 has released its RC version, so the dependency issues should not be a big problem anymore. Looking forward to this being merged.
```python
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention

    flex_attention = torch.compile(flex_attention)
```
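To illustrate how the imports above are typically used together (this is only a sketch, not the PR's actual forward pass; the 128-token window, sequence length, and tensor shapes are arbitrary assumptions):

```python
# Illustrative sketch: build a bidirectional sliding-window block mask with
# create_block_mask and run it through a compiled flex_attention call.
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

WINDOW = 128
SEQ_LEN = 1024
device = "cuda"  # the FlexAttention kernels are assumed to run on GPU here

def sliding_window_mask(b, h, q_idx, kv_idx):
    # Local attention: each token attends only within +/- WINDOW positions.
    return (q_idx - kv_idx).abs() <= WINDOW

block_mask = create_block_mask(
    sliding_window_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device=device
)

q, k, v = (torch.randn(1, 8, SEQ_LEN, 64, device=device) for _ in range(3))
compiled_flex = torch.compile(flex_attention)
out = compiled_flex(q, k, v, block_mask=block_mask)
```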
These are needed to actually use the FlexAttention kernels, but utils/modular_model_converter.py does not allow them in the converted file.
Let me know if there is a better way to do this.
cc @ArthurZucker @tomaarsen (sorry for the ping)
Any ideas on how to support compiling the flex_attention function in transformers?
The failing CI is related to these two lines; there is no clear way to support compile cleanly here.
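One possible workaround, offered purely as a sketch and not as what this PR does, is to defer torch.compile until the first call so that nothing is compiled at import time; the helper names below are hypothetical:

```python
# Hypothetical sketch: lazily compile flex_attention on first use instead of
# at module import time, so the converted file only contains plain helpers.
from functools import lru_cache

import torch
from torch.nn.attention.flex_attention import flex_attention


@lru_cache(maxsize=1)
def _get_compiled_flex_attention():
    # Compiled once on first use and cached afterwards.
    return torch.compile(flex_attention)


def run_flex_attention(query, key, value, **kwargs):
    return _get_compiled_flex_attention()(query, key, value, **kwargs)
```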
With the release of PyTorch 2.6, it is now possible to use FlexAttention with ModernBERT without a nightly requirement.
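With FlexAttention now in stable PyTorch, the nightly requirement could in principle be reduced to a plain version gate; a minimal sketch (the helper name is hypothetical) might be:

```python
# Hypothetical version gate: only enable the FlexAttention path on torch >= 2.6.
import torch
from packaging import version


def flex_attention_supported() -> bool:
    return version.parse(torch.__version__) >= version.parse("2.6.0")
```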
What does this PR do?
This PR adds FlexAttention support for ModernBERT.
Note:
- The current version requires one of the latest torch nightlies (e.g. 2.6.0.dev20241112).
- Currently, transformers does not allow compiling the flex_attention function, IIUC.

A usage sketch follows below.
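For illustration, once merged the feature would presumably be selected the same way as other attention backends (e.g. "sdpa" or "flash_attention_2"); the exact attn_implementation string and the checkpoint name here are assumptions, not confirmed by this PR:

```python
# Sketch of selecting the FlexAttention backend for ModernBERT via
# attn_implementation, analogous to the existing "sdpa" backend selection.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"  # assumed checkpoint name
tok = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(
    model_id,
    attn_implementation="flex_attention",  # assumed backend identifier
    torch_dtype=torch.float16,
).to("cuda")

inputs = tok("Paris is the [MASK] of France.", return_tensors="pt").to("cuda")
with torch.no_grad():
    logits = model(**inputs).logits
```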