Commit a8473d6: Simplify
abheesht17 committed Oct 28, 2023
1 parent a3391b9 commit a8473d6
Showing 4 changed files with 35 additions and 186 deletions.
2 changes: 0 additions & 2 deletions keras_nlp/models/whisper/whisper_backbone_test.py
@@ -75,7 +75,6 @@ def test_saved_model(self):
             input_data=self.input_data,
         )
 
-    @pytest.mark.skip  # TODO: fix weight mismatch error.
     @pytest.mark.large
     def test_smallest_preset(self):
         self.run_preset_test(
@@ -105,7 +104,6 @@ def test_smallest_preset(self):
             },
         )
 
-    @pytest.mark.skip  # TODO: fix weight mismatch error.
     @pytest.mark.extra_large
     def test_all_presets(self):
         for preset in WhisperBackbone.presets:
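
Both `@pytest.mark.skip` decorators are removed, presumably because the weight-mismatch TODO they reference is resolved, so the preset tests are again gated only by the `large` and `extra_large` markers. As a rough illustration of how such marker-gated tests are typically wired up in a conftest.py (the `--run_large` flag and option names below are assumptions, not necessarily what KerasNLP uses):

import pytest


def pytest_addoption(parser):
    # Hypothetical opt-in flag for heavyweight tests.
    parser.addoption(
        "--run_large",
        action="store_true",
        default=False,
        help="run tests marked as large",
    )


def pytest_collection_modifyitems(config, items):
    # Unless the flag is passed, skip anything carrying the "large" marker.
    if config.getoption("--run_large"):
        return
    skip_large = pytest.mark.skip(reason="need --run_large option to run")
    for item in items:
        if "large" in item.keywords:
            item.add_marker(skip_large)
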
76 changes: 31 additions & 45 deletions keras_nlp/models/whisper/whisper_decoder.py
@@ -37,7 +37,7 @@ class WhisperDecoder(TransformerDecoder):
     def build(
         self,
         decoder_sequence_shape,
-        encoder_sequence_shape=None,
+        encoder_sequence_shape,
     ):
         self._decoder_sequence_shape = decoder_sequence_shape
         self._encoder_sequence_shape = encoder_sequence_shape
@@ -62,16 +62,11 @@ def build(
             dtype=self.dtype_policy,
             name="self_attention",
         )
-        if hasattr(self._self_attention_layer, "_build_from_signature"):
-            self._self_attention_layer._build_from_signature(
-                query=decoder_sequence_shape,
-                value=decoder_sequence_shape,
-            )
-        else:
-            self._self_attention_layer.build(
-                query_shape=decoder_sequence_shape,
-                value_shape=decoder_sequence_shape,
-            )
+
+        self._self_attention_layer.build(
+            query_shape=decoder_sequence_shape,
+            value_shape=decoder_sequence_shape,
+        )
         self._self_attention_layer_norm = keras.layers.LayerNormalization(
             epsilon=self.layer_norm_epsilon,
             dtype=self.dtype_policy,
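
The branch removed above existed to support two attention APIs at once: older tf.keras exposed the private `_build_from_signature()` helper, while Keras 3's `MultiHeadAttention` has a public, shape-based `build()`. Keeping only the `build()` call suggests the code now targets the Keras 3 style API. A minimal sketch of the surviving path, shown here with the stock Keras layer and made-up shapes rather than the Whisper-specific one:

import keras

# Keras 3 style: build the attention weights directly from shapes; no
# _build_from_signature() compatibility shim is needed.
attention = keras.layers.MultiHeadAttention(num_heads=6, key_dim=64)
decoder_sequence_shape = (None, 448, 384)  # (batch, target tokens, hidden dim)
attention.build(
    query_shape=decoder_sequence_shape,
    value_shape=decoder_sequence_shape,
)
print(attention.built)  # True
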
@@ -84,40 +79,31 @@ def build(
             name="self_attention_dropout",
         )
 
-        # Cross attention layers are optional.
-        self._cross_attention_layer = None
-        if encoder_sequence_shape:
-            self._cross_attention_layer = WhisperCachedMultiHeadAttention(
-                num_heads=self.num_heads,
-                key_dim=head_dim,
-                value_dim=head_dim,
-                dropout=self.dropout,
-                kernel_initializer=clone_initializer(self.kernel_initializer),
-                bias_initializer=clone_initializer(self.bias_initializer),
-                dtype=self.dtype_policy,
-                name="cross_attention",
-            )
-            if hasattr(self._cross_attention_layer, "_build_from_signature"):
-                self._cross_attention_layer._build_from_signature(
-                    query=decoder_sequence_shape,
-                    value=encoder_sequence_shape,
-                )
-            else:
-                self._cross_attention_layer.build(
-                    query_shape=decoder_sequence_shape,
-                    value_shape=encoder_sequence_shape,
-                )
-            self._cross_attention_layer_norm = keras.layers.LayerNormalization(
-                epsilon=self.layer_norm_epsilon,
-                dtype=self.dtype_policy,
-                name="cross_attention_layer_norm",
-            )
-            self._cross_attention_layer_norm.build(decoder_sequence_shape)
-            self._cross_attention_dropout = keras.layers.Dropout(
-                rate=self.dropout,
-                dtype=self.dtype_policy,
-                name="cross_attention_dropout",
-            )
+        self._cross_attention_layer = WhisperCachedMultiHeadAttention(
+            num_heads=self.num_heads,
+            key_dim=head_dim,
+            value_dim=head_dim,
+            dropout=self.dropout,
+            kernel_initializer=clone_initializer(self.kernel_initializer),
+            bias_initializer=clone_initializer(self.bias_initializer),
+            dtype=self.dtype_policy,
+            name="cross_attention",
+        )
+        self._cross_attention_layer.build(
+            query_shape=decoder_sequence_shape,
+            value_shape=encoder_sequence_shape,
+        )
+        self._cross_attention_layer_norm = keras.layers.LayerNormalization(
+            epsilon=self.layer_norm_epsilon,
+            dtype=self.dtype_policy,
+            name="cross_attention_layer_norm",
+        )
+        self._cross_attention_layer_norm.build(decoder_sequence_shape)
+        self._cross_attention_dropout = keras.layers.Dropout(
+            rate=self.dropout,
+            dtype=self.dtype_policy,
+            name="cross_attention_dropout",
+        )
 
         # Feedforward layers.
         self._feedforward_intermediate_dense = keras.layers.Dense(
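
The other decoder change makes cross-attention unconditional: `encoder_sequence_shape` is now a required argument to `build()`, and the `self._cross_attention_layer = None` / `if encoder_sequence_shape:` path is gone. That matches Whisper's architecture, in which the decoder always cross-attends to encoded audio features. A hypothetical usage sketch (constructor arguments and shapes are illustrative, roughly whisper-tiny sized, and assume `WhisperDecoder` keeps `TransformerDecoder`'s constructor):

from keras_nlp.models.whisper.whisper_decoder import WhisperDecoder

decoder = WhisperDecoder(intermediate_dim=1536, num_heads=6)
# Both shapes must now be supplied; omitting encoder_sequence_shape raises a
# TypeError instead of silently building a decoder without cross-attention.
decoder.build(
    decoder_sequence_shape=(None, 448, 384),   # (batch, target tokens, hidden dim)
    encoder_sequence_shape=(None, 1500, 384),  # (batch, audio frames, hidden dim)
)
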
8 changes: 4 additions & 4 deletions keras_nlp/models/whisper/whisper_encoder.py
@@ -16,8 +16,8 @@
 
 from keras_nlp.backend import keras
 from keras_nlp.layers.modeling.transformer_encoder import TransformerEncoder
-from keras_nlp.models.whisper.whisper_multi_head_attention import (
-    WhisperMultiHeadAttention,
+from keras_nlp.models.whisper.whisper_cached_multi_head_attention import (
+    WhisperCachedMultiHeadAttention,
 )
 from keras_nlp.utils.keras_utils import clone_initializer
 
@@ -28,7 +28,7 @@ class WhisperEncoder(TransformerEncoder):
     Inherits from `keras_nlp.layers.TransformerEncoder`, and overrides the
     `_build` method to use the
-    `keras_nlp.models.whisper.whisper_multi_head_attention.WhisperMultiHeadAttention`
+    `keras_nlp.models.whisper.whisper_cached_multi_head_attention.WhisperCachedMultiHeadAttention`
     layer instead of `keras.layers.MultiHeadAttention`.
     """
 
@@ -45,7 +45,7 @@ def build(self, inputs_shape):
         )
 
         # Self attention layers.
-        self._self_attention_layer = WhisperMultiHeadAttention(
+        self._self_attention_layer = WhisperCachedMultiHeadAttention(
             num_heads=self.num_heads,
             key_dim=key_dim,
             dropout=self.dropout,
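
With both the encoder and the decoder now built on `WhisperCachedMultiHeadAttention`, the separate `WhisperMultiHeadAttention` class is redundant, which is why the file below is deleted. The swap is safe for the encoder because a cached attention layer with no cache supplied behaves like ordinary multi-head self-attention. A rough sketch, assuming the layer keeps `keras.layers.MultiHeadAttention`'s constructor and returns only the attention output when called without a cache:

import numpy as np

from keras_nlp.models.whisper.whisper_cached_multi_head_attention import (
    WhisperCachedMultiHeadAttention,
)

attention = WhisperCachedMultiHeadAttention(num_heads=6, key_dim=64)
# Illustrative encoder features: (batch, audio frames, hidden dim).
features = np.random.uniform(size=(1, 1500, 384)).astype("float32")

# Plain self-attention: query, key, and value are all the encoder features.
encoded = attention(query=features, value=features)
print(encoded.shape)  # (1, 1500, 384)
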
135 changes: 0 additions & 135 deletions keras_nlp/models/whisper/whisper_multi_head_attention.py

This file was deleted.
