Skip to content

Commit

Permalink
Port metrics to multi-backend
Browse files (browse the repository at this point in the history)
  • Loading branch information
mattdangerw committed Jul 28, 2023
1 parent b117fbc commit 4eeba35
Show file tree
Hide file tree
Showing 11 changed files with 242 additions and 482 deletions.
30 changes: 13 additions & 17 deletions keras_nlp/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.utils.tensor_utils import assert_tf_backend
from keras_nlp.backend import ops
from keras_nlp.utils.tensor_utils import is_floating_dtype
from keras_nlp.utils.tensor_utils import tensor_to_list

Expand Down Expand Up @@ -112,8 +112,6 @@ def __init__(
name="bleu",
**kwargs,
):
assert_tf_backend(self.__class__.__name__)

super().__init__(name=name, dtype=dtype, **kwargs)

if not is_floating_dtype(dtype):
Expand Down Expand Up @@ -290,8 +288,10 @@ def _corpus_bleu(
)

def _calculate_bleu_score(self, references, translation):
references = tensor_to_list(references)
translation = tensor_to_list(translation)
if isinstance(references, (tf.Tensor, tf.RaggedTensor)):
references = tensor_to_list(references)
if isinstance(translation, (tf.Tensor, tf.RaggedTensor)):
translation = tensor_to_list(translation)

matches = self._matches.numpy()
possible_matches = self._possible_matches.numpy()
Expand All @@ -315,11 +315,11 @@ def _calculate_bleu_score(self, references, translation):
smooth=self.smooth,
)
return (
tf.constant(bleu_score, dtype=self.dtype),
tf.constant(matches, dtype=self.dtype),
tf.constant(possible_matches, dtype=self.dtype),
tf.constant(translation_length, dtype=self.dtype),
tf.constant(reference_length, dtype=self.dtype),
bleu_score,
matches,
possible_matches,
translation_length,
reference_length,
)

def update_state(self, y_true, y_pred, sample_weight=None):
Expand Down Expand Up @@ -357,11 +357,7 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
possible_matches,
translation_length,
reference_length,
) = tf.py_function(
func=self._calculate_bleu_score,
inp=[y_true, y_pred],
Tout=[self.dtype, self.dtype, self.dtype, self.dtype, self.dtype],
)
) = self._calculate_bleu_score(y_true, y_pred)

self._matches.assign(matches)
self._possible_matches.assign(possible_matches)
Expand All @@ -374,10 +370,10 @@ def result(self):

def reset_state(self):
self._matches.assign(
tf.zeros(shape=(self.max_order,), dtype=self.dtype)
ops.zeros(shape=(self.max_order,), dtype=self.dtype)
)
self._possible_matches.assign(
tf.zeros(shape=(self.max_order,), dtype=self.dtype)
ops.zeros(shape=(self.max_order,), dtype=self.dtype)
)
self._translation_length.assign(0.0)
self._reference_length.assign(0.0)
Expand Down
120 changes: 43 additions & 77 deletions keras_nlp/metrics/bleu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from keras_nlp.tokenizers.byte_tokenizer import ByteTokenizer


@pytest.mark.tf_only
class BleuTest(TestCase):
def test_initialization(self):
bleu = Bleu()
Expand Down Expand Up @@ -69,64 +68,38 @@ def test_2d_list_input(self):
bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val, 0.243, delta=1e-3)

def test_1d_tensor_input(self):
bleu = Bleu()
y_true = tf.ragged.constant(
[
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
)
y_pred = tf.constant(
[
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]
)

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val, 0.243, delta=1e-3)

def test_2d_tensor_input(self):
bleu = Bleu()
y_true = tf.constant(
[
[["He eats a sweet apple."]],
[["Silicon Valley is one of my favourite shows!"]],
]
)
y_pred = tf.constant(
[
["He He He eats sweet apple which is a fruit."],
["I love Silicon Valley, it's one of my favourite shows."],
]
)

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val, 0.243, delta=1e-3)

def test_custom_tokenizer(self):
byte_tokenizer = ByteTokenizer()
bleu = Bleu(tokenizer=byte_tokenizer)
y_true = tf.ragged.constant(
[
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
)
y_pred = tf.constant(
[
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]
)
y_true = [
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
y_pred = [
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val, 0.609, delta=1e-3)

def test_different_order(self):
bleu = Bleu(max_order=5)
y_true = tf.ragged.constant(
y_true = [
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
y_pred = [
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val, 0.188, delta=1e-3)

def test_tensor_input(self):
bleu = Bleu()
y_true = tf.constant(
[
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
Expand All @@ -140,8 +113,9 @@ def test_different_order(self):
)

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val, 0.188, delta=1e-3)
self.assertAlmostEqual(bleu_val, 0.243, delta=1e-3)

@pytest.mark.tf_only # string model output only applies to tf.
def test_model_compile(self):
inputs = keras.Input(shape=(), dtype="string")
outputs = keras.layers.Identity()(inputs)
Expand All @@ -166,18 +140,14 @@ def test_model_compile(self):

def test_reset_state(self):
bleu = Bleu()
y_true = tf.ragged.constant(
[
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
)
y_pred = tf.constant(
[
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]
)
y_true = [
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
y_pred = [
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]

bleu.update_state(y_true, y_pred)
bleu_val = bleu.result()
Expand All @@ -189,25 +159,21 @@ def test_reset_state(self):

def test_update_state(self):
bleu = Bleu()
y_true_1 = tf.ragged.constant(
[
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
)
y_pred_1 = tf.constant(
[
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]
)
y_true_1 = [
["He eats a sweet apple."],
["Silicon Valley is one of my favourite shows!"],
]
y_pred_1 = [
"He He He eats sweet apple which is a fruit.",
"I love Silicon Valley, it's one of my favourite shows.",
]

bleu.update_state(y_true_1, y_pred_1)
bleu_val = bleu.result()
self.assertAlmostEqual(bleu_val, 0.243, delta=1e-3)

y_true_2 = tf.constant(["Virat Kohli is the GOAT."])
y_pred_2 = tf.constant("Virat Kohli is the greatest of all time!")
y_true_2 = ["Virat Kohli is the GOAT."]
y_pred_2 = "Virat Kohli is the greatest of all time!"

bleu.update_state(y_true_2, y_pred_2)
bleu_val = bleu.result()
Expand Down
23 changes: 0 additions & 23 deletions keras_nlp/metrics/edit_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from keras_nlp.api_export import keras_nlp_export
from keras_nlp.backend import keras
from keras_nlp.utils.tensor_utils import assert_tf_backend
from keras_nlp.utils.tensor_utils import is_floating_dtype


Expand Down Expand Up @@ -67,13 +66,6 @@ class EditDistance(keras.metrics.Metric):
>>> edit_distance(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.36363637>
Rank 1 tensor.
>>> edit_distance = keras_nlp.metrics.EditDistance()
>>> y_true = tf.strings.split("the tiny little cat was found under the big funny bed")
>>> y_pred = tf.strings.split("the cat was found under the bed")
>>> edit_distance(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.36363637>
Nested Python list.
>>> edit_distance = keras_nlp.metrics.EditDistance()
>>> y_true = [
Expand All @@ -86,19 +78,6 @@ class EditDistance(keras.metrics.Metric):
... ]
>>> edit_distance(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.73333335>
Rank 2 tensor.
>>> edit_distance = keras_nlp.metrics.EditDistance()
>>> y_true = tf.strings.split([
... "the tiny little cat was found under the big funny bed",
... "it is sunny today",
... ])
>>> y_pred = tf.strings.split([
... "the cat was found under the bed",
... "it is sunny but with a hint of cloud cover",
... ])
>>> edit_distance(y_true, y_pred)
<tf.Tensor: shape=(), dtype=float32, numpy=0.73333335>
"""

def __init__(
Expand All @@ -108,8 +87,6 @@ def __init__(
name="edit_distance",
**kwargs,
):
assert_tf_backend(self.__class__.__name__)

super().__init__(name=name, dtype=dtype, **kwargs)

if not is_floating_dtype(dtype):
Expand Down
Loading

0 comments on commit 4eeba35

Please sign in to comment.