diff --git a/keras_nlp/models/whisper/whisper_cached_multi_head_attention.py b/keras_nlp/models/whisper/whisper_cached_multi_head_attention.py
new file mode 100644
index 0000000000..5ffb07e116
--- /dev/null
+++ b/keras_nlp/models/whisper/whisper_cached_multi_head_attention.py
@@ -0,0 +1,137 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Whisper Cached Multi-Head Attention layer."""
+
+
+import string
+
+from keras_nlp.backend import keras
+from keras_nlp.layers.modeling.cached_multi_head_attention import (
+    CachedMultiHeadAttention,
+)
+
+
+def _index_to_einsum_variable(i):
+    """Converts an index to an einsum variable name.
+
+    We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'.
+    """
+    return string.ascii_lowercase[i]
+
+
+def _build_proj_equation(free_dims, bound_dims, output_dims):
+    """Builds an einsum equation for projections inside multi-head attention."""
+    input_str = ""
+    kernel_str = ""
+    output_str = ""
+    bias_axes = ""
+    letter_offset = 0
+    for i in range(free_dims):
+        char = _index_to_einsum_variable(i + letter_offset)
+        input_str += char
+        output_str += char
+
+    letter_offset += free_dims
+    for i in range(bound_dims):
+        char = _index_to_einsum_variable(i + letter_offset)
+        input_str += char
+        kernel_str += char
+
+    letter_offset += bound_dims
+    for i in range(output_dims):
+        char = _index_to_einsum_variable(i + letter_offset)
+        kernel_str += char
+        output_str += char
+        bias_axes += char
+    equation = f"{input_str},{kernel_str}->{output_str}"
+
+    return equation, bias_axes, len(output_str)
+
+
+def _get_output_shape(output_rank, known_last_dims):
+    return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
+
+
+class WhisperCachedMultiHeadAttention(CachedMultiHeadAttention):
+    """Whisper Cached Multi-Head Attention layer.
+
+    Inherits from `keras_nlp.layers.CachedMultiHeadAttention`, and overrides
+    the `build` method so that the Q and V projection layers have a bias term,
+    whereas the K projection layer does not.
+ """ + + def build( + self, + query_shape, + value_shape, + key_shape=None, + ): + key_shape = value_shape if key_shape is None else key_shape + query_rank = len(query_shape) + value_rank = len(value_shape) + key_rank = len(key_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + query_rank - 1, bound_dims=1, output_dims=2 + ) + self._query_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="query", + **self._get_common_kwargs_for_sublayer(), + ) + self._query_dense.build(query_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + key_rank - 1, bound_dims=1, output_dims=2 + ) + self._key_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + name="key", + **self._get_common_kwargs_for_sublayer(), + ) + self._key_dense.build(key_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + value_rank - 1, bound_dims=1, output_dims=2 + ) + self._value_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._value_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="value", + **self._get_common_kwargs_for_sublayer(), + ) + self._value_dense.build(value_shape) + + # Builds the attention computations for multi-head dot product + # attention. These computations could be wrapped into the keras + # attention layer once it supports multi-head einsum computations. + self._build_attention(output_rank) + self._output_dense = self._make_output_dense( + query_shape, + self._get_common_kwargs_for_sublayer(), + "attention_output", + ) + output_dense_input_shape = list( + self._query_dense.compute_output_shape(query_shape) + ) + output_dense_input_shape[-1] = self._value_dim + self._output_dense.build(tuple(output_dense_input_shape)) + self.built = True diff --git a/keras_nlp/models/whisper/whisper_decoder.py b/keras_nlp/models/whisper/whisper_decoder.py index 7f5d834741..3debacc45b 100644 --- a/keras_nlp/models/whisper/whisper_decoder.py +++ b/keras_nlp/models/whisper/whisper_decoder.py @@ -11,17 +11,26 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +"""Whisper decoder block.""" + from keras_nlp.backend import keras from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder +from keras_nlp.models.whisper.whisper_cached_multi_head_attention import ( + WhisperCachedMultiHeadAttention, +) +from keras_nlp.utils.keras_utils import clone_initializer -@keras.saving.register_keras_serializable(package="keras_nlp") class WhisperDecoder(TransformerDecoder): - """A Whisper decoder. + """Whisper decoder. Inherits from `keras_nlp.layers.TransformerDecoder`, and overrides the - `build` method so as to remove the bias term from the key projection layer. + `build` method to use the + `keras_nlp.models.whisper.whisper_multi_head_attention.WhisperMultiHeadAttention` + layer instead of `keras.layers.MultiHeadAttention` and + `keras_nlp.models.whisper.whisper_cached_multi_head_attention.WhisperCachedMultiHeadAttention` + instead of `keras_nlp.layers.cached_multi_head_attention.CachedMultiHeadAttention`. 
""" def build( @@ -29,15 +38,116 @@ def build( decoder_sequence_shape, encoder_sequence_shape=None, ): - super().build( - decoder_sequence_shape, - encoder_sequence_shape=encoder_sequence_shape, + self._decoder_sequence_shape = decoder_sequence_shape + self._encoder_sequence_shape = encoder_sequence_shape + # Infer the dimension of our hidden feature size from the build shape. + hidden_dim = decoder_sequence_shape[-1] + # Attention head size is `hidden_dim` over the number of heads. + head_dim = int(hidden_dim // self.num_heads) + if head_dim == 0: + raise ValueError( + "Attention `head_dim` computed cannot be zero. " + f"The `hidden_dim` value of {hidden_dim} has to be equal to " + f"or greater than `num_heads` value of {self.num_heads}." + ) + + # Self attention layers. + self._self_attention_layer = WhisperCachedMultiHeadAttention( + num_heads=self.num_heads, + key_dim=head_dim, + dropout=self.dropout, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="self_attention", + ) + if hasattr(self._self_attention_layer, "_build_from_signature"): + self._self_attention_layer._build_from_signature( + query=decoder_sequence_shape, + value=decoder_sequence_shape, + ) + else: + self._self_attention_layer.build( + query_shape=decoder_sequence_shape, + value_shape=decoder_sequence_shape, + ) + self._self_attention_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="self_attention_layer_norm", ) + self._self_attention_layer_norm.build(decoder_sequence_shape) + self._self_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="self_attention_dropout", + ) + + # Cross attention layers are optional. + self._cross_attention_layer = None + if encoder_sequence_shape: + self._cross_attention_layer = WhisperCachedMultiHeadAttention( + num_heads=self.num_heads, + key_dim=head_dim, + value_dim=head_dim, + dropout=self.dropout, + kernel_initializer=clone_initializer(self.kernel_initializer), + bias_initializer=clone_initializer(self.bias_initializer), + dtype=self.dtype_policy, + name="cross_attention", + ) + if hasattr(self._cross_attention_layer, "_build_from_signature"): + self._cross_attention_layer._build_from_signature( + query=decoder_sequence_shape, + value=encoder_sequence_shape, + ) + else: + self._cross_attention_layer.build( + query_shape=decoder_sequence_shape, + value_shape=encoder_sequence_shape, + ) + self._cross_attention_layer_norm = keras.layers.LayerNormalization( + epsilon=self.layer_norm_epsilon, + dtype=self.dtype_policy, + name="cross_attention_layer_norm", + ) + self._cross_attention_layer_norm.build(decoder_sequence_shape) + self._cross_attention_dropout = keras.layers.Dropout( + rate=self.dropout, + dtype=self.dtype_policy, + name="cross_attention_dropout", + ) - # Since there is no exposed option for this in MHA, we will reach into - # the internals of the layer for now. - self._self_attention_layer._key_dense.bias_axes = None - self._self_attention_layer._key_dense.bias = None - if self._cross_attention_layer: - self._cross_attention_layer._key_dense.bias_axes = None - self._cross_attention_layer._key_dense.bias = None + # Feedforward layers. 
+        self._feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            activation=self.activation,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
+        )
+        self._feedforward_intermediate_dense.build(decoder_sequence_shape)
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+        intermediate_shape = list(decoder_sequence_shape)
+        intermediate_shape[-1] = self.intermediate_dim
+        self._feedforward_output_dense.build(tuple(intermediate_shape))
+        self._feedforward_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
+        )
+        self._feedforward_layer_norm.build(decoder_sequence_shape)
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
+        )
+        self.built = True
diff --git a/keras_nlp/models/whisper/whisper_encoder.py b/keras_nlp/models/whisper/whisper_encoder.py
index 31267cbf78..a3d6a5c5fc 100644
--- a/keras_nlp/models/whisper/whisper_encoder.py
+++ b/keras_nlp/models/whisper/whisper_encoder.py
@@ -11,23 +11,99 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+"""Whisper encoder block."""
+

 from keras_nlp.backend import keras
 from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
+from keras_nlp.models.whisper.whisper_multi_head_attention import (
+    WhisperMultiHeadAttention,
+)
+from keras_nlp.utils.keras_utils import clone_initializer


-@keras.saving.register_keras_serializable(package="keras_nlp")
 class WhisperEncoder(TransformerEncoder):
-    """A Whisper encoder.
+    """Whisper encoder.

     Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the
-    `build` method so as to remove the bias term from the key projection layer.
+    `build` method to use the
+    `keras_nlp.models.whisper.whisper_multi_head_attention.WhisperMultiHeadAttention`
+    layer instead of `keras.layers.MultiHeadAttention`.
     """

     def build(self, inputs_shape):
-        super().build(inputs_shape)
+        # Infer the dimension of our hidden feature size from the build shape.
+        hidden_dim = inputs_shape[-1]
+        # Attention head size is `hidden_dim` over the number of heads.
+        key_dim = int(hidden_dim // self.num_heads)
+        if key_dim == 0:
+            raise ValueError(
+                "Attention `key_dim` computed cannot be zero. "
+                f"The `hidden_dim` value of {hidden_dim} has to be equal to "
+                f"or greater than `num_heads` value of {self.num_heads}."
+            )
+
+        # Self attention layers.
+        self._self_attention_layer = WhisperMultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=key_dim,
+            dropout=self.dropout,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="self_attention_layer",
+        )
+        if hasattr(self._self_attention_layer, "_build_from_signature"):
+            self._self_attention_layer._build_from_signature(
+                query=inputs_shape,
+                value=inputs_shape,
+            )
+        else:
+            self._self_attention_layer.build(
+                query_shape=inputs_shape,
+                value_shape=inputs_shape,
+            )
+        self._self_attention_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="self_attention_layer_norm",
+        )
+        self._self_attention_layer_norm.build(inputs_shape)
+        self._self_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="self_attention_dropout",
+        )

-        # Since there is no exposed option for this in MHA, we will reach into
-        # the internals of the layer for now.
-        self._self_attention_layer._key_dense.bias_axes = None
-        self._self_attention_layer._key_dense.bias = None
+        # Feedforward layers.
+        self._feedforward_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="feedforward_layer_norm",
+        )
+        self._feedforward_layer_norm.build(inputs_shape)
+        self._feedforward_intermediate_dense = keras.layers.Dense(
+            self.intermediate_dim,
+            activation=self.activation,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_intermediate_dense",
+        )
+        self._feedforward_intermediate_dense.build(inputs_shape)
+        self._feedforward_output_dense = keras.layers.Dense(
+            hidden_dim,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="feedforward_output_dense",
+        )
+        intermediate_shape = list(inputs_shape)
+        intermediate_shape[-1] = self.intermediate_dim
+        self._feedforward_output_dense.build(tuple(intermediate_shape))
+        self._feedforward_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="feedforward_dropout",
+        )
+        self.built = True
diff --git a/keras_nlp/models/whisper/whisper_multi_head_attention.py b/keras_nlp/models/whisper/whisper_multi_head_attention.py
new file mode 100644
index 0000000000..350474ba7d
--- /dev/null
+++ b/keras_nlp/models/whisper/whisper_multi_head_attention.py
@@ -0,0 +1,134 @@
+# Copyright 2023 The KerasNLP Authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Whisper Multi-Head Attention layer."""
+
+
+import string
+
+from keras_nlp.backend import keras
+
+
+def _index_to_einsum_variable(i):
+    """Converts an index to an einsum variable name.
+
+    We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'.
+ """ + return string.ascii_lowercase[i] + + +def _build_proj_equation(free_dims, bound_dims, output_dims): + """Builds an einsum equation for projections inside multi-head attention.""" + input_str = "" + kernel_str = "" + output_str = "" + bias_axes = "" + letter_offset = 0 + for i in range(free_dims): + char = _index_to_einsum_variable(i + letter_offset) + input_str += char + output_str += char + + letter_offset += free_dims + for i in range(bound_dims): + char = _index_to_einsum_variable(i + letter_offset) + input_str += char + kernel_str += char + + letter_offset += bound_dims + for i in range(output_dims): + char = _index_to_einsum_variable(i + letter_offset) + kernel_str += char + output_str += char + bias_axes += char + equation = f"{input_str},{kernel_str}->{output_str}" + + return equation, bias_axes, len(output_str) + + +def _get_output_shape(output_rank, known_last_dims): + return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims) + + +class WhisperMultiHeadAttention(keras.layers.MultiHeadAttention): + """Whisper Multi-Head Attention layer. + + Inherits from `keras.layers.MultiHeadAttention`, and overrides the + `build` method so that Q, V projection layers have bias + whereas K projection layer does not. + """ + + def build( + self, + query_shape, + value_shape, + key_shape=None, + ): + key_shape = value_shape if key_shape is None else key_shape + query_rank = len(query_shape) + value_rank = len(value_shape) + key_rank = len(key_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + query_rank - 1, bound_dims=1, output_dims=2 + ) + self._query_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="query", + **self._get_common_kwargs_for_sublayer(), + ) + self._query_dense.build(query_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + key_rank - 1, bound_dims=1, output_dims=2 + ) + self._key_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._key_dim] + ), + name="key", + **self._get_common_kwargs_for_sublayer(), + ) + self._key_dense.build(key_shape) + einsum_equation, bias_axes, output_rank = _build_proj_equation( + value_rank - 1, bound_dims=1, output_dims=2 + ) + self._value_dense = keras.layers.EinsumDense( + einsum_equation, + output_shape=_get_output_shape( + output_rank - 1, [self._num_heads, self._value_dim] + ), + bias_axes=bias_axes if self._use_bias else None, + name="value", + **self._get_common_kwargs_for_sublayer(), + ) + self._value_dense.build(value_shape) + + # Builds the attention computations for multi-head dot product + # attention. These computations could be wrapped into the keras + # attention layer once it supports multi-head einsum computations. + self._build_attention(output_rank) + self._output_dense = self._make_output_dense( + query_shape, + self._get_common_kwargs_for_sublayer(), + "attention_output", + ) + output_dense_input_shape = list( + self._query_dense.compute_output_shape(query_shape) + ) + output_dense_input_shape[-1] = self._value_dim + self._output_dense.build(tuple(output_dense_input_shape)) + self.built = True