diff --git a/keras_nlp/models/t5/t5_backbone.py b/keras_nlp/models/t5/t5_backbone.py
index 7514cc51ae..13db116f43 100644
--- a/keras_nlp/models/t5/t5_backbone.py
+++ b/keras_nlp/models/t5/t5_backbone.py
@@ -19,7 +19,6 @@
 from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm
 from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer
 from keras_nlp.utils.python_utils import classproperty
-from keras_nlp.utils.tensor_utils import assert_tf_backend
 
 
 @keras_nlp_export("keras_nlp.models.T5Backbone")
@@ -81,8 +80,6 @@ def __init__(
         tie_embedding_weights=False,
         **kwargs,
     ):
-        assert_tf_backend(self.__class__.__name__)
-
         # Encoder inputs
         encoder_token_ids = keras.Input(
             shape=(None,), dtype="int32", name="encoder_token_ids"
         )
@@ -121,7 +118,7 @@ def __init__(
 
         position_bias = None
         for i in range(num_layers):
-            x, position_bias = T5TransformerLayer(
+            output = T5TransformerLayer(
                 is_decoder=False,
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
@@ -138,6 +135,8 @@ def __init__(
                 position_bias=position_bias,
                 use_causal_mask=False,
             )
+            if isinstance(output, tuple):
+                x, position_bias = output
 
         x = T5LayerNorm(
             epsilon=layer_norm_epsilon,
@@ -162,7 +161,7 @@ def __init__(
 
         position_bias = None
         for i in range(num_layers):
-            x, position_bias = T5TransformerLayer(
+            output = T5TransformerLayer(
                 is_decoder=True,
                 hidden_dim=hidden_dim,
                 intermediate_dim=intermediate_dim,
@@ -181,6 +180,8 @@ def __init__(
                 encoder_attention_mask=encoder_attention_mask,
                 use_causal_mask=True,
             )
+            if isinstance(output, tuple):
+                x, position_bias = output
 
         x = T5LayerNorm(
             epsilon=layer_norm_epsilon,
diff --git a/keras_nlp/models/t5/t5_backbone_test.py b/keras_nlp/models/t5/t5_backbone_test.py
index e5e147705e..b8041e876e 100644
--- a/keras_nlp/models/t5/t5_backbone_test.py
+++ b/keras_nlp/models/t5/t5_backbone_test.py
@@ -19,7 +19,6 @@
 from keras_nlp.tests.test_case import TestCase
 
 
-@pytest.mark.tf_only
 class T5BackboneTest(TestCase):
     def setUp(self):
         self.init_kwargs = {
diff --git a/keras_nlp/models/t5/t5_layer_norm.py b/keras_nlp/models/t5/t5_layer_norm.py
index 7cfdb2315e..b4f157c004 100644
--- a/keras_nlp/models/t5/t5_layer_norm.py
+++ b/keras_nlp/models/t5/t5_layer_norm.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import tensorflow as tf
-
 from keras_nlp.backend import keras
+from keras_nlp.backend import ops
 
 
 class T5LayerNorm(keras.layers.Layer):
@@ -31,8 +30,6 @@ def build(self, input_shape):
         self.built = True
 
     def call(self, hidden_states):
-        variance = tf.math.reduce_mean(
-            tf.math.square(hidden_states), axis=-1, keepdims=True
-        )
-        hidden_states = hidden_states * tf.math.rsqrt(variance + self.epsilon)
+        variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
+        hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon)
         return self.weight * hidden_states
diff --git a/keras_nlp/models/t5/t5_multi_head_attention.py b/keras_nlp/models/t5/t5_multi_head_attention.py
index 479de51e7d..5cb59769dc 100644
--- a/keras_nlp/models/t5/t5_multi_head_attention.py
+++ b/keras_nlp/models/t5/t5_multi_head_attention.py
@@ -12,18 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import tensorflow as tf
-from tensorflow.compiler.tf2xla.python.xla import dynamic_slice
+import numpy as np
 
 from keras_nlp.backend import keras
-
-
-def shape_list(tensor):
-    dynamic = tf.shape(tensor)
-    if tensor.shape == tf.TensorShape(None):
-        return dynamic
-    static = tensor.shape.as_list()
-    return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+from keras_nlp.backend import ops
 
 
 class T5MultiHeadAttention(keras.layers.Layer):
@@ -123,39 +115,39 @@ def _relative_position_bucket(
         if bidirectional:
             num_buckets //= 2
             relative_buckets += (
-                tf.cast(
-                    tf.math.greater(relative_position, 0),
+                ops.cast(
+                    ops.greater(relative_position, 0),
                     dtype=relative_position.dtype,
                 )
                 * num_buckets
             )
-            relative_position = tf.math.abs(relative_position)
+            relative_position = ops.abs(relative_position)
         else:
-            relative_position = -tf.math.minimum(relative_position, 0)
+            relative_position = -ops.minimum(relative_position, 0)
         # now n is in the range [0, inf)
         max_exact = num_buckets // 2
-        is_small = tf.math.less(relative_position, max_exact)
-        relative_position_if_large = max_exact + tf.cast(
-            tf.math.log(
-                tf.cast(relative_position, "float32")
-                / tf.cast(max_exact, "float32")
+        is_small = ops.less(relative_position, max_exact)
+        relative_position_if_large = max_exact + ops.cast(
+            ops.log(
+                ops.cast(relative_position, "float32")
+                / ops.cast(max_exact, "float32")
             )
-            / tf.math.log(max_distance / max_exact)
+            / ops.cast(ops.log(max_distance / max_exact), "float32")
             * (num_buckets - max_exact),
             dtype=relative_position.dtype,
         )
-        relative_position_if_large = tf.math.minimum(
+        relative_position_if_large = ops.minimum(
            relative_position_if_large, num_buckets - 1
         )
-        relative_buckets += tf.where(
+        relative_buckets += ops.where(
             is_small, relative_position, relative_position_if_large
         )
         return relative_buckets
 
     def compute_bias(self, query_length, key_length):
         """Compute binned relative position bias"""
-        context_position = tf.range(query_length)[:, None]
-        memory_position = tf.range(key_length)[None, :]
+        context_position = ops.arange(query_length)[:, None]
+        memory_position = ops.arange(key_length)[None, :]
         relative_position = (
             memory_position - context_position
         )  # shape (query_length, key_length)
@@ -165,11 +157,11 @@ def compute_bias(self, query_length, key_length):
             num_buckets=self.relative_attention_buckets,
             max_distance=self.relative_attention_max_distance,
         )
-        values = tf.gather(
-            self.relative_attention_bias, relative_position_bucket
+        values = ops.take(
+            self.relative_attention_bias, relative_position_bucket, axis=0
         )  # shape (query_length, key_length, num_heads)
-        values = tf.expand_dims(
-            tf.transpose(values, [2, 0, 1]), axis=0
+        values = ops.expand_dims(
+            ops.transpose(values, axes=(2, 0, 1)), axis=0
         )  # shape (1, num_heads, query_length, key_length)
         return values
 
@@ -186,7 +178,7 @@ def call(
     ):
         # Input is (batch_size, query_length, dim)
         # past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head)
-        batch_size, seq_length = shape_list(hidden_states)[:2]
+        batch_size, seq_length = ops.shape(hidden_states)[:2]
 
         real_seq_length = seq_length
 
@@ -197,7 +189,7 @@ def call(
                     f"keys and values. Got {len(past_key_value)} past states."
                 )
             real_seq_length += (
-                shape_list(past_key_value[0])[2]
+                ops.shape(past_key_value[0])[2]
                 if query_length is None
                 else query_length
             )
@@ -205,21 +197,21 @@ def call(
         key_length = (
             real_seq_length
             if key_value_states is None
-            else shape_list(key_value_states)[1]
+            else ops.shape(key_value_states)[1]
         )
 
         def shape(hidden_states):
-            return tf.transpose(
-                tf.reshape(
+            return ops.transpose(
+                ops.reshape(
                     hidden_states,
                     (batch_size, -1, self.num_heads, self.key_value_dim),
                 ),
-                perm=(0, 2, 1, 3),
+                axes=(0, 2, 1, 3),
             )
 
         def unshape(hidden_states):
-            return tf.reshape(
-                tf.transpose(hidden_states, perm=(0, 2, 1, 3)),
+            return ops.reshape(
+                ops.transpose(hidden_states, axes=(0, 2, 1, 3)),
                 (batch_size, -1, self.inner_dim),
             )
 
@@ -240,7 +232,7 @@ def project(
                 if key_value_states is None:
                     # self-attention
                     # (batch_size, num_heads, key_length, dim_per_head)
-                    hidden_states = tf.concat(
+                    hidden_states = ops.concat(
                         [past_key_value, hidden_states], axis=2
                     )
                 else:
@@ -267,13 +259,13 @@ def project(
             past_key_value[1] if past_key_value is not None else None,
         )
 
-        scores = tf.einsum(
+        scores = ops.einsum(
             "bnqd,bnkd->bnqk", query_states, key_states
         )  # (batch_size, num_heads, query_length, key_length)
 
         if position_bias is None:
             if not self.use_relative_attention_bias:
-                position_bias = tf.zeros(
+                position_bias = ops.zeros(
                     (1, self.num_heads, real_seq_length, key_length),
                     self.compute_dtype,
                 )
@@ -289,10 +281,10 @@ def project(
                     # we might have a padded past structure,
                     # in which case we want to fetch the position bias slice
                     # right after the most recently filled past index
-                    most_recently_filled_past_index = tf.reduce_max(
-                        tf.where(past_key_value[0][0, 0, :, 0] != 0.0)
+                    most_recently_filled_past_index = ops.amax(
+                        ops.where(past_key_value[0][0, 0, :, 0] != 0.0)
                     )
-                    position_bias = dynamic_slice(
+                    position_bias = ops.slice(
                         position_bias,
                         (0, 0, most_recently_filled_past_index + 1, 0),
                         (1, self.num_heads, seq_length, real_seq_length),
@@ -300,13 +292,13 @@ def project(
 
             if mask is not None:
                 # Add a new mask axis for the head dim.
-                mask = mask[:, tf.newaxis, :, :]
+                mask = mask[:, np.newaxis, :, :]
                 # Add a very large negative position bias for masked positions.
-                mask = (1.0 - tf.cast(mask, position_bias.dtype)) * -1e9
+                mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9
                 position_bias = position_bias + mask
 
         scores += position_bias
-        weights = tf.nn.softmax(
+        weights = ops.nn.softmax(
             scores, axis=-1
         )  # (batch_size, num_heads, query_length, key_length)
         weights = self.dropout_layer(
@@ -315,9 +307,9 @@ def project(
 
         # Optionally mask heads
        if layer_head_mask is not None:
-            weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
+            weights = ops.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
 
-        attention_output = tf.matmul(
+        attention_output = ops.matmul(
             weights, value_states
         )  # (batch_size, num_heads, query_length, dim_per_head)
 
diff --git a/keras_nlp/models/t5/t5_transformer_layer.py b/keras_nlp/models/t5/t5_transformer_layer.py
index ce4a28d67f..22c2dc1c74 100644
--- a/keras_nlp/models/t5/t5_transformer_layer.py
+++ b/keras_nlp/models/t5/t5_transformer_layer.py
@@ -12,9 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import tensorflow as tf
-
 from keras_nlp.backend import keras
+from keras_nlp.backend import ops
 from keras_nlp.layers.modeling.transformer_layer_utils import (
     compute_causal_mask,
 )
@@ -103,10 +102,10 @@ def call(
         training=False,
     ):
         if use_causal_mask:
-            shape = tf.shape(hidden_states)
+            shape = ops.shape(hidden_states)
             batch_size, length = shape[0], shape[1]
             causal_mask = compute_causal_mask(batch_size, length, length)
-            attention_mask = tf.cast(attention_mask, "int32")
+            attention_mask = ops.cast(attention_mask, "int32")
             attention_mask = causal_mask & attention_mask
 
         x = hidden_states  # Intermediate result.
@@ -147,4 +146,7 @@ def call(
         x = self.dropout_layer(x, training=training)
         x = x + residual
 
-        return x, position_bias
+        if position_bias is not None:
+            return x, position_bias
+        else:
+            return x
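Note on the backbone change: T5TransformerLayer now returns a (hidden_states, position_bias) tuple only when a position bias exists and a bare tensor otherwise, so the encoder and decoder loops unpack conditionally. Below is a plain-Python toy of that pattern; toy_layer is hypothetical and not part of keras_nlp, and the 0.1 bias is just a stand-in value.

def toy_layer(x, position_bias=None, compute_bias=False):
    # Mirrors the layer's return contract: (output, bias) when a bias exists,
    # otherwise just the output.
    if compute_bias:
        position_bias = 0.1  # stand-in for the real relative attention bias
    if position_bias is not None:
        return x + position_bias, position_bias
    return x

x, position_bias = 1.0, None
for i in range(4):
    output = toy_layer(x, position_bias=position_bias, compute_bias=(i == 0))
    if isinstance(output, tuple):
        # As in the backbone loops: layer 0 always produces a bias, so every
        # iteration ends up unpacking a tuple here.
        x, position_bias = output
print(x)  # 1.4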
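Note on the layer-norm change: T5LayerNorm keeps T5's RMS-style normalization (mean of squares over the last axis, no mean subtraction), now written with backend-agnostic ops. A minimal standalone sketch follows, assuming keras_nlp.backend.ops resolves to keras.ops in the multi-backend setup; the shapes and epsilon are illustrative only.

import numpy as np
from keras import ops  # assumed equivalent of keras_nlp.backend.ops

def t5_style_rms_norm(x, weight, epsilon=1e-6):
    # Mean of squares over the feature axis; no mean subtraction or bias term.
    x = ops.convert_to_tensor(x)
    weight = ops.convert_to_tensor(weight)
    variance = ops.mean(ops.square(x), axis=-1, keepdims=True)
    return weight * (x * ops.rsqrt(variance + epsilon))

x = np.random.rand(2, 4, 8).astype("float32")   # (batch, seq_len, hidden_dim)
weight = np.ones((8,), dtype="float32")         # stands in for the learned scale
print(ops.shape(t5_style_rms_norm(x, weight)))  # (2, 4, 8)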