diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 06779d0d..174e8f3d 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -976,6 +976,7 @@ def _tokenize_plus( truncate_input_tokens: int, texts: List[str], implicit_truncation_errors: bool = True, + **kwargs, ) -> TruncatedTokensTuple: """Tokenize with support for truncation handling and returning the token count Args: @@ -1015,7 +1016,7 @@ def _tokenize_plus( texts = [str(s).strip() for s in texts] # Call tokenizer with the same truncation parameters every time - tokenized = self._get_tokenized(texts) + tokenized = self._get_tokenized(texts, **kwargs) # Custom truncation and/or error raise if needed truncation_needed = self._truncation_needed(tokenized, max_length, texts) @@ -1023,13 +1024,13 @@ def _tokenize_plus( # Truncate texts in place _truncate_texts(texts, tokenized, max_length, truncation_needed) # Re-tokenize the truncated texts - tokenized = self._get_tokenized(texts) + tokenized = self._get_tokenized(texts, **kwargs) truncation_needed = [] # truncation accomplished input_token_count = sum_token_count(tokenized) return TruncatedTokensTuple(tokenized, input_token_count, truncation_needed) - def _get_tokenized(self, texts): + def _get_tokenized(self, texts, **kwargs): """Intentionally always call tokenizer the same way to avoid thread issues. Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID). @@ -1039,6 +1040,8 @@ def _get_tokenized(self, texts): the fast tokenizer with different truncation settings. """ + padding_strategy = kwargs.pop("padding_strategy", True) + # Keep copies of tokenizer per thread (in each wrapped model instance) thread_id = threading.get_ident() tokenizer = ( @@ -1056,7 +1059,7 @@ def _get_tokenized(self, texts): return_length=False, return_tensors="pt", truncation=True, # DO NOT CHANGE else "Already borrowed" errors - padding=True, # DO NOT CHANGE else "Already borrowed" errors + padding=padding_strategy, # DO NOT CHANGE else "Already borrowed" errors max_length=self.max_seq_length, # DO NOT CHANGE else "Already borrowed" errors ) @@ -1077,6 +1080,7 @@ def encode( return_token_count: bool = False, implicit_truncation_errors: bool = True, autocast: bool = False, + **kwargs, ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: """ Computes sentence embeddings @@ -1161,6 +1165,7 @@ def encode( truncate_input_tokens, sentences_batch, implicit_truncation_errors=implicit_truncation_errors, + **kwargs, ) if truncation_needed: # truncation was needed and was not done/not allowed diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index cbd22ee8..a5559648 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -2,12 +2,14 @@ # Standard from typing import List, Tuple +from unittest.mock import patch import os import tempfile # Third Party from pytest import approx from torch.backends import mps +from transformers import BatchEncoding import numpy as np import pytest import torch @@ -1143,3 +1145,71 @@ def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens): assert not np.allclose( separate_vectors[1], separate_vectors[2], rtol=1e-05, atol=1e-08 ) + + +def custom_sum_token_count( + tokenized: BatchEncoding, +) -> int: + """Returns total number of tokens regardless of attention_mask value""" + + token_count = 0 + for encoding in tokenized.encodings: + token_count += len(encoding.attention_mask) + + return token_count + + +@pytest.mark.parametrize("padding_strategy", [True, "max_length"]) +def test_pad_to_max_length(padding_strategy, loaded_model): + """Tests for tokenization kwargs max_length will modify tokenizer""" + model_max = loaded_model.model.max_seq_length + + tokenizer_kwargs = {"padding_strategy": padding_strategy} + max_seq = "x " * (model_max - 2) # Subtract 2 for begin/end tokens + max_seq_minus_one = "x " * ( + model_max - 3 + ) # 1 token length shorter than max_seq_length + single = "x " + + if padding_strategy is True: + normal_result = loaded_model._encode_with_retry( + [max_seq_minus_one], return_token_count=True + ) + padded_result = loaded_model._encode_with_retry( + [max_seq_minus_one], + return_token_count=True, + **tokenizer_kwargs, + ) + assert np.all(normal_result.embedding == padded_result.embedding) + elif padding_strategy == "max_length": + with patch( + "caikit_nlp.modules.text_embedding.embedding.sum_token_count" + ) as mock_sum_token_count: + mock_sum_token_count.side_effect = custom_sum_token_count + normal_result = loaded_model._encode_with_retry( + [max_seq_minus_one], return_token_count=True + ) + padded_result = loaded_model._encode_with_retry( + [max_seq_minus_one], + return_token_count=True, + **tokenizer_kwargs, + ) + assert normal_result.input_token_count != padded_result.input_token_count + assert padded_result.input_token_count == model_max + assert not np.all(normal_result.embedding == padded_result.embedding) + normal_result = loaded_model._encode_with_retry( + [max_seq], return_token_count=True + ) + padded_result = loaded_model._encode_with_retry( + [max_seq], return_token_count=True, **tokenizer_kwargs + ) + assert normal_result.input_token_count == padded_result.input_token_count + assert np.all(normal_result.embedding == padded_result.embedding) + normal_result = loaded_model._encode_with_retry( + [single], return_token_count=True + ) + padded_result = loaded_model._encode_with_retry( + [single], return_token_count=True, **tokenizer_kwargs + ) + assert normal_result.input_token_count != padded_result.input_token_count + assert not np.all(normal_result.embedding == padded_result.embedding)