diff --git a/keras/api/_tf_keras/keras/config/__init__.py b/keras/api/_tf_keras/keras/config/__init__.py
index 13e334cb7c0..4ff4387aec9 100644
--- a/keras/api/_tf_keras/keras/config/__init__.py
+++ b/keras/api/_tf_keras/keras/config/__init__.py
@@ -13,6 +13,9 @@
 from keras.src.backend.config import set_image_data_format
 from keras.src.dtype_policies.dtype_policy import dtype_policy
 from keras.src.dtype_policies.dtype_policy import set_dtype_policy
+from keras.src.layers.attention.attention import disable_flash_attention
+from keras.src.layers.attention.attention import enable_flash_attention
+from keras.src.layers.attention.attention import is_flash_attention_enabled
 from keras.src.saving.serialization_lib import enable_unsafe_deserialization
 from keras.src.utils.backend_utils import set_backend
 from keras.src.utils.io_utils import disable_interactive_logging
diff --git a/keras/api/config/__init__.py b/keras/api/config/__init__.py
index 13e334cb7c0..4ff4387aec9 100644
--- a/keras/api/config/__init__.py
+++ b/keras/api/config/__init__.py
@@ -13,6 +13,9 @@
 from keras.src.backend.config import set_image_data_format
 from keras.src.dtype_policies.dtype_policy import dtype_policy
 from keras.src.dtype_policies.dtype_policy import set_dtype_policy
+from keras.src.layers.attention.attention import disable_flash_attention
+from keras.src.layers.attention.attention import enable_flash_attention
+from keras.src.layers.attention.attention import is_flash_attention_enabled
 from keras.src.saving.serialization_lib import enable_unsafe_deserialization
 from keras.src.utils.backend_utils import set_backend
 from keras.src.utils.io_utils import disable_interactive_logging
diff --git a/keras/src/backend/jax/nn.py b/keras/src/backend/jax/nn.py
index e4c3dfeb028..8003f801ce3 100644
--- a/keras/src/backend/jax/nn.py
+++ b/keras/src/backend/jax/nn.py
@@ -6,6 +6,9 @@
 import jax.numpy as jnp
 from jax import lax
 from jax import nn as jnn
+from jax.experimental.pallas.ops.tpu.flash_attention import (
+    flash_attention as flash_attention_tpu,
+)
 
 from keras.src import backend
 from keras.src.backend.common.backend_utils import (
@@ -1019,7 +1022,18 @@ def dot_product_attention(
         f"Received: query.shape={query.shape}, key.shape={key.shape}, "
         f"value.shape={value.shape}."
     )
-
+    is_tpu = jax.devices()[0].platform == "tpu"
+    if is_tpu and flash_attention:
+        # Use the TPU-optimized flash attention kernel from Pallas.
+        return flash_attention_tpu(
+            query,
+            key,
+            value,
+            ab=bias,
+            segment_ids=mask,
+            causal=is_causal,
+            sm_scale=scale,
+        )
    # `dot_product_attention` is only available in jax>=0.4.31
    if hasattr(jax.nn, "dot_product_attention"):
        implementation = "cudnn" if flash_attention else "xla"
@@ -1040,7 +1054,6 @@
             "Flash attention is not supported in your "
             "current JAX version. Please update it "
             "using `pip install -U jax jaxlib`."
         )
-
     # Ref: jax.nn.dot_product_attention
     # https://github.com/jax-ml/jax/blob/jax-v0.4.33/jax/_src/nn/functions.py#L886
     # Not support `query_seq_lengths` and `key_value_seq_lengths` args
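Note (not part of the diff): on non-TPU devices the JAX backend above dispatches to `jax.nn.dot_product_attention` (available in jax>=0.4.31), passing `implementation="cudnn"` when flash attention is requested. A minimal standalone sketch of that call, assuming half-precision inputs and falling back to the XLA implementation when no cuDNN-capable GPU is present:

import jax
import jax.numpy as jnp

# Layout matches the Keras backend: (batch, seq_len, num_heads, head_dim).
query = jnp.ones((2, 8, 4, 16), dtype=jnp.bfloat16)
key = jnp.ones((2, 6, 4, 16), dtype=jnp.bfloat16)
value = jnp.ones((2, 6, 4, 16), dtype=jnp.bfloat16)

impl = "cudnn" if jax.devices()[0].platform == "gpu" else "xla"
out = jax.nn.dot_product_attention(
    query, key, value, is_causal=False, implementation=impl
)
print(out.shape)  # (2, 8, 4, 16)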
diff --git a/keras/src/backend/torch/nn.py b/keras/src/backend/torch/nn.py
index 841b0b40923..1d7caed0bfa 100644
--- a/keras/src/backend/torch/nn.py
+++ b/keras/src/backend/torch/nn.py
@@ -954,7 +954,14 @@ def dot_product_attention(
             scale=scale,
         )
     else:
+        if mask is not None:
+            mask = mask.contiguous()
         attention_output = torch.nn.functional.scaled_dot_product_attention(
-            query, key, value, attn_mask=mask, is_causal=is_causal, scale=scale
+            query.contiguous(),
+            key.contiguous(),
+            value.contiguous(),
+            attn_mask=mask,
+            is_causal=is_causal,
+            scale=scale,
         )
     return torch.transpose(attention_output, axis1, axis0)
diff --git a/keras/src/layers/attention/attention.py b/keras/src/layers/attention/attention.py
index 592468fe802..f019c0f6af8 100644
--- a/keras/src/layers/attention/attention.py
+++ b/keras/src/layers/attention/attention.py
@@ -1,6 +1,7 @@
 from keras.src import backend
 from keras.src import ops
 from keras.src.api_export import keras_export
+from keras.src.backend.common import global_state
 from keras.src.layers.layer import Layer
 
 
@@ -282,3 +283,49 @@ def get_config(self):
             "dropout": self.dropout,
         }
         return {**base_config, **config}
+
+
+@keras_export("keras.config.enable_flash_attention")
+def enable_flash_attention():
+    """Enable flash attention.
+
+    Flash attention offers performance optimization for attention layers,
+    making it especially useful for large language models (LLMs) that
+    benefit from faster and more memory-efficient attention computations.
+
+    Once enabled, supported layers like `MultiHeadAttention` will
+    use flash attention for faster computations.
+    """
+    global_state.set_global_attribute("flash_attention", True)
+
+
+@keras_export("keras.config.disable_flash_attention")
+def disable_flash_attention():
+    """Disable flash attention.
+
+    Flash attention offers performance optimization for attention layers,
+    making it especially useful for large language models (LLMs) that
+    benefit from faster and more memory-efficient attention computations.
+
+    Once disabled, supported layers like `MultiHeadAttention` will not
+    use flash attention for faster computations.
+    """
+    global_state.set_global_attribute("flash_attention", False)
+
+
+@keras_export("keras.config.is_flash_attention_enabled")
+def is_flash_attention_enabled():
+    """Checks whether flash attention is globally enabled in Keras.
+
+    Flash attention is a performance-optimized method for computing attention
+    in large models, such as transformers, allowing for faster and more
+    memory-efficient operations. This function checks the global Keras
+    configuration to determine if flash attention is enabled for compatible
+    layers (e.g., `MultiHeadAttention`).
+
+    Returns:
+        bool or None: Returns `True` if flash attention is enabled,
+        `False` if it is disabled, and `None` if the global
+        setting has not been defined.
+    """
+    return global_state.get_global_attribute("flash_attention", default=None)
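Note (not part of the diff): a usage sketch of the config API added above. The global flag is read by `MultiHeadAttention` at construction time, so toggle it before building the layer; on backends or hardware without flash attention support, the layer call may raise an error instead.

import numpy as np
import keras

keras.config.enable_flash_attention()
print(keras.config.is_flash_attention_enabled())  # True

# Layers built while the flag is set pick it up automatically.
layer = keras.layers.MultiHeadAttention(num_heads=2, key_dim=16)
query = np.random.random((2, 8, 32)).astype("float32")
value = np.random.random((2, 4, 32)).astype("float32")
output = layer(query=query, value=value)

keras.config.disable_flash_attention()
print(keras.config.is_flash_attention_enabled())  # False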
+ """ + return global_state.get_global_attribute("flash_attention", default=None) diff --git a/keras/src/layers/attention/multi_head_attention.py b/keras/src/layers/attention/multi_head_attention.py index cddfae38f78..195a4a353cf 100644 --- a/keras/src/layers/attention/multi_head_attention.py +++ b/keras/src/layers/attention/multi_head_attention.py @@ -11,6 +11,7 @@ from keras.src import regularizers from keras.src.api_export import keras_export from keras.src.layers.activations.softmax import Softmax +from keras.src.layers.attention.attention import is_flash_attention_enabled from keras.src.layers.core.einsum_dense import EinsumDense from keras.src.layers.layer import Layer from keras.src.layers.regularization.dropout import Dropout @@ -52,6 +53,9 @@ class MultiHeadAttention(Layer): feature dim (the query input's last dimension). attention_axes: axes over which the attention is applied. `None` means attention over all axes, but batch, heads, and features. + flash_attention: If unspecified, defaults to the global flash attention + configuration setting (which can be set via + `keras.config.enable_flash_attention(). kernel_initializer: Initializer for dense layer kernels. bias_initializer: Initializer for dense layer biases. kernel_regularizer: Regularizer for dense layer kernels. @@ -104,6 +108,7 @@ def __init__( use_bias=True, output_shape=None, attention_axes=None, + flash_attention=None, kernel_initializer="glorot_uniform", bias_initializer="zeros", kernel_regularizer=None, @@ -131,6 +136,8 @@ def __init__( self._activity_regularizer = regularizers.get(activity_regularizer) self._kernel_constraint = constraints.get(kernel_constraint) self._bias_constraint = constraints.get(bias_constraint) + self._flash_attention = flash_attention or is_flash_attention_enabled() + if isinstance(attention_axes, int): attention_axes = (attention_axes,) elif attention_axes and not isinstance(attention_axes, (list, tuple)): @@ -392,7 +399,13 @@ def _masked_softmax(self, attention_scores, attention_mask=None): return self._softmax(attention_scores, mask=attention_mask) def _compute_attention( - self, query, key, value, attention_mask=None, training=None + self, + query, + key, + value, + return_attention_scores, + attention_mask=None, + training=None, ): """Applies Dot-product attention with query, key, value tensors. @@ -415,9 +428,57 @@ def _compute_attention( attention_output: Multi-headed outputs of attention computation. attention_scores: Multi-headed attention weights. """ - # Note: Applying scalar multiply at the smaller end of einsum improves - # XLA performance, but may introduce slight numeric differences in - # the Transformer attention head. + + # Check for flash attention constraints + if self._flash_attention and return_attention_scores: + raise ValueError( + "Returning attention scores is not supported when flash " + "attention is enabled. Please disable flash attention to access" + " attention scores." + ) + if self._flash_attention and self._dropout > 0.0: + raise ValueError( + "Dropout is not supported when flash " + "attention is enabled. Please set dropout to 0.0 to use " + "flash attention." + ) + + # Determine whether to use dot-product attention + use_dot_product_attention = not ( + self._dropout > 0.0 + or return_attention_scores + or (len(query.shape) != 4) + ) + + if use_dot_product_attention: + if attention_mask is not None: + # Ensure attention_mask has the correct shape for broadcasting + # Expected shape: [batch_size, num_heads, query_seq_len, + # key_seq_len]. 
@@ -426,13 +487,13 @@
         # attention scores.
         attention_scores = ops.einsum(self._dot_product_equation, key, query)
 
+        # Apply the mask using the custom masked softmax
         attention_scores = self._masked_softmax(
             attention_scores, attention_mask
         )
 
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        if self.dropout:
+        # Apply dropout to the attention scores if needed
+        if self._dropout > 0.0:
             final_attn_scores = self._dropout_layer(
                 attention_scores, training=training
             )
@@ -460,7 +521,6 @@ def call(
     ):
         if key is None:
             key = value
-
         attention_mask = self._compute_attention_mask(
             query,
             value,
@@ -470,9 +530,9 @@
             attention_mask=attention_mask,
             use_causal_mask=use_causal_mask,
         )
-
         # N = `num_attention_heads`
         # H = `size_per_head`
+        # `query` = [B, T, N, H]
         query = self._query_dense.call(query)
 
         # `key` = [B, S, N, H]
         key = self._key_dense.call(key)
 
         # `value` = [B, S, N, H]
         value = self._value_dense.call(value)
-
         attention_output, attention_scores = self._compute_attention(
-            query, key, value, attention_mask, training
+            query,
+            key,
+            value,
+            return_attention_scores,
+            attention_mask,
+            training,
         )
 
         attention_output = self._output_dense.call(attention_output)
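Note (not part of the diff): a sketch of the per-layer `flash_attention` argument and the constraints enforced in `_compute_attention` above. Shapes and values are illustrative, and flash attention still requires a supported backend and hardware.

import numpy as np
import keras

query = np.random.random((2, 8, 32)).astype("float32")
value = np.random.random((2, 4, 32)).astype("float32")

# Opt a single layer into flash attention.
layer = keras.layers.MultiHeadAttention(
    num_heads=2, key_dim=16, flash_attention=True
)
output = layer(query=query, value=value)  # OK: no dropout, no scores requested

# Both of these are rejected at call time while flash attention is enabled:
# layer(query=query, value=value, return_attention_scores=True)  # ValueError
# keras.layers.MultiHeadAttention(
#     num_heads=2, key_dim=16, dropout=0.1, flash_attention=True
# )(query, value)  # ValueError (dropout > 0.0)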
diff --git a/keras/src/layers/attention/multi_head_attention_test.py b/keras/src/layers/attention/multi_head_attention_test.py
index 3f8fe667c5d..d42d484dc47 100644
--- a/keras/src/layers/attention/multi_head_attention_test.py
+++ b/keras/src/layers/attention/multi_head_attention_test.py
@@ -12,6 +12,8 @@
 from keras.src import models
 from keras.src import saving
 from keras.src import testing
+from keras.src.layers.attention.attention import disable_flash_attention
+from keras.src.layers.attention.attention import enable_flash_attention
 
 
 class MultiHeadAttentionTest(testing.TestCase):
@@ -51,6 +53,79 @@ def test_basics(self):
             run_training_check=False,
         )
 
+    def test_basics_with_flash_attention(self):
+        if backend.backend() in [
+            "torch",
+            "tensorflow",
+            "numpy",
+        ]:
+            self.skipTest(
+                "Flash attention is not supported in TF and NumPy, and is "
+                "only supported for PyTorch with specific requirements."
+            )
+
+        if backend.backend() == "jax":
+            try:
+                enable_flash_attention()
+                self.run_layer_test(
+                    layers.MultiHeadAttention,
+                    init_kwargs={
+                        "num_heads": 2,
+                        "key_dim": 2,
+                    },
+                    input_shape={
+                        "query_shape": (2, 8, 16),
+                        "value_shape": (2, 4, 16),
+                    },
+                    expected_output_shape=(2, 8, 16),
+                    expected_num_trainable_weights=8,
+                    expected_num_non_trainable_weights=0,
+                    expected_num_seed_generators=0,
+                    expected_num_losses=0,
+                    supports_masking=True,
+                    run_training_check=False,
+                )
+
+                self.run_layer_test(
+                    layers.MultiHeadAttention,
+                    init_kwargs={
+                        "num_heads": 2,
+                        "key_dim": 2,
+                        "value_dim": 4,
+                        "use_bias": False,
+                        "dropout": 0.5,
+                    },
+                    input_shape={
+                        "query_shape": (2, 8, 16),
+                        "value_shape": (2, 4, 16),
+                    },
+                    expected_output_shape=(2, 8, 16),
+                    expected_num_trainable_weights=4,
+                    expected_num_non_trainable_weights=0,
+                    expected_num_seed_generators=1,
+                    expected_num_losses=0,
+                    supports_masking=True,
+                    run_training_check=False,
+                )
+                disable_flash_attention()
+            except ValueError as e:
+                if e.args[0].startswith(
+                    "Flash attention is not supported in your "
+                    "current JAX version"
+                ):
+                    self.skipTest(
+                        "JAX version does not have "
+                        "`dot_product_attention` function."
+                    )
+            except RuntimeError as e:
+                if e.args[0] == "cuDNN is not detected.":
+                    self.skipTest("No CuDNN to run flash attention for JAX.")
+                elif e.args[0] == "Require at least Ampere arch to run":
+                    self.skipTest(
+                        "Requires at least Ampere arch to run flash attention "
+                        "for JAX."
+                    )
+
     @parameterized.named_parameters(
         ("4d_inputs_1freebatch_mask2", (3, 4), (3, 2), (4, 2), (2,)),
         ("4d_inputs_1freebatch_mask3", (3, 4), (3, 2), (3, 4, 2), (2,)),
@@ -189,12 +264,23 @@ def test_initializer(self):
         )
 
     def test_query_mask_propagation(self):
         """Test automatic propagation of the query's mask."""
-        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
-        self.assertTrue(layer.supports_masking)
-        query = np.array([[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]])
-        masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
-        value = np.random.normal(size=(3, 3, 8))
-        output = layer(query=masked_query, value=value)
+        try:
+            layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
+            self.assertTrue(layer.supports_masking)
+            query = np.array(
+                [[1, 2, 3, 0, 0], [3, 3, 1, 1, 2], [1, 0, 0, 0, 0]]
+            )
+            masked_query = layers.Embedding(4, 8, mask_zero=True)(query)
+            value = np.random.normal(size=(3, 3, 8))
+            output = layer(query=masked_query, value=value)
+        except RuntimeError as e:
+            if e.args[0].startswith(
+                "(*bias): last dimension must be contiguous"
+            ):
+                self.skipTest(
+                    "PyTorch errors out on GPU: issue to track bug is here "
+                    "https://github.com/keras-team/keras/issues/20459"
+                )
         self.assertAllClose(masked_query._keras_mask, output._keras_mask)
 
     @parameterized.named_parameters(("causal", True), ("not_causal", 0))
@@ -252,7 +338,6 @@ def test_correctness(self):
         bias = np.zeros((2, 2))
         output_bias = np.zeros((2,))
         layer.set_weights([kernel, bias] * 3 + [kernel, output_bias])
-
         # Call layer and assert output.
         output, scores = layer(
             query=query,
@@ -413,3 +498,55 @@ def test_dtype_policy_map(self):
         self.assertDType(layer._query_dense._kernel, "int8")
         self.assertDType(layer._key_dense._kernel, "int8")
         self.assertDType(layer._value_dense._kernel, "int8")
+
+    def test_flash_attention_with_attention_scores_error(self):
+        if backend.backend() in ("numpy", "tensorflow"):
+            pytest.skip(
+                reason="Flash attention is not supported on TensorFlow "
+                "and numpy."
+            )
+        # Enable flash attention globally
+        enable_flash_attention()
+        # Setup layer with required parameters
+        layer = layers.MultiHeadAttention(num_heads=2, key_dim=2)
+
+        # Define sample input
+        query = np.random.random((2, 4, 8))
+        value = np.random.random((2, 4, 8))
+
+        # Check if ValueError is raised when return_attention_scores=True
+        with self.assertRaisesRegex(
+            ValueError,
+            "Returning attention scores is not supported when flash "
+            "attention is enabled. Please disable flash attention to access"
+            " attention scores.",
+        ):
+            layer(query=query, value=value, return_attention_scores=True)
+        disable_flash_attention()
+
+    def test_flash_attention_numerical_correctness(self):
+        if backend.backend() == "numpy" or backend.backend() == "tensorflow":
+            pytest.skip(
+                reason="Flash attention is not supported on TensorFlow "
+                "and numpy."
+            )
+        # Define sample input
+        query = np.random.random((2, 4, 8))
+        value = np.random.random((2, 4, 8))
+
+        # Initialize MultiHeadAttention layer
+        mha_layer = layers.MultiHeadAttention(
+            num_heads=2,
+            key_dim=2,
+        )
+
+        # Run with flash attention enabled. The layer reads the global flag at
+        # construction time, so mirror the setting on the layer explicitly.
+        enable_flash_attention()
+        mha_layer._flash_attention = True
+        output_with_flash = mha_layer(query=query, value=value, training=False)
+
+        # Run with flash attention disabled
+        disable_flash_attention()
+        mha_layer._flash_attention = False
+        output_without_flash = mha_layer(
+            query=query, value=value, training=False
+        )
+
+        self.assertAllClose(output_with_flash, output_without_flash)
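Note (not part of the diff): a quick numerical sanity check at the op level, mirroring the test above. This assumes `keras.ops.dot_product_attention` exposes the same `flash_attention` flag that the layer passes internally, and that the backend and hardware support it; inputs use the rank-4 `(batch, seq_len, num_heads, head_dim)` layout required by the fast path.

import numpy as np
from keras import ops

query = np.random.random((2, 8, 4, 16)).astype("float16")
key = np.random.random((2, 6, 4, 16)).astype("float16")
value = np.random.random((2, 6, 4, 16)).astype("float16")

out_flash = ops.dot_product_attention(
    query, key, value, scale=1.0 / np.sqrt(16), flash_attention=True
)
out_ref = ops.dot_product_attention(
    query, key, value, scale=1.0 / np.sqrt(16), flash_attention=False
)
np.testing.assert_allclose(
    ops.convert_to_numpy(out_flash),
    ops.convert_to_numpy(out_ref),
    rtol=1e-3,
    atol=1e-3,
)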