diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c26b9621..3987f0dc 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -8,6 +8,7 @@ repos:
     hooks:
       - id: black
         exclude: imports
+        additional_dependencies: ["platformdirs"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
diff --git a/Dockerfile b/Dockerfile
index 9525d5cd..af44cb75 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -1,15 +1,15 @@
-FROM registry.access.redhat.com/ubi8/ubi-minimal:latest as base
+FROM registry.access.redhat.com/ubi9/ubi-minimal:latest as base
 
 RUN microdnf update -y && \
     microdnf install -y \
-    git python39 && \
-    pip3 install --upgrade --no-cache-dir pip wheel && \
+    git python-pip && \
+    pip install --upgrade --no-cache-dir pip wheel && \
     microdnf clean all
 
 FROM base as builder
 WORKDIR /build
 
-RUN pip3 install --no-cache tox
+RUN pip install --no-cache tox
 COPY README.md .
 COPY pyproject.toml .
 COPY tox.ini .
@@ -20,7 +20,7 @@ RUN --mount=source=.git,target=.git,type=bind tox -e build
 
 FROM base as deploy
 
-RUN python3 -m venv /opt/caikit/
+RUN python -m venv /opt/caikit/
 
 ENV VIRTUAL_ENV=/opt/caikit
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py
index c2ef5045..476e9782 100644
--- a/caikit_nlp/modules/text_embedding/embedding.py
+++ b/caikit_nlp/modules/text_embedding/embedding.py
@@ -90,9 +90,6 @@ def env_var_to_int(name, default):
 # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used
 BATCH_SIZE = env_var_to_int("BATCH_SIZE", default=0)
 
-# Retry count for catching sporadic encode() or tokenize() errors (in case if they come back)
-RETRY_COUNT = env_var_to_int("RETRY_COUNT", default=5)
-
 
 @module(
     "eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
@@ -109,6 +106,9 @@ def env_var_to_int(name, default):
 )
 class EmbeddingModule(ModuleBase):
 
+    # Retry count if enabled to try again (was for thread contention errors)
+    RETRY_COUNT = max(env_var_to_int("RETRY_COUNT", default=0), 0)
+
     _ARTIFACTS_PATH_KEY = "artifacts_path"
     _ARTIFACTS_PATH_DEFAULT = "artifacts"
 
@@ -153,6 +153,7 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
         ipex = cls._get_ipex()
         device = cls._select_device(ipex)
         model = SentenceTransformer(model_name_or_path=artifacts_path, device=device)
+        model.eval()  # required for IPEX at least
         if device is not None:
             model.to(torch.device(device))
         model = EmbeddingModule._optimize(model, ipex, device)
@@ -238,17 +239,25 @@ def _optimize(model, ipex, device):
                 logger.warning(warn_msg, exc_info=True)
         return model
 
-    @staticmethod
-    def _with_retry(fn, *args, **kwargs):
-        retries = max(RETRY_COUNT, 0)
-        for count in range(1 + retries):  # try once plus retries (if needed)
+    def _with_retry(self, fn, *args, **kwargs):
+        first_exception = None
+        for count in range(1 + self.RETRY_COUNT):  # try once plus retries (if needed)
             try:
                 return fn(*args, **kwargs)
             except Exception as e:  # pylint: disable=broad-exception-caught
-                warn_msg = f"Retry {fn} due to: {e}"
-                logger.warning(warn_msg, exc_info=True)
-                time.sleep(0.1 * (count * 2))
-        error.log_raise("", RuntimeError(f"Too many retries of fn={fn}"))
+                if first_exception is None:
+                    first_exception = e
+                if self.RETRY_COUNT > 0:
+                    warn_msg = f"Try {count + 1}: {fn} failed due to: {e}"
+                    logger.warning("", warn_msg, exc_info=True)
+                    if count + 1 < self.RETRY_COUNT:
+                        time.sleep(0.1 * (count * 2))
+
+        # If above return did not happen, raise the first exception
+        error.log_raise(
+            log_code="",
+            exception=first_exception,
+        )
 
     def _encode_with_retry(self, *args, **kwargs):
         """All encode calls should use this for consistent param adding and retry loop"""
diff --git a/caikit_nlp/modules/token_classification/filtered_span_classification.py b/caikit_nlp/modules/token_classification/filtered_span_classification.py
index 15ea432a..55733df1 100644
--- a/caikit_nlp/modules/token_classification/filtered_span_classification.py
+++ b/caikit_nlp/modules/token_classification/filtered_span_classification.py
@@ -136,8 +136,16 @@ def run(
         Returns:
             TokenClassificationResults
         """
+        error.type_check("", str, text=text)
+        error.type_check("", float, allow_none=True, threshold=threshold)
+
         if threshold is None:
             threshold = self.default_threshold
+        if not text:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            return TokenClassificationResults(results=[])
+
         token_classification_results = []
         if self.classification_task == TextClassificationTask:
             # Split document into spans
@@ -196,10 +204,17 @@ def run_bidi_stream(
         Returns:
             Iterable[TokenClassificationStreamResult]
         """
+        error.type_check("", float, allow_none=True, threshold=threshold)
         # TODO: For optimization implement window based approach.
         if threshold is None:
             threshold = self.default_threshold
 
+        # Types on the stream are checked later on iteration
+        if len(text_stream) == 0:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            yield TokenClassificationStreamResult(results=[], processed_index=0)
+
         for span_output in self._stream_span_output(text_stream):
             classification_result = self.classifier.run(span_output.text)
             results_to_end_of_span = False
@@ -344,6 +359,7 @@ def __update_spans(token):
             return token
 
         for text in text_stream:
+            error.type_check("", str, text=text)
             stream_accumulator += text
             # In order to avoid processing all of the spans again, we only
             # send out the spans that are not yet finalized in detected_spans
diff --git a/caikit_nlp/resources/pretrained_model/base.py b/caikit_nlp/resources/pretrained_model/base.py
index eba74744..b68c8d67 100644
--- a/caikit_nlp/resources/pretrained_model/base.py
+++ b/caikit_nlp/resources/pretrained_model/base.py
@@ -201,25 +201,27 @@ def bootstrap(
             else "right"
         )
 
-        # Load the tokenizer and set up the pad token if needed
-        tokenizer = AutoTokenizer.from_pretrained(
-            tokenizer_name,
-            local_files_only=not get_config().allow_downloads,
-            padding_side=padding_side,
-            # We can't disable use_fast otherwise unit test fails
-            # use_fast=False,
-        )
+        with alog.ContextTimer(log.info, "Tokenizer loaded in "):
+            # Load the tokenizer and set up the pad token if needed
+            tokenizer = AutoTokenizer.from_pretrained(
+                tokenizer_name,
+                local_files_only=not get_config().allow_downloads,
+                padding_side=padding_side,
+                # We can't disable use_fast otherwise unit test fails
+                # use_fast=False,
+            )
         if tokenizer.pad_token_id is None:
             tokenizer.pad_token_id = tokenizer.eos_token_id
 
-        # Load the model
-        model = cls.MODEL_TYPE.from_pretrained(
-            model_name,
-            local_files_only=not get_config().allow_downloads,
-            torch_dtype=torch_dtype,
-            **kwargs,
-        )
+        with alog.ContextTimer(log.info, f"Model {model_name} loaded in "):
+            # Load the model
+            model = cls.MODEL_TYPE.from_pretrained(
+                model_name,
+                local_files_only=not get_config().allow_downloads,
+                torch_dtype=torch_dtype,
+                **kwargs,
+            )
 
         log.debug4("Model Details: %s", model)
 
         # Create the class instance
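Editor's note: the `_with_retry` rework above stops wrapping failures in a generic `RuntimeError` and instead re-raises the first caught exception, retrying (with a short linear backoff) only when `RETRY_COUNT` is set. A minimal standalone sketch of that pattern follows; plain stdlib `logging` stands in for the module's alog-based logger, and the names are illustrative, not the module's API:

```python
import logging
import time

logger = logging.getLogger(__name__)

RETRY_COUNT = 3  # illustrative; the real module reads this from an env var


def with_retry(fn, *args, **kwargs):
    """Call fn; on failure retry RETRY_COUNT times, then re-raise the first error."""
    first_exception = None
    for count in range(1 + RETRY_COUNT):  # one initial try plus RETRY_COUNT retries
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-exception-caught
            if first_exception is None:
                first_exception = e  # preserve the original failure for the caller
            logger.warning("Try %d: %s failed due to: %s", count + 1, fn, e)
            time.sleep(0.1 * (count * 2))  # 0s, 0.2s, 0.4s, ... between tries
    raise first_exception  # every try failed; surface the first exception
```

With `RETRY_COUNT = 0` this degrades to a single try that re-raises immediately, which is why the tests below can exercise both behaviors just by patching the count.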
diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py
index fa386b3e..b0a610c7 100644
--- a/tests/modules/text_embedding/test_embedding.py
+++ b/tests/modules/text_embedding/test_embedding.py
@@ -727,19 +727,45 @@ def test__with_retry_happy_path(loaded_model):
     loaded_model._with_retry(print, "hello", "world", sep="<:)>", end="!!!\n")
 
 
-def test__with_retry_fail(loaded_model):
-    """fn never works, loops then raises RuntimeError"""
+def test__with_retry_fail(loaded_model, monkeypatch):
+    """fn never works, loops then raises the exception"""
 
     def fn():
-        assert 0
+        raise (ValueError("always fails with ValueError"))
 
-    with pytest.raises(RuntimeError):
+    with pytest.raises(ValueError):
         loaded_model._with_retry(fn)
 
 
-def test__with_retry_fail_fail_win(loaded_model):
+def test__with_retry_fail_fail(loaded_model, monkeypatch):
+    """fn needs a few tries, tries twice and fails."""
+
+    monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1)  # less than 3 tries
+
+    def generate_ints():
+        yield from range(9)  # More than enough for retry loop
+
+    ints = generate_ints()
+
+    def fail_fail_win():
+        for i in ints:
+            if i < 2:  # fail, fail
+                raise (ValueError(f"fail {i}"))
+            else:  # win and return 3
+                return i + 1
+
+    # Without a third try raises first exception
+    with pytest.raises(ValueError) as e:
+        loaded_model._with_retry(fail_fail_win)
+
+    assert e.value.args[0] == "fail 0", "expected first exception 'fail 0'"
+
+
+def test__with_retry_fail_fail_win(loaded_model, monkeypatch):
     """fn needs a few tries, logs, loops and succeeds"""
 
+    monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6)  # test needs at least 3 tries
+
     def generate_ints():
         yield from range(9)  # More than enough for retry loop
 
@@ -748,8 +774,8 @@ def generate_ints():
     def fail_fail_win():
         for i in ints:
             if i < 2:  # fail, fail
-                assert 0
-            else:  # win
+                raise (ValueError("fail, fail"))
+            else:  # win and return 3
                 return i + 1
 
     # Third try did not raise an exception. Returns 3.
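Editor's note: these tests lean on `monkeypatch.setattr` to override `RETRY_COUNT`, which now lives on the class rather than at module scope. Patching through the instance shadows the class default for that test only. A self-contained illustration of the mechanism, with hypothetical class and test names:

```python
class Client:
    # Class-level default, read through self.RETRY_COUNT at call time
    RETRY_COUNT = 0


def test_retry_count_override(monkeypatch):
    client = Client()
    # Overriding on the instance shadows the class default; pytest's
    # monkeypatch fixture restores the original value after the test.
    monkeypatch.setattr(client, "RETRY_COUNT", 6)
    assert client.RETRY_COUNT == 6
    assert Client.RETRY_COUNT == 0  # the class default is untouched
```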
diff --git a/tests/modules/token_classification/test_filtered_span_classification.py b/tests/modules/token_classification/test_filtered_span_classification.py
index ce20c5cb..c8f14036 100644
--- a/tests/modules/token_classification/test_filtered_span_classification.py
+++ b/tests/modules/token_classification/test_filtered_span_classification.py
@@ -51,6 +51,14 @@
 )
 TOK_CLASSIFICATION_RESULT = TokenClassificationResults(results=[FOX_CLASS, DOG_CLASS])
 
+# NOTE: First test will test this separately
+BOOTSTRAPPED_MODEL = FilteredSpanClassification.bootstrap(
+    lang="en",
+    tokenizer=SENTENCE_TOKENIZER,
+    classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
+    default_threshold=0.5,
+)
+
 # Modules that already returns token classification for tests
 @module(
     "44d61711-c64b-4774-a39f-a9f40f1fcff0",
@@ -120,13 +128,7 @@ def test_bootstrap_run():
 
 def test_bootstrap_run_with_threshold():
     """Check if we can bootstrap span classification models with overriden threshold"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-    token_classification_result = model.run(DOCUMENT, threshold=0.0)
+    token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0.0)
     assert isinstance(token_classification_result, TokenClassificationResults)
     assert (
         len(token_classification_result.results) == 4
@@ -187,16 +189,17 @@ def test_bootstrap_run_with_token_classification_no_results():
     assert len(token_classification_result.results) == 0
 
 
+def test_bootstrap_run_empty():
+    """Check if span classification model can run with empty string"""
+    token_classification_result = BOOTSTRAPPED_MODEL.run("")
+    assert isinstance(token_classification_result, TokenClassificationResults)
+    assert len(token_classification_result.results) == 0
+
+
 def test_save_load_and_run_model():
     """Check if we can run a saved model successfully"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
     with tempfile.TemporaryDirectory() as model_dir:
-        model.save(model_dir)
+        BOOTSTRAPPED_MODEL.save(model_dir)
         assert os.path.exists(os.path.join(model_dir, "config.yml"))
         assert os.path.exists(os.path.join(model_dir, "tokenizer"))
         assert os.path.exists(os.path.join(model_dir, "classification"))
@@ -216,14 +219,9 @@ def test_run_bidi_stream_model():
     """Check if model prediction works as expected for bi-directional stream"""
     stream_input = data_model.DataStream.from_iterable(DOCUMENT)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
     )
-
-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -351,14 +349,10 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     works as expected for bi-directional stream"""
     doc_stream = (DOCUMENT, " I am another sentence.")
     stream_input = data_model.DataStream.from_iterable(doc_stream)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -385,6 +379,20 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     assert count == expected_number_of_sentences
 
 
+def test_run_bidi_stream_empty():
+    """Check if span classification model can run with empty string for streaming"""
+    stream_input = data_model.DataStream.from_iterable("")
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
+    assert isinstance(streaming_token_classification_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_token_classification_result)
+    assert len(result_list) == 1
+    assert result_list[0].results == []
+    assert result_list[0].processed_index == 0
+
+
 def test_run_stream_vs_no_stream():
     """Check if model prediction on stream with multiple sentences/spans
     works as expected for bi-directional stream and gives expected span results
@@ -392,15 +400,9 @@ def test_run_stream_vs_no_stream():
     multiple_sentences = (
         "The dragon hoarded gold. The cow ate grass. What is happening? What a day!"
     )
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
 
     # Non-stream run
-    nonstream_classification_result = model.run(multiple_sentences)
+    nonstream_classification_result = BOOTSTRAPPED_MODEL.run(multiple_sentences)
     assert len(nonstream_classification_result.results) == 4
     assert nonstream_classification_result.results[0].word == "The dragon hoarded gold."
     assert nonstream_classification_result.results[0].start == 0
@@ -411,7 +413,7 @@ def test_run_stream_vs_no_stream():
 
     # Char-based stream
     stream_input = data_model.DataStream.from_iterable(multiple_sentences)
-    stream_classification_result = model.run_bidi_stream(stream_input)
+    stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(stream_input)
     # Convert to list to more easily check outputs
     result_list = list(stream_classification_result)
     assert len(result_list) == 4  # one per sentence
@@ -422,7 +424,9 @@ def test_run_stream_vs_no_stream():
 
     # Chunk-based stream
     chunk_stream_input = data_model.DataStream.from_iterable((multiple_sentences,))
-    chunk_stream_classification_result = model.run_bidi_stream(chunk_stream_input)
+    chunk_stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        chunk_stream_input
+    )
     result_list = list(chunk_stream_classification_result)
     assert len(result_list) == 4  # one per sentence
     assert result_list[0].processed_index == 24
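Editor's note: the empty-input guards added in `filtered_span_classification.py` short-circuit before any tokenizer or classifier call, and the `test_run_bidi_stream_empty` test above checks that shape: exactly one empty result with `processed_index == 0`. A rough sketch of the streaming guard's behavior, using simplified stand-in types rather than the caikit data model:

```python
from typing import Iterable, Iterator, List, NamedTuple


class StreamResult(NamedTuple):
    """Simplified stand-in for TokenClassificationStreamResult."""

    results: List[str]
    processed_index: int


def run_stream(text_stream: Iterable[str]) -> Iterator[StreamResult]:
    chunks = list(text_stream)
    if len(chunks) == 0:
        # Some tokenizers/classifiers error on empty input, so emit one
        # empty, fully-processed result instead of passing "" downstream.
        yield StreamResult(results=[], processed_index=0)
        # Fall through: the loop below is a no-op for an empty stream.
    processed = 0
    for chunk in chunks:
        processed += len(chunk)
        yield StreamResult(results=[chunk], processed_index=processed)


assert list(run_stream([])) == [StreamResult(results=[], processed_index=0)]
```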