Convert T5 to Keras 3 (#1274)
* Change TF ops to Keras Core ops

* Fix formatting

* Remove build override

* Fix formatting and remove unneeded function
nkovela1 authored Oct 18, 2023
1 parent ab376b1 commit 871f664
Showing 5 changed files with 55 additions and 64 deletions.
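
All five files below follow the same mechanical pattern: TensorFlow-only calls are replaced with their counterparts from keras_nlp.backend.ops, which resolves to the multi-backend Keras 3 ops namespace, so the model no longer requires the TensorFlow backend (hence the removal of assert_tf_backend and the tf_only test marker). A hedged sketch of the kind of one-line substitutions the diff makes; plain keras.ops is used here for illustration, the diff itself imports the keras_nlp.backend alias:

import numpy as np
from keras import ops  # stands in for `from keras_nlp.backend import ops`

x = np.array([[1.0, 2.0], [3.0, 4.0]], dtype="float32")

# tf.math.reduce_mean(x, axis=-1, keepdims=True)  ->  ops.mean(...)
mean = ops.mean(x, axis=-1, keepdims=True)

# tf.transpose(x, perm=(1, 0))  ->  ops.transpose(x, axes=(1, 0))
xt = ops.transpose(x, axes=(1, 0))

# tf.gather(table, idx)  ->  ops.take(table, idx, axis=0)
shuffled = ops.take(x, np.array([1, 0]), axis=0)

# tf.where(cond, a, b)  ->  ops.where(cond, a, b)
clipped = ops.where(x > 2.0, x, ops.zeros_like(x))
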
11 changes: 6 additions & 5 deletions keras_nlp/models/t5/t5_backbone.py
@@ -19,7 +19,6 @@
from keras_nlp.models.t5.t5_layer_norm import T5LayerNorm
from keras_nlp.models.t5.t5_transformer_layer import T5TransformerLayer
from keras_nlp.utils.python_utils import classproperty
- from keras_nlp.utils.tensor_utils import assert_tf_backend


@keras_nlp_export("keras_nlp.models.T5Backbone")
@@ -81,8 +80,6 @@ def __init__(
tie_embedding_weights=False,
**kwargs,
):
- assert_tf_backend(self.__class__.__name__)

# Encoder inputs
encoder_token_ids = keras.Input(
shape=(None,), dtype="int32", name="encoder_token_ids"
@@ -121,7 +118,7 @@ def __init__(

position_bias = None
for i in range(num_layers):
- x, position_bias = T5TransformerLayer(
+ output = T5TransformerLayer(
is_decoder=False,
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
@@ -138,6 +135,8 @@
position_bias=position_bias,
use_causal_mask=False,
)
+ if isinstance(output, tuple):
+     x, position_bias = output

x = T5LayerNorm(
epsilon=layer_norm_epsilon,
@@ -162,7 +161,7 @@

position_bias = None
for i in range(num_layers):
- x, position_bias = T5TransformerLayer(
+ output = T5TransformerLayer(
is_decoder=True,
hidden_dim=hidden_dim,
intermediate_dim=intermediate_dim,
@@ -181,6 +180,8 @@
encoder_attention_mask=encoder_attention_mask,
use_causal_mask=True,
)
+ if isinstance(output, tuple):
+     x, position_bias = output

x = T5LayerNorm(
epsilon=layer_norm_epsilon,
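
The loop change above exists because T5TransformerLayer (see t5_transformer_layer.py at the end of this diff) now returns a bare tensor when no position bias is produced and an (output, position_bias) tuple otherwise, so the backbone unpacks conditionally. A minimal, self-contained sketch of that calling convention; fake_layer is a hypothetical stand-in, not the real layer:

from keras import ops

def fake_layer(x, position_bias=None, computes_bias=False):
    # Hypothetical stand-in for T5TransformerLayer: does a dummy update and,
    # on request, fabricates a bias so the tuple branch gets exercised.
    if computes_bias and position_bias is None:
        position_bias = ops.zeros((1, 1, 4, 4))
    x = x + 1.0
    return (x, position_bias) if position_bias is not None else x

x = ops.zeros((2, 4, 8))
position_bias = None
for i in range(3):
    output = fake_layer(x, position_bias=position_bias, computes_bias=(i == 0))
    if isinstance(output, tuple):
        x, position_bias = output  # the first call also hands back the shared bias
    else:
        x = output
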
1 change: 0 additions & 1 deletion keras_nlp/models/t5/t5_backbone_test.py
@@ -19,7 +19,6 @@
from keras_nlp.tests.test_case import TestCase


- @pytest.mark.tf_only
class T5BackboneTest(TestCase):
def setUp(self):
self.init_kwargs = {
9 changes: 3 additions & 6 deletions keras_nlp/models/t5/t5_layer_norm.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import tensorflow as tf

from keras_nlp.backend import keras
+ from keras_nlp.backend import ops


class T5LayerNorm(keras.layers.Layer):
@@ -31,8 +30,6 @@ def build(self, input_shape):
self.built = True

def call(self, hidden_states):
- variance = tf.math.reduce_mean(
-     tf.math.square(hidden_states), axis=-1, keepdims=True
- )
- hidden_states = hidden_states * tf.math.rsqrt(variance + self.epsilon)
+ variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
+ hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon)
return self.weight * hidden_states
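
For context, the converted layer is a T5-style RMS norm: no mean subtraction, just scaling by the reciprocal root-mean-square and a learned weight. A rough reconstruction of the whole layer after this change; the constructor and build() details are assumptions inferred from the visible call() body, not copied from the diff:

import keras
from keras import ops

class RMSLayerNorm(keras.layers.Layer):
    # Sketch of a T5-style layer norm, not keras_nlp's actual T5LayerNorm.
    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        # Single learned scale, initialized to ones (assumed).
        self.weight = self.add_weight(
            name="weight", shape=(input_shape[-1],), initializer="ones"
        )
        self.built = True

    def call(self, hidden_states):
        variance = ops.mean(ops.square(hidden_states), axis=-1, keepdims=True)
        hidden_states = hidden_states * ops.rsqrt(variance + self.epsilon)
        return self.weight * hidden_states
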
86 changes: 39 additions & 47 deletions keras_nlp/models/t5/t5_multi_head_attention.py
@@ -12,18 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import tensorflow as tf
- from tensorflow.compiler.tf2xla.python.xla import dynamic_slice
+ import numpy as np

from keras_nlp.backend import keras


- def shape_list(tensor):
-     dynamic = tf.shape(tensor)
-     if tensor.shape == tf.TensorShape(None):
-         return dynamic
-     static = tensor.shape.as_list()
-     return [dynamic[i] if s is None else s for i, s in enumerate(static)]
+ from keras_nlp.backend import ops


class T5MultiHeadAttention(keras.layers.Layer):
@@ -123,39 +115,39 @@ def _relative_position_bucket(
if bidirectional:
num_buckets //= 2
relative_buckets += (
- tf.cast(
-     tf.math.greater(relative_position, 0),
+ ops.cast(
+     ops.greater(relative_position, 0),
dtype=relative_position.dtype,
)
* num_buckets
)
- relative_position = tf.math.abs(relative_position)
+ relative_position = ops.abs(relative_position)
else:
- relative_position = -tf.math.minimum(relative_position, 0)
+ relative_position = -ops.minimum(relative_position, 0)
# now n is in the range [0, inf)
max_exact = num_buckets // 2
- is_small = tf.math.less(relative_position, max_exact)
- relative_position_if_large = max_exact + tf.cast(
-     tf.math.log(
-         tf.cast(relative_position, "float32")
-         / tf.cast(max_exact, "float32")
+ is_small = ops.less(relative_position, max_exact)
+ relative_position_if_large = max_exact + ops.cast(
+     ops.log(
+         ops.cast(relative_position, "float32")
+         / ops.cast(max_exact, "float32")
)
- / tf.math.log(max_distance / max_exact)
+ / ops.cast(ops.log(max_distance / max_exact), "float32")
* (num_buckets - max_exact),
dtype=relative_position.dtype,
)
- relative_position_if_large = tf.math.minimum(
+ relative_position_if_large = ops.minimum(
relative_position_if_large, num_buckets - 1
)
- relative_buckets += tf.where(
+ relative_buckets += ops.where(
is_small, relative_position, relative_position_if_large
)
return relative_buckets

def compute_bias(self, query_length, key_length):
"""Compute binned relative position bias"""
- context_position = tf.range(query_length)[:, None]
- memory_position = tf.range(key_length)[None, :]
+ context_position = ops.arange(query_length)[:, None]
+ memory_position = ops.arange(key_length)[None, :]
relative_position = (
memory_position - context_position
) # shape (query_length, key_length)
@@ -165,11 +157,11 @@ def compute_bias(self, query_length, key_length):
num_buckets=self.relative_attention_buckets,
max_distance=self.relative_attention_max_distance,
)
- values = tf.gather(
-     self.relative_attention_bias, relative_position_bucket
+ values = ops.take(
+     self.relative_attention_bias, relative_position_bucket, axis=0
) # shape (query_length, key_length, num_heads)
- values = tf.expand_dims(
-     tf.transpose(values, [2, 0, 1]), axis=0
+ values = ops.expand_dims(
+     ops.transpose(values, axes=(2, 0, 1)), axis=0
) # shape (1, num_heads, query_length, key_length)
return values

@@ -186,7 +178,7 @@ def call(
):
# Input is (batch_size, query_length, dim)
# past_key_value[0] is (batch_size, num_heads, q_len - 1, dim_per_head)
- batch_size, seq_length = shape_list(hidden_states)[:2]
+ batch_size, seq_length = ops.shape(hidden_states)[:2]

real_seq_length = seq_length

@@ -197,29 +189,29 @@
f"keys and values. Got {len(past_key_value)} past states."
)
real_seq_length += (
- shape_list(past_key_value[0])[2]
+ ops.shape(past_key_value[0])[2]
if query_length is None
else query_length
)

key_length = (
real_seq_length
if key_value_states is None
- else shape_list(key_value_states)[1]
+ else ops.shape(key_value_states)[1]
)

def shape(hidden_states):
- return tf.transpose(
-     tf.reshape(
+ return ops.transpose(
+     ops.reshape(
hidden_states,
(batch_size, -1, self.num_heads, self.key_value_dim),
),
- perm=(0, 2, 1, 3),
+ axes=(0, 2, 1, 3),
)

def unshape(hidden_states):
- return tf.reshape(
-     tf.transpose(hidden_states, perm=(0, 2, 1, 3)),
+ return ops.reshape(
+     ops.transpose(hidden_states, axes=(0, 2, 1, 3)),
(batch_size, -1, self.inner_dim),
)

@@ -240,7 +232,7 @@ def project(
if key_value_states is None:
# self-attention
# (batch_size, num_heads, key_length, dim_per_head)
- hidden_states = tf.concat(
+ hidden_states = ops.concat(
[past_key_value, hidden_states], axis=2
)
else:
@@ -267,13 +259,13 @@ def project(
past_key_value[1] if past_key_value is not None else None,
)

- scores = tf.einsum(
+ scores = ops.einsum(
"bnqd,bnkd->bnqk", query_states, key_states
) # (batch_size, num_heads, query_length, key_length)

if position_bias is None:
if not self.use_relative_attention_bias:
- position_bias = tf.zeros(
+ position_bias = ops.zeros(
(1, self.num_heads, real_seq_length, key_length),
self.compute_dtype,
)
@@ -289,24 +281,24 @@ def project(
# we might have a padded past structure,
# in which case we want to fetch the position bias slice
# right after the most recently filled past index
- most_recently_filled_past_index = tf.reduce_max(
-     tf.where(past_key_value[0][0, 0, :, 0] != 0.0)
+ most_recently_filled_past_index = ops.amax(
+     ops.where(past_key_value[0][0, 0, :, 0] != 0.0)
)
- position_bias = dynamic_slice(
+ position_bias = ops.slice(
position_bias,
(0, 0, most_recently_filled_past_index + 1, 0),
(1, self.num_heads, seq_length, real_seq_length),
)

if mask is not None:
# Add a new mask axis for the head dim.
- mask = mask[:, tf.newaxis, :, :]
+ mask = mask[:, np.newaxis, :, :]
# Add a very large negative position bias for masked positions.
- mask = (1.0 - tf.cast(mask, position_bias.dtype)) * -1e9
+ mask = (1.0 - ops.cast(mask, position_bias.dtype)) * -1e9
position_bias = position_bias + mask

scores += position_bias
- weights = tf.nn.softmax(
+ weights = ops.nn.softmax(
scores, axis=-1
) # (batch_size, num_heads, query_length, key_length)
weights = self.dropout_layer(
@@ -315,9 +307,9 @@ def project(

# Optionally mask heads
if layer_head_mask is not None:
- weights = tf.reshape(layer_head_mask, (1, -1, 1, 1)) * weights
+ weights = ops.reshape(layer_head_mask, (1, -1, 1, 1)) * weights

- attention_output = tf.matmul(
+ attention_output = ops.matmul(
weights, value_states
) # (batch_size, num_heads, query_length, dim_per_head)

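
The least mechanical change in this file is dropping the XLA dynamic_slice import: ops.slice(inputs, start_indices, shape) covers the same need, and tf.newaxis becomes np.newaxis, which is why numpy is now imported. A small standalone check of those two substitutions with made-up sizes (plain keras.ops, illustrative only):

import numpy as np
from keras import ops

num_heads, seq_length, real_seq_length = 2, 1, 5
position_bias = np.arange(
    num_heads * real_seq_length * real_seq_length, dtype="float32"
).reshape(1, num_heads, real_seq_length, real_seq_length)

# dynamic_slice -> ops.slice(inputs, start_indices, shape): grab a
# (1, num_heads, seq_length, real_seq_length) window starting at row 3.
window = ops.slice(
    position_bias,
    (0, 0, 3, 0),
    (1, num_heads, seq_length, real_seq_length),
)
print(ops.shape(window))  # (1, 2, 1, 5)

# tf.newaxis -> np.newaxis when adding a head axis to an attention mask.
mask = np.ones((1, seq_length, real_seq_length))[:, np.newaxis, :, :]
print(mask.shape)  # (1, 1, 1, 5)
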
12 changes: 7 additions & 5 deletions keras_nlp/models/t5/t5_transformer_layer.py
@@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

- import tensorflow as tf

from keras_nlp.backend import keras
+ from keras_nlp.backend import ops
from keras_nlp.layers.modeling.transformer_layer_utils import (
compute_causal_mask,
)
@@ -103,10 +102,10 @@ def call(
training=False,
):
if use_causal_mask:
- shape = tf.shape(hidden_states)
+ shape = ops.shape(hidden_states)
batch_size, length = shape[0], shape[1]
causal_mask = compute_causal_mask(batch_size, length, length)
- attention_mask = tf.cast(attention_mask, "int32")
+ attention_mask = ops.cast(attention_mask, "int32")
attention_mask = causal_mask & attention_mask

x = hidden_states # Intermediate result.
@@ -147,4 +146,7 @@ def call(
x = self.dropout_layer(x, training=training)
x = x + residual

- return x, position_bias
+ if position_bias is not None:
+     return x, position_bias
+ else:
+     return x
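
One more detail worth noting in the causal-mask branch above: both masks are cast to int32 and combined with a bitwise AND, now via ops.shape and ops.cast instead of their tf equivalents. A tiny self-contained illustration of that combination, assuming a (batch, length) padding mask that gets broadcast over query positions; the real layer builds its causal mask with keras_nlp's compute_causal_mask:

import numpy as np
from keras import ops

length = 4
# Padding mask for one sequence of length 4 whose last position is padding.
padding_mask = ops.cast(np.array([[1.0, 1.0, 1.0, 0.0]]), "int32")

# Lower-triangular causal mask (1 = may attend), also int32 so & is defined.
causal_mask = ops.cast(ops.tril(ops.ones((1, length, length))), "int32")

combined = causal_mask & padding_mask[:, None, :]
print(ops.convert_to_numpy(combined)[0])
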
