Enable flash attention #20448
Conversation
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##           master   #20448      +/-   ##
==========================================
- Coverage   82.03%   82.02%   -0.02%
==========================================
  Files         515      515
  Lines       47346    47383      +37
  Branches     7427     7435       +8
==========================================
+ Hits        38842    38865      +23
- Misses       6705     6714       +9
- Partials    1799     1804       +5

Flags with carried forward coverage won't be shown.
Thanks for the PR!
Force-pushed from 3e34498 to a36d119
query,
key,
value,
return_attention_scores,
@divyashreepathihalli shouldn't return_attention_scores have a default value here? There might be many call sites of _compute_attention that don't pass a return_attention_scores value, and I think this change can break them. There is an example here. So I was wondering if it's possible to set a default value here so that other references to _compute_attention work as before.
I think this change is probably the reason that the test is failing in keras-team/keras-hub#1977.
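To make the concern concrete, here is a minimal sketch of the failure mode with stand-in signatures (simplified; not the actual Keras method bodies):

```python
# Before this PR (simplified): no return_attention_scores parameter.
def _compute_attention_old(query, key, value, attention_mask=None, training=None):
    ...

# After this PR (simplified): a new positional parameter with no default.
def _compute_attention_new(query, key, value, return_attention_scores,
                           attention_mask=None, training=None):
    ...

# Existing downstream call sites were written against the old signature:
_compute_attention_old("q", "k", "v")  # fine

try:
    _compute_attention_new("q", "k", "v")
except TypeError as e:
    # missing 1 required positional argument: 'return_attention_scores'
    print(e)
```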
This is tricky: right now the value passed to the call method is passed through to _compute_attention. If we add a default value here and users don't forward the arg value from call, the default might override the call arg, and that could cause discrepancies.
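Sketching the discrepancy being described, with simplified stand-ins (not the real implementation): once a default exists, an internal call that forgets to forward the flag fails silently instead of loudly:

```python
def _compute_attention(query, key, value, return_attention_scores=False):
    # With a default, nothing forces callers to forward the flag.
    return ("output", "scores") if return_attention_scores else "output"

def call(query, key, value, return_attention_scores=False):
    # Bug: the call-time flag is not forwarded, so the parameter's
    # default (False) silently overrides the user's True.
    return _compute_attention(query, key, value)

# Prints "output": the requested attention scores are silently dropped.
print(call("q", "k", "v", return_attention_scores=True))
```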
Isn't there any other way to check and make sure that users pass the arg value (rather than relying on the lack of a default value here)?
I think a workaround would be to add a self._return_attention_scores attribute, set it in the call method, and use it in _compute_attention. wdyt?
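A minimal sketch of that workaround, using a toy layer (ToyAttention and its einsum math are illustrative, not the merged Keras code):

```python
import keras
from keras import ops


class ToyAttention(keras.layers.Layer):
    def call(self, query, key, value, return_attention_scores=False):
        # Stash the call-time flag on the instance so _compute_attention
        # keeps its old signature and existing overrides/call sites
        # continue to work unchanged.
        self._return_attention_scores = return_attention_scores
        return self._compute_attention(query, key, value)

    def _compute_attention(self, query, key, value):
        # Plain scaled dot-product attention, for illustration only.
        scale = 1.0 / ops.sqrt(ops.cast(ops.shape(query)[-1], "float32"))
        scores = ops.softmax(
            ops.einsum("bqd,bkd->bqk", query, key) * scale, axis=-1
        )
        output = ops.einsum("bqk,bkd->bqd", scores, value)
        # Read the flag from the instance instead of a new parameter.
        if self._return_attention_scores:
            return output, scores
        return output
```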
This PR:
- the _compute_attention method would just call ops.dot_product_attention
- adds keras.config.enable_flash_attention and keras.config.is_flash_attention_enabled
- _masked_softmax implementation
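For reference, the two config functions and the op named above can be exercised directly; the shapes below are arbitrary, and whether a fused flash kernel is actually used depends on the backend and hardware:

```python
import numpy as np
import keras
from keras import ops

# Global toggle added by this PR; layers that support it pick it up.
keras.config.enable_flash_attention()
print(keras.config.is_flash_attention_enabled())

# ops.dot_product_attention computes softmax(QK^T * scale) V in one op,
# dispatching to a flash-attention kernel where the backend supports it.
# Inputs follow the (batch, seq_len, num_heads, head_dim) convention.
batch, seq_len, num_heads, head_dim = 2, 8, 4, 16
q = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
k = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")
v = np.random.normal(size=(batch, seq_len, num_heads, head_dim)).astype("float32")

out = ops.dot_product_attention(q, k, v)
print(out.shape)  # (2, 8, 4, 16)
```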