diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index 92a72d9b..f6e7f026 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -36,5 +36,23 @@ training_data_limit: 6655831b-960a-4dc5-8df4-867026e2cd41: add_model_name_here: 10000 +# Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32 +embedding: + # Number of times to retry on error. Most deployments should use 0 retries. + retries: 0 + # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used + batch_size: 0 + # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this + implicit_truncation_errors: true + # Attempt to optimize with PyTorch compile() + pt2_compile: false + # Use IPEX optimize. Works best when used with autocast (bfloat16) below. + ipex: false + # Use autocast in encode with its default dtype (bfloat16) + autocast: false + # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU. + # Otherwise, the default does automatic checks for cuda GPU (else cpu). + device: "" + runtime: library: caikit_nlp diff --git a/caikit_nlp/modules/__init__.py b/caikit_nlp/modules/__init__.py index d77f740b..e0bff734 100644 --- a/caikit_nlp/modules/__init__.py +++ b/caikit_nlp/modules/__init__.py @@ -12,4 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # Local -from . import text_classification, text_embedding, text_generation, token_classification +from . import ( + text_classification, + text_embedding, + text_generation, + token_classification, + tokenization, +) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 476e9782..d4447ebe 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -13,17 +13,21 @@ # limitations under the License. # Standard -from copy import deepcopy -from typing import List, Optional +from collections.abc import Sized +from enum import Enum, auto +from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union import importlib import os import time # Third Party from torch.backends import mps +from transformers import BatchEncoding +import numpy as np import torch # First Party +from caikit import get_config from caikit.core import ModuleBase, ModuleConfig, ModuleSaver, module from caikit.core.data_model.json_dict import JsonDict from caikit.core.exceptions import error_handler @@ -49,21 +53,24 @@ ) import alog +# Local +from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int + logger = alog.use_channel("TXT_EMB") error = error_handler.get(logger) + # To avoid dependency problems, make sentence-transformers an optional import and # defer any ModuleNotFoundError until someone actually tries to init a model with this module. 
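# Editor's sketch (not part of the diff): the `embedding:` section added to config.yml
# above feeds the module-level flags set further down in this file. Per the config
# comment, values may also come from environment variables such as EMBEDDING_RETRIES=32
# (how env overrides are merged is caikit config behavior and is assumed here). Minimal
# read-back using the helpers this PR adds in caikit_nlp/modules/text_embedding/utils.py:
from caikit import get_config
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int

cfg = get_config().get("embedding", {})
retries = env_val_to_int(val=cfg.get("retries"), default=0)         # 0 -> no retries
batch_size = env_val_to_int(val=cfg.get("batch_size"), default=0)   # <= 0 -> sentence-transformers default
use_ipex = env_val_to_bool(val=cfg.get("ipex"))                     # falsy strings ("off", "0", ...) -> False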
try: sentence_transformers = importlib.import_module("sentence_transformers") # Third Party from sentence_transformers import SentenceTransformer + from sentence_transformers.util import batch_to_device, cos_sim, dot_score from sentence_transformers.util import ( - cos_sim, - dot_score, - normalize_embeddings, - semantic_search, + normalize_embeddings as normalize, # avoid parameter shadowing ) + from sentence_transformers.util import semantic_search except ModuleNotFoundError: # When it is not available, create a dummy that raises an error on attempted init() class SentenceTransformerNotAvailable: @@ -73,22 +80,33 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument SentenceTransformer = SentenceTransformerNotAvailable +embedding_cfg = get_config().get("embedding", {}) + +AUTOCAST = env_val_to_bool(val=embedding_cfg.get("autocast")) +IPEX = env_val_to_bool(val=embedding_cfg.get("ipex")) +PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile")) +RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0) +BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0) +NO_IMPLICIT_TRUNCATION = env_val_to_bool( + val=embedding_cfg.get("implicit_truncation_errors", True) +) +DEVICE = embedding_cfg.get("device", "") + +RT = TypeVar("RT") # return type + -# For testing env vars for values that mean false -FALSY = ("no", "n", "false", "0", "f", "off") +class EmbeddingResultTuple(NamedTuple): + """Output of SentenceTransformerWithTruncate.encode()""" + embedding: np.ndarray + input_token_count: int -def env_var_to_int(name, default): - """Returns the integer value of name env var or default value if None or invalid integer""" - s = os.getenv(name, default) - try: - return int(s) - except (TypeError, ValueError): - return default +class TruncatedTokensTuple(NamedTuple): + """Output of SentenceTransformerWithTruncate._truncate_input_tokens()""" -# Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used -BATCH_SIZE = env_var_to_int("BATCH_SIZE", default=0) + tokenized: BatchEncoding + input_token_count: int @module( @@ -105,9 +123,8 @@ def env_var_to_int(name, default): ], ) class EmbeddingModule(ModuleBase): - # Retry count if enabled to try again (was for thread contention errors) - RETRY_COUNT = max(env_var_to_int("RETRY_COUNT", default=0), 0) + RETRY_COUNT = max(RETRIES, 0) # Ensure non-negative, before using in loop! _ARTIFACTS_PATH_KEY = "artifacts_path" _ARTIFACTS_PATH_DEFAULT = "artifacts" @@ -119,12 +136,6 @@ def __init__( super().__init__() self.model = model - # Separate copy of tokenizer just for _truncate_input_tokens() - # This avoids RuntimeError('Already borrowed') which is way too frequent - # otherwise when using Python threads and tokenize for truncation followed - # by sentence-transformers tokenize/encode. 
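# Editor's sketch (not part of the diff): expected behavior of the config helpers used
# for the module-level flags above, per caikit_nlp/modules/text_embedding/utils.py
# added later in this PR.
from caikit_nlp.modules.text_embedding.utils import env_val_to_bool, env_val_to_int

assert env_val_to_bool("off") is False          # "no", "n", "false", "0", "f", "off", "" -> False
assert env_val_to_bool("true") is True          # any other value -> True
assert env_val_to_bool(None) is False           # unset -> False
assert env_val_to_int("32", default=0) == 32    # numeric strings convert
assert env_val_to_int("oops", default=0) == 0   # invalid values fall back to the default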
- self._tokenizer = deepcopy(self.model.tokenizer) - @classmethod def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": """Load model @@ -150,13 +161,15 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": artifacts_path = os.path.abspath(os.path.join(model_path, artifacts_path)) error.dir_check("", artifacts_path) - ipex = cls._get_ipex() - device = cls._select_device(ipex) - model = SentenceTransformer(model_name_or_path=artifacts_path, device=device) + ipex = cls._get_ipex(IPEX) + device = cls._select_device(ipex, DEVICE) + model = SentenceTransformerWithTruncate( + model_name_or_path=artifacts_path, device=device + ) model.eval() # required for IPEX at least if device is not None: model.to(torch.device(device)) - model = EmbeddingModule._optimize(model, ipex, device) + model = EmbeddingModule._optimize(model, ipex, device, AUTOCAST, PT2_COMPILE) # Validate model with any encode test (simple and hardcoded for now). # This gets some of the first-time inference cost out of the way. @@ -165,8 +178,8 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls(model) - @staticmethod - def _get_ipex(): + @classmethod + def _get_ipex(cls, ipex_flag): """Get IPEX optimization library if enabled and available, else return False Returns ipex library or False @@ -174,32 +187,28 @@ def _get_ipex(): ret = False # Enabled by environment variable - # When IPEX_OPTIMIZE is not false, attempt to import the library and use it. - if os.getenv("IPEX_OPTIMIZE", "false").lower() not in FALSY: + # When IPEX is not false, attempt to import the library and use it. + if ipex_flag: try: ret = importlib.import_module("intel_extension_for_pytorch") except Exception as ie: # pylint: disable=broad-exception-caught # We don't require the module so catch, log, proceed to return False msg = ( - f"IPEX_OPTIMIZE enabled in env, but skipping ipex.optimize() because " + f"IPEX enabled in env, but skipping ipex.optimize() because " f"import intel_extension_for_pytorch failed with exception: {ie}" ) - logger.warning(msg, exc_info=1) + logger.warning(msg, exc_info=True) return ret @staticmethod - def _select_device(use_ipex): + def _select_device(use_ipex, device): """Use environment variables and availability to determine the device to use""" if use_ipex: # If enabled, use "xpu" (IPEX on GPU instead of IPEX on CPU) - if os.getenv("USE_XPU", "false").lower() not in FALSY: + if device == "xpu": return "xpu" - elif ( - os.getenv("USE_MPS", "false").lower() not in FALSY - and mps.is_built() - and mps.is_available() - ): + elif device == "mps" and mps.is_built() and mps.is_available(): # Never use on ipex, but otherwise use mps if enabled and available return "mps" @@ -220,26 +229,30 @@ def _get_backend(use_ipex, use_device): return "inductor" # default backend @staticmethod - def _optimize(model, ipex, device): - + def _optimize(model, ipex, device, autocast, pt2_compile): if ipex: - model = ipex.optimize(model) + if autocast: # IPEX performs best with autocast using bfloat16 + model = ipex.optimize( + model, dtype=torch.bfloat16, weights_prepack=False + ) + else: + model = ipex.optimize(model, weights_prepack=False) # torch.compile won't work everywhere, but when set we'll try it - if os.getenv("PT2_COMPILE", "false").lower() not in FALSY: + if pt2_compile: backend = EmbeddingModule._get_backend(ipex, device) try: model = torch.compile(model, backend=backend, mode="max-autotune") except Exception as e: # pylint: disable=broad-exception-caught # Not always supported 
(e.g. in a python version) so catch, log, proceed. warn_msg = ( - f"PT2_COMPILE enabled in env, but continuing without torch.compile() " + f"PT2_COMPILE enabled, but continuing without torch.compile() " f"because it failed with exception: {e}" ) logger.warning(warn_msg, exc_info=True) return model - def _with_retry(self, fn, *args, **kwargs): + def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT: first_exception = None for count in range(1 + self.RETRY_COUNT): # try once plus retries (if needed) try: @@ -259,7 +272,9 @@ def _with_retry(self, fn, *args, **kwargs): exception=first_exception, ) - def _encode_with_retry(self, *args, **kwargs): + def _encode_with_retry( + self, *args, **kwargs + ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: """All encode calls should use this for consistent param adding and retry loop""" # Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE @@ -269,117 +284,25 @@ def _encode_with_retry(self, *args, **kwargs): if "batch_size" not in kwargs: kwargs["batch_size"] = BATCH_SIZE + if isinstance(self.model, SentenceTransformerWithTruncate): + kwargs[ + "implicit_truncation_errors" + ] = NO_IMPLICIT_TRUNCATION # config/env overrides default + return self._with_retry(self.model.encode, *args, **kwargs) + + # Else... + # It's possible to init with a model that doesn't have the added kwargs. + # E.g. a SentenceTransformer or other transformer model. Remove those kwargs! + # This is not the normal use case but at least don't pass invalid kwargs, to encode() + # and don't return the unexpected tuple (adding token count). + if "truncate_input_tokens" in kwargs: + del kwargs["truncate_input_tokens"] + if "return_token_count" in kwargs: + del kwargs["return_token_count"] + if "implicit_truncation_errors" in kwargs: + del kwargs["implicit_truncation_errors"] return self._with_retry(self.model.encode, *args, **kwargs) - def _truncate_input_tokens( - self, truncate_input_tokens, texts: List[str] - ) -> List[str]: - """Truncate input tokens - Args: - truncate_input_tokens: int - Truncation length for input tokens. - If less than zero, this is disabled (returns texts without processing). - If zero or greater than the model's maximum, then this is a test - to see if truncation is needed. If needed, an exception is thrown. - Otherwise, we take this usable truncation limit to truncate the tokens and then - decode them to return truncated strings that can be used with this model. - texts: List[str] - Input texts to be checked and optionally truncated. - Returns: - List[str]: the texts after checking and/or truncating - """ - - # NOTE: When inference is called immediately after load (typical case with lazy loading), - # using the tokenizer right away here results in a warning like: - # huggingface/tokenizers: The current process just got forked, after parallelism - # has already been used. Disabling parallelism to avoid deadlocks... - # To disable this warning, you can either: - # - Avoid using `tokenizers` before the fork if possible - # - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false) - # A warmup encode() call in load() will take care of this (the first option above). - # This here comment is in case we need to set TOKENIZERS_PARALLELISM in the future. 
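# Editor's sketch (not part of the diff): the retry contract implemented by
# _with_retry() above. The module tries the call once plus RETRY_COUNT retries, logs
# each failure, and raises based on the first exception if every attempt fails.
# Simplified stand-alone illustration of that pattern:
def with_retry(fn, retries, *args, **kwargs):
    first_exception = None
    for _attempt in range(1 + max(retries, 0)):  # one try, then up to `retries` more
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-exception-caught
            if first_exception is None:
                first_exception = e              # remember the first failure
    raise first_exception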
- - if truncate_input_tokens < 0: - return texts - - max_tokens = self.model.max_seq_length - - # Do truncation if given a usable truncation value, else test for need to truncation - if 0 < truncate_input_tokens <= max_tokens: - okay_to_truncate = True - max_length = truncate_input_tokens - ret = [] # will return truncated texts - else: - okay_to_truncate = False - max_length = max_tokens - ret = texts # will not alter texts when truncation is not allowed - - tokenized = self._with_retry( - self._tokenizer, - texts, - return_attention_mask=False, - return_token_type_ids=False, - return_overflowing_tokens=True, - return_offsets_mapping=True, - return_length=True, - truncation=True, - max_length=max_length, - ) - - texts_map = tokenized["overflow_to_sample_mapping"] - - for text_number, text in enumerate(texts): - # positions: the positions (in lengths and offsets arrays) that belong to this text - positions = [ - position - for position, sample_number in enumerate(texts_map) - if sample_number == text_number - ] - lengths = [tokenized["length"][pos] for pos in positions] - - was_truncated = len(lengths) > 1 # multiple lengths when truncated - - if not okay_to_truncate and was_truncated: - # Raise error. We don't allow silent truncation in this case. - tokens = sum(lengths) # add up total tokens for error message - error.log_raise( - "", - ValueError( - f"Token sequence length is longer than the specified " - f"maximum sequence length for this model ({tokens} > {max_tokens})." - ), - ) - - elif okay_to_truncate and not was_truncated: - ret.append(text) # collect original text to return - - elif okay_to_truncate and was_truncated: - # Truncate the text that maps to the truncated tokens. - # The offset_mapping describes the text position for each token. - # Added tokens were not in the text, so they show up as (0, 0). - - # Get the text offsets for the tokens that are to be kept after truncation. - # Take the first set of offsets that mapped to this text's positions. - # This first set represents what will be kept after truncation. - # Each offset tells us which chars in the original text map to this token. - offsets = next(tokenized["offset_mapping"][pos] for pos in positions) - - # Find the first offset that is not empty (0, 0) to avoid added tokens - start = next(offset for offset in offsets if offset != (0, 0)) - - # Find the last offset that is not empty (0, 0) to avoid added tokens - end = next( - offset for offset in reversed(list(offsets)) if offset != (0, 0) - ) - - # Use the start-beginning end-ending to slice the text based on token truncation - # i.e. 
if start=(0,5) and end=(72,78) then we want slice [0:78] - truncated_text = text[start[0] : end[1]] - - ret.append(truncated_text) # return the truncated text for this one - - return ret - @EmbeddingTask.taskmethod() def run_embedding( self, @@ -402,10 +325,15 @@ def run_embedding( """ error.type_check("", str, text=text) - text = self._truncate_input_tokens(truncate_input_tokens, [text])[0] + embeddings, input_token_count = self._encode_with_retry( + text, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + ) return EmbeddingResult( - result=Vector1D.from_vector(self._encode_with_retry(text)), + result=Vector1D.from_vector(embeddings), producer_id=self.PRODUCER_ID, + input_token_count=input_token_count, ) @EmbeddingTasks.taskmethod() @@ -434,12 +362,17 @@ def run_embeddings( ): # encode allows str, but the result would lack a dimension texts = [texts] - texts = self._truncate_input_tokens(truncate_input_tokens, texts) - - embeddings = self._encode_with_retry(texts) + embeddings, input_token_count = self._encode_with_retry( + texts, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + ) vectors = [Vector1D.from_vector(e) for e in embeddings] + return EmbeddingResults( - results=ListOfVector1D(vectors=vectors), producer_id=self.PRODUCER_ID + results=ListOfVector1D(vectors=vectors), + producer_id=self.PRODUCER_ID, + input_token_count=input_token_count, ) @SentenceSimilarityTask.taskmethod() @@ -465,18 +398,24 @@ def run_sentence_similarity( SentenceSimilarityResult: Similarity scores for each sentence. """ - source_sentence = self._truncate_input_tokens( - truncate_input_tokens, [source_sentence] - )[0] - sentences = self._truncate_input_tokens(truncate_input_tokens, sentences) - - source_embedding = self._encode_with_retry(source_sentence) - embeddings = self._encode_with_retry(sentences) + source_embedding, source_token_count = self._encode_with_retry( + source_sentence, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + ) + embeddings, sentences_token_count = self._encode_with_retry( + sentences, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + ) + input_token_count = source_token_count + sentences_token_count res = cos_sim(source_embedding, embeddings) + return SentenceSimilarityResult( result=SentenceSimilarityScores(scores=res.tolist()[0]), producer_id=self.PRODUCER_ID, + input_token_count=input_token_count, ) @SentenceSimilarityTasks.taskmethod() @@ -503,19 +442,25 @@ def run_sentence_similarities( Each one contains the source-sentence's score for each sentence in order. 
""" - source_sentences = self._truncate_input_tokens( - truncate_input_tokens, source_sentences + source_embedding, source_token_count = self._encode_with_retry( + source_sentences, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + ) + embeddings, sentences_token_count = self._encode_with_retry( + sentences, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) - sentences = self._truncate_input_tokens(truncate_input_tokens, sentences) - - source_embedding = self._encode_with_retry(source_sentences) - embeddings = self._encode_with_retry(sentences) + input_token_count = source_token_count + sentences_token_count res = cos_sim(source_embedding, embeddings) float_list_list = res.tolist() + return SentenceSimilarityResults( results=[SentenceSimilarityScores(fl) for fl in float_list_list], producer_id=self.PRODUCER_ID, + input_token_count=input_token_count, ) @RerankTask.taskmethod() @@ -579,17 +524,22 @@ def run_rerank_query( return_documents=return_documents, return_queries=return_query, return_text=return_text, - ).results + ) - if results: - return RerankResult(result=results[0], producer_id=self.PRODUCER_ID) + if results.results: + return RerankResult( + result=results.results[0], + producer_id=self.PRODUCER_ID, + input_token_count=results.input_token_count, + ) RerankResult( - producer_id=self.PRODUCER_ID, result=RerankScore( scores=[], query=query if return_query else None, ), + producer_id=self.PRODUCER_ID, + input_token_count=results.input_token_count, ) @RerankTasks.taskmethod() @@ -661,20 +611,21 @@ def get_text(doc): doc_texts = [get_text(doc) for doc in documents] - doc_texts = self._truncate_input_tokens(truncate_input_tokens, doc_texts) - queries = self._truncate_input_tokens(truncate_input_tokens, queries) - - doc_embeddings = normalize_embeddings( - self._encode_with_retry(doc_texts, convert_to_tensor=True).to( - self.model.device - ) + doc_embeddings, doc_token_count = self._encode_with_retry( + doc_texts, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + convert_to_tensor=True, ) + doc_embeddings = normalize(doc_embeddings.to(self.model.device)) - query_embeddings = normalize_embeddings( - self._encode_with_retry(queries, convert_to_tensor=True).to( - self.model.device - ) + query_embeddings, query_token_count = self._encode_with_retry( + queries, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, + convert_to_tensor=True, ) + query_embeddings = normalize(query_embeddings.to(self.model.device)) res = semantic_search( query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score @@ -702,8 +653,13 @@ def add_query(q): ) for q, r in enumerate(res) ] + input_token_count = doc_token_count + query_token_count - return RerankResults(results=results, producer_id=self.PRODUCER_ID) + return RerankResults( + results=results, + producer_id=self.PRODUCER_ID, + input_token_count=input_token_count, + ) @classmethod def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": @@ -750,3 +706,313 @@ def save(self, model_path: str, *args, **kwargs): # Save the config ModuleConfig(saver.config).save(model_config_path) + + +def get_sample_start_indexes(tokenized: BatchEncoding) -> List[int]: + """Returns a list containing the index for the first encoding of each sample + contained in tokenized.""" + + # When truncating occurs a sample is split across multiple encodings + # ie. 
len(tokenized.encodings) > the number of text samples input for tokenization + + # Knowing the encoding index of where each sample's first encoding is located allows us to + # access the encodings for individual samples + + # note: tokenized["overflow_to_sample_mapping"] is a torch.Tensor + + samples_start_indexes: Dict[int, int] = {} + for i, tensor_sample in enumerate(tokenized["overflow_to_sample_mapping"]): + int_sample = int(tensor_sample) + if int_sample not in samples_start_indexes: + samples_start_indexes[int_sample] = i + + return list(samples_start_indexes.values()) + + +class TruncateCountBehavior(Enum): + ONLY = auto() + ALL = auto() + IGNORE = auto() + + +def sum_token_count( + tokenized: BatchEncoding, + truncate_only: bool, +) -> int: + """Returns the number of non-special tokens. + Args: + tokenized: BatchEncoding + truncate_only: bool + Returns: + Int total of all tokens contained in tokenized. + """ + # Encoding objects have various attributes of note: + # - tokens: list of tokens (sub-parts of the input strings after word/subword + # splitting and before conversion to integer indices) + # - attention_mask: List of indices specifying which tokens should be attended to + # by the model. Note that [PAD] = 0, while [CLS] / [SEP] = 1 + # - special_tokens_mask: List of 0s and 1s, with 1 specifying added special tokens + # and 0 specifying regular sequence tokens + + error.type_check( + "", + BatchEncoding, + tokenized=tokenized, + ) + error.value_check( + "", + tokenized.encodings, + "Number of tokenized encodings is only known when a non-python tokenizer is used", + ) + + token_count = 0 + + if truncate_only: + # Only sum the length for the 1st encoding of each sample + samples_start_idx = get_sample_start_indexes(tokenized) + + token_count = sum( + ( + x + for idx in samples_start_idx + for x in tokenized.encodings[idx].attention_mask + ) + ) + else: + # Sum the length of all encodings for all samples + for encoding in tokenized.encodings: + token_count += sum(encoding.attention_mask) + + return token_count + + +class SentenceTransformerWithTruncate(SentenceTransformer): + def _truncate_input_tokens( + self, + truncate_input_tokens: int, + texts: List[str], + implicit_truncation_errors: bool = True, + ) -> TruncatedTokensTuple: + """Truncate input tokens + Args: + truncate_input_tokens: int + Truncation length for input tokens. + If less than zero, this truncation is left up to the tokenizer default (model max). + If zero or greater than the model's maximum, then this is used as a test + to see if truncation is needed. If needed is needed, an exception is thrown. + Otherwise, we take this usable truncation limit to truncate the input tokens. + texts: List[str] + Input texts to be checked and optionally truncated. + implicit_truncation_errors: bool + Configuration indicates whether implicit truncation should be rejected. + Returns: + Tuple containing a dictionary of lists/arrays/tensors returned by the tokenizer, with + proper truncation ('input_ids', 'attention_mask', etc.), and the input_token_count int. 
+ """ + + max_tokens = self.max_seq_length + + # Do truncation if given a usable truncation value, else test for need to truncation + if truncate_input_tokens < 0: + okay_to_truncate = True + max_length = max_tokens + elif 0 < truncate_input_tokens <= max_tokens: + okay_to_truncate = True + # Add 2 for begin/end tokens, but don't go higher than model's max_tokens + max_length = min(truncate_input_tokens + 2, max_tokens) + + else: + okay_to_truncate = not implicit_truncation_errors + max_length = max_tokens + + assert len(texts) > 0, "Cannot truncate nothing" + assert isinstance(texts[0], str), "Only str can be truncated" + + to_tokenize = [texts] + to_tokenize = [[str(s).strip() for s in col] for col in to_tokenize] + tokenized = self.tokenizer( + *to_tokenize, + return_attention_mask=True, + return_token_type_ids=False, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_length=True, + return_tensors="pt", + truncation=True, + padding=True, + max_length=max_length, + ) + + # When truncation occurs multiple encodings are created for a single sample text + was_truncated = len(tokenized.encodings) > len(to_tokenize[0]) + + if not okay_to_truncate and was_truncated: + # re-tokenize without truncation to eliminate the duplication of certain + # special tokens (eg. [CLS] and [SEP]) with each overflow encoding. + tokenized = self.tokenizer( + *to_tokenize, + return_attention_mask=True, + return_token_type_ids=False, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_length=True, + return_tensors="pt", + truncation=False, + padding=True, + ) + + tokens = 0 + for encoding in tokenized.encodings: + tokens = max(sum(encoding.attention_mask), tokens) + error.log_raise( + "", + ValueError( + f"Token sequence length is longer than the specified " + f"maximum sequence length for this model ({tokens} > {max_tokens})." + ), + ) + + input_token_count = sum_token_count(tokenized, truncate_only=True) + + # Tokenize without overflow for batching and truncation to work together. + tokenized = self.tokenizer( + *to_tokenize, + return_attention_mask=True, + return_token_type_ids=False, + return_overflowing_tokens=False, + return_offsets_mapping=False, + return_length=False, + return_tensors="pt", + truncation=True, + padding=True, + max_length=max_length, + ) + + return TruncatedTokensTuple(tokenized, input_token_count) + + def encode( + self, + sentences: Union[str, List[str]], + batch_size: int = 32, + show_progress_bar: bool = None, + output_value: str = "sentence_embedding", + convert_to_numpy: bool = True, + convert_to_tensor: bool = False, + device: str = None, + normalize_embeddings: bool = False, + truncate_input_tokens: int = 0, + return_token_count: bool = False, + implicit_truncation_errors: bool = True, + ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: + """ + Computes sentence embeddings + + :param sentences: the sentences to embed + :param batch_size: the batch size used for the computation + :param show_progress_bar: Ignored here. Added for compatibility with super API. + :param output_value: Ignored here. Added for compatibility with super API. + :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list + of pytorch tensors. + :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any + setting from convert_to_numpy + :param device: Which torch.device to use for the computation + :param normalize_embeddings: Ignored here. Added for compatibility with super API. 
+ :param truncate_input_tokens: Truncation length for input tokens. + Truncation length for input tokens. + If less than zero, this truncation is left up to the tokenizer default (model max). + If zero or greater than the model's maximum, then this is used as a test + to see if truncation is needed. If truncation is needed, an exception is thrown, + unless implicit_truncation_errors=False (see below). + Otherwise, we take this usable truncation limit to truncate the input tokens. + :param return_token_count: If true, a tuple is returned to add the input token count. + :param implicit_truncation_errors: If true (default) implicit truncation throws an error. + If false, the model default behavior or used. + + :return: + If return_token_count is False, the embedding is returned as a numpy matrix. + If return_token_count is True, a tuple is returned with both the embedding and + the input token count. + """ + + # These args are for API compatability, but are currently ignored in our version of encode() + _ = ( + show_progress_bar, + output_value, + normalize_embeddings, + ) + + self.eval() + + if convert_to_tensor: + convert_to_numpy = False + + input_was_string = False + list_of_sentences = sentences + if isinstance(list_of_sentences, str) or not isinstance( + sentences, Sized + ): # Cast an individual sentence to a list with length 1 + list_of_sentences = [sentences] + input_was_string = True + + error.type_check_all("", str, sentences=list_of_sentences) + + if device is None: + device = self.device + + self.to(device) + + all_embeddings = [] + + # Sort sentences according to length, from longest to shortest + # OOM errors then occurs at start of encoding + length_sorted_idx = np.argsort( + [-self._text_length(sen) for sen in list_of_sentences] + ) + sentences_sorted: list[str] = [ + list_of_sentences[idx] for idx in length_sorted_idx + ] + + input_token_count = 0 + + for start_index in range(0, len(list_of_sentences), batch_size): + sentences_batch = sentences_sorted[start_index : start_index + batch_size] + features, token_count = self._truncate_input_tokens( + truncate_input_tokens, + sentences_batch, + implicit_truncation_errors=implicit_truncation_errors, + ) + input_token_count += token_count + + features = batch_to_device(features, device) + + if AUTOCAST: + with torch.no_grad(), torch.cpu.amp.autocast(): + out_features = self.forward(features) + embeddings = out_features["sentence_embedding"] + if convert_to_numpy: + embeddings = embeddings.detach().cpu() + all_embeddings.extend(embeddings) + else: + with torch.no_grad(): + out_features = self.forward(features) + embeddings = out_features["sentence_embedding"] + if convert_to_numpy: + embeddings = embeddings.detach().cpu() + all_embeddings.extend(embeddings) + + # Restore original order + all_embeddings = [all_embeddings[idx] for idx in np.argsort(length_sorted_idx)] + + if convert_to_tensor: + all_embeddings = torch.stack(all_embeddings) + elif convert_to_numpy: + all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings]) + + if input_was_string: + all_embeddings = all_embeddings[0] + + return ( + EmbeddingResultTuple(all_embeddings, input_token_count) + if return_token_count + else all_embeddings + ) diff --git a/caikit_nlp/modules/text_embedding/utils.py b/caikit_nlp/modules/text_embedding/utils.py new file mode 100644 index 00000000..39adfb82 --- /dev/null +++ b/caikit_nlp/modules/text_embedding/utils.py @@ -0,0 +1,32 @@ +# Copyright The Caikit Authors +# +# Licensed under the Apache License, Version 2.0 (the 
"License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +def env_val_to_bool(val): + """Returns the bool value of env var""" + if val is None: + return False + if isinstance(val, bool): + return val + + # For testing env vars for values that mean false (else True!) + return str(val).lower().strip() not in ("no", "n", "false", "0", "f", "off", "") + + +def env_val_to_int(val, default): + """Returns the integer value of env var or default value if None or invalid integer""" + try: + return int(val) + except (TypeError, ValueError): + return default diff --git a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py index 5642c824..91f747c9 100644 --- a/caikit_nlp/modules/text_generation/peft_prompt_tuning.py +++ b/caikit_nlp/modules/text_generation/peft_prompt_tuning.py @@ -48,8 +48,9 @@ ClassificationTrainRecord, GeneratedTextResult, GeneratedTextStreamResult, + TokenizationResults, ) -from caikit.interfaces.nlp.tasks import TextGenerationTask +from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask import alog # Local @@ -87,7 +88,7 @@ id="6655831b-960a-4dc5-8df4-867026e2cd41", name="Peft generation", version="0.1.0", - task=TextGenerationTask, + tasks=[TextGenerationTask, TokenizationTask], ) class PeftPromptTuning(ModuleBase): @@ -274,6 +275,22 @@ def run_stream_out( stop_sequences=stop_sequences, ) + @TokenizationTask.taskmethod() + def run_tokenizer( + self, + text: str, + ) -> TokenizationResults: + """Run tokenization task against the model + + Args: + text: str + Text to tokenize + Returns: + TokenizationResults + The token count + """ + raise NotImplementedError("Tokenization not implemented for local") + @classmethod def train( cls, diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 38e05704..6a87ab45 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -29,8 +29,9 @@ from caikit.interfaces.nlp.data_model import ( GeneratedTextResult, GeneratedTextStreamResult, + TokenizationResults, ) -from caikit.interfaces.nlp.tasks import TextGenerationTask +from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask from caikit_tgis_backend import TGISBackend import alog @@ -47,7 +48,11 @@ error = error_handler.get(log) -@modules.module(backend_type=TGISBackend.backend_type, base_module=PeftPromptTuning) +@modules.module( + backend_type=TGISBackend.backend_type, + base_module=PeftPromptTuning, + tasks=[TextGenerationTask, TokenizationTask], +) class PeftPromptTuningTGIS(ModuleBase): # pylint: disable=too-many-instance-attributes SUPPORTED_LOAD_BACKENDS = [TGISBackend.backend_type, backend_types.LOCAL] ## Module Interface ## @@ -197,6 +202,10 @@ def run( stop_sequences: Optional[List[str]] = None, seed: Optional[np.uint64] = None, preserve_input_text: bool = False, + input_tokens: bool = False, + generated_tokens: bool = True, + token_logprobs: bool = True, + token_ranks: bool = True, ) -> 
GeneratedTextResult: f"""Run inference against the model running in TGIS. @@ -216,6 +225,10 @@ def run( return self.tgis_generation_client.unary_generate( text=verbalized_text, preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, truncate_input_tokens=truncate_input_tokens, @@ -251,6 +264,10 @@ def run_stream_out( stop_sequences: Optional[List[str]] = None, seed: Optional[np.uint64] = None, preserve_input_text: bool = False, + input_tokens: bool = False, + generated_tokens: bool = True, + token_logprobs: bool = True, + token_ranks: bool = True, ) -> Iterable[GeneratedTextStreamResult]: f"""Run output stream inferencing against the model running in TGIS @@ -270,6 +287,10 @@ def run_stream_out( return self.tgis_generation_client.stream_generate( text=verbalized_text, preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, truncate_input_tokens=truncate_input_tokens, @@ -284,3 +305,21 @@ def run_stream_out( exponential_decay_length_penalty=exponential_decay_length_penalty, stop_sequences=stop_sequences, ) + + @TokenizationTask.taskmethod() + def run_tokenizer( + self, + text: str, + ) -> TokenizationResults: + """Run tokenization task against the model running in TGIS. + + Args: + text: str + Text to tokenize + Returns: + TokenizationResults + The token count + """ + return self.tgis_generation_client.unary_tokenize( + text=text, + ) diff --git a/caikit_nlp/modules/text_generation/text_generation_local.py b/caikit_nlp/modules/text_generation/text_generation_local.py index e0558d4a..290551ba 100644 --- a/caikit_nlp/modules/text_generation/text_generation_local.py +++ b/caikit_nlp/modules/text_generation/text_generation_local.py @@ -31,8 +31,8 @@ from caikit.core.data_model import DataStream from caikit.core.exceptions import error_handler from caikit.core.modules import ModuleBase, ModuleConfig, ModuleSaver, module -from caikit.interfaces.nlp.data_model import GeneratedTextResult -from caikit.interfaces.nlp.tasks import TextGenerationTask +from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults +from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask import alog # Local @@ -60,7 +60,7 @@ id="f9181353-4ccf-4572-bd1e-f12bcda26792", name="Text Generation", version="0.1.0", - task=TextGenerationTask, + tasks=[TextGenerationTask, TokenizationTask], ) class TextGeneration(ModuleBase): """Module to provide text generation capabilities""" @@ -521,6 +521,7 @@ def save(self, model_path): json.dump(loss_log, f) f.write("\n") + @TextGenerationTask.taskmethod() def run( self, text: str, @@ -575,6 +576,22 @@ def run( **kwargs, ) + @TokenizationTask.taskmethod() + def run_tokenizer( + self, + text: str, + ) -> TokenizationResults: + """Run tokenization task against the model + + Args: + text: str + Text to tokenize + Returns: + TokenizationResults + The token count + """ + raise NotImplementedError("Tokenization not implemented for local") + ################################## Private Functions ###################################### @staticmethod diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index c57998f3..7c55cc2c 100644 --- 
a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -27,8 +27,9 @@ from caikit.interfaces.nlp.data_model import ( GeneratedTextResult, GeneratedTextStreamResult, + TokenizationResults, ) -from caikit.interfaces.nlp.tasks import TextGenerationTask +from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask from caikit_tgis_backend import TGISBackend import alog @@ -51,7 +52,11 @@ # pylint: disable=too-many-instance-attributes -@module(backend_type=TGISBackend.backend_type, base_module=TextGeneration) +@module( + backend_type=TGISBackend.backend_type, + base_module=TextGeneration, + tasks=[TextGenerationTask, TokenizationTask], +) class TextGenerationTGIS(ModuleBase): """Module to provide text generation capabilities""" @@ -222,6 +227,10 @@ def run( stop_sequences: Optional[List[str]] = None, seed: Optional[np.uint64] = None, preserve_input_text: bool = False, + input_tokens: bool = False, + generated_tokens: bool = True, + token_logprobs: bool = True, + token_ranks: bool = True, ) -> GeneratedTextResult: f"""Run inference against the model running in TGIS. @@ -231,11 +240,14 @@ def run( GeneratedTextResult Generated text result produced by TGIS. """ - if self._model_loaded: return self.tgis_generation_client.unary_generate( text=text, preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, truncate_input_tokens=truncate_input_tokens, @@ -271,6 +283,10 @@ def run_stream_out( stop_sequences: Optional[List[str]] = None, seed: Optional[np.uint64] = None, preserve_input_text: bool = False, + input_tokens: bool = False, + generated_tokens: bool = True, + token_logprobs: bool = True, + token_ranks: bool = True, ) -> Iterable[GeneratedTextStreamResult]: f"""Run output stream inferencing for text generation module. @@ -284,6 +300,10 @@ def run_stream_out( return self.tgis_generation_client.stream_generate( text=text, preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, truncate_input_tokens=truncate_input_tokens, @@ -298,3 +318,22 @@ def run_stream_out( exponential_decay_length_penalty=exponential_decay_length_penalty, stop_sequences=stop_sequences, ) + + @TokenizationTask.taskmethod() + def run_tokenizer( + self, + text: str, + ) -> TokenizationResults: + """Run tokenization task against the model running in TGIS. 
+ + Args: + text: str + Text to tokenize + Returns: + TokenizationResults + The token count + """ + if self._model_loaded: + return self.tgis_generation_client.unary_tokenize( + text=text, + ) diff --git a/caikit_nlp/modules/token_classification/filtered_span_classification.py b/caikit_nlp/modules/token_classification/filtered_span_classification.py index 55733df1..55963cb8 100644 --- a/caikit_nlp/modules/token_classification/filtered_span_classification.py +++ b/caikit_nlp/modules/token_classification/filtered_span_classification.py @@ -17,7 +17,8 @@ At this time this module is only designed for inference""" # Standard -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Union +import itertools import os # First Party @@ -122,7 +123,7 @@ def __init__( @TokenClassificationTask.taskmethod() def run( - self, text: str, threshold: Optional[float] = None + self, text: str, threshold: Optional[Union[float, int]] = None ) -> TokenClassificationResults: """Run classification on text split into spans. Returns results based on score threshold for labels that are to be outputted @@ -130,14 +131,16 @@ def run( Args: text: str Document to run classification on - threshold: float + threshold: float | int (Optional) Threshold based on which to return score results Returns: TokenClassificationResults """ error.type_check("", str, text=text) - error.type_check("", float, allow_none=True, threshold=threshold) + error.type_check( + "", float, int, allow_none=True, threshold=threshold + ) if threshold is None: threshold = self.default_threshold @@ -189,7 +192,7 @@ def run( @TokenClassificationTask.taskmethod(input_streaming=True, output_streaming=True) def run_bidi_stream( - self, text_stream: Iterable[str], threshold: Optional[float] = None + self, text_stream: Iterable[str], threshold: Optional[Union[float, int]] = None ) -> Iterable[TokenClassificationStreamResult]: """Run bi-directional streaming inferencing for this module. Run classification on text split into spans. Returns results @@ -198,24 +201,31 @@ def run_bidi_stream( Args: text_stream: Iterable[str] Text stream to run classification on - threshold: float + threshold: float | int (Optional) Threshold based on which to return score results Returns: Iterable[TokenClassificationStreamResult] """ - error.type_check("", float, allow_none=True, threshold=threshold) + error.type_check( + "", float, int, allow_none=True, threshold=threshold + ) # TODO: For optimization implement window based approach. 
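# Editor's sketch (not part of the diff): the emptiness probe used just below in
# run_bidi_stream(). A streamed input may be a generator, so len() is unavailable;
# itertools.tee() yields two independent iterators -- one is advanced once to detect
# an empty stream, the other is kept intact for the real iteration.
import itertools

def probe_empty(stream):
    probe, main = itertools.tee(stream, 2)
    try:
        next(probe)
    except StopIteration:
        return True, main       # nothing in the stream
    return False, main          # main still yields every item, including the probed one

empty, rest = probe_empty(iter([]))
assert empty and list(rest) == []
empty, rest = probe_empty(iter(["span one", "span two"]))
assert not empty and list(rest) == ["span one", "span two"]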
if threshold is None: threshold = self.default_threshold - # Types on the stream are checked later on iteration - if len(text_stream) == 0: + # Avoid length check here since it can be time consuming to iterate through stream + # Tee stream to 2 - one to check emptiness, one for full iteration + analysis + text_streams = itertools.tee(text_stream, 2) + try: + next(text_streams[0]) + except StopIteration: + # Types on the stream are checked later on iteration # Allow empty text case to fall through - some tokenizers or # classifiers may error on this yield TokenClassificationStreamResult(results=[], processed_index=0) - for span_output in self._stream_span_output(text_stream): + for span_output in self._stream_span_output(text_streams[1]): classification_result = self.classifier.run(span_output.text) results_to_end_of_span = False for classification in classification_result.results: diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 50728f0d..8b50dd15 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -22,6 +22,7 @@ GeneratedTextResult, GeneratedTextStreamResult, GeneratedToken, + TokenizationResults, TokenStreamDetails, ) from caikit_tgis_backend.protobufs import generation_pb2 @@ -36,9 +37,19 @@ GENERATE_FUNCTION_TGIS_ARGS = """ {} - preserve_input_text: str + preserve_input_text: bool Whether or not the source string should be contained in the generated output, e.g., as a prefix. + input_tokens: bool + Whether or not to include list of input tokens. + generated_tokens: bool + Whether or not to include list of individual generated tokens. + token_logprobs: bool + Whether or not to include logprob for each returned token. + Applicable only if generated_tokens == true and/or input_tokens == true + token_ranks: bool + Whether or not to include rank of each returned token. 
+ Applicable only if generated_tokens == true and/or input_tokens == true """.format( GENERATE_FUNCTION_ARGS ) @@ -47,6 +58,10 @@ def validate_inf_params( text, preserve_input_text, + input_tokens, + generated_tokens, + token_logprobs, + token_ranks, eos_token, max_new_tokens, min_new_tokens, @@ -73,6 +88,10 @@ def validate_inf_params( ) error.type_check("", str, text=text) error.type_check("", bool, preserve_input_text=preserve_input_text) + error.type_check("", bool, input_tokens=input_tokens) + error.type_check("", bool, generated_tokens=generated_tokens) + error.type_check("", bool, token_logprobs=token_logprobs) + error.type_check("", bool, token_ranks=token_ranks) error.type_check("", str, allow_none=True, eos_token=eos_token) error.type_check( "", @@ -173,6 +192,10 @@ def validate_inf_params( def get_params( preserve_input_text, + input_tokens, + generated_tokens, + token_logprobs, + token_ranks, max_new_tokens, min_new_tokens, truncate_input_tokens, @@ -210,10 +233,10 @@ def get_params( res_options = generation_pb2.ResponseOptions( input_text=preserve_input_text, - generated_tokens=True, - input_tokens=False, - token_logprobs=True, - token_ranks=True, + generated_tokens=generated_tokens, + input_tokens=input_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, ) stopping = generation_pb2.StoppingCriteria( stop_sequences=stop_sequences, @@ -267,6 +290,10 @@ def unary_generate( self, text, preserve_input_text, + input_tokens, + generated_tokens, + token_logprobs, + token_ranks, max_new_tokens, min_new_tokens, truncate_input_tokens, @@ -304,6 +331,10 @@ def unary_generate( validate_inf_params( text=text, preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, eos_token=self.eos_token, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, @@ -324,6 +355,10 @@ def unary_generate( params = get_params( preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, truncate_input_tokens=truncate_input_tokens, @@ -365,6 +400,24 @@ def unary_generate( ) response = batch_response.responses[0] + token_list = [] + if response.tokens is not None: + for token in response.tokens: + token_list.append( + GeneratedToken( + text=token.text, logprob=token.logprob, rank=token.rank + ) + ) + + input_token_list = [] + if response.input_tokens is not None: + for token in response.input_tokens: + input_token_list.append( + GeneratedToken( + text=token.text, logprob=token.logprob, rank=token.rank + ) + ) + return GeneratedTextResult( generated_text=response.text, generated_tokens=response.generated_token_count, @@ -372,12 +425,18 @@ def unary_generate( producer_id=self.producer_id, input_token_count=response.input_token_count, seed=seed, + tokens=token_list, + input_tokens=input_token_list, ) def stream_generate( self, text, preserve_input_text, + input_tokens, + generated_tokens, + token_logprobs, + token_ranks, max_new_tokens, min_new_tokens, truncate_input_tokens, @@ -415,6 +474,10 @@ def stream_generate( validate_inf_params( text=text, preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, eos_token=self.eos_token, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, @@ -433,6 +496,10 @@ def 
stream_generate( params = get_params( preserve_input_text=preserve_input_text, + input_tokens=input_tokens, + generated_tokens=generated_tokens, + token_logprobs=token_logprobs, + token_ranks=token_ranks, max_new_tokens=max_new_tokens, min_new_tokens=min_new_tokens, truncate_input_tokens=truncate_input_tokens, @@ -475,12 +542,71 @@ def stream_generate( input_token_count=stream_part.input_token_count, ) token_list = [] - for token in stream_part.tokens: - token_list.append( - GeneratedToken(text=token.text, logprob=token.logprob) - ) + if stream_part.tokens is not None: + for token in stream_part.tokens: + token_list.append( + GeneratedToken( + text=token.text, logprob=token.logprob, rank=token.rank + ) + ) + input_token_list = [] + if stream_part.input_tokens is not None: + for token in stream_part.input_tokens: + input_token_list.append( + GeneratedToken( + text=token.text, logprob=token.logprob, rank=token.rank + ) + ) yield GeneratedTextStreamResult( generated_text=stream_part.text, tokens=token_list, + input_tokens=input_token_list, details=details, ) + + def unary_tokenize( + self, + text: str, + ) -> TokenizationResults: + """Tokenize unary input using TGIS + + Args: + text: str + Text to tokenize + Returns: + TokenizationResults + The token count + """ + + # In case internal client is not configured - tokenization + # cannot be done (individual modules may already check + # for this) + error.value_check( + "", + self.tgis_client is not None, + "Backend must be configured and loaded for tokenization", + ) + + log.debug("Building protobuf request to send to TGIS") + + gen_reqs = [generation_pb2.TokenizeRequest(text=text)] + + request = generation_pb2.BatchedTokenizeRequest( + requests=gen_reqs, + model_id=self.base_model_name, + ) + + # Currently, we send a batch request of len(x)==1, so we expect one response back + with alog.ContextTimer(log.trace, "TGIS request duration: "): + batch_response = self.tgis_client.Tokenize(request) + + error.value_check( + "", + len(batch_response.responses) == 1, + f"Got {len(batch_response.responses)} responses for a single request", + ) + response = batch_response.responses[0] + + return TokenizationResults( + token_count=response.token_count, + ) diff --git a/caikit_nlp/toolkit/torch_run.py b/caikit_nlp/toolkit/torch_run.py index 3a8879c8..43184f2d 100644 --- a/caikit_nlp/toolkit/torch_run.py +++ b/caikit_nlp/toolkit/torch_run.py @@ -24,7 +24,8 @@ # Third Party from torch import cuda -from torch.distributed.launcher.api import LaunchConfig, Std +from torch.distributed.elastic.multiprocessing.api import Std +from torch.distributed.launcher.api import LaunchConfig import torch.distributed as dist # First Party diff --git a/pyproject.toml b/pyproject.toml index 24ce9459..bcb82db0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.25.0,<0.27.0", + "caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0", "caikit-tgis-backend>=0.1.27,<0.2.0", # TODO: loosen dependencies "accelerate>=0.22.0", @@ -24,8 +24,9 @@ dependencies = [ "pandas>=1.5.0", "scikit-learn>=1.1", "scipy>=1.8.1", + "sentence-transformers>=2.3.1,<2.4.0", "tokenizers>=0.13.3", - "torch>=2.0.1", + "torch>=2.0.1,<2.3.0", "tqdm>=4.65.0", "transformers>=4.32.0", "peft==0.6.0", diff --git a/runtime_config.yaml b/runtime_config.yaml index 3c3812b3..cbd27421 100644 --- a/runtime_config.yaml +++ b/runtime_config.yaml @@ -41,3 +41,21 @@ model_management: 
ca_cert_file: null client_cert_file: null client_key_file: null + +# Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32 +embedding: + # Number of times to retry on error. Most deployments should use 0 retries. + retries: 0 + # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used + batch_size: 0 + # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this + implicit_truncation_errors: true + # Attempt to optimize with PyTorch compile() + pt2_compile: false + # Use IPEX optimize. Works best when used with autocast (bfloat16) below. + ipex: false + # Use autocast in encode with its default dtype (bfloat16) + autocast: false + # For testing, set device to "mps" on MacOS or "xpu" for IPEX GPU. + # Otherwise, the default does automatic checks for cuda GPU (else cpu). + device: "" diff --git a/sentence-transformers.nodeps.txt b/sentence-transformers.nodeps.txt deleted file mode 100644 index dad81257..00000000 --- a/sentence-transformers.nodeps.txt +++ /dev/null @@ -1,7 +0,0 @@ -# These can be installed with --no-deps. - -# Minimum needed to use sentence-transformers: -sentence-transformers>=2.2.2,<2.3.0 -nltk>=3.8.1,<3.9.0 -Pillow>=10.0.0,<10.1.0 - diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 31f22ad2..e7440e83 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -18,7 +18,7 @@ # First Party from caikit.config.config import merge_configs -from caikit.interfaces.nlp.data_model import GeneratedTextResult +from caikit.interfaces.nlp.data_model import GeneratedTextResult, TokenizationResults from caikit_tgis_backend import TGISBackend from caikit_tgis_backend.tgis_connection import TGISConnection import aconfig @@ -204,6 +204,9 @@ def Generate(self, request): def GenerateStream(self, request): return StubTGISClient.stream_generate(request) + def Tokenize(self, request): + return StubTGISClient.tokenize(request) + @staticmethod def unary_generate(request): fake_response = mock.Mock() @@ -212,6 +215,16 @@ def unary_generate(request): fake_result.generated_token_count = 1 fake_result.text = "moose" fake_result.input_token_count = 1 + token = mock.Mock() + token.text = "moose" + token.logprob = 0.2 + token.rank = 1 + fake_result.tokens = [token] + input_tokens = mock.Mock() + input_tokens.text = "moose" + input_tokens.logprob = 0.2 + input_tokens.rank = 1 + fake_result.input_tokens = [input_tokens] fake_response.responses = [fake_result] return fake_response @@ -225,11 +238,25 @@ def stream_generate(request): token = mock.Mock() token.text = "moose" token.logprob = 0.2 + token.rank = 1 fake_stream.tokens = [token] + input_tokens = mock.Mock() + input_tokens.text = "moose" + input_tokens.logprob = 0.2 + input_tokens.rank = 1 + fake_stream.input_tokens = [input_tokens] fake_stream.text = "moose" for _ in range(3): yield fake_stream + @staticmethod + def tokenize(request): + fake_response = mock.Mock() + fake_result = mock.Mock() + fake_result.token_count = 1 + fake_response.responses = [fake_result] + return fake_response + @staticmethod def validate_unary_generate_response(result): assert isinstance(result, GeneratedTextResult) @@ -237,6 +264,12 @@ def validate_unary_generate_response(result): assert result.generated_tokens == 1 assert result.finish_reason == 5 assert result.input_token_count == 1 + assert result.tokens[0].text == "moose" + assert result.tokens[0].logprob == 0.2 + assert result.tokens[0].rank == 1 + assert 
result.input_tokens[0].text == "moose" + assert result.input_tokens[0].logprob == 0.2 + assert result.input_tokens[0].rank == 1 @staticmethod def validate_stream_generate_response(stream_result): @@ -248,11 +281,20 @@ def validate_stream_generate_response(stream_result): assert first_result.generated_text == "moose" assert first_result.tokens[0].text == "moose" assert first_result.tokens[0].logprob == 0.2 + assert first_result.tokens[0].rank == 1 + assert first_result.input_tokens[0].text == "moose" + assert first_result.input_tokens[0].logprob == 0.2 + assert first_result.input_tokens[0].rank == 1 assert first_result.details.finish_reason == 5 assert first_result.details.generated_tokens == 1 assert first_result.details.seed == 10 assert first_result.details.input_token_count == 1 + @staticmethod + def validate_tokenize_response(result): + assert isinstance(result, TokenizationResults) + assert result.token_count == 1 + class StubTGISBackend(TGISBackend): def __init__( diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index b0a610c7..e625588b 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -1,7 +1,7 @@ -"""Tests for text embedding module -""" +"""Tests for text embedding module""" + # Standard -from typing import List +from typing import List, Tuple import os import tempfile @@ -10,6 +10,7 @@ from torch.backends import mps import numpy as np import pytest +import torch # First Party from caikit.core import ModuleConfig @@ -23,7 +24,11 @@ ) # Local -from caikit_nlp.modules.text_embedding import EmbeddingModule +from caikit_nlp.modules.text_embedding import EmbeddingModule, utils +from caikit_nlp.modules.text_embedding.embedding import ( + get_sample_start_indexes, + sum_token_count, +) from tests.fixtures import SEQ_CLASS_MODEL ## Setup ######################################################################## @@ -32,14 +37,31 @@ # .bootstrap is tested separately in the first test BOOTSTRAPPED_MODEL = EmbeddingModule.bootstrap(SEQ_CLASS_MODEL) +# Token counts: +# All expected token counts were calculated with reference to the +# `BertForSequenceClassification` model. Each model's tokenizer behaves differently +# which can lead to the expected token counts being invalid. + INPUT = "The quick brown fox jumps over the lazy dog." +INPUT_TOKEN_COUNT = 36 + 2 # [CLS] Thequickbrownfoxjumpsoverthelazydog. [SEP] + +MANY_INPUTS = [ + "The quick brown fox jumps over the lazy dog.", + "But I must explain to you how all this mistaken idea.", + "No one rejects or dislikes.", +] QUERY = "What is foo bar?" +QUERY_TOKEN_COUNT = 13 + 2 # [CLS] Whatisfoobar? [SEP] QUERIES: List[str] = [ "Who is foo?", "Where is the bar?", ] +QUERIES_TOKEN_COUNT = (9 + 2) + ( + 14 + 2 +) # [CLS] Whoisfoo? [SEP], [CLS] Whereisthebar? 
[SEP] + # These are used to test that documents can handle different types in and out TYPE_KEYS = "str_test", "int_test", "float_test", "nested_dict_test" @@ -67,14 +89,21 @@ }, ] +# The `text` and `_text` keys are extracted from DOCS as input to the tokenizer +# [CLS] foo [SEP], [CLS] bar [SEP], [CLS] fooandbar [SEP], [CLS] Whereisthebar [SEP] +DOCS_TOKEN_COUNT = (3 + 2) + (3 + 2) + (9 + 2) + (13 + 2) + # Use text or _text from DOCS for our test sentences SENTENCES = [d.get("text", d.get("_text")) for d in DOCS] +# [CLS] foo [SEP], [CLS] bar [SEP], [CLS] fooandbar [SEP], [CLS] Whereisthebar [SEP] +SENTENCES_TOKEN_COUNT = (3 + 2) + (3 + 2) + (9 + 2) + (13 + 2) + ## Tests ######################################################################## -@pytest.fixture(scope="module") -def loaded_model(tmp_path_factory): +@pytest.fixture(scope="module", name="loaded_model") +def fixture_loaded_model(tmp_path_factory): models_dir = tmp_path_factory.mktemp("models") model_path = str(models_dir / "model_id") BOOTSTRAPPED_MODEL.save(model_path) @@ -141,8 +170,15 @@ def _assert_valid_scores(scores, type_tests={}): return type_tests -def test_bootstrap_reuse(): - assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap reuse error" +def test_bootstrap_model(loaded_model): + assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap model type" + assert ( + BOOTSTRAPPED_MODEL.model.__class__.__name__ == "SentenceTransformer" + ), "bootstrap model class name" + # worth noting that bootstrap does not wrap, but load does + assert ( + loaded_model.model.__class__.__name__ == "SentenceTransformerWithTruncate" + ), "loaded model class name" def test_save_load_and_run(): @@ -211,6 +247,7 @@ def test_run_embedding_type_check(loaded_model): def test_run_embedding(loaded_model): res = loaded_model.run_embedding(text=INPUT) _assert_is_expected_embedding_result(res) + assert res.input_token_count == INPUT_TOKEN_COUNT def test_run_embeddings_str_type(loaded_model): @@ -224,6 +261,7 @@ def test_run_embeddings(loaded_model): res = loaded_model.run_embeddings(texts=[INPUT]) assert isinstance(res.results.vectors, list) _assert_is_expected_embeddings_results(res.results) + assert res.input_token_count == INPUT_TOKEN_COUNT @pytest.mark.parametrize( @@ -245,7 +283,8 @@ def test_run_rerank_query_type_error(query, docs, top_n, loaded_model): def test_run_rerank_query_no_type_error(loaded_model): """no type error with list of string queries and list of dict documents""" - loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=1) + res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=1) + assert res.input_token_count == QUERY_TOKEN_COUNT + DOCS_TOKEN_COUNT @pytest.mark.parametrize( @@ -263,6 +302,7 @@ def test_run_rerank_query_top_n(top_n, expected, loaded_model): res = loaded_model.run_rerank_query(query=QUERY, documents=DOCS, top_n=top_n) assert isinstance(res, RerankResult) assert len(res.result.scores) == expected + assert res.input_token_count == QUERY_TOKEN_COUNT + DOCS_TOKEN_COUNT def test_run_rerank_query_no_query(loaded_model): @@ -286,6 +326,7 @@ def test_run_rerank_query(loaded_model): types_found = _assert_valid_scores(scores) _assert_types_found(types_found) + assert res.input_token_count == QUERY_TOKEN_COUNT + DOCS_TOKEN_COUNT @pytest.mark.parametrize( @@ -300,7 +341,8 @@ def test_run_rerank_queries_type_error(queries, docs, loaded_model): def test_run_rerank_queries_no_type_error(loaded_model): """no type error with list of string queries and list of dict documents""" - 
loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) + res = loaded_model.run_rerank_queries(queries=QUERIES, documents=DOCS, top_n=99) + assert res.input_token_count == QUERIES_TOKEN_COUNT + DOCS_TOKEN_COUNT @pytest.mark.parametrize( @@ -321,6 +363,7 @@ def test_run_rerank_queries_top_n(top_n, expected, loaded_model): assert len(res.results) == len(QUERIES) for result in res.results: assert len(result.scores) == expected + assert res.input_token_count == QUERIES_TOKEN_COUNT + DOCS_TOKEN_COUNT @pytest.mark.parametrize( @@ -361,6 +404,7 @@ def test_run_rerank_queries(loaded_model): # Make sure our document fields of different types made it in/out ok _assert_types_found(types_found) + assert rerank_result.input_token_count == QUERIES_TOKEN_COUNT + DOCS_TOKEN_COUNT def test_run_sentence_similarity(loaded_model): @@ -371,6 +415,7 @@ def test_run_sentence_similarity(loaded_model): assert len(scores) == len(SENTENCES) for score in scores: assert isinstance(score, float) + assert res.input_token_count == QUERY_TOKEN_COUNT + SENTENCES_TOKEN_COUNT def test_run_sentence_similarities(loaded_model): @@ -384,35 +429,28 @@ def test_run_sentence_similarities(loaded_model): assert len(scores) == len(SENTENCES) for score in scores: assert isinstance(score, float) + assert res.input_token_count == QUERIES_TOKEN_COUNT + SENTENCES_TOKEN_COUNT @pytest.mark.parametrize( - "use_ipex, use_xpu, use_mps, expected", + "use_ipex, device, expected", [ - (True, "true", "true", "xpu"), - (True, "true", "false", "xpu"), - (True, "false", "true", None), - (True, "false", "false", None), - (False, "false", "false", None), - (False, "true", "false", None), - ( - False, - "false", - "true", - "mps" if mps.is_built() and mps.is_available() else None, - ), + (True, "", None), + (False, "", None), + (True, None, None), + (False, None, None), + (False, "xpu", None), + (True, "xpu", "xpu"), + (True, "mps", None), ( False, - "true", - "true", + "mps", "mps" if mps.is_built() and mps.is_available() else None, ), ], ) -def test__select_device(use_ipex, use_xpu, use_mps, expected, monkeypatch): - monkeypatch.setenv("USE_XPU", use_xpu) - monkeypatch.setenv("USE_MPS", use_mps) - assert EmbeddingModule._select_device(use_ipex) == expected +def test__select_device(use_ipex, device, expected): + assert EmbeddingModule._select_device(use_ipex, device) == expected @pytest.mark.parametrize( @@ -433,52 +471,73 @@ def test__get_backend(use_ipex, use_device, expected): "use_ipex", [None, "true", "True", "False", "false"], ) -def test__get_ipex(use_ipex, monkeypatch): +def test__get_ipex(use_ipex): """Test that _get_ipex returns False instead of raising an exception. Assumes that when running tests, we won't have IPEX installed. 
""" - monkeypatch.setenv("IPEX_OPTIMIZE", use_ipex) - assert not EmbeddingModule._get_ipex() + assert not EmbeddingModule._get_ipex(use_ipex) -def test__optimize(monkeypatch): +def test__optimize(): """Test that _optimize does nothing when disabled""" fake = "fake model" # Will be returned as-is - monkeypatch.setenv("PT2_COMPILE", "False") - assert fake == EmbeddingModule._optimize(fake, False, "bogus") + assert fake == EmbeddingModule._optimize(fake, False, "bogus", False, False) -@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 10, 333]) -def test__truncate_input_tokens(truncate_input_tokens, loaded_model): - - if truncate_input_tokens < 0: - num_xs = 500 # fill-er up - else: - num_xs = truncate_input_tokens - 4 # subtract room for (y y), but not z +@pytest.mark.parametrize("truncate_input_tokens", [0, 513]) +def test__truncate_input_tokens_raises(truncate_input_tokens, loaded_model): + model_max = loaded_model.model.max_seq_length - too_long = "x " * num_xs + "y y z " # z will go over - actual = loaded_model._truncate_input_tokens( - truncate_input_tokens=truncate_input_tokens, texts=[too_long, too_long] + too_long = "x " * (model_max - 1) # This will go over + over = model_max + 1 + with pytest.raises(ValueError, match=f"({over} > {model_max})"): + loaded_model.model.encode( + sentences=[too_long], truncate_input_tokens=truncate_input_tokens + ) + # Same behavior when implicit_truncation_errors is True (the default) + with pytest.raises(ValueError, match=f"({over} > {model_max})"): + loaded_model.model.encode( + sentences=[too_long], + truncate_input_tokens=truncate_input_tokens, + implicit_truncation_errors=True, + ) + # Different behavior when implicit_truncation_errors is False -- no error raised! + loaded_model.model.encode( + sentences=[too_long], + truncate_input_tokens=truncate_input_tokens, + implicit_truncation_errors=False, ) - assert actual[0] == actual[1] # they are still the same - if truncate_input_tokens < 0: - assert actual[0] == too_long, "expected no truncation" - else: - assert actual[0] + " z " == too_long, "expected truncation" +def test__implicit_truncation(loaded_model): + """Test that implicit truncation happens (when allowed)""" + model_max = loaded_model.model.max_seq_length + too_long = "x " * (model_max - 1) # This will go over a little + extra_long = ( + too_long + + "more clever words that surely change the meaning of this text" + * (model_max - 1) + ) -@pytest.mark.parametrize("truncate_input_tokens", [0, 513]) -def test__truncate_input_tokens_raises(truncate_input_tokens, loaded_model): - model_max = loaded_model.model.max_seq_length + # Allowed truncation using default tokens (0) and config to disable the error. 
+ res = loaded_model.model.encode( + sentences=[too_long], truncate_input_tokens=0, implicit_truncation_errors=False + ) + # Allowed truncation using model max + res_extra_max = loaded_model.model.encode( + sentences=[extra_long], truncate_input_tokens=loaded_model.model.max_seq_length + ) + # Allowed truncation using -1 to just let the model do its thing + res_extra_neg = loaded_model.model.encode( + sentences=[extra_long], truncate_input_tokens=-1 + ) - too_long = "x " * (model_max - 1) # This will go over - with pytest.raises(ValueError): - loaded_model._truncate_input_tokens( - truncate_input_tokens=truncate_input_tokens, texts=[too_long] - ) + # Demonstrating that when implicit truncation is allowed, sentence-transformers is quietly truncating at model max + # The simple too_long string of x's, is equivalent to the string with significantly different extra text (truncated) + assert np.allclose(res, res_extra_max) + assert np.allclose(res, res_extra_neg) def test_not_too_many_tokens(loaded_model): @@ -505,40 +564,41 @@ def test_too_many_tokens_default(loaded_model): """These endpoints raise an error when truncation would happen.""" model_max = loaded_model.model.max_seq_length + over = model_max + 1 ok = "x " * (model_max - 2) # Subtract 2 for begin/end tokens too_long = "x " * (model_max - 1) # This will go over # embedding(s) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_embedding(text=too_long) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_embeddings(texts=[too_long]) # sentence similarity(ies) test both source_sentence and sentences - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarity(source_sentence=too_long, sentences=[ok]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarity(source_sentence=ok, sentences=[too_long]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarities( source_sentences=[too_long], sentences=[ok] ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarities( source_sentences=[ok], sentences=[too_long] ) # reranker test both query and document text - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_query(query=too_long, documents=[{"text": ok}]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_query(query=ok, documents=[{"text": too_long}]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_queries(queries=[too_long], documents=[{"text": ok}]) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_queries(queries=[ok], documents=[{"text": too_long}]) @@ -551,41 +611,42 @@ def test_too_many_tokens_error_params(truncate_input_tokens, loaded_model): """ model_max = loaded_model.model.max_seq_length + over = model_max + 1 ok = "x " * (model_max - 2) # Subtract 2 for begin/end tokens too_long = "x " * (model_max - 1) # This will go over # embedding(s) - with pytest.raises(ValueError): + with 
pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_embedding( text=too_long, truncate_input_tokens=truncate_input_tokens ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_embeddings( texts=[too_long], truncate_input_tokens=truncate_input_tokens ) # sentence similarity(ies) test both source_sentence and sentences - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarity( source_sentence=too_long, sentences=[ok], truncate_input_tokens=truncate_input_tokens, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarity( source_sentence=ok, sentences=[too_long], truncate_input_tokens=truncate_input_tokens, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarities( source_sentences=[too_long], sentences=[ok], truncate_input_tokens=truncate_input_tokens, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_sentence_similarities( source_sentences=[ok], sentences=[too_long], @@ -593,26 +654,26 @@ def test_too_many_tokens_error_params(truncate_input_tokens, loaded_model): ) # reranker test both query and document text - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_query( query=too_long, documents=[{"text": ok}], truncate_input_tokens=truncate_input_tokens, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_query( query=ok, documents=[{"text": too_long}], truncate_input_tokens=truncate_input_tokens, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_queries( queries=[too_long], documents=[{"text": ok}], truncate_input_tokens=truncate_input_tokens, ) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match=f"({over} > {model_max})"): loaded_model.run_rerank_queries( queries=[ok], documents=[{"text": too_long}], @@ -620,12 +681,14 @@ def test_too_many_tokens_error_params(truncate_input_tokens, loaded_model): ) -@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 512]) +@pytest.mark.parametrize("truncate_input_tokens", [-1, 99, 510, 511, 512]) def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_model): """truncate_input_tokens prevents these endpoints from raising an error when too many tokens. Test with -1 which lets the model do truncation instead of raising an error. - Test with 99 (< 512) which causes our code to do the truncation instead of raising an error. + Test with 99 (< 512 -2) which causes our code to do the truncation instead of raising an error. + Test with 510 (512 -2) which causes our code to do the truncation instead of raising an error. + 511 and 512 also behave like 510. The value is allowed, but begin/end tokens will take space. 
""" model_max = loaded_model.model.max_seq_length @@ -688,21 +751,27 @@ def test_too_many_tokens_with_truncation_working(truncate_input_tokens, loaded_m ) -@pytest.mark.parametrize("truncate_input_tokens", [99, 512, -1]) +@pytest.mark.parametrize( + "truncate_input_tokens", [1, 2, 3, 4, 99, 100, 101, 510, 511, 512, -1] +) def test_embeddings_with_truncation(truncate_input_tokens, loaded_model): """verify that results are as expected with truncation""" + max_len = loaded_model.model.max_seq_length - 2 if truncate_input_tokens is None or truncate_input_tokens < 0: - # For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length - repeat = loaded_model.model.max_seq_length + # For -1 we don't truncate, but sentence-transformers will truncate at max_seq_length - 2 + repeat = max_len else: - repeat = truncate_input_tokens + repeat = min( + truncate_input_tokens, max_len + ) # max_len is used when we need -2 for begin/end # Build a text like "x x x.. x " with room for one more token - repeat = repeat - 2 # space for start/end tokens repeat = repeat - 1 # space for the final x or y token to show difference - base = "x " * repeat # A bunch of "x" tokens + base = "" + if repeat > 0: + base = "x " * repeat # A bunch of "x" tokens x = base + "x" # One last "x" that will not get truncated y = base + "y" # A different last character "y" not truncated z = y + "z" # Add token "z" after "y". This should get truncated. @@ -710,11 +779,29 @@ def test_embeddings_with_truncation(truncate_input_tokens, loaded_model): res = loaded_model.run_embeddings( texts=[base, x, y, z], truncate_input_tokens=truncate_input_tokens ) + vectors = res.results.vectors # vectors from batch embeddings - vectors = res.results.vectors + # Compare with results from individual embedding calls in a loop + loop_res = [] + for t in [base, x, y, z]: + r = loaded_model.run_embedding( + text=t, truncate_input_tokens=truncate_input_tokens + ) + loop_res.append(r) + loop_vectors = [ + r.result for r in loop_res + ] # vectors from loop of single embedding calls + + assert len(vectors) == len(loop_vectors), "expected the same length vectors" + # compare the vectors from batch with the single calls + for i, e in enumerate(vectors): + assert np.allclose(e.data.values, loop_vectors[i].data.values) # x...xyz is the same as x...xy because that is exactly where truncation worked + assert len(vectors[2].data.values) == len(vectors[3].data.values) assert np.allclose(vectors[2].data.values, vectors[3].data.values) + for i in range(len(vectors[2].data.values)): + assert approx(vectors[2].data.values[i]) == approx(vectors[3].data.values[i]) # Make sure the base, x, y are not a match (we kept the significant last char) assert not np.allclose(vectors[0].data.values, vectors[1].data.values) @@ -727,7 +814,7 @@ def test__with_retry_happy_path(loaded_model): loaded_model._with_retry(print, "hello", "world", sep="<:)>", end="!!!\n") -def test__with_retry_fail(loaded_model, monkeypatch): +def test__with_retry_fail(loaded_model): """fn never works, loops then raises the exception""" def fn(): @@ -780,3 +867,230 @@ def fail_fail_win(): # Third try did not raise an exception. Returns 3. 
assert 3 == loaded_model._with_retry(fail_fail_win) + + +def test_env_val_to_bool(): + assert not utils.env_val_to_bool(None) + assert not utils.env_val_to_bool("") + assert not utils.env_val_to_bool(" ") + assert not utils.env_val_to_bool(0) + assert not utils.env_val_to_bool("0") + assert not utils.env_val_to_bool(" False ") + assert not utils.env_val_to_bool(" false ") + assert not utils.env_val_to_bool(" fAlSE ") + + assert utils.env_val_to_bool(1) + assert utils.env_val_to_bool("1") + assert utils.env_val_to_bool(" True ") + assert utils.env_val_to_bool(" true ") + assert utils.env_val_to_bool(" tRuE ") + + +def test_env_val_to_int(): + expected_default = 12345 + assert expected_default == utils.env_val_to_int(None, expected_default) + assert expected_default == utils.env_val_to_int("", expected_default) + assert expected_default == utils.env_val_to_int(" ", expected_default) + assert expected_default == utils.env_val_to_int(" ss ", expected_default) + assert expected_default == utils.env_val_to_int(" sss ", expected_default) + assert expected_default == utils.env_val_to_int(" ssss ", expected_default) + + assert 0 == utils.env_val_to_int(0, expected_default) + assert 0 == utils.env_val_to_int("0", expected_default) + assert 0 == utils.env_val_to_int(False, expected_default) + assert 456 == utils.env_val_to_int("456", expected_default) + assert 456 == utils.env_val_to_int(" 456 ", expected_default) + assert 1 == utils.env_val_to_int(True, expected_default) + + +@pytest.mark.parametrize( + # `expected_count` are valid for the `BertForSequenceClassification` model. + ["texts", "expected_count"], + [ + # Only tokens requiring model attention is counted. + # [PAD] doesn't attract model attention, but [CLS] and [SEP] does + # [CLS] 5 normal tokens [SEP] + (["12345"], 5 + 2), + # [CLS] 5 normal [SEP], [CLS] 4 normal [SEP] [PAD] + (["12 345", "6 789"], 9 + 4), + ], +) +def test_sum_token_count_no_truncation(texts, expected_count, loaded_model): + + tokenized = loaded_model.model.tokenizer( + texts, + return_attention_mask=True, + return_token_type_ids=False, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_length=True, + return_tensors="pt", + truncation=True, + padding=True, + max_length=loaded_model.model.max_seq_length, + ) + token_count = sum_token_count( + tokenized, + truncate_only=False, + ) + + assert token_count == expected_count + + +@pytest.mark.parametrize( + # `expected_count` are valid for the `BertForSequenceClassification` model. + ["texts", "truncate", "expected_count"], + [ + # Only tokens requiring model attention is counted. 
+ # [PAD] doesn't attract model attention, but [CLS] and [SEP] does + # + # All encodings: [CLS] 12345 [SEP] + # No truncation + (["12345"], 10, 7), + # All encodings: [CLS] 123 [SEP] + [CLS] 45 [SEP] [PAD] + # Only truncated: [CLS] 123 [SEP] + (["12345"], 5, 3 + 2), + # + # All encodings: [CLS] 123 [SEP] + [CLS] 45 [SEP] [PAD], [CLS] 678 [SEP] + [CLS] 9 [SEP] [PAD] [PAD] + # Only truncated: [CLS] 123 [SEP] , [CLS] 678 [SEP] + (["12 345", "6 789"], 5, (3 + 2) + (3 + 2)), + ], +) +def test_sum_token_count_with_truncation(texts, truncate, expected_count, loaded_model): + tokenized = loaded_model.model.tokenizer( + texts, + return_attention_mask=True, + return_token_type_ids=False, + return_overflowing_tokens=True, + return_offsets_mapping=True, + return_length=True, + return_tensors="pt", + truncation=True, + padding=True, + max_length=truncate, + ) + token_count = sum_token_count( + tokenized, + truncate_only=True, + ) + + assert token_count == expected_count + + +@pytest.mark.parametrize( + "truncate_input_tokens", [0, 1, 2, 3, 4, 99, 100, 101, 510, 511, 512, 513, -1] +) +def test_encoding_order(loaded_model: EmbeddingModule, truncate_input_tokens): + """Confirm that encoding doesn't modify the original sort order""" + separate_embeddings = [ + loaded_model.run_embedding(text=i, truncate_input_tokens=truncate_input_tokens) + for i in MANY_INPUTS + ] + combined_embeddings = loaded_model.run_embeddings( + texts=MANY_INPUTS, truncate_input_tokens=truncate_input_tokens + ) + + separate_vectors = [ + e.to_dict()["result"]["data"]["values"] for e in separate_embeddings + ] + combined_vectors = [ + e["data"]["values"] for e in combined_embeddings.to_dict()["results"]["vectors"] + ] + + assert len(separate_vectors) == len( + combined_vectors + ), "expected the same number separate and combined embeddings" + + # test order by comparing value of individual embeddings in sequence + for i, e in enumerate(separate_vectors): + assert np.allclose(e, combined_vectors[i]) + + # test expected failure case by reordering + shifted_separate_vectors = separate_vectors[1:] + [separate_vectors[0]] + + for i, e in enumerate(shifted_separate_vectors): + assert e != separate_vectors[i], "expected order to be have been altered" + assert ( + not approx(e) == combined_vectors[i] + ), "expected altered order to not match combined vectors" + assert not np.allclose( + e, combined_vectors[i] + ), "expected altered order to not match combined" + + +@pytest.mark.parametrize( + ("mapping", "expected"), + [ + ([0, 0, 0, 0, 0], [0]), + ([0, 1, 2, 3, 4], [0, 1, 2, 3, 4]), + ([0, 0, 1, 1, 1, 2], [0, 2, 5]), + ([], []), + ], +) +def test_get_sample_start_indexes(mapping, expected): + mock_tokenized = { + "overflow_to_sample_mapping": torch.Tensor(mapping).type(torch.int8) + } + assert get_sample_start_indexes(mock_tokenized) == expected + + +def test_encode_extensions(loaded_model): + # loaded model can return_token_count + ret = loaded_model._encode_with_retry("text here", return_token_count=True) + assert isinstance(ret, Tuple) + assert isinstance(ret[0], np.ndarray) + assert isinstance(ret[1], int) + ret = loaded_model._encode_with_retry("text here", return_token_count=False) + assert isinstance(ret, np.ndarray) + + # Make sure use with un-wrapped SentenceTransformer model is unaffected by extended params or return tokens + ret = BOOTSTRAPPED_MODEL._encode_with_retry( + "text here", + return_token_count=True, + truncate_input_tokens=123, + implicit_truncation_errors=False, + ) + assert isinstance(ret, np.ndarray) + 
BOOTSTRAPPED_MODEL._encode_with_retry( + "text here" + ) # and no KeyError trying to remove non-existing keys + + +@pytest.mark.parametrize( + "truncate_input_tokens", + [0, 1, 2, 3, 4, 5, 99, 100, 101, 300, 510, 511, 512, 513, 1000, -1], +) +def test_same_same(loaded_model: EmbeddingModule, truncate_input_tokens): + """Confirm that same text gives same results""" + + inputs = ["What is generative ai?", "What is generative ai?", "different"] + + # First ensuring that batch input vs loop over inputs is the same + separate_embeddings = [ + loaded_model.run_embedding(text=i, truncate_input_tokens=truncate_input_tokens) + for i in inputs + ] + combined_embeddings = loaded_model.run_embeddings( + texts=inputs, truncate_input_tokens=truncate_input_tokens + ) + + separate_vectors = [ + e.to_dict()["result"]["data"]["values"] for e in separate_embeddings + ] + combined_vectors = [ + e["data"]["values"] for e in combined_embeddings.to_dict()["results"]["vectors"] + ] + + assert len(separate_vectors) == len( + combined_vectors + ), "expected the same number separate and combined embeddings" + + # test order by comparing value of individual embeddings in sequence + for i, e in enumerate(separate_vectors): + assert np.allclose(e, combined_vectors[i]) + + # Next ensuring that the two identical sentences yield identical results (and 3rd does not) + assert np.array_equal(combined_vectors[0], combined_vectors[1]) + assert not np.array_equal(combined_vectors[1], combined_vectors[2]) + assert np.array_equal(separate_vectors[0], separate_vectors[1]) + assert not np.array_equal(separate_vectors[1], separate_vectors[2]) diff --git a/tests/modules/text_generation/test_peft_prompt_tuning.py b/tests/modules/text_generation/test_peft_prompt_tuning.py index 5cf82439..74360d36 100644 --- a/tests/modules/text_generation/test_peft_prompt_tuning.py +++ b/tests/modules/text_generation/test_peft_prompt_tuning.py @@ -429,6 +429,14 @@ def test_run_exponential_decay_len_penatly_object(causal_lm_dummy_model): assert isinstance(pred, GeneratedTextResult) +def test_run_tokenizer_not_implemented(causal_lm_dummy_model): + with pytest.raises(NotImplementedError): + causal_lm_dummy_model.run_tokenizer("This text doesn't matter") + + +######################## Test train ############################################### + + def test_train_with_data_validation_raises(causal_lm_train_kwargs, set_cpu_device): """Check if we are able to throw error for when number of examples are more than configured limit""" patch_kwargs = { diff --git a/tests/modules/text_generation/test_peft_tgis_remote.py b/tests/modules/text_generation/test_peft_tgis_remote.py index 2895088a..bdc11beb 100644 --- a/tests/modules/text_generation/test_peft_tgis_remote.py +++ b/tests/modules/text_generation/test_peft_tgis_remote.py @@ -57,6 +57,26 @@ def test_load_and_run(causal_lm_dummy_model, stub_tgis_backend): assert model_prompt_dir == stub_generation_request.prefix_id +def test_load_and_tokenize(causal_lm_dummy_model, stub_tgis_backend): + """Ensure we can export an in memory model, load it, and tokenize it""" + # Patch our stub backend into caikit so that we don't actually try to start TGIS + causal_lm_dummy_model.verbalizer = "hello distributed {{input}}" + + with mock.patch.object(StubTGISClient, "Tokenize") as mock_gen: + mock_gen.side_effect = StubTGISClient.tokenize + + # Save the local model & reload it a TGIS backend distributed module + with tempfile.TemporaryDirectory() as model_dir: + causal_lm_dummy_model.save(model_dir) + mock_tgis_model = 
PeftPromptTuningTGIS.load(model_dir, stub_tgis_backend) + + result = mock_tgis_model.run_tokenizer(SAMPLE_TEXT) + StubTGISClient.validate_tokenize_response(result) + + # Validate that our verbalizer carried over correctly & was applied at inference time + assert mock_tgis_model.verbalizer == causal_lm_dummy_model.verbalizer + + def test_load_and_run_stream_out(causal_lm_dummy_model, stub_tgis_backend): """Ensure we can export an in memory model, load it, and (mock) run output streaming with the right text & prefix ID.""" diff --git a/tests/modules/text_generation/test_text_generation_local.py b/tests/modules/text_generation/test_text_generation_local.py index 5afce777..f76400f1 100644 --- a/tests/modules/text_generation/test_text_generation_local.py +++ b/tests/modules/text_generation/test_text_generation_local.py @@ -209,3 +209,9 @@ def test_zero_epoch_case(disable_wip): } model = TextGeneration.train(**train_kwargs) assert isinstance(model.model, HFAutoSeq2SeqLM) + + +def test_run_tokenizer_not_implemented(): + with pytest.raises(NotImplementedError): + model = TextGeneration.bootstrap(SEQ2SEQ_LM_MODEL) + model.run_tokenizer("This text doesn't matter") diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py index 208dc462..741f1e27 100644 --- a/tests/modules/text_generation/test_text_generation_tgis.py +++ b/tests/modules/text_generation/test_text_generation_tgis.py @@ -51,6 +51,26 @@ def test_bootstrap_and_run_seq2seq(): StubTGISClient.validate_unary_generate_response(result) +def test_bootstrap_and_tokenize_casualllm(): + """Check if we can bootstrap and tokenize text""" + model = TextGenerationTGIS.bootstrap( + CAUSAL_LM_MODEL, load_backend=StubTGISBackend() + ) + + result = model.run_tokenizer(SAMPLE_TEXT) + StubTGISClient.validate_tokenize_response(result) + + +def test_bootstrap_and_tokenize_seq2seq(): + """Check if we can bootstrap and tokenize text""" + model = TextGenerationTGIS.bootstrap( + SEQ2SEQ_LM_MODEL, load_backend=StubTGISBackend() + ) + + result = model.run_tokenizer(SAMPLE_TEXT) + StubTGISClient.validate_tokenize_response(result) + + def test_run_multi_response_errors(): """Check if multiple responses errors""" with mock.patch.object(StubTGISClient, "Generate") as mock_gen_stream: diff --git a/tests/modules/token_classification/test_filtered_span_classification.py b/tests/modules/token_classification/test_filtered_span_classification.py index c8f14036..1120170c 100644 --- a/tests/modules/token_classification/test_filtered_span_classification.py +++ b/tests/modules/token_classification/test_filtered_span_classification.py @@ -135,6 +135,15 @@ def test_bootstrap_run_with_threshold(): ) # 4 (all) results over 0.0 expected +def test_bootstrap_run_with_int_threshold(): + """Check if we can bootstrap span classification models with overriden int threshold""" + token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0) + assert isinstance(token_classification_result, TokenClassificationResults) + assert ( + len(token_classification_result.results) == 4 + ) # 4 (all) results over 0 expected + + def test_bootstrap_run_with_optional_labels_to_output(): """Check if we can run span classification models with labels_to_output""" model = FilteredSpanClassification.bootstrap( diff --git a/tox.ini b/tox.ini index bae63fe3..c220f022 100644 --- a/tox.ini +++ b/tox.ini @@ -15,9 +15,7 @@ passenv = LOG_FORMATTER LOG_THREAD_ID LOG_CHANNEL_WIDTH -commands = - python -I -m pip install 
--force-reinstall --no-deps -r sentence-transformers.nodeps.txt - pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} +commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} ; Unclear: We probably want to test wheel packaging ; But! tox will fail when this is set and _any_ interpreter is missing
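
The embedding section added in the config above can also be driven by environment variables, as its comment notes ("use env vars like EMBEDDING_RETRIES=32"). A hedged illustration of that override path, assuming the variable names follow the EMBEDDING_<KEY> pattern; only EMBEDDING_RETRIES is confirmed by the comment, the other names are assumptions mirroring the YAML keys:

    # Hypothetical override of the embedding config via environment variables.
    # These must be set before caikit / caikit_nlp load their configuration.
    import os

    os.environ["EMBEDDING_RETRIES"] = "32"      # example given in the config comment
    os.environ["EMBEDDING_BATCH_SIZE"] = "64"   # assumed name, mirrors embedding.batch_size
    os.environ["EMBEDDING_IPEX"] = "true"       # assumed name, mirrors embedding.ipex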
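
test_env_val_to_bool and test_env_val_to_int pin down how those values are parsed: blank or false-looking strings (any casing, surrounding whitespace) are False, booleans and plain ints pass through, and anything unparseable falls back to the default. A minimal sketch that satisfies those cases; the shipped caikit_nlp/modules/text_embedding/utils.py may differ, for example in the exact set of falsy strings:

    def env_val_to_bool(val) -> bool:
        """Interpret a config/env value as a boolean."""
        if val is None:
            return False
        if isinstance(val, str):
            # Trim and lower-case so " False " and " fAlSE " are treated as falsy
            return val.strip().lower() not in ("", "0", "false")
        return bool(val)


    def env_val_to_int(val, default: int) -> int:
        """Interpret a config/env value as an int, falling back to default when invalid."""
        if isinstance(val, int):  # bool is a subclass of int, so True -> 1 and False -> 0
            return int(val)
        try:
            return int(str(val).strip())
        except (TypeError, ValueError):
            return default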
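
get_sample_start_indexes works off the overflow_to_sample_mapping produced when tokenizing with return_overflowing_tokens=True: it returns, for each original sample, the batch index of its first encoding segment. A sketch consistent with the parametrized cases in test_get_sample_start_indexes:

    from typing import List


    def get_sample_start_indexes(tokenized) -> List[int]:
        """Index of the first overflow segment for each original input sample."""
        mapping = tokenized["overflow_to_sample_mapping"].tolist()
        starts: List[int] = []
        seen = set()
        for index, sample_id in enumerate(mapping):
            if sample_id not in seen:
                seen.add(sample_id)
                starts.append(index)
        return starts

For example, a mapping of [0, 0, 1, 1, 1, 2] (sample 0 split into two segments, sample 1 into three) yields [0, 2, 5], matching the test case.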
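
sum_token_count, per the two parametrized tests, counts only tokens the model attends to ([CLS] and [SEP] included, [PAD] excluded), and with truncate_only=True it counts just the first, kept segment of each sample, which is how the input_token_count values asserted throughout these tests are derived. A sketch under those assumptions, reusing the helper sketched above:

    def sum_token_count(tokenized, truncate_only: bool) -> int:
        """Sum attended tokens; with truncate_only, only the kept segment per sample."""
        attention_mask = tokenized["attention_mask"]
        if not truncate_only:
            # All segments, all attended tokens; padding has mask 0 and drops out of the sum
            return int(attention_mask.sum().item())
        # After truncation only the first segment of each original sample survives
        starts = get_sample_start_indexes(tokenized)
        return int(sum(attention_mask[i].sum().item() for i in starts))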
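
test_encode_extensions shows _encode_with_retry forwarding return_token_count, truncate_input_tokens, and implicit_truncation_errors only when the underlying model is the SentenceTransformerWithTruncate wrapper; a plain SentenceTransformer (as produced by bootstrap) gets those keywords stripped, and the removal must not raise KeyError. One plausible shape for that guard, checking the class name the same way the tests do; this is hypothetical, not the module's confirmed code:

    def _encode_with_retry(self, *args, **kwargs):
        """Call encode() with retries, dropping extended kwargs for an unwrapped model."""
        if type(self.model).__name__ != "SentenceTransformerWithTruncate":
            # pop(..., None) so a missing key never raises KeyError
            kwargs.pop("return_token_count", None)
            kwargs.pop("truncate_input_tokens", None)
            kwargs.pop("implicit_truncation_errors", None)
        return self._with_retry(self.model.encode, *args, **kwargs)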