Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable flash attention #20448

Merged

Conversation

divyashreepathihalli
Copy link
Collaborator

@divyashreepathihalli divyashreepathihalli commented Nov 4, 2024

This PR

  • refactors the MHA layer so that its _compute_attention method would just call ops.dot_production_attention
  • Adds a global toggle keras.config.enable_flash_attention and keras.config.is_flash_attention_enabled
  • Modify attention mask to match _masked_softmax implementation

@codecov-commenter
Copy link

codecov-commenter commented Nov 4, 2024

Codecov Report

Attention: Patch coverage is 73.68421% with 10 lines in your changes missing coverage. Please review.

Project coverage is 82.02%. Comparing base (30a6b87) to head (9a5200e).

Files with missing lines Patch % Lines
keras/src/layers/attention/multi_head_attention.py 68.75% 2 Missing and 3 partials ⚠️
keras/api/_tf_keras/keras/config/__init__.py 0.00% 3 Missing ⚠️
keras/src/backend/jax/nn.py 50.00% 1 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20448      +/-   ##
==========================================
- Coverage   82.03%   82.02%   -0.02%     
==========================================
  Files         515      515              
  Lines       47346    47383      +37     
  Branches     7427     7435       +8     
==========================================
+ Hits        38842    38865      +23     
- Misses       6705     6714       +9     
- Partials     1799     1804       +5     
Flag Coverage Δ
keras 81.87% <73.68%> (-0.02%) ⬇️
keras-jax 64.94% <68.42%> (-0.01%) ⬇️
keras-numpy 59.89% <50.00%> (-0.02%) ⬇️
keras-tensorflow 65.96% <60.52%> (-0.02%) ⬇️
keras-torch 64.86% <71.05%> (-0.01%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

Copy link
Member

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

keras/src/backend/config.py Outdated Show resolved Hide resolved
keras/src/backend/config.py Outdated Show resolved Hide resolved
keras/src/layers/attention/multi_head_attention_test.py Outdated Show resolved Hide resolved
@divyashreepathihalli divyashreepathihalli marked this pull request as draft November 5, 2024 01:31
keras/src/layers/attention/multi_head_attention.py Outdated Show resolved Hide resolved
keras/src/layers/attention/multi_head_attention.py Outdated Show resolved Hide resolved
keras/src/layers/attention/multi_head_attention.py Outdated Show resolved Hide resolved
keras/src/layers/attention/multi_head_attention.py Outdated Show resolved Hide resolved
keras/src/layers/attention/multi_head_attention.py Outdated Show resolved Hide resolved
@fchollet fchollet marked this pull request as ready for review November 5, 2024 22:44
@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Nov 7, 2024
@fchollet fchollet merged commit 5bf4ac7 into keras-team:master Nov 7, 2024
8 of 9 checks passed
@google-ml-butler google-ml-butler bot removed awaiting review ready to pull Ready to be merged into the codebase kokoro:force-run labels Nov 7, 2024
query,
key,
value,
return_attention_scores,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@divyashreepathihalli shouldn't return_attention_scores have a default value here?
There might be many instances of _compute_attention call that don't pass return_attention_scores value and I think this change can break them. There is an example here. So I was wondering if it's possible to set a default value here so that other references of _compute_attention work as before.
I think this change is probably the reason that the test is failing in keras-team/keras-hub#1977

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is tricky, right now the value passed to the call method is passed to _compute_attention. If we add a default value here and users don't pass the arg value from call it might override the call arg. and that could cause discrepancies.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't there any other way to check and make sure the users pass the arg value (rather than lack of default value here)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think a work around would be to add a self._return_attention_scores and then set it in the call method and use it in _compute_attention. wdyt?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants