
Commit

tokenizer now working with ds.map(...)
susnato committed Aug 12, 2023
1 parent e03e9bb commit e431196
Showing 2 changed files with 58 additions and 25 deletions.
67 changes: 42 additions & 25 deletions keras_nlp/models/xlnet/xlnet_tokenizer.py
@@ -18,11 +18,12 @@

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.tokenizers.sentence_piece_tokenizer import SentencePieceTokenizer
from keras_nlp.utils.tensor_utils import assert_tf_text_installed

try:
import unicodedata
import tensorflow_text as tf_text
except ImportError:
unicodedata = None
tf_text = None


@keras_nlp_export("keras_nlp.models.XLNetTokenizer")
@@ -90,6 +91,8 @@ class XLNetTokenizer(SentencePieceTokenizer):
"""

def __init__(self, proto, **kwargs):
assert_tf_text_installed(self.__class__.__name__)

super().__init__(proto=proto, **kwargs)

# Check for necessary special tokens.
@@ -129,30 +132,28 @@ def preprocess_text(self, inputs):
"""Preprocesses the text. This method removes spaces and accents from the text."""

# remove space
outputs = " ".join(inputs.strip().split())
outputs = outputs.replace("``", '"').replace("''", '"')
outputs = tf.strings.split(tf.strings.strip(inputs), sep=" ")
outputs = tf.strings.reduce_join(
outputs, separator=" ", keepdims=True, axis=-1
)
outputs = tf.strings.regex_replace(outputs, pattern="``", rewrite='"')
outputs = tf.strings.regex_replace(outputs, pattern="''", rewrite='"')

# remove accents
outputs = unicodedata.normalize("NFKD", outputs)
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
outputs = tf_text.normalize_utf8(outputs, "nfkd")

return outputs

def tokenize(self, text):
"""Tokenize a string."""

# check if there are multiple examples present or not
is_batched = isinstance(text, list)
if not is_batched:
text = [text]
def postprocess(self, batch_token_ids):
batch_token_ids = (
tf.squeeze(batch_token_ids, -2)
if tf.rank(batch_token_ids) > 2
else batch_token_ids
)

tokenized_text = []
for each_text in text:
each_text = self.preprocess_text(each_text)
pieces = [
self.id_to_token(token_id)
for token_id in super().tokenize(each_text)
]
for each_token_ids in batch_token_ids:
pieces = [self.id_to_token(token_id) for token_id in each_token_ids]

new_pieces = []
for piece in pieces:
@@ -164,7 +165,9 @@ def tokenize(self, text):
cur_pieces = [
self.id_to_token(cur_piece_id)
for cur_piece_id in super().tokenize(
piece[:-1].replace("▁", "")
tf.strings.regex_replace(
piece[:-1], pattern="▁", rewrite=""
)
)
]
if piece[0] != "▁" and cur_pieces[0][0] == "▁":
@@ -183,14 +186,28 @@ def tokenize(self, text):
]
# add the sep_token and cls_token at the end.
new_pieces.extend([self.sep_token_id, self.cls_token_id])

tokenized_text.append(new_pieces)

# if there are multiple examples present, then return a `tf.RaggedTensor`.
if is_batched:
return tf.ragged.constant(tokenized_text)
tokenized_text = tf.ragged.constant(tokenized_text)

return tokenized_text

def tokenize(self, text):
"""Tokenize a string."""

text = self.preprocess_text(text)
token_ids = super().tokenize(text)
token_ids = tf.py_function(
func=self.postprocess,
inp=[token_ids],
Tout=tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32),
)

# if there is only one example in the batch, output a tf.Tensor; otherwise a tf.RaggedTensor
if isinstance(text, str):
return token_ids.to_tensor()

return tokenized_text[0]
return token_ids

def detokenize(self, inputs):
"""Detokenize the input_ids into text."""
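
For readers of this diff: below is a minimal, hypothetical sketch (not part of the commit) of the tf.py_function pattern the rewritten tokenize() relies on. Python-level postprocessing, which cannot be traced into a TensorFlow graph, is wrapped in an eagerly executed function while the surrounding ops stay graph-compatible, which is what lets ds.map(...) work. The names tokenize_fn and eager_postprocess, and the hash-bucket "tokenization", are illustrative stand-ins only.

import tensorflow as tf

def eager_postprocess(token_ids):
    # Plain Python logic (loops, list building) is allowed here because
    # tf.py_function executes this body eagerly, even under ds.map(...).
    kept = [int(i) for i in token_ids.numpy() if int(i) != 0]
    return tf.ragged.constant([kept], dtype=tf.int32)

def tokenize_fn(text):
    # Stand-in tokenization: whitespace split plus hashing to integer ids.
    tokens = tf.strings.split(tf.strings.strip(text))
    token_ids = tf.strings.to_hash_bucket_fast(tokens, num_buckets=100)
    # Wrap the eager-only postprocessing so tokenize_fn itself stays traceable.
    return tf.py_function(
        func=eager_postprocess,
        inp=[token_ids],
        Tout=tf.RaggedTensorSpec(shape=[None, None], dtype=tf.int32),
    )

ds = tf.data.Dataset.from_tensor_slices(["the quick brown fox"])
for ids in ds.map(tokenize_fn):
    print(ids)  # a tf.RaggedTensor with one row of token ids
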
16 changes: 16 additions & 0 deletions keras_nlp/models/xlnet/xlnet_tokenizer_test.py
@@ -61,6 +61,22 @@ def test_tokenize_batch(self):
output, [[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]
)

def test_tokenize_ds(self):
input_ds = tf.data.Dataset.from_tensor_slices(
["the quick brown fox", "the earth is round"]
)
input_ds = input_ds.map(self.tokenizer)
outputs = []
for each_item in input_ds.take(2):
self.assertTrue(isinstance(each_item, tf.RaggedTensor))
outputs.append(each_item.to_tensor())

outputs = tf.squeeze(tf.convert_to_tensor(outputs), 1)
self.assertAllEqual(
outputs,
tf.convert_to_tensor([[7, 12, 8, 10, 6, 5], [7, 9, 11, 13, 6, 5]]),
)

def test_detokenize(self):
input_data = [[7, 12, 8, 10, 6, 5]]
output = self.tokenizer.detokenize(input_data)
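
For context, a hedged usage sketch of what the new test_tokenize_ds exercises: mapping the tokenizer directly over a tf.data.Dataset. The proto path below is a placeholder assumption, not a file shipped with this change; any SentencePiece model trained for XLNet would stand in.

import tensorflow as tf
from keras_nlp.models import XLNetTokenizer

# Placeholder path: supply your own SentencePiece model file here.
tokenizer = XLNetTokenizer(proto="xlnet_spiece.model")

ds = tf.data.Dataset.from_tensor_slices(
    ["the quick brown fox", "the earth is round"]
)
# After this commit, the tokenizer can be mapped over the dataset directly.
ds = ds.map(tokenizer, num_parallel_calls=tf.data.AUTOTUNE)

for token_ids in ds:
    print(token_ids)  # a tf.RaggedTensor of token ids per example
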

