diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index 3301367a..174e8f3d 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -37,7 +37,6 @@
 from torch import nn
 from torch.backends import mps
 from transformers import BatchEncoding
-from transformers.tokenization_utils import PaddingStrategy
 import numpy as np
 import torch
 
@@ -1041,7 +1040,7 @@ def _get_tokenized(self, texts, **kwargs):
         the fast tokenizer with different truncation settings.
         """
-        pad_to_max_length = kwargs.pop("pad_to_max_length", None)
+        padding_strategy = kwargs.pop("padding_strategy", True)
 
         # Keep copies of tokenizer per thread (in each wrapped model instance)
         thread_id = threading.get_ident()
 
@@ -1051,32 +1050,18 @@
             else self.tokenizers.setdefault(thread_id, deepcopy(self.tokenizer))
         )
 
-        if pad_to_max_length:
-            return tokenizer(
-                texts,
-                return_attention_mask=True,  # Used for determining token count
-                return_token_type_ids=False,
-                return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
-                return_offsets_mapping=True,  # Used for truncation
-                return_length=False,
-                return_tensors="pt",
-                truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
-                padding=PaddingStrategy.MAX_LENGTH,  # DO NOT CHANGE else "Already borrowed" errors
-                max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
-            )
-        else:
-            return tokenizer(
-                texts,
-                return_attention_mask=True,  # Used for determining token count
-                return_token_type_ids=False,
-                return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
-                return_offsets_mapping=True,  # Used for truncation
-                return_length=False,
-                return_tensors="pt",
-                truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
-                padding=True,  # DO NOT CHANGE else "Already borrowed" errors
-                max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
-            )
+        return tokenizer(
+            texts,
+            return_attention_mask=True,  # Used for determining token count
+            return_token_type_ids=False,
+            return_overflowing_tokens=False,  # DO NOT USE overflow tokens break sentence batches
+            return_offsets_mapping=True,  # Used for truncation
+            return_length=False,
+            return_tensors="pt",
+            truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
+            padding=padding_strategy,  # DO NOT CHANGE else "Already borrowed" errors
+            max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
+        )
 
     def encode(
         self,
@@ -1095,7 +1080,7 @@
         return_token_count: bool = False,
         implicit_truncation_errors: bool = True,
         autocast: bool = False,
-        tokenizer_kwargs: Dict[str, Any] = {},
+        **kwargs,
     ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
@@ -1180,7 +1165,7 @@
                     truncate_input_tokens,
                     sentences_batch,
                     implicit_truncation_errors=implicit_truncation_errors,
-                    **tokenizer_kwargs
+                    **kwargs,
                 )
 
                 if truncation_needed:  # truncation was needed and was not done/not allowed
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index 209522a6..a5559648 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -2,12 +2,14 @@
 # Standard
 from typing import List, Tuple
+from unittest.mock import patch
 import os
 import tempfile
 
 # Third Party
 from pytest import approx
 from torch.backends import mps
+from transformers import BatchEncoding
 import numpy as np
 import pytest
 import torch
 
@@ -1144,21 +1146,70 @@ def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens):
         separate_vectors[1], separate_vectors[2], rtol=1e-05, atol=1e-08
     )
 
 
-@pytest.mark.parametrize("pad_to_max_length", [None, False, True, 0, 1])
-def test_pad_to_max_length(pad_to_max_length, loaded_model):
-    """Tests for tokenization kwargs pad_to_max_length will modify tokenizer and give same result"""
-    model_max = loaded_model.model.max_seq_length
-    tokenizer_kwargs = {'pad_to_max_length': pad_to_max_length}
-    max_seq = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
-    max_seq_minus_one = "x " * (model_max - 3)  # 1 token length shorter than max_seq_length
-    short = "x "
+def custom_sum_token_count(
+    tokenized: BatchEncoding,
+) -> int:
+    """Returns total number of tokens regardless of attention_mask value"""
+
+    token_count = 0
+    for encoding in tokenized.encodings:
+        token_count += len(encoding.attention_mask)
+
+    return token_count
 
 
-    normal_result = loaded_model._encode_with_retry(
-        [max_seq_minus_one, max_seq, short], return_token_count=True
-    )
-    padded_result = loaded_model._encode_with_retry(
-        [max_seq_minus_one, max_seq, short], return_token_count=True, tokenizer_kwargs=tokenizer_kwargs
-    )
-    assert np.all(normal_result.embedding == padded_result.embedding)
\ No newline at end of file
+@pytest.mark.parametrize("padding_strategy", [True, "max_length"])
+def test_pad_to_max_length(padding_strategy, loaded_model):
+    """Tests for tokenization kwargs max_length will modify tokenizer"""
+    model_max = loaded_model.model.max_seq_length
+
+    tokenizer_kwargs = {"padding_strategy": padding_strategy}
+    max_seq = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
+    max_seq_minus_one = "x " * (
+        model_max - 3
+    )  # 1 token length shorter than max_seq_length
+    single = "x "
+
+    if padding_strategy is True:
+        normal_result = loaded_model._encode_with_retry(
+            [max_seq_minus_one], return_token_count=True
+        )
+        padded_result = loaded_model._encode_with_retry(
+            [max_seq_minus_one],
+            return_token_count=True,
+            **tokenizer_kwargs,
+        )
+        assert np.all(normal_result.embedding == padded_result.embedding)
+    elif padding_strategy == "max_length":
+        with patch(
+            "caikit_nlp.modules.text_embedding.embedding.sum_token_count"
+        ) as mock_sum_token_count:
+            mock_sum_token_count.side_effect = custom_sum_token_count
+            normal_result = loaded_model._encode_with_retry(
+                [max_seq_minus_one], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [max_seq_minus_one],
+                return_token_count=True,
+                **tokenizer_kwargs,
+            )
+            assert normal_result.input_token_count != padded_result.input_token_count
+            assert padded_result.input_token_count == model_max
+            assert not np.all(normal_result.embedding == padded_result.embedding)
+            normal_result = loaded_model._encode_with_retry(
+                [max_seq], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [max_seq], return_token_count=True, **tokenizer_kwargs
+            )
+            assert normal_result.input_token_count == padded_result.input_token_count
+            assert np.all(normal_result.embedding == padded_result.embedding)
+            normal_result = loaded_model._encode_with_retry(
+                [single], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [single], return_token_count=True, **tokenizer_kwargs
+            )
+            assert normal_result.input_token_count != padded_result.input_token_count
+            assert not np.all(normal_result.embedding == padded_result.embedding)
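
For context, here is a minimal usage sketch (not part of the diff) of how the new padding_strategy kwarg flows from a caller through encode(**kwargs) into _get_tokenized(). The model path and the EmbeddingModule.load() call are illustrative assumptions, and _encode_with_retry is the same private helper the tests above exercise; this is a sketch, not the documented public API.

# Sketch: exercising the padding_strategy kwarg introduced by this diff.
# Assumptions: "path/to/embedding-model" is a hypothetical local model;
# EmbeddingModule.load() and _encode_with_retry() are used as in the tests above.
from caikit_nlp.modules.text_embedding import EmbeddingModule

loaded_model = EmbeddingModule.load("path/to/embedding-model")  # hypothetical path

# Default behavior: padding_strategy defaults to True in _get_tokenized(),
# i.e. dynamic padding to the longest sentence in the batch.
dynamic = loaded_model._encode_with_retry(
    ["first sentence", "a longer second sentence"],
    return_token_count=True,
)

# Pad every sequence to the model's max_seq_length instead, replacing the old
# pad_to_max_length flag with the tokenizer's "max_length" padding strategy.
padded = loaded_model._encode_with_retry(
    ["first sentence", "a longer second sentence"],
    return_token_count=True,
    padding_strategy="max_length",
)

# As the new test shows, with "max_length" the raw token accounting (and, for
# short inputs, the resulting embeddings) can differ from dynamic padding.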