Merge pull request caikit#393 from kcirred/main
Enable using kwargs for selecting pad-to-max-length strategy for tokenizer in embeddings
gkumbhat authored Oct 16, 2024
2 parents 1695c3b + f79d65b commit 0219d50
Showing 2 changed files with 79 additions and 4 deletions.
caikit_nlp/modules/text_embedding/embedding.py (13 changes: 9 additions & 4 deletions)
@@ -976,6 +976,7 @@ def _tokenize_plus(
         truncate_input_tokens: int,
         texts: List[str],
         implicit_truncation_errors: bool = True,
+        **kwargs,
     ) -> TruncatedTokensTuple:
         """Tokenize with support for truncation handling and returning the token count
         Args:
@@ -1015,21 +1016,21 @@ def _tokenize_plus(
         texts = [str(s).strip() for s in texts]

         # Call tokenizer with the same truncation parameters every time
-        tokenized = self._get_tokenized(texts)
+        tokenized = self._get_tokenized(texts, **kwargs)

         # Custom truncation and/or error raise if needed
         truncation_needed = self._truncation_needed(tokenized, max_length, texts)
         if truncation_needed and okay_to_truncate:
             # Truncate texts in place
             _truncate_texts(texts, tokenized, max_length, truncation_needed)
             # Re-tokenize the truncated texts
-            tokenized = self._get_tokenized(texts)
+            tokenized = self._get_tokenized(texts, **kwargs)
             truncation_needed = []  # truncation accomplished

         input_token_count = sum_token_count(tokenized)
         return TruncatedTokensTuple(tokenized, input_token_count, truncation_needed)

-    def _get_tokenized(self, texts):
+    def _get_tokenized(self, texts, **kwargs):
         """Intentionally always call tokenizer the same way to avoid thread issues.
         Use a copy of the tokenizer per-model (self) and per-thread (map by thread ID).
@@ -1039,6 +1040,8 @@ def _get_tokenized(self, texts):
         the fast tokenizer with different truncation settings.
         """

+        padding_strategy = kwargs.pop("padding_strategy", True)
+
         # Keep copies of tokenizer per thread (in each wrapped model instance)
         thread_id = threading.get_ident()
         tokenizer = (
@@ -1056,7 +1059,7 @@
             return_length=False,
             return_tensors="pt",
             truncation=True,  # DO NOT CHANGE else "Already borrowed" errors
-            padding=True,  # DO NOT CHANGE else "Already borrowed" errors
+            padding=padding_strategy,  # DO NOT CHANGE else "Already borrowed" errors
             max_length=self.max_seq_length,  # DO NOT CHANGE else "Already borrowed" errors
         )

@@ -1077,6 +1080,7 @@ def encode(
         return_token_count: bool = False,
         implicit_truncation_errors: bool = True,
         autocast: bool = False,
+        **kwargs,
     ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
@@ -1161,6 +1165,7 @@ def encode(
                 truncate_input_tokens,
                 sentences_batch,
                 implicit_truncation_errors=implicit_truncation_errors,
+                **kwargs,
             )

             if truncation_needed:  # truncation was needed and was not done/not allowed
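
Taken together, the embedding.py changes thread a padding_strategy keyword from encode(**kwargs) through _tokenize_plus and into _get_tokenized, where it is popped (defaulting to True, the previous behavior) and forwarded as the tokenizer's padding argument. A minimal standalone sketch of what that switch changes at the Hugging Face tokenizer level follows; the checkpoint name and sample texts are illustrative only, not taken from this repository:

# Sketch: padding=True pads to the longest text in the batch, while
# padding="max_length" pads every sequence out to max_length.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
texts = ["short", "a somewhat longer example sentence"]

batch_longest = tokenizer(
    texts, truncation=True, padding=True, max_length=128, return_tensors="pt"
)
batch_padded = tokenizer(
    texts, truncation=True, padding="max_length", max_length=128, return_tensors="pt"
)

print(batch_longest["input_ids"].shape)  # second dim == longest text in the batch
print(batch_padded["input_ids"].shape)   # second dim == 128
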
tests/modules/text_embedding/test_embedding.py (70 changes: 70 additions & 0 deletions)
@@ -2,12 +2,14 @@

 # Standard
 from typing import List, Tuple
+from unittest.mock import patch
 import os
 import tempfile

 # Third Party
 from pytest import approx
 from torch.backends import mps
+from transformers import BatchEncoding
 import numpy as np
 import pytest
 import torch
@@ -1143,3 +1145,71 @@ def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens):
     assert not np.allclose(
         separate_vectors[1], separate_vectors[2], rtol=1e-05, atol=1e-08
     )
+
+
+def custom_sum_token_count(
+    tokenized: BatchEncoding,
+) -> int:
+    """Returns total number of tokens regardless of attention_mask value"""
+
+    token_count = 0
+    for encoding in tokenized.encodings:
+        token_count += len(encoding.attention_mask)
+
+    return token_count
+
+
+@pytest.mark.parametrize("padding_strategy", [True, "max_length"])
+def test_pad_to_max_length(padding_strategy, loaded_model):
+    """Tests for tokenization kwargs max_length will modify tokenizer"""
+    model_max = loaded_model.model.max_seq_length
+
+    tokenizer_kwargs = {"padding_strategy": padding_strategy}
+    max_seq = "x " * (model_max - 2)  # Subtract 2 for begin/end tokens
+    max_seq_minus_one = "x " * (
+        model_max - 3
+    )  # 1 token length shorter than max_seq_length
+    single = "x "
+
+    if padding_strategy is True:
+        normal_result = loaded_model._encode_with_retry(
+            [max_seq_minus_one], return_token_count=True
+        )
+        padded_result = loaded_model._encode_with_retry(
+            [max_seq_minus_one],
+            return_token_count=True,
+            **tokenizer_kwargs,
+        )
+        assert np.all(normal_result.embedding == padded_result.embedding)
+    elif padding_strategy == "max_length":
+        with patch(
+            "caikit_nlp.modules.text_embedding.embedding.sum_token_count"
+        ) as mock_sum_token_count:
+            mock_sum_token_count.side_effect = custom_sum_token_count
+            normal_result = loaded_model._encode_with_retry(
+                [max_seq_minus_one], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [max_seq_minus_one],
+                return_token_count=True,
+                **tokenizer_kwargs,
+            )
+            assert normal_result.input_token_count != padded_result.input_token_count
+            assert padded_result.input_token_count == model_max
+            assert not np.all(normal_result.embedding == padded_result.embedding)
+            normal_result = loaded_model._encode_with_retry(
+                [max_seq], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [max_seq], return_token_count=True, **tokenizer_kwargs
+            )
+            assert normal_result.input_token_count == padded_result.input_token_count
+            assert np.all(normal_result.embedding == padded_result.embedding)
+            normal_result = loaded_model._encode_with_retry(
+                [single], return_token_count=True
+            )
+            padded_result = loaded_model._encode_with_retry(
+                [single], return_token_count=True, **tokenizer_kwargs
+            )
+            assert normal_result.input_token_count != padded_result.input_token_count
+            assert not np.all(normal_result.embedding == padded_result.embedding)
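
As a usage note, the only thing a caller has to do to opt in is pass padding_strategy="max_length" as a keyword argument; omitting it (or passing True) keeps the old pad-to-longest behavior. A hypothetical extra check in the same test module, sketching that default, could look roughly like the following; it reuses the existing loaded_model fixture and is not part of the PR:

def test_default_padding_strategy_is_noop(loaded_model):
    """Hypothetical sketch: passing padding_strategy=True explicitly should match the default."""
    base = loaded_model._encode_with_retry(["hello world"], return_token_count=True)
    explicit = loaded_model._encode_with_retry(
        ["hello world"], return_token_count=True, padding_strategy=True
    )
    assert base.input_token_count == explicit.input_token_count
    assert np.all(base.embedding == explicit.embedding)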
