Commit 1f12a98

susnato committed Aug 8, 2023
1 parent 7b01356 commit 1f12a98
Showing 2 changed files with 6 additions and 19 deletions.
18 changes: 3 additions & 15 deletions keras_nlp/models/xlnet/xlnet_backbone.py
@@ -55,9 +55,6 @@ class XLNetBackbone(Backbone):
bias_initializer: string or `keras.initializers` initializer,
defaults to "zeros". The bias initializer for
the dense and multiheaded relative attention layers.
- layer_norm_epsilon: float, defaults to 1e-12. The epsilon value in layer
- normalization components.
**kwargs: other keyword arguments.
Call Args:
token_ids: Indices of input sequence tokens in the vocabulary of shape
@@ -67,12 +64,6 @@ class XLNetBackbone(Backbone):
padding_mask: Mask to avoid performing attention on padding token indices
of shape `[batch_size, sequence_length]`.
- Returns:
- last_hidden_state: last hidden state of query state of shape
- `[batch_size, num_predict, hidden_dim]` if query state is not None
- otherwise last hidden state of content of shape
- `[batch_size, sequence_length, hidden_dim]`.
Examples:
```python
import numpy as np
@@ -113,7 +104,6 @@ def __init__(
activation="gelu",
kernel_initializer_range=0.02,
bias_initializer="zeros",
- layer_norm_epsilon=1e-12,
**kwargs,
):
# Inputs
@@ -158,7 +148,7 @@ def __init__(
intermediate_dim=intermediate_dim,
dropout=dropout,
activation=activation,
- layer_norm_epsilon=layer_norm_epsilon,
+ layer_norm_epsilon=1e-12,
kernel_initializer_range=kernel_initializer_range,
bias_initializer=bias_initializer,
name=f"xlnet_encoder_{i}",
@@ -178,7 +168,7 @@ def __init__(
"padding_mask": padding_mask,
"segment_ids": segment_ids,
},
- outputs={"last_hidden_state": output},
+ outputs=output,
**kwargs,
)

@@ -192,7 +182,6 @@ def __init__(
self.activation = activation
self.kernel_initializer_range = kernel_initializer_range
self.bias_initializer = bias_initializer
- self.layer_norm_epsilon = layer_norm_epsilon

def get_config(self):
config = super().get_config()
@@ -203,11 +192,10 @@ def get_config(self):
"num_heads": self.num_heads,
"hidden_dim": self.hidden_dim,
"intermediate_dim": self.intermediate_dim,
- "dropout": self.layer_norm_epsilon,
+ "dropout": self.dropout,
"activation": self.activation,
"kernel_initializer_range": self.kernel_initializer_range,
"bias_initializer": self.bias_initializer,
- "layer_norm_epsilon": self.layer_norm_epsilon,
}
)
return config
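For context, a minimal usage sketch of the backbone after this change. The constructor values and input shapes below are illustrative only (not taken from the diff); the point is that the call now returns the last hidden state tensor directly instead of a `{"last_hidden_state": ...}` dict, and `layer_norm_epsilon` is fixed internally at 1e-12 rather than exposed as a constructor argument.

```python
import numpy as np

from keras_nlp.models.xlnet.xlnet_backbone import XLNetBackbone

# Illustrative (hypothetical) small configuration.
backbone = XLNetBackbone(
    vocabulary_size=1000,
    num_layers=2,
    num_heads=2,
    hidden_dim=32,
    intermediate_dim=64,
)

inputs = {
    "token_ids": np.ones((1, 12), dtype="int32"),
    "segment_ids": np.zeros((1, 12), dtype="int32"),
    "padding_mask": np.ones((1, 12), dtype="int32"),
}

# After this commit, calling the backbone returns a tensor of shape
# (batch_size, sequence_length, hidden_dim) rather than a dict.
last_hidden_state = backbone(inputs)
print(last_hidden_state.shape)  # (1, 12, 32)
```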
7 changes: 3 additions & 4 deletions tools/checkpoint_conversion/convert_xlnet_checkpoints.py
@@ -251,21 +251,20 @@
print(hf_preds["last_hidden_state"])

knlp_preds = knlp_model(tokenized_knlp, training=False)
- print(knlp_preds["last_hidden_state"], end="\n\n")
+ print(knlp_preds, end="\n\n")

print(
"Outputs matching or not for Last Hidden State : ",
np.allclose(
hf_preds["last_hidden_state"]
.numpy()
.reshape(-1, hf_model.config.d_model),
- knlp_preds["last_hidden_state"]
- .numpy()
- .reshape(-1, hf_model.config.d_model),
+ knlp_preds.numpy().reshape(-1, hf_model.config.d_model),
atol=1e-3,
),
)

+ # won't work since the recent version of the model doesn't return any mems!
if check_mems:
for i in range(hf_model.config.n_layer):
print(
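The numerical check in the conversion script boils down to flattening both last hidden states to `(-1, d_model)` and comparing with `np.allclose`. A self-contained sketch of that pattern (the arrays here are random stand-ins; in the real script they come from `hf_model` and `knlp_model`):

```python
import numpy as np

d_model = 8  # stands in for hf_model.config.d_model

# Stand-in arrays for the two models' last hidden states.
hf_last_hidden_state = np.random.rand(2, 4, d_model).astype("float32")
knlp_output = hf_last_hidden_state + np.float32(1e-4)  # nearly identical, within atol

print(
    "Outputs matching or not for Last Hidden State : ",
    np.allclose(
        hf_last_hidden_state.reshape(-1, d_model),
        knlp_output.reshape(-1, d_model),
        atol=1e-3,
    ),
)
```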
