Commit

Fix Whisper
abheesht17 committed Oct 27, 2023
1 parent d254b02 commit 04d931d
Showing 4 changed files with 478 additions and 21 deletions.
137 changes: 137 additions & 0 deletions keras_nlp/models/whisper/whisper_cached_multi_head_attention.py
@@ -0,0 +1,137 @@
# Copyright 2023 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Whisper Cached Multi-Head Attention layer."""


import string

from keras_nlp.backend import keras
from keras_nlp.layers.modeling.cached_multi_head_attention import (
    CachedMultiHeadAttention,
)


def _index_to_einsum_variable(i):
"""Converts an index to a einsum variable name.
We simply map indices to lowercase characters, e.g. 0 -> 'a', 1 -> 'b'.
"""
return string.ascii_lowercase[i]


def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _index_to_einsum_variable(i + letter_offset)
input_str += char
output_str += char

letter_offset += free_dims
for i in range(bound_dims):
char = _index_to_einsum_variable(i + letter_offset)
input_str += char
kernel_str += char

letter_offset += bound_dims
for i in range(output_dims):
char = _index_to_einsum_variable(i + letter_offset)
kernel_str += char
output_str += char
bias_axes += char
equation = f"{input_str},{kernel_str}->{output_str}"

return equation, bias_axes, len(output_str)
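
# For example, for a rank-3 query of shape `[batch, sequence, hidden]`,
# `_build_proj_equation(free_dims=2, bound_dims=1, output_dims=2)` returns
# `("abc,cde->abde", "de", 4)`: an einsum that projects
# `[batch, sequence, hidden] x [hidden, num_heads, key_dim]` into
# `[batch, sequence, num_heads, key_dim]`, with a bias over the last two axes.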


def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
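
# For example, `_get_output_shape(3, [8, 64])` returns `[None, 8, 64]`.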


class WhisperCachedMultiHeadAttention(CachedMultiHeadAttention):
    """Whisper Cached Multi-Head Attention layer.

    Inherits from `keras_nlp.layers.CachedMultiHeadAttention`, and overrides
    the `build` method so that the query and value projection layers have a
    bias term, whereas the key projection layer does not.
    """

def build(
self,
query_shape,
value_shape,
key_shape=None,
):
key_shape = value_shape if key_shape is None else key_shape
query_rank = len(query_shape)
value_rank = len(value_shape)
key_rank = len(key_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
query_rank - 1, bound_dims=1, output_dims=2
)
self._query_dense = keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._key_dim]
),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**self._get_common_kwargs_for_sublayer(),
)
self._query_dense.build(query_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_rank - 1, bound_dims=1, output_dims=2
)
self._key_dense = keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._key_dim]
),
name="key",
**self._get_common_kwargs_for_sublayer(),
)
self._key_dense.build(key_shape)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_rank - 1, bound_dims=1, output_dims=2
)
self._value_dense = keras.layers.EinsumDense(
einsum_equation,
output_shape=_get_output_shape(
output_rank - 1, [self._num_heads, self._value_dim]
),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**self._get_common_kwargs_for_sublayer(),
)
self._value_dense.build(value_shape)

# Builds the attention computations for multi-head dot product
# attention. These computations could be wrapped into the keras
# attention layer once it supports multi-head einsum computations.
self._build_attention(output_rank)
self._output_dense = self._make_output_dense(
query_shape,
self._get_common_kwargs_for_sublayer(),
"attention_output",
)
output_dense_input_shape = list(
self._query_dense.compute_output_shape(query_shape)
)
output_dense_input_shape[-1] = self._value_dim
self._output_dense.build(tuple(output_dense_input_shape))
self.built = True
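
As a quick check of the bias layout described in the class docstring, a minimal sketch (illustrative only; the constructor arguments and shapes are assumptions, not values from this commit) could look like:

from keras_nlp.models.whisper.whisper_cached_multi_head_attention import (
    WhisperCachedMultiHeadAttention,
)

layer = WhisperCachedMultiHeadAttention(num_heads=6, key_dim=64)
layer.build(query_shape=(None, 10, 384), value_shape=(None, 10, 384))
print(layer._query_dense.bias is not None)  # True: query projection has a bias
print(layer._value_dense.bias is not None)  # True: value projection has a bias
print(layer._key_dense.bias is None)        # True: key projection has no bias
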
136 changes: 123 additions & 13 deletions keras_nlp/models/whisper/whisper_decoder.py
@@ -11,33 +11,143 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Whisper decoder block."""


from keras_nlp.backend import keras
from keras_nlp.layers.modeling.transformer_decoder import TransformerDecoder
from keras_nlp.models.whisper.whisper_cached_multi_head_attention import (
WhisperCachedMultiHeadAttention,
)
from keras_nlp.utils.keras_utils import clone_initializer


@keras.saving.register_keras_serializable(package="keras_nlp")
class WhisperDecoder(TransformerDecoder):
"""A Whisper decoder.
"""Whisper decoder.
Inherits from `keras_nlp.layers.TransformerDecoder`, and overrides the
`build` method so as to remove the bias term from the key projection layer.
`build` method to use the
`keras_nlp.models.whisper.whisper_multi_head_attention.WhisperMultiHeadAttention`
layer instead of `keras.layers.MultiHeadAttention` and
`keras_nlp.models.whisper.whisper_cached_multi_head_attention.WhisperCachedMultiHeadAttention`
instead of `keras_nlp.layers.cached_multi_head_attention.CachedMultiHeadAttention`.
"""

def build(
self,
decoder_sequence_shape,
encoder_sequence_shape=None,
):
self._decoder_sequence_shape = decoder_sequence_shape
self._encoder_sequence_shape = encoder_sequence_shape
# Infer the dimension of our hidden feature size from the build shape.
hidden_dim = decoder_sequence_shape[-1]
# Attention head size is `hidden_dim` over the number of heads.
head_dim = int(hidden_dim // self.num_heads)
if head_dim == 0:
raise ValueError(
"Attention `head_dim` computed cannot be zero. "
f"The `hidden_dim` value of {hidden_dim} has to be equal to "
f"or greater than `num_heads` value of {self.num_heads}."
)
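        # For example, Whisper-tiny uses `hidden_dim=384` and `num_heads=6`,
        # giving `head_dim=64`.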

# Self attention layers.
self._self_attention_layer = WhisperCachedMultiHeadAttention(
num_heads=self.num_heads,
key_dim=head_dim,
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="self_attention",
)
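        # Keras 2's `MultiHeadAttention` builds lazily via
        # `_build_from_signature()`, whereas Keras 3 exposes a regular
        # `build()` that takes shapes, so support both code paths here.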
if hasattr(self._self_attention_layer, "_build_from_signature"):
self._self_attention_layer._build_from_signature(
query=decoder_sequence_shape,
value=decoder_sequence_shape,
)
else:
self._self_attention_layer.build(
query_shape=decoder_sequence_shape,
value_shape=decoder_sequence_shape,
)
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="self_attention_layer_norm",
)
self._self_attention_layer_norm.build(decoder_sequence_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="self_attention_dropout",
)

# Cross attention layers are optional.
self._cross_attention_layer = None
if encoder_sequence_shape:
self._cross_attention_layer = WhisperCachedMultiHeadAttention(
num_heads=self.num_heads,
key_dim=head_dim,
value_dim=head_dim,
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="cross_attention",
)
if hasattr(self._cross_attention_layer, "_build_from_signature"):
self._cross_attention_layer._build_from_signature(
query=decoder_sequence_shape,
value=encoder_sequence_shape,
)
else:
self._cross_attention_layer.build(
query_shape=decoder_sequence_shape,
value_shape=encoder_sequence_shape,
)
self._cross_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="cross_attention_layer_norm",
)
self._cross_attention_layer_norm.build(decoder_sequence_shape)
self._cross_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="cross_attention_dropout",
)

# Feedforward layers.
self._feedforward_intermediate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(decoder_sequence_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
dtype=self.dtype_policy,
name="feedforward_output_dense",
)
intermediate_shape = list(decoder_sequence_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
dtype=self.dtype_policy,
name="feedforward_layer_norm",
)
self._feedforward_layer_norm.build(decoder_sequence_shape)
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
dtype=self.dtype_policy,
name="feedforward_dropout",
)
# Create layers based on input shape.
self.built = True
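
For reference, a minimal end-to-end sketch of the updated decoder block (illustrative only; the constructor arguments and shapes are assumptions roughly matching Whisper-tiny, not values from this commit):

import numpy as np
from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder

decoder = WhisperDecoder(intermediate_dim=1536, num_heads=6)
decoder_sequence = np.zeros((1, 16, 384), dtype="float32")
encoder_sequence = np.zeros((1, 100, 384), dtype="float32")
outputs = decoder(decoder_sequence, encoder_sequence)
print(outputs.shape)  # (1, 16, 384)

# Both attention layers build their key projections without a bias term.
print(decoder._self_attention_layer._key_dense.bias is None)   # True
print(decoder._cross_attention_layer._key_dense.bias is None)  # True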
