Enable flash attention #20448

Merged

36 commits (changes shown from 33 commits)
37e2302
Enable flash attention
divyashreepathihalli Nov 4, 2024
057aa66
code reformat
divyashreepathihalli Nov 4, 2024
1abc948
address review comments
divyashreepathihalli Nov 5, 2024
5784361
add docstring
divyashreepathihalli Nov 5, 2024
3a47c53
update docstring
divyashreepathihalli Nov 5, 2024
760e4b2
add numerical correctness test
divyashreepathihalli Nov 5, 2024
045f153
code reformat
divyashreepathihalli Nov 5, 2024
71bf0ce
use causal mask from call method
divyashreepathihalli Nov 5, 2024
3d2875f
address review comments
divyashreepathihalli Nov 5, 2024
6f99a57
update if
divyashreepathihalli Nov 5, 2024
4065bf3
fix tests
divyashreepathihalli Nov 5, 2024
a7390a6
update tests
divyashreepathihalli Nov 5, 2024
71cbbd5
enable flash attention on TPU JAX
divyashreepathihalli Nov 5, 2024
3af4c95
update code
divyashreepathihalli Nov 5, 2024
b9e075a
minor fix
divyashreepathihalli Nov 5, 2024
5b4d23d
address review comments
divyashreepathihalli Nov 5, 2024
59b0672
fix tests
divyashreepathihalli Nov 6, 2024
f269d06
run api_gen
divyashreepathihalli Nov 6, 2024
02f3451
code reformat
divyashreepathihalli Nov 6, 2024
f622918
fix mask issue
divyashreepathihalli Nov 6, 2024
3cf22e4
disable causal mask in dpa because it is computed in comput_attention_…
divyashreepathihalli Nov 6, 2024
5b7f81a
fix masks tests
divyashreepathihalli Nov 6, 2024
53e082a
code reformat
divyashreepathihalli Nov 6, 2024
0ee6752
disable tests if env is not supported
divyashreepathihalli Nov 6, 2024
e310f22
fix code reformat error
divyashreepathihalli Nov 6, 2024
cb62bbc
fix torch GPU tests
divyashreepathihalli Nov 6, 2024
d0e5de5
Merge branch 'master' into enable_flash_attention
divyashreepathihalli Nov 6, 2024
c28e460
fix torch gpu tests
divyashreepathihalli Nov 6, 2024
a94a9e6
make everything contiguous
divyashreepathihalli Nov 6, 2024
36c1131
check if mask is not None before calling contiguous
divyashreepathihalli Nov 6, 2024
58ae551
disable pytorch GPU test
divyashreepathihalli Nov 7, 2024
18d6ebd
merge master
divyashreepathihalli Nov 7, 2024
a36d119
code reformat
divyashreepathihalli Nov 7, 2024
ae34d7f
set bias to None
divyashreepathihalli Nov 7, 2024
98edd93
Merge branch 'keras-team:master' into enable_flash_attention
divyashreepathihalli Nov 7, 2024
9a5200e
disable GPU test
divyashreepathihalli Nov 7, 2024
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/config/__init__.py
@@ -13,6 +13,9 @@
from keras.src.backend.config import set_image_data_format
from keras.src.dtype_policies.dtype_policy import dtype_policy
from keras.src.dtype_policies.dtype_policy import set_dtype_policy
from keras.src.layers.attention.attention import disable_flash_attention
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.saving.serialization_lib import enable_unsafe_deserialization
from keras.src.utils.backend_utils import set_backend
from keras.src.utils.io_utils import disable_interactive_logging
3 changes: 3 additions & 0 deletions keras/api/config/__init__.py
@@ -13,6 +13,9 @@
from keras.src.backend.config import set_image_data_format
from keras.src.dtype_policies.dtype_policy import dtype_policy
from keras.src.dtype_policies.dtype_policy import set_dtype_policy
from keras.src.layers.attention.attention import disable_flash_attention
from keras.src.layers.attention.attention import enable_flash_attention
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.saving.serialization_lib import enable_unsafe_deserialization
from keras.src.utils.backend_utils import set_backend
from keras.src.utils.io_utils import disable_interactive_logging
17 changes: 15 additions & 2 deletions keras/src/backend/jax/nn.py
@@ -6,6 +6,9 @@
import jax.numpy as jnp
from jax import lax
from jax import nn as jnn
from jax.experimental.pallas.ops.tpu import (
flash_attention as flash_attention_tpu,
)

from keras.src import backend
from keras.src.backend.common.backend_utils import (
@@ -1019,7 +1022,18 @@ def dot_product_attention(
f"Received: query.shape={query.shape}, key.shape={key.shape}, "
f"value.shape={value.shape}."
)

is_tpu = jax.devices()[0].platform == "tpu"
if is_tpu and flash_attention:
# Use TPU-optimized flash attention from Pallas
return flash_attention_tpu(
query,
key,
value,
ab=bias,
segment_ids=mask,
causal=is_causal,
sm_scale=scale,
)
# `dot_product_attention` is only available in jax>=0.4.31
if hasattr(jax.nn, "dot_product_attention"):
implementation = "cudnn" if flash_attention else "xla"
@@ -1040,7 +1054,6 @@ def dot_product_attention(
"current JAX version. Please update it "
"using `pip install -U jax jaxlib`."
)

# Ref: jax.nn.dot_product_attention
# https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
# Not support `query_seq_lengths` and `key_value_seq_lengths` args
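For reference, the non-TPU path in this hunk relies on `jax.nn.dot_product_attention` (only available in jax>=0.4.31), requesting the cuDNN flash-attention kernel when `flash_attention` is set and falling back to the XLA implementation otherwise. Below is a minimal standalone sketch of that call, under the assumption of a recent JAX install and the [batch, seq_len, num_heads, head_dim] layout the function expects; shapes and dtypes are illustrative, not taken from the PR.

```python
import jax
import jax.numpy as jnp

# Illustrative shapes: batch=2, seq_len=16, heads=4, head_dim=32 (BTNH layout).
rng = jax.random.PRNGKey(0)
q_rng, k_rng, v_rng = jax.random.split(rng, 3)
query = jax.random.normal(q_rng, (2, 16, 4, 32), dtype=jnp.float16)
key = jax.random.normal(k_rng, (2, 16, 4, 32), dtype=jnp.float16)
value = jax.random.normal(v_rng, (2, 16, 4, 32), dtype=jnp.float16)

# "cudnn" asks for the fused flash-attention kernel (needs a supported GPU
# and half precision); "xla" is the portable fallback used when flash
# attention is disabled.
implementation = "xla"  # switch to "cudnn" on a supported GPU
out = jax.nn.dot_product_attention(
    query, key, value,
    is_causal=True,
    implementation=implementation,
)
print(out.shape)  # (2, 16, 4, 32)
```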
9 changes: 8 additions & 1 deletion keras/src/backend/torch/nn.py
@@ -954,7 +954,14 @@ def dot_product_attention(
scale=scale,
)
else:
if mask is not None:
mask = mask.contiguous()
attention_output = torch.nn.functional.scaled_dot_product_attention(
query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
query.contiguous(),
key.contiguous(),
value.contiguous(),
attn_mask=mask,
is_causal=is_causal,
scale=scale,
)
return torch.transpose(attention_output, axis1, axis0)
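For context, this change makes the query, key, value, and (when present) mask tensors contiguous before handing them to PyTorch's fused SDPA kernel, since they arrive as transposed, non-contiguous views at this point. A minimal standalone sketch of the same pattern follows; shapes and names are illustrative, and the `scale` keyword assumes a reasonably recent PyTorch release.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: batch=2, heads=4, seq_len=16, head_dim=32.
q = torch.randn(2, 16, 4, 32).transpose(1, 2)  # transpose -> non-contiguous view
k = torch.randn(2, 16, 4, 32).transpose(1, 2)
v = torch.randn(2, 16, 4, 32).transpose(1, 2)
mask = torch.ones(2, 1, 16, 16, dtype=torch.bool)  # broadcastable boolean mask

# Fused SDPA backends can be sensitive to memory layout, so the PR calls
# .contiguous() on each tensor (and on the mask only when it is not None).
out = F.scaled_dot_product_attention(
    q.contiguous(),
    k.contiguous(),
    v.contiguous(),
    attn_mask=mask.contiguous(),
    is_causal=False,
    scale=None,  # None means the default 1/sqrt(head_dim) scaling
)
print(out.shape)  # torch.Size([2, 4, 16, 32])
```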
47 changes: 47 additions & 0 deletions keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
from keras.src import backend
from keras.src import ops
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.layers.layer import Layer


@@ -282,3 +283,49 @@ def get_config(self):
"dropout": self.dropout,
}
return {**base_config, **config}


@keras_export("keras.config.enable_flash_attention")
def enable_flash_attention():
"""Enable flash attention.

Flash attention offers performance optimization for attention layers,
making it especially useful for large language models (LLMs) that
benefit from faster and more memory-efficient attention computations.

Once enabled, supported layers like `MultiHeadAttention` will
use flash attention for faster computations.
"""
global_state.set_global_attribute("flash_attention", True)


@keras_export("keras.config.disable_flash_attention")
def disable_flash_attention():
"""Disable flash attention.

Flash attention offers performance optimization for attention layers,
making it especially useful for large language models (LLMs) that
benefit from faster and more memory-efficient attention computations.

Once disabled, supported layers like `MultiHeadAttention` will not
use flash attention for faster computations.
"""
global_state.set_global_attribute("flash_attention", False)


@keras_export("keras.config.is_flash_attention_enabled")
def is_flash_attention_enabled():
"""Checks whether flash attention is globally enabled in Keras.

Flash attention is a performance-optimized method for computing attention
in large models, such as transformers, allowing for faster and more
memory-efficient operations. This function checks the global Keras
configuration to determine if flash attention is enabled for compatible
layers (e.g., `MultiHeadAttention`).

Returns:
bool or None: Returns `True` if flash attention is enabled,
`False` if it is disabled, and `None` if the global
setting has not been defined.
"""
return global_state.get_global_attribute("flash_attention", default=None)
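Taken together, the three functions added above give a global on/off switch plus a query helper. A short usage sketch of this API, assuming a Keras build that includes this PR:

```python
import keras

# Nothing set yet: the global attribute defaults to None.
print(keras.config.is_flash_attention_enabled())  # None

# Opt in globally; supported layers such as MultiHeadAttention pick this up.
keras.config.enable_flash_attention()
print(keras.config.is_flash_attention_enabled())  # True

# Opt back out, e.g. when attention scores or dropout are needed.
keras.config.disable_flash_attention()
print(keras.config.is_flash_attention_enabled())  # False
```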
85 changes: 74 additions & 11 deletions keras/src/layers/attention/multi_head_attention.py
@@ -11,6 +11,7 @@
from keras.src import regularizers
from keras.src.api_export import keras_export
from keras.src.layers.activations.softmax import Softmax
from keras.src.layers.attention.attention import is_flash_attention_enabled
from keras.src.layers.core.einsum_dense import EinsumDense
from keras.src.layers.layer import Layer
from keras.src.layers.regularization.dropout import Dropout
@@ -52,6 +53,9 @@ class MultiHeadAttention(Layer):
feature dim (the query input's last dimension).
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
flash_attention: If unspecified, defaults to the global flash attention
configuration setting (which can be set via
`keras.config.enable_flash_attention()`).
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
@@ -104,6 +108,7 @@ def __init__(
use_bias=True,
output_shape=None,
attention_axes=None,
flash_attention=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
@@ -131,6 +136,8 @@ def __init__(
self._activity_regularizer = regularizers.get(activity_regularizer)
self._kernel_constraint = constraints.get(kernel_constraint)
self._bias_constraint = constraints.get(bias_constraint)
self._flash_attention = flash_attention or is_flash_attention_enabled()

if isinstance(attention_axes, int):
attention_axes = (attention_axes,)
elif attention_axes and not isinstance(attention_axes, (list, tuple)):
@@ -392,7 +399,13 @@ def _masked_softmax(self, attention_scores, attention_mask=None):
return self._softmax(attention_scores, mask=attention_mask)

def _compute_attention(
self, query, key, value, attention_mask=None, training=None
self,
query,
key,
value,
return_attention_scores,
Member:
@divyashreepathihalli shouldn't return_attention_scores have a default value here? There are likely many existing callers of _compute_attention that don't pass a return_attention_scores value, and this change can break them (there is an example here). Would it be possible to set a default value so that other references to _compute_attention keep working as before? I think this change is probably the reason the test is failing in keras-team/keras-hub#1977.

Collaborator Author:
This is tricky. Right now the value passed to the call method is forwarded to _compute_attention. If we add a default value here and users don't pass the value from call, it might override the call argument, which could cause discrepancies.

Member:
Isn't there another way to make sure users pass the argument value, rather than relying on the lack of a default value here?

Collaborator Author:
I think a workaround would be to add a self._return_attention_scores attribute, set it in the call method, and read it in _compute_attention. WDYT? (A sketch of this proposed workaround appears after this file's diff.)

attention_mask=None,
training=None,
):
"""Applies Dot-product attention with query, key, value tensors.

@@ -415,9 +428,56 @@ def _compute_attention(
attention_output: Multi-headed outputs of attention computation.
attention_scores: Multi-headed attention weights.
"""
# Note: Applying scalar multiply at the smaller end of einsum improves
# XLA performance, but may introduce slight numeric differences in
# the Transformer attention head.

# Check for flash attention constraints
if self._flash_attention and return_attention_scores:
raise ValueError(
"Returning attention scores is not supported when flash "
"attention is enabled. Please disable flash attention to access"
" attention scores."
)
if self._flash_attention and self._dropout > 0.0:
raise ValueError(
"Dropout is not supported when flash "
"attention is enabled. Please set dropout to 0.0 to use "
"flash attention."
)

# Determine whether to use dot-product attention
use_dot_product_attention = not (
self._dropout > 0.0
or return_attention_scores
or (len(query.shape) != 4)
)

if use_dot_product_attention:
if attention_mask is not None:
# Ensure attention_mask has the correct shape for broadcasting
# Expected shape: [batch_size, num_heads, query_seq_len,
# key_seq_len]. This is because masked_softmax is not supported
# in JAX.
while len(attention_mask.shape) < 4:
attention_mask = ops.expand_dims(
attention_mask, axis=1
) # Add dimension for num_heads
if attention_mask.shape[1] != self._num_heads:
attention_mask = ops.tile(
attention_mask, [1, self._num_heads, 1, 1]
)
# Directly compute the attention output using dot-product attention
attention_output = ops.dot_product_attention(
query=query,
key=key,
value=value,
mask=attention_mask,
scale=self._inverse_sqrt_key_dim,
is_causal=False,
flash_attention=self._flash_attention,
)
return attention_output, None

# Default behavior without flash attention, with explicit attention
# scores
query = ops.multiply(
query, ops.cast(self._inverse_sqrt_key_dim, query.dtype)
)
@@ -426,13 +486,13 @@ def _compute_attention(
# attention scores.
attention_scores = ops.einsum(self._dot_product_equation, key, query)

# Apply the mask using the custom masked softmax
attention_scores = self._masked_softmax(
attention_scores, attention_mask
)

# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
if self.dropout:
# Apply dropout to the attention scores if needed
if self._dropout > 0.0:
final_attn_scores = self._dropout_layer(
attention_scores, training=training
)
Expand Down Expand Up @@ -460,7 +520,6 @@ def call(
):
if key is None:
key = value

attention_mask = self._compute_attention_mask(
query,
value,
@@ -470,9 +529,9 @@
attention_mask=attention_mask,
use_causal_mask=use_causal_mask,
)

# N = `num_attention_heads`
# H = `size_per_head`

# `query` = [B, T, N ,H]
query = self._query_dense.call(query)

@@ -481,9 +540,13 @@

# `value` = [B, S, N, H]
value = self._value_dense.call(value)

attention_output, attention_scores = self._compute_attention(
query, key, value, attention_mask, training
query,
key,
value,
return_attention_scores,
attention_mask,
training,
)
attention_output = self._output_dense.call(attention_output)

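The review thread earlier in this file's diff discusses whether return_attention_scores should receive a default value; the author instead suggests stashing the flag on the instance so that _compute_attention keeps its previous signature. The snippet below is a hypothetical, self-contained toy illustrating that pattern, not code from this PR; the class and names are invented for the example.

```python
import numpy as np


class AttentionSketch:
    """Toy stand-in for the discussed workaround: call() records
    return_attention_scores on the instance, so the internal helper
    keeps its old signature and existing callers are unaffected."""

    def __init__(self, flash_attention=False):
        self._flash_attention = flash_attention
        self._return_attention_scores = False

    def call(self, query, key, value, return_attention_scores=False):
        # Stash the flag instead of adding a new required positional
        # argument to _compute_attention.
        self._return_attention_scores = return_attention_scores
        return self._compute_attention(query, key, value)

    def _compute_attention(self, query, key, value, attention_mask=None):
        # Read the flag set by call(); raise early when it conflicts
        # with flash attention, mirroring the constraint in this PR.
        if self._flash_attention and self._return_attention_scores:
            raise ValueError(
                "Attention scores are unavailable with flash attention."
            )
        scores = query @ key.swapaxes(-1, -2) / np.sqrt(query.shape[-1])
        weights = np.exp(scores - scores.max(-1, keepdims=True))
        weights /= weights.sum(-1, keepdims=True)
        output = weights @ value
        return output, (weights if self._return_attention_scores else None)


q = k = v = np.random.normal(size=(2, 8, 16))
out, scores = AttentionSketch().call(q, k, v, return_attention_scores=True)
print(out.shape, scores.shape)  # (2, 8, 16) (2, 8, 8)
```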