diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py index 50392f2150..614ee153ea 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_preprocessor_test.py @@ -1,111 +1,111 @@ -# Copyright 2022 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tests for DistilBERT masked language model preprocessor layer.""" - -import tensorflow as tf - -from keras_nlp.backend import keras -from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import ( - DistilBertMaskedLMPreprocessor, -) -from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( - DistilBertTokenizer, -) -from keras_nlp.tests.test_case import TestCase - - -class DistilBertMaskedLMPreprocessorTest(TestCase): - def setUp(self): - self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] - self.vocab += ["THE", "QUICK", "BROWN", "FOX"] - self.vocab += ["the", "quick", "brown", "fox"] - - self.preprocessor = DistilBertMaskedLMPreprocessor( - tokenizer=DistilBertTokenizer( - vocabulary=self.vocab, - ), - # Simplify our testing by masking every available token. - mask_selection_rate=1.0, - mask_token_rate=1.0, - random_token_rate=0.0, - mask_selection_length=5, - sequence_length=8, - ) - - def test_preprocess_strings(self): - input_data = " THE QUICK BROWN FOX." - - x, y, sw = self.preprocessor(input_data) - self.assertAllEqual(x["token_ids"], [2, 4, 4, 4, 4, 4, 3, 0]) - self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]) - self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4, 5]) - self.assertAllEqual(y, [5, 6, 7, 8, 1]) - self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) - - def test_preprocess_list_of_strings(self): - input_data = [" THE QUICK BROWN FOX."] * 4 - - x, y, sw = self.preprocessor(input_data) - self.assertAllEqual(x["token_ids"], [[2, 4, 4, 4, 4, 4, 3, 0]] * 4) - self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) - self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) - self.assertAllEqual(y, [[5, 6, 7, 8, 1]] * 4) - self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) - - def test_preprocess_dataset(self): - sentences = tf.constant([" THE QUICK BROWN FOX."] * 4) - ds = tf.data.Dataset.from_tensor_slices(sentences) - ds = ds.map(self.preprocessor) - x, y, sw = ds.batch(4).take(1).get_single_element() - self.assertAllEqual(x["token_ids"], [[2, 4, 4, 4, 4, 4, 3, 0]] * 4) - self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) - self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) - self.assertAllEqual(y, [[5, 6, 7, 8, 1]] * 4) - self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) - - def test_mask_multiple_sentences(self): - sentence_one = tf.constant(" THE QUICK") - sentence_two = tf.constant(" BROWN FOX.") - - x, y, sw = self.preprocessor((sentence_one, sentence_two)) - self.assertAllEqual(x["token_ids"], [2, 4, 4, 3, 4, 4, 4, 3]) - self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1]) - self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5, 6]) - self.assertAllEqual(y, [5, 6, 7, 8, 1]) - self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) - - def test_no_masking_zero_rate(self): - no_mask_preprocessor = DistilBertMaskedLMPreprocessor( - self.preprocessor.tokenizer, - mask_selection_rate=0.0, - mask_selection_length=5, - sequence_length=8, - ) - input_data = " THE QUICK BROWN FOX." - - x, y, sw = no_mask_preprocessor(input_data) - self.assertAllEqual(x["token_ids"], [2, 5, 6, 7, 8, 1, 3, 0]) - self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]) - self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0]) - self.assertAllEqual(y, [0, 0, 0, 0, 0]) - self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0]) - - def test_serialization(self): - config = keras.saving.serialize_keras_object(self.preprocessor) - new_preprocessor = keras.saving.deserialize_keras_object(config) - self.assertEqual( - new_preprocessor.get_config(), - self.preprocessor.get_config(), - ) +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DistilBERT masked language model preprocessor layer.""" + +import tensorflow as tf + +from keras_nlp.backend import keras +from keras_nlp.models.distil_bert.distil_bert_masked_lm_preprocessor import ( + DistilBertMaskedLMPreprocessor, +) +from keras_nlp.models.distil_bert.distil_bert_tokenizer import ( + DistilBertTokenizer, +) +from keras_nlp.tests.test_case import TestCase + + +class DistilBertMaskedLMPreprocessorTest(TestCase): + def setUp(self): + self.vocab = ["[PAD]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"] + self.vocab += ["THE", "QUICK", "BROWN", "FOX"] + self.vocab += ["the", "quick", "brown", "fox"] + + self.preprocessor = DistilBertMaskedLMPreprocessor( + tokenizer=DistilBertTokenizer( + vocabulary=self.vocab, + ), + # Simplify our testing by masking every available token. + mask_selection_rate=1.0, + mask_token_rate=1.0, + random_token_rate=0.0, + mask_selection_length=5, + sequence_length=8, + ) + + def test_preprocess_strings(self): + input_data = " THE QUICK BROWN FOX." + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [2, 4, 4, 4, 4, 4, 3, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]) + self.assertAllEqual(x["mask_positions"], [1, 2, 3, 4, 5]) + self.assertAllEqual(y, [5, 6, 7, 8, 1]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) + + def test_preprocess_list_of_strings(self): + input_data = [" THE QUICK BROWN FOX."] * 4 + + x, y, sw = self.preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [[2, 4, 4, 4, 4, 4, 3, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) + self.assertAllEqual(y, [[5, 6, 7, 8, 1]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + + def test_preprocess_dataset(self): + sentences = tf.constant([" THE QUICK BROWN FOX."] * 4) + ds = tf.data.Dataset.from_tensor_slices(sentences) + ds = ds.map(self.preprocessor) + x, y, sw = ds.batch(4).take(1).get_single_element() + self.assertAllEqual(x["token_ids"], [[2, 4, 4, 4, 4, 4, 3, 0]] * 4) + self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4) + self.assertAllEqual(x["mask_positions"], [[1, 2, 3, 4, 5]] * 4) + self.assertAllEqual(y, [[5, 6, 7, 8, 1]] * 4) + self.assertAllEqual(sw, [[1.0, 1.0, 1.0, 1.0, 1.0]] * 4) + + def test_mask_multiple_sentences(self): + sentence_one = tf.constant(" THE QUICK") + sentence_two = tf.constant(" BROWN FOX.") + + x, y, sw = self.preprocessor((sentence_one, sentence_two)) + self.assertAllEqual(x["token_ids"], [2, 4, 4, 3, 4, 4, 4, 3]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 1]) + self.assertAllEqual(x["mask_positions"], [1, 2, 4, 5, 6]) + self.assertAllEqual(y, [5, 6, 7, 8, 1]) + self.assertAllEqual(sw, [1.0, 1.0, 1.0, 1.0, 1.0]) + + def test_no_masking_zero_rate(self): + no_mask_preprocessor = DistilBertMaskedLMPreprocessor( + self.preprocessor.tokenizer, + mask_selection_rate=0.0, + mask_selection_length=5, + sequence_length=8, + ) + input_data = " THE QUICK BROWN FOX." + + x, y, sw = no_mask_preprocessor(input_data) + self.assertAllEqual(x["token_ids"], [2, 5, 6, 7, 8, 1, 3, 0]) + self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 1, 0]) + self.assertAllEqual(x["mask_positions"], [0, 0, 0, 0, 0]) + self.assertAllEqual(y, [0, 0, 0, 0, 0]) + self.assertAllEqual(sw, [0.0, 0.0, 0.0, 0.0, 0.0]) + + def test_serialization(self): + config = keras.saving.serialize_keras_object(self.preprocessor) + new_preprocessor = keras.saving.deserialize_keras_object(config) + self.assertEqual( + new_preprocessor.get_config(), + self.preprocessor.get_config(), + ) diff --git a/keras_nlp/models/f_net/f_net_classifier.py b/keras_nlp/models/f_net/f_net_classifier.py index 3bb812090a..1a84a83f48 100644 --- a/keras_nlp/models/f_net/f_net_classifier.py +++ b/keras_nlp/models/f_net/f_net_classifier.py @@ -1,168 +1,168 @@ -# Copyright 2022 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""FNet classification model.""" - -import copy - -from keras_nlp.api_export import keras_nlp_export -from keras_nlp.backend import keras -from keras_nlp.models.f_net.f_net_backbone import FNetBackbone -from keras_nlp.models.f_net.f_net_backbone import f_net_kernel_initializer -from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor -from keras_nlp.models.f_net.f_net_presets import backbone_presets -from keras_nlp.models.task import Task -from keras_nlp.utils.python_utils import classproperty - - -@keras_nlp_export("keras_nlp.models.FNetClassifier") -class FNetClassifier(Task): - """An end-to-end f_net model for classification tasks. - - This model attaches a classification head to a - `keras_nlp.model.FNetBackbone` instance, mapping from the backbone outputs - to logits suitable for a classification task. For usage of this model with - pre-trained weights, use the `from_preset()` constructor. - - This model can optionally be configured with a `preprocessor` layer, in - which case it will automatically apply preprocessing to raw inputs during - `fit()`, `predict()`, and `evaluate()`. This is done by default when - creating the model with `from_preset()`. - - Disclaimer: Pre-trained models are provided on an "as is" basis, without - warranties or conditions of any kind. - - Args: - backbone: A `keras_nlp.models.FNetBackbone` instance. - num_classes: int. Number of classes to predict. - preprocessor: A `keras_nlp.models.FNetPreprocessor` or `None`. If - `None`, this model will not apply preprocessing, and inputs should - be preprocessed before calling the model. - activation: Optional `str` or callable. The - activation function to use on the model outputs. Set - `activation="softmax"` to return output probabilities. - Defaults to `None`. - hidden_dim: int. The size of the pooler layer. - dropout: float. The dropout probability value, applied after the dense - layer. - - Examples: - - Raw string data. - ```python - features = ["The quick brown fox jumped.", "I forgot my homework."] - labels = [0, 3] - - # Pretrained classifier. - classifier = keras_nlp.models.FNetClassifier.from_preset( - "f_net_base_en", - num_classes=4, - ) - classifier.fit(x=features, y=labels, batch_size=2) - classifier.predict(x=features, batch_size=2) - - # Re-compile (e.g., with a new learning rate). - classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), - optimizer=keras.optimizers.Adam(5e-5), - jit_compile=True, - ) - # Access backbone programatically (e.g., to change `trainable`). - classifier.backbone.trainable = False - # Fit again. - classifier.fit(x=features, y=labels, batch_size=2) - ``` - - Preprocessed integer data. - ```python - features = { - "token_ids": np.ones(shape=(2, 12), dtype="int32"), - "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2), - } - labels = [0, 3] - - # Pretrained classifier without preprocessing. - classifier = keras_nlp.models.FNetClassifier.from_preset( - "f_net_base_en", - num_classes=4, - preprocessor=None, - ) - classifier.fit(x=features, y=labels, batch_size=2) - ``` - """ - - def __init__( - self, - backbone, - num_classes, - preprocessor=None, - activation=None, - dropout=0.1, - **kwargs, - ): - inputs = backbone.input - pooled = backbone(inputs)["pooled_output"] - pooled = keras.layers.Dropout(dropout)(pooled) - outputs = keras.layers.Dense( - num_classes, - kernel_initializer=f_net_kernel_initializer(), - activation=activation, - name="logits", - )(pooled) - # Instantiate using Functional API Model constructor - super().__init__( - inputs=inputs, - outputs=outputs, - include_preprocessing=preprocessor is not None, - **kwargs, - ) - # All references to `self` below this line - self.backbone = backbone - self.preprocessor = preprocessor - self.num_classes = num_classes - self.activation = keras.activations.get(activation) - self.dropout = dropout - - logit_output = self.activation == keras.activations.linear - self.compile( - loss=keras.losses.SparseCategoricalCrossentropy( - from_logits=logit_output - ), - optimizer=keras.optimizers.Adam(5e-5), - metrics=[keras.metrics.SparseCategoricalAccuracy()], - jit_compile=True, - ) - - def get_config(self): - config = super().get_config() - config.update( - { - "num_classes": self.num_classes, - "dropout": self.dropout, - "activation": keras.activations.serialize(self.activation), - } - ) - return config - - @classproperty - def backbone_cls(cls): - return FNetBackbone - - @classproperty - def preprocessor_cls(cls): - return FNetPreprocessor - - @classproperty - def presets(cls): - return copy.deepcopy(backbone_presets) +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FNet classification model.""" + +import copy + +from keras_nlp.api_export import keras_nlp_export +from keras_nlp.backend import keras +from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_backbone import f_net_kernel_initializer +from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor +from keras_nlp.models.f_net.f_net_presets import backbone_presets +from keras_nlp.models.task import Task +from keras_nlp.utils.python_utils import classproperty + + +@keras_nlp_export("keras_nlp.models.FNetClassifier") +class FNetClassifier(Task): + """An end-to-end f_net model for classification tasks. + + This model attaches a classification head to a + `keras_nlp.model.FNetBackbone` instance, mapping from the backbone outputs + to logits suitable for a classification task. For usage of this model with + pre-trained weights, use the `from_preset()` constructor. + + This model can optionally be configured with a `preprocessor` layer, in + which case it will automatically apply preprocessing to raw inputs during + `fit()`, `predict()`, and `evaluate()`. This is done by default when + creating the model with `from_preset()`. + + Disclaimer: Pre-trained models are provided on an "as is" basis, without + warranties or conditions of any kind. + + Args: + backbone: A `keras_nlp.models.FNetBackbone` instance. + num_classes: int. Number of classes to predict. + preprocessor: A `keras_nlp.models.FNetPreprocessor` or `None`. If + `None`, this model will not apply preprocessing, and inputs should + be preprocessed before calling the model. + activation: Optional `str` or callable. The + activation function to use on the model outputs. Set + `activation="softmax"` to return output probabilities. + Defaults to `None`. + hidden_dim: int. The size of the pooler layer. + dropout: float. The dropout probability value, applied after the dense + layer. + + Examples: + + Raw string data. + ```python + features = ["The quick brown fox jumped.", "I forgot my homework."] + labels = [0, 3] + + # Pretrained classifier. + classifier = keras_nlp.models.FNetClassifier.from_preset( + "f_net_base_en", + num_classes=4, + ) + classifier.fit(x=features, y=labels, batch_size=2) + classifier.predict(x=features, batch_size=2) + + # Re-compile (e.g., with a new learning rate). + classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + jit_compile=True, + ) + # Access backbone programatically (e.g., to change `trainable`). + classifier.backbone.trainable = False + # Fit again. + classifier.fit(x=features, y=labels, batch_size=2) + ``` + + Preprocessed integer data. + ```python + features = { + "token_ids": np.ones(shape=(2, 12), dtype="int32"), + "segment_ids": np.array([[0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0]] * 2), + } + labels = [0, 3] + + # Pretrained classifier without preprocessing. + classifier = keras_nlp.models.FNetClassifier.from_preset( + "f_net_base_en", + num_classes=4, + preprocessor=None, + ) + classifier.fit(x=features, y=labels, batch_size=2) + ``` + """ + + def __init__( + self, + backbone, + num_classes, + preprocessor=None, + activation=None, + dropout=0.1, + **kwargs, + ): + inputs = backbone.input + pooled = backbone(inputs)["pooled_output"] + pooled = keras.layers.Dropout(dropout)(pooled) + outputs = keras.layers.Dense( + num_classes, + kernel_initializer=f_net_kernel_initializer(), + activation=activation, + name="logits", + )(pooled) + # Instantiate using Functional API Model constructor + super().__init__( + inputs=inputs, + outputs=outputs, + include_preprocessing=preprocessor is not None, + **kwargs, + ) + # All references to `self` below this line + self.backbone = backbone + self.preprocessor = preprocessor + self.num_classes = num_classes + self.activation = keras.activations.get(activation) + self.dropout = dropout + + logit_output = self.activation == keras.activations.linear + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy( + from_logits=logit_output + ), + optimizer=keras.optimizers.Adam(5e-5), + metrics=[keras.metrics.SparseCategoricalAccuracy()], + jit_compile=True, + ) + + def get_config(self): + config = super().get_config() + config.update( + { + "num_classes": self.num_classes, + "dropout": self.dropout, + "activation": keras.activations.serialize(self.activation), + } + ) + return config + + @classproperty + def backbone_cls(cls): + return FNetBackbone + + @classproperty + def preprocessor_cls(cls): + return FNetPreprocessor + + @classproperty + def presets(cls): + return copy.deepcopy(backbone_presets) diff --git a/keras_nlp/models/f_net/f_net_classifier_test.py b/keras_nlp/models/f_net/f_net_classifier_test.py index 917886b24a..f30e043da1 100644 --- a/keras_nlp/models/f_net/f_net_classifier_test.py +++ b/keras_nlp/models/f_net/f_net_classifier_test.py @@ -1,148 +1,148 @@ -# Copyright 2023 The KerasNLP Authors -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Tests for FNet classification model.""" - -import io -import os - -import numpy as np -import pytest -import sentencepiece -import tensorflow as tf - -from keras_nlp.backend import keras -from keras_nlp.backend import ops -from keras_nlp.models.f_net.f_net_backbone import FNetBackbone -from keras_nlp.models.f_net.f_net_classifier import FNetClassifier -from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor -from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer -from keras_nlp.tests.test_case import TestCase - - -class FNetClassifierTest(TestCase): - def setUp(self): - # Setup Model - bytes_io = io.BytesIO() - vocab_data = tf.data.Dataset.from_tensor_slices( - ["the quick brown fox", "the earth is round"] - ) - - sentencepiece.SentencePieceTrainer.train( - sentence_iterator=vocab_data.as_numpy_iterator(), - model_writer=bytes_io, - vocab_size=12, - model_type="WORD", - pad_id=3, - unk_id=0, - bos_id=4, - eos_id=5, - pad_piece="", - unk_piece="", - bos_piece="[CLS]", - eos_piece="[SEP]", - user_defined_symbols="[MASK]", - ) - - self.proto = bytes_io.getvalue() - - self.preprocessor = FNetPreprocessor( - tokenizer=FNetTokenizer(proto=self.proto), - sequence_length=8, - ) - self.backbone = FNetBackbone( - vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), - num_layers=2, - hidden_dim=2, - intermediate_dim=4, - max_sequence_length=self.preprocessor.packer.sequence_length, - ) - self.classifier = FNetClassifier( - self.backbone, - num_classes=4, - preprocessor=self.preprocessor, - # Check we handle serialization correctly. - activation=keras.activations.softmax, - ) - - # Setup data. - self.raw_batch = [ - "the quick brown fox.", - "the slow brown fox.", - ] - self.preprocessed_batch = self.preprocessor(self.raw_batch) - self.raw_dataset = tf.data.Dataset.from_tensor_slices( - (self.raw_batch, np.ones((2,))) - ).batch(2) - self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) - - def test_valid_call_classifier(self): - self.classifier(self.preprocessed_batch) - - def test_classifier_predict(self): - preds1 = self.classifier.predict(self.raw_batch) - self.classifier.preprocessor = None - preds2 = self.classifier.predict(self.preprocessed_batch) - # Assert predictions match. - self.assertAllClose(preds1, preds2) - # Assert valid softmax output. - self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) - - def test_fnet_classifier_fit(self): - self.classifier.fit(self.raw_dataset) - self.classifier.preprocessor = None - self.classifier.fit(self.preprocessed_dataset) - - def test_classifier_fit_no_xla(self): - self.classifier.preprocessor = None - self.classifier.compile( - loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), - jit_compile=False, - ) - self.classifier.fit(self.preprocessed_dataset) - - def test_serialization(self): - # Defaults. - original = FNetClassifier( - self.backbone, - num_classes=2, - ) - config = keras.saving.serialize_keras_object(original) - restored = keras.saving.deserialize_keras_object(config) - self.assertEqual(restored.get_config(), original.get_config()) - # With options. - original = FNetClassifier( - self.backbone, - num_classes=4, - preprocessor=self.preprocessor, - activation=keras.activations.softmax, - name="test", - trainable=False, - ) - config = keras.saving.serialize_keras_object(original) - restored = keras.saving.deserialize_keras_object(config) - self.assertEqual(restored.get_config(), original.get_config()) - - @pytest.mark.large - def test_saved_model(self): - model_output = self.classifier.predict(self.raw_batch) - path = os.path.join(self.get_temp_dir(), "model.keras") - self.classifier.save(path, save_format="keras_v3") - restored_model = keras.models.load_model(path) - - # Check we got the real object back. - self.assertIsInstance(restored_model, FNetClassifier) - - # Check that output matches. - restored_output = restored_model.predict(self.raw_batch) - self.assertAllClose(model_output, restored_output) +# Copyright 2023 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for FNet classification model.""" + +import io +import os + +import numpy as np +import pytest +import sentencepiece +import tensorflow as tf + +from keras_nlp.backend import keras +from keras_nlp.backend import ops +from keras_nlp.models.f_net.f_net_backbone import FNetBackbone +from keras_nlp.models.f_net.f_net_classifier import FNetClassifier +from keras_nlp.models.f_net.f_net_preprocessor import FNetPreprocessor +from keras_nlp.models.f_net.f_net_tokenizer import FNetTokenizer +from keras_nlp.tests.test_case import TestCase + + +class FNetClassifierTest(TestCase): + def setUp(self): + # Setup Model + bytes_io = io.BytesIO() + vocab_data = tf.data.Dataset.from_tensor_slices( + ["the quick brown fox", "the earth is round"] + ) + + sentencepiece.SentencePieceTrainer.train( + sentence_iterator=vocab_data.as_numpy_iterator(), + model_writer=bytes_io, + vocab_size=12, + model_type="WORD", + pad_id=3, + unk_id=0, + bos_id=4, + eos_id=5, + pad_piece="", + unk_piece="", + bos_piece="[CLS]", + eos_piece="[SEP]", + user_defined_symbols="[MASK]", + ) + + self.proto = bytes_io.getvalue() + + self.preprocessor = FNetPreprocessor( + tokenizer=FNetTokenizer(proto=self.proto), + sequence_length=8, + ) + self.backbone = FNetBackbone( + vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(), + num_layers=2, + hidden_dim=2, + intermediate_dim=4, + max_sequence_length=self.preprocessor.packer.sequence_length, + ) + self.classifier = FNetClassifier( + self.backbone, + num_classes=4, + preprocessor=self.preprocessor, + # Check we handle serialization correctly. + activation=keras.activations.softmax, + ) + + # Setup data. + self.raw_batch = [ + "the quick brown fox.", + "the slow brown fox.", + ] + self.preprocessed_batch = self.preprocessor(self.raw_batch) + self.raw_dataset = tf.data.Dataset.from_tensor_slices( + (self.raw_batch, np.ones((2,))) + ).batch(2) + self.preprocessed_dataset = self.raw_dataset.map(self.preprocessor) + + def test_valid_call_classifier(self): + self.classifier(self.preprocessed_batch) + + def test_classifier_predict(self): + preds1 = self.classifier.predict(self.raw_batch) + self.classifier.preprocessor = None + preds2 = self.classifier.predict(self.preprocessed_batch) + # Assert predictions match. + self.assertAllClose(preds1, preds2) + # Assert valid softmax output. + self.assertAllClose(ops.sum(preds2, axis=-1), [1.0, 1.0]) + + def test_fnet_classifier_fit(self): + self.classifier.fit(self.raw_dataset) + self.classifier.preprocessor = None + self.classifier.fit(self.preprocessed_dataset) + + def test_classifier_fit_no_xla(self): + self.classifier.preprocessor = None + self.classifier.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False), + jit_compile=False, + ) + self.classifier.fit(self.preprocessed_dataset) + + def test_serialization(self): + # Defaults. + original = FNetClassifier( + self.backbone, + num_classes=2, + ) + config = keras.saving.serialize_keras_object(original) + restored = keras.saving.deserialize_keras_object(config) + self.assertEqual(restored.get_config(), original.get_config()) + # With options. + original = FNetClassifier( + self.backbone, + num_classes=4, + preprocessor=self.preprocessor, + activation=keras.activations.softmax, + name="test", + trainable=False, + ) + config = keras.saving.serialize_keras_object(original) + restored = keras.saving.deserialize_keras_object(config) + self.assertEqual(restored.get_config(), original.get_config()) + + @pytest.mark.large + def test_saved_model(self): + model_output = self.classifier.predict(self.raw_batch) + path = os.path.join(self.get_temp_dir(), "model.keras") + self.classifier.save(path, save_format="keras_v3") + restored_model = keras.models.load_model(path) + + # Check we got the real object back. + self.assertIsInstance(restored_model, FNetClassifier) + + # Check that output matches. + restored_output = restored_model.predict(self.raw_batch) + self.assertAllClose(model_output, restored_output)