Improve layer naming consistency #1219

Merged 1 commit on Aug 21, 2023
8 changes: 6 additions & 2 deletions keras_nlp/layers/modeling/f_net_encoder.py
@@ -99,11 +99,13 @@ def build(self, inputs_shape):

# Layer Norm layers.
self._mixing_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon
epsilon=self.layer_norm_epsilon,
name="mixing_layer_norm",
)
self._mixing_layer_norm.build(inputs_shape)
self._output_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon
epsilon=self.layer_norm_epsilon,
name="output_layer_norm",
)
self._output_layer_norm.build(inputs_shape)

@@ -113,12 +115,14 @@ def build(self, inputs_shape):
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="intermediate_dense",
)
self._intermediate_dense.build(inputs_shape)
self._output_dense = keras.layers.Dense(
feature_size,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="output_dense",
)
self._output_dense.build(
self._intermediate_dense.compute_output_shape(inputs_shape)
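The change above only adds explicit name= arguments to the sublayers built in FNetEncoder.build(). As a minimal sketch (plain Keras, not the keras_nlp source) of why this matters: an unnamed sublayer gets an auto-generated, counter-based name, while an explicitly named one keeps the same readable name on every run, which keeps variable paths in saved models and debug output stable.

import keras

# Unnamed layers are numbered by a global counter, so the name depends on how
# many layers were created before this one in the current program.
unnamed = keras.layers.LayerNormalization()
named = keras.layers.LayerNormalization(name="mixing_layer_norm")

print(unnamed.name)  # e.g. "layer_normalization_3" (run-dependent suffix)
print(named.name)    # always "mixing_layer_norm"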
16 changes: 9 additions & 7 deletions keras_nlp/layers/modeling/masked_lm_head.py
@@ -49,7 +49,7 @@ class MaskedLMHead(keras.layers.Layer):
token_embedding: Optional. A `keras_nlp.layers.ReversibleEmbedding`
instance. If passed, the layer will be used to project from the
`hidden_dim` of the model to the output `vocabulary_size`.
intermediate_activation: The activation function of inner dense layer.
intermediate_activation: The activation function of the intermediate dense layer.
activation: The activation function for the outputs of the layer.
Usually either `None` (return logits), or `"softmax"`
(return probabilities).
@@ -138,22 +138,24 @@ def build(self, inputs_shape, mask_positions_shape=None):
else:
feature_size = inputs_shape[-1]

self._dense = keras.layers.Dense(
self._intermediate_dense = keras.layers.Dense(
feature_size,
activation=self.intermediate_activation,
kernel_initializer=self.kernel_initializer,
bias_initializer=self.bias_initializer,
name="intermediate_dense",
)
self._layer_norm = keras.layers.LayerNormalization(
self._intermediate_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
name="intermediate_layer_norm",
)
# The gather length does not affect any of our built variables, so
# we can pass any value here.
gather_length = None
shape = (inputs_shape[0], gather_length, inputs_shape[-1])
self._dense.build(shape)
self._intermediate_dense.build(shape)
shape = (inputs_shape[0], gather_length, feature_size)
self._layer_norm.build(shape)
self._intermediate_layer_norm.build(shape)
if self.token_embedding is None:
self._kernel = self.add_weight(
name="output_kernel",
@@ -174,8 +176,8 @@ def call(self, inputs, mask_positions):
x = ops.take_along_axis(inputs, mask_positions, axis=1)

# Apply a trainable linear transformation and a layer norm.
x = self._dense(x)
x = self._layer_norm(x)
x = self._intermediate_dense(x)
x = self._intermediate_layer_norm(x)

# Transform encodings to vocabulary_size predictions.
if self.token_embedding:
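The rename above is internal: the private attributes _dense and _layer_norm become _intermediate_dense and _intermediate_layer_norm to match the new sublayer names, and the public behavior of the layer is unchanged. For context, a hedged usage sketch of MaskedLMHead based on its documented interface at the time of this PR (shapes and values here are illustrative):

import numpy as np
import keras_nlp

batch_size, seq_length, hidden_dim, vocab_size = 2, 10, 16, 100

# Random "encoder outputs" and the positions of the masked tokens to predict.
encoded_tokens = np.random.normal(size=(batch_size, seq_length, hidden_dim))
mask_positions = np.random.randint(seq_length, size=(batch_size, 3))

masked_lm_head = keras_nlp.layers.MaskedLMHead(
    vocabulary_size=vocab_size,
    activation="softmax",
)
# Output shape: (batch_size, number of masked positions, vocab_size).
predictions = masked_lm_head(encoded_tokens, mask_positions=mask_positions)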
6 changes: 2 additions & 4 deletions keras_nlp/layers/modeling/token_and_position_embedding.py
@@ -95,14 +95,12 @@ def __init__(
self.embeddings_initializer
),
mask_zero=mask_zero,
name="token_embedding"
+ str(keras.backend.get_uid("token_embedding")),
name="token_embedding",
)
self.position_embedding = PositionEmbedding(
sequence_length=sequence_length,
initializer=clone_initializer(self.embeddings_initializer),
name="position_embedding"
+ str(keras.backend.get_uid("position_embedding")),
name="position_embedding",
)
self.supports_masking = self.token_embedding.supports_masking

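The diff above drops the keras.backend.get_uid(...) suffix from the two sublayer names. A short sketch of what get_uid does and why the suffix was removed: the counter increases on every call, so each TokenAndPositionEmbedding instance used to get differently numbered sublayer names (token_embedding1, token_embedding2, ...), which made variable paths non-deterministic, while a fixed name= is already unique within the parent layer. Printed values below are illustrative.

import keras

# Each call bumps a per-prefix counter, so the suffix depends on how many times
# the prefix has already been used in the current program.
print(keras.backend.get_uid("token_embedding"))  # e.g. 1
print(keras.backend.get_uid("token_embedding"))  # e.g. 2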
24 changes: 12 additions & 12 deletions keras_nlp/layers/modeling/transformer_decoder.py
@@ -173,10 +173,10 @@ def build(
query_shape=decoder_sequence_shape,
value_shape=decoder_sequence_shape,
)
self._self_attention_layernorm = keras.layers.LayerNormalization(
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
self._self_attention_layernorm.build(decoder_sequence_shape)
self._self_attention_layer_norm.build(decoder_sequence_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
)
@@ -202,10 +202,10 @@ def build(
query_shape=encoder_sequence_shape,
value_shape=encoder_sequence_shape,
)
self._cross_attention_layernorm = keras.layers.LayerNormalization(
self._cross_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
self._cross_attention_layernorm.build(encoder_sequence_shape)
self._cross_attention_layer_norm.build(encoder_sequence_shape)
self._cross_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
)
@@ -226,10 +226,10 @@ def build(
intermediate_shape = list(decoder_sequence_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_layernorm = keras.layers.LayerNormalization(
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
self._feedforward_layernorm.build(decoder_sequence_shape)
self._feedforward_layer_norm.build(decoder_sequence_shape)
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
)
@@ -364,7 +364,7 @@ def call(
# Self attention block.
residual = x
if self.normalize_first:
x = self._self_attention_layernorm(x)
x = self._self_attention_layer_norm(x)
x, self_attention_cache = self._self_attention_layer(
query=x,
value=x,
@@ -375,7 +375,7 @@ def call(
x = self._self_attention_dropout(x)
x = x + residual
if not self.normalize_first:
x = self._self_attention_layernorm(x)
x = self._self_attention_layer_norm(x)

# Cross attention is optional.
if has_cross_attention:
@@ -387,7 +387,7 @@ def call(
# Cross attention block.
residual = x
if self.normalize_first:
x = self._cross_attention_layernorm(x)
x = self._cross_attention_layer_norm(x)
x, cross_attention_cache = self._cross_attention_layer(
query=x,
value=encoder_sequence,
@@ -398,18 +398,18 @@ def call(
x = self._cross_attention_dropout(x)
x = x + residual
if not self.normalize_first:
x = self._cross_attention_layernorm(x)
x = self._cross_attention_layer_norm(x)

# Feedforward block.
residual = x
if self.normalize_first:
x = self._feedforward_layernorm(x)
x = self._feedforward_layer_norm(x)
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = x + residual
if not self.normalize_first:
x = self._feedforward_layernorm(x)
x = self._feedforward_layer_norm(x)

if self_attention_cache is not None:
if has_cross_attention:
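Aside from the attribute rename (_*_layernorm becomes _*_layer_norm), the call() logic above is unchanged. As a stripped-down sketch of the residual pattern that each of the three decoder blocks follows (names here are illustrative, not the keras_nlp implementation):

def residual_block(x, sub_layer, layer_norm, normalize_first):
    """Apply one decoder sub-block with a residual connection."""
    residual = x
    if normalize_first:        # pre-norm: normalize before the sub-layer
        x = layer_norm(x)
    x = sub_layer(x)
    x = x + residual           # residual (skip) connection
    if not normalize_first:    # post-norm: normalize after adding the residual
        x = layer_norm(x)
    return x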
21 changes: 13 additions & 8 deletions keras_nlp/layers/modeling/transformer_encoder.py
@@ -127,6 +127,7 @@ def build(self, inputs_shape):
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="self_attention_layer",
)
if hasattr(self._self_attention_layer, "_build_from_signature"):
self._self_attention_layer._build_from_signature(
@@ -138,30 +139,34 @@ def build(self, inputs_shape):
query_shape=inputs_shape,
value_shape=inputs_shape,
)
self._self_attention_layernorm = keras.layers.LayerNormalization(
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
name="self_attention_layer_norm",
)
self._self_attention_layernorm.build(inputs_shape)
self._self_attention_layer_norm.build(inputs_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
name="self_attention_dropout",
)

# Feedforward layers.
self._feedforward_layernorm = keras.layers.LayerNormalization(
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
)
self._feedforward_layernorm.build(inputs_shape)
self._feedforward_layer_norm.build(inputs_shape)
self._feedforward_intermediate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(inputs_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="feedforward_output_dense",
)
intermediate_shape = list(inputs_shape)
intermediate_shape[-1] = self.intermediate_dim
@@ -197,7 +202,7 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
# Self attention block.
residual = x
if self.normalize_first:
x = self._self_attention_layernorm(x)
x = self._self_attention_layer_norm(x)
x = self._self_attention_layer(
query=x,
value=x,
@@ -206,18 +211,18 @@ def call(self, inputs, padding_mask=None, attention_mask=None):
x = self._self_attention_dropout(x)
x = x + residual
if not self.normalize_first:
x = self._self_attention_layernorm(x)
x = self._self_attention_layer_norm(x)

# Feedforward block.
residual = x
if self.normalize_first:
x = self._feedforward_layernorm(x)
x = self._feedforward_layer_norm(x)
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = x + residual
if not self.normalize_first:
x = self._feedforward_layernorm(x)
x = self._feedforward_layer_norm(x)

return x

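One user-visible effect of naming the sublayers built here is that the encoder's weights get stable, descriptive paths. A hedged sketch of how to inspect that (exact path strings depend on the Keras version, so treat the printed names as illustrative):

import keras_nlp

encoder = keras_nlp.layers.TransformerEncoder(intermediate_dim=64, num_heads=2)
encoder.build((None, 16, 32))  # (batch, sequence_length, hidden_dim)

# With this change, names such as "self_attention_layer_norm" and
# "feedforward_output_dense" should appear in the variable names.
for weight in encoder.weights:
    print(weight.name)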
19 changes: 13 additions & 6 deletions keras_nlp/models/deberta_v3/disentangled_attention_encoder.py
@@ -101,38 +101,45 @@ def build(self, inputs_shape):
dropout=self.dropout,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="self_attention_layer",
)
self._self_attention_layer.build(inputs_shape)
self._self_attention_layernorm = keras.layers.LayerNormalization(
self._self_attention_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
name="self_attention_layer_norm",
)
self._self_attention_layernorm.build(inputs_shape)
self._self_attention_layer_norm.build(inputs_shape)
self._self_attention_dropout = keras.layers.Dropout(
rate=self.dropout,
name="self_attention_dropout",
)

# Feedforward layers.
self._feedforward_layernorm = keras.layers.LayerNormalization(
self._feedforward_layer_norm = keras.layers.LayerNormalization(
epsilon=self.layer_norm_epsilon,
name="feedforward_layer_norm",
)
self._feedforward_layernorm.build(inputs_shape)
self._feedforward_layer_norm.build(inputs_shape)
self._feedforward_intermediate_dense = keras.layers.Dense(
self.intermediate_dim,
activation=self.activation,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="feedforward_intermediate_dense",
)
self._feedforward_intermediate_dense.build(inputs_shape)
self._feedforward_output_dense = keras.layers.Dense(
hidden_dim,
kernel_initializer=clone_initializer(self.kernel_initializer),
bias_initializer=clone_initializer(self.bias_initializer),
name="feedforward_output_dense",
)
intermediate_shape = list(inputs_shape)
intermediate_shape[-1] = self.intermediate_dim
self._feedforward_output_dense.build(tuple(intermediate_shape))
self._feedforward_dropout = keras.layers.Dropout(
rate=self.dropout,
name="feedforward_dropout",
)
self.built = True

@@ -177,15 +184,15 @@ def call(
)
x = self._self_attention_dropout(x)
x = x + residual
x = self._self_attention_layernorm(x)
x = self._self_attention_layer_norm(x)

# Feedforward block.
residual = x
x = self._feedforward_intermediate_dense(x)
x = self._feedforward_output_dense(x)
x = self._feedforward_dropout(x)
x = x + residual
x = self._feedforward_layernorm(x)
x = self._feedforward_layer_norm(x)

return x

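Finally, a hypothetical helper in the spirit of this PR (not part of the change itself): walk a layer's attributes and flag any LayerNormalization sublayer whose attribute name still uses the old "layernorm" spelling rather than "layer_norm".

import keras

def find_old_style_layer_norm_attrs(layer):
    """Return attribute names that hold a LayerNormalization but spell it 'layernorm'."""
    flagged = []
    for attr, value in vars(layer).items():
        if isinstance(value, keras.layers.LayerNormalization) and "layernorm" in attr:
            flagged.append(attr)
    return flagged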