Enable flash attention #20448

Merged

Changes from 7 commits

36 commits
37e2302
Enable flash attention
divyashreepathihalli Nov 4, 2024
057aa66
code reformat
divyashreepathihalli Nov 4, 2024
1abc948
address review comments
divyashreepathihalli Nov 5, 2024
5784361
add docstring
divyashreepathihalli Nov 5, 2024
3a47c53
update docstring
divyashreepathihalli Nov 5, 2024
760e4b2
add numerical correctness test
divyashreepathihalli Nov 5, 2024
045f153
code reformat
divyashreepathihalli Nov 5, 2024
71bf0ce
use causal mask from call method
divyashreepathihalli Nov 5, 2024
3d2875f
address review comments
divyashreepathihalli Nov 5, 2024
6f99a57
update if
divyashreepathihalli Nov 5, 2024
4065bf3
fix tests
divyashreepathihalli Nov 5, 2024
a7390a6
update tests
divyashreepathihalli Nov 5, 2024
71cbbd5
enable flash attention on TPU JAX
divyashreepathihalli Nov 5, 2024
3af4c95
update code
divyashreepathihalli Nov 5, 2024
b9e075a
minor fix
divyashreepathihalli Nov 5, 2024
5b4d23d
address review comments
divyashreepathihalli Nov 5, 2024
59b0672
fix tests
divyashreepathihalli Nov 6, 2024
f269d06
run api_gen
divyashreepathihalli Nov 6, 2024
02f3451
code reformat
divyashreepathihalli Nov 6, 2024
f622918
fix mask issue
divyashreepathihalli Nov 6, 2024
3cf22e4
disable causal mask in dpa because it is computed in comput_attention_…
divyashreepathihalli Nov 6, 2024
5b7f81a
fix masks tests
divyashreepathihalli Nov 6, 2024
53e082a
code reformat
divyashreepathihalli Nov 6, 2024
0ee6752
disable tests if env is not supported
divyashreepathihalli Nov 6, 2024
e310f22
fix code reformat error
divyashreepathihalli Nov 6, 2024
cb62bbc
fix torch GPU tests
divyashreepathihalli Nov 6, 2024
d0e5de5
Merge branch 'master' into enable_flash_attention
divyashreepathihalli Nov 6, 2024
c28e460
fix torch gpu tests
divyashreepathihalli Nov 6, 2024
a94a9e6
make everything contiguous
divyashreepathihalli Nov 6, 2024
36c1131
check if mask is not None before calling contiguous
divyashreepathihalli Nov 6, 2024
58ae551
disable pytorch GPU test
divyashreepathihalli Nov 7, 2024
18d6ebd
merge master
divyashreepathihalli Nov 7, 2024
a36d119
code reformat
divyashreepathihalli Nov 7, 2024
ae34d7f
set bias to None
divyashreepathihalli Nov 7, 2024
98edd93
Merge branch 'keras-team:master' into enable_flash_attention
divyashreepathihalli Nov 7, 2024
9a5200e
disable GPU test
divyashreepathihalli Nov 7, 2024
2 changes: 2 additions & 0 deletions keras/api/_tf_keras/keras/config/__init__.py
@@ -13,6 +13,8 @@
from keras.src.backend.config import set_image_data_format
from keras.src.dtype_policies.dtype_policy import dtype_policy
from keras.src.dtype_policies.dtype_policy import set_dtype_policy
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.saving.serialization_lib import enable_unsafe_deserialization
from keras.src.utils.backend_utils import set_backend
from keras.src.utils.io_utils import disable_interactive_logging
2 changes: 2 additions & 0 deletions keras/api/config/__init__.py
@@ -13,6 +13,8 @@
from keras.src.backend.config import set_image_data_format
from keras.src.dtype_policies.dtype_policy import dtype_policy
from keras.src.dtype_policies.dtype_policy import set_dtype_policy
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.saving.serialization_lib import enable_unsafe_deserialization
from keras.src.utils.backend_utils import set_backend
from keras.src.utils.io_utils import disable_interactive_logging
33 changes: 33 additions & 0 deletions keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.layers.layer import Layer


@@ -282,3 +283,35 @@ def get_config(self):
"dropout": self.dropout,
}
return {**base_config, **config}


@keras_export("keras.config.enable_flash_attention")
def enable_flash_attention(value):
"""Enable flash attention.

Flash attention offers performance optimization for attention layers,
making it especially useful for large language models (LLMs) that
benefit from faster and more memory-efficient attention computations.

Once enabled, supported layers like `MultiHeadAttention` will
use flash attention for faster computations.
"""
global_state.set_global_attribute("flash_attention", value)


@keras_export("keras.config.is_flash_attention_enabled")
def is_flash_attention_enabled():
"""Checks whether flash attention is globally enabled in Keras.

Flash attention is a performance-optimized method for computing attention
in large models, such as transformers, allowing for faster and more
memory-efficient operations. This function checks the global Keras
configuration to determine if flash attention is enabled for compatible
layers (e.g., `MultiHeadAttention`).

Returns:
bool or None: Returns `True` if flash attention is enabled,
`False` if it is disabled, and `None` if the global
setting has not been defined.
"""
return global_state.get_global_attribute("flash_attention", default=None)
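
A minimal usage sketch of the two config functions added in this file, based on the docstrings above (it assumes only the keras.config exports added by this PR):

```python
import keras

# The global setting starts out undefined.
assert keras.config.is_flash_attention_enabled() is None

# Enable flash attention for supported layers such as MultiHeadAttention.
keras.config.enable_flash_attention(True)
assert keras.config.is_flash_attention_enabled() is True

# Passing False turns it off again.
keras.config.enable_flash_attention(False)
assert keras.config.is_flash_attention_enabled() is False
```
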
43 changes: 38 additions & 5 deletions keras/src/layers/attention/multi_head_attention.py
@@ -11,6 +11,7 @@
from keras.src import regularizers
from keras.src.api_export import keras_export
from keras.src.layers.activations.softmax import Softmax
from keras.src.layers.attention import attention
from keras.src.layers.core.einsum_dense import EinsumDense
from keras.src.layers.layer import Layer
from keras.src.layers.regularization.dropout import Dropout
@@ -104,6 +105,7 @@ def __init__(
use_bias=True,
output_shape=None,
attention_axes=None,
use_flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
@@ -131,6 +133,12 @@ def __init__(
self._activity_regularizer = regularizers.get(activity_regularizer)
self._kernel_constraint = constraints.get(kernel_constraint)
self._bias_constraint = constraints.get(bias_constraint)
if attention.is_flash_attention_enabled() is None:
if use_flash_attention is not None:
attention.enable_flash_attention(use_flash_attention)
else:
attention.enable_flash_attention(True)

if isinstance(attention_axes, int):
attention_axes = (attention_axes,)
elif attention_axes and not isinstance(attention_axes, (list, tuple)):
@@ -392,7 +400,13 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
return self._softmax(attention_scores, mask=attention_mask)

def _compute_attention(
self, query, key, value, attention_mask=None, training=None
self,
query,
key,
value,
return_attention_scores,
Member:
@divyashreepathihalli shouldn't return_attention_scores have a default value here? There may be many existing calls to _compute_attention that don't pass a return_attention_scores value, and this change can break them. There is an example here. Would it be possible to set a default value here so that other references of _compute_attention keep working as before? I think this change is probably the reason the test is failing in keras-team/keras-hub#1977.

Collaborator Author:
This is tricky. Right now the value passed to the call method is forwarded to _compute_attention. If we add a default value here and users don't pass the argument from call, the default might override the call argument, which could cause discrepancies.

Member:
Isn't there another way to check and make sure users pass the argument value, rather than relying on the absence of a default value here?

Collaborator Author:
I think a workaround would be to add a self._return_attention_scores attribute, set it in the call method, and use it in _compute_attention. WDYT?

attention_mask=None,
training=None,
):
"""Applies Dot-product attention with query, key, value tensors.

@@ -415,9 +429,28 @@ def _compute_attention(
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.
if attention.is_flash_attention_enabled() and return_attention_scores:
raise ValueError(
"Returning attention scores is not supported when flash "
"attention is enabled. Please disable flash attention to access"
" attention scores."
)
if attention.is_flash_attention_enabled():
# Directly compute the attention output using flash attention
attention_output = ops.dot_product_attention(
query=query,
key=key,
value=value,
mask=attention_mask,
scale=self._inverse_sqrt_key_dim,
flash_attention=True,
)
# Return only the attention output, as scores are not separately
# available
return attention_output, None

# Default behavior without flash attention, with explicit attention
# scores
query = ops.multiply(
query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
)
@@ -483,7 +516,7 @@ def call(
value = self._value_dense(value)

attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, training
query, key, value, return_attention_scores, attention_mask, training
)
attention_output = self._output_dense(attention_output)

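
A usage sketch of the new MultiHeadAttention behavior (a hypothetical snippet, not part of the diff; it assumes a backend and device on which ops.dot_product_attention supports flash attention):

```python
import numpy as np
from keras import config, layers

query = np.random.random((2, 4, 8)).astype("float32")
value = np.random.random((2, 4, 8)).astype("float32")
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)

# With flash attention enabled, the layer routes through
# ops.dot_product_attention(..., flash_attention=True) and only the
# attention output is available.
config.enable_flash_attention(True)
output = layer(query=query, value=value)

# Requesting scores while flash attention is enabled raises the
# ValueError added in _compute_attention.
try:
    layer(query=query, value=value, return_attention_scores=True)
except ValueError as err:
    print(err)

# Disabling flash attention restores the explicit-scores path.
config.enable_flash_attention(False)
output, scores = layer(query=query, value=value, return_attention_scores=True)
```
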
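The workaround floated at the end of the review thread above, storing the flag on the layer instead of threading it through _compute_attention, could look roughly like the toy pattern below. The class and names are hypothetical and only illustrate the idea; this is not code from the PR:

```python
class AttentionLike:
    """Toy stand-in showing the 'store the flag on self' pattern."""

    def __init__(self):
        self._return_attention_scores = False

    def call(self, query, value, return_attention_scores=False):
        # Record the caller's choice once, in call()...
        self._return_attention_scores = return_attention_scores
        return self._compute_attention(query, value)

    def _compute_attention(self, query, value):
        # ...so this helper keeps its old signature, and subclasses or
        # callers that invoke it directly keep working unchanged.
        if self._return_attention_scores:
            return query, "attention-scores-placeholder"
        return query, None
```
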
51 changes: 50 additions & 1 deletion keras/src/layers/attention/multi_head_attention_test.py
@@ -12,10 +12,12 @@
from keras.src import models
from keras.src import saving
from keras.src import testing
from keras.src.layers.attention import attention


class MultiHeadAttentionTest(testing.TestCase):
def test_basics(self):
attention.enable_flash_attention(True)
self.run_layer_test(
layers.MultiHeadAttention,
init_kwargs={
Expand Down Expand Up @@ -50,6 +52,7 @@ def test_basics(self):
supports_masking=True,
run_training_check=False,
)
attention.enable_flash_attention(False)

@parameterized.named_parameters(
("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)),
@@ -252,7 +255,7 @@ def test_correctness(self):
bias = np.zeros((2, 2))
output_bias = np.zeros((2,))
layer.set_weights([kernel, bias] * 3 + [kernel, output_bias])

attention.enable_flash_attention(False)
# Call layer and assert output.
output, scores = layer(
query=query,
@@ -378,6 +381,7 @@ def test_lora(self):

@parameterized.parameters([((1, 2, 3),), ((2, 3, 5),)])
def test_symbolic_return_attention_scores(self, shape):
attention.enable_flash_attention(False)
mha = layers.MultiHeadAttention(num_heads=4, key_dim=2)
x = layers.Input(batch_shape=shape)
y = layers.Input(batch_shape=shape)
@@ -413,3 +417,48 @@ def test_dtype_policy_map(self):
self.assertDType(layer._query_dense._kernel, "int8")
self.assertDType(layer._key_dense._kernel, "int8")
self.assertDType(layer._value_dense._kernel, "int8")

def test_flash_attention_with_attention_scores_error(self):
# Enable flash attention globally
attention.enable_flash_attention(True)
# Setup layer with required parameters
layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)

# Define sample input
query = np.random.random((2, 4, 8))
value = np.random.random((2, 4, 8))

# Check if ValueError is raised when return_attention_scores=True
with self.assertRaisesRegex(
ValueError,
"Returning attention scores is not supported when flash "
"attention is enabled. Please disable flash attention to access"
" attention scores.",
):
layer(query=query, value=value, return_attention_scores=True)

attention.enable_flash_attention(False)

def test_flash_attention_numerical_correctness(self):
# Define sample input
query = np.random.random((2, 4, 8))
value = np.random.random((2, 4, 8))

# Initialize MultiHeadAttention layer
mha_layer = layers.MultiHeadAttention(
num_heads=2,
key_dim=2,
)

# Run with flash attention enabled
attention.enable_flash_attention(True)
output_with_flash = mha_layer(query=query, value=value, training=False)

# Run with flash attention disabled
attention.enable_flash_attention(False)
output_without_flash = mha_layer(
query=query, value=value, training=False
)

self.assertAllClose(output_with_flash, output_without_flash)