Commit
Merge remote-tracking branch 'upstream/main@4396e54' into sync-upstream
dtrifiro committed Feb 21, 2024
2 parents 4888e72 + 4396e54 commit 5865ed4
Showing 7 changed files with 133 additions and 75 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -8,6 +8,7 @@ repos:
     hooks:
       - id: black
         exclude: imports
+        additional_dependencies: ["platformdirs"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.11.5
     hooks:
10 changes: 5 additions & 5 deletions Dockerfile
@@ -1,15 +1,15 @@
-FROM registry.access.redhat.com/ubi8/ubi-minimal:latest as base
+FROM registry.access.redhat.com/ubi9/ubi-minimal:latest as base

 RUN microdnf update -y && \
     microdnf install -y \
-    git python39 && \
-    pip3 install --upgrade --no-cache-dir pip wheel && \
+    git python-pip && \
+    pip install --upgrade --no-cache-dir pip wheel && \
     microdnf clean all

 FROM base as builder
 WORKDIR /build

-RUN pip3 install --no-cache tox
+RUN pip install --no-cache tox
 COPY README.md .
 COPY pyproject.toml .
 COPY tox.ini .
@@ -20,7 +20,7 @@ RUN --mount=source=.git,target=.git,type=bind tox -e build

 FROM base as deploy

-RUN python3 -m venv /opt/caikit/
+RUN python -m venv /opt/caikit/

 ENV VIRTUAL_ENV=/opt/caikit
 ENV PATH="$VIRTUAL_ENV/bin:$PATH"
31 changes: 20 additions & 11 deletions caikit_nlp/modules/text_embedding/embedding.py
@@ -90,9 +90,6 @@ def env_var_to_int(name, default):
 # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used
 BATCH_SIZE = env_var_to_int("BATCH_SIZE", default=0)

-# Retry count for catching sporadic encode() or tokenize() errors (in case if they come back)
-RETRY_COUNT = env_var_to_int("RETRY_COUNT", default=5)
-

 @module(
     "eeb12558-b4fa-4f34-a9fd-3f5890e9cd3f",
@@ -109,6 +106,9 @@ def env_var_to_int(name, default):
 )
 class EmbeddingModule(ModuleBase):

+    # Retry count if enabled to try again (was for thread contention errors)
+    RETRY_COUNT = max(env_var_to_int("RETRY_COUNT", default=0), 0)
+
     _ARTIFACTS_PATH_KEY = "artifacts_path"
     _ARTIFACTS_PATH_DEFAULT = "artifacts"

@@ -153,6 +153,7 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":
         ipex = cls._get_ipex()
         device = cls._select_device(ipex)
         model = SentenceTransformer(model_name_or_path=artifacts_path, device=device)
+        model.eval()  # required for IPEX at least
         if device is not None:
             model.to(torch.device(device))
         model = EmbeddingModule._optimize(model, ipex, device)
@@ -238,17 +239,25 @@ def _optimize(model, ipex, device):
             logger.warning(warn_msg, exc_info=True)
         return model

-    @staticmethod
-    def _with_retry(fn, *args, **kwargs):
-        retries = max(RETRY_COUNT, 0)
-        for count in range(1 + retries):  # try once plus retries (if needed)
+    def _with_retry(self, fn, *args, **kwargs):
+        first_exception = None
+        for count in range(1 + self.RETRY_COUNT):  # try once plus retries (if needed)
             try:
                 return fn(*args, **kwargs)
             except Exception as e:  # pylint: disable=broad-exception-caught
-                warn_msg = f"Retry {fn} due to: {e}"
-                logger.warning(warn_msg, exc_info=True)
-                time.sleep(0.1 * (count * 2))
-        error.log_raise("<NLP31069292E>", RuntimeError(f"Too many retries of fn={fn}"))
+                if first_exception is None:
+                    first_exception = e
+                if self.RETRY_COUNT > 0:
+                    warn_msg = f"Try {count + 1}: {fn} failed due to: {e}"
+                    logger.warning("<NLP54902271W>", warn_msg, exc_info=True)
+                if count + 1 < self.RETRY_COUNT:
+                    time.sleep(0.1 * (count * 2))
+
+        # If above return did not happen, raise the first exception
+        error.log_raise(
+            log_code="<NLP13096081E>",
+            exception=first_exception,
+        )

     def _encode_with_retry(self, *args, **kwargs):
         """All encode calls should use this for consistent param adding and retry loop"""
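Note on the change above: retries are now opt-in (RETRY_COUNT defaults to 0 instead of 5 and is clamped to be non-negative), the first exception is remembered and re-raised instead of a synthetic RuntimeError, and the linear backoff only runs between real retries. A minimal standalone sketch of the new loop, assuming simplified stand-ins for the module's env-var parsing and logging helpers (not the exact caikit code):

import logging
import os
import time

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger("retry-sketch")


def env_var_to_int(name, default):
    """Parse an int env var; fall back to the default on missing/bad values."""
    try:
        return int(os.environ.get(name, default))
    except (TypeError, ValueError):
        return default


# Opt-in now: default 0 (was 5); negative values are clamped to 0
RETRY_COUNT = max(env_var_to_int("RETRY_COUNT", default=0), 0)


def with_retry(fn, *args, **kwargs):
    first_exception = None
    for count in range(1 + RETRY_COUNT):  # try once, plus retries if enabled
        try:
            return fn(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-exception-caught
            if first_exception is None:
                first_exception = e  # remember the original failure
            if RETRY_COUNT > 0:
                logger.warning("Try %d: %s failed due to: %s", count + 1, fn, e)
            if count + 1 < RETRY_COUNT:
                time.sleep(0.1 * (count * 2))  # linear backoff: 0s, 0.2s, 0.4s, ...
    raise first_exception  # surface the first error, not a generic RuntimeError

With the new default of RETRY_COUNT=0, a failing call raises its original exception immediately, which is exactly what the updated tests below assert.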
@@ -136,8 +136,16 @@ def run(
         Returns:
             TokenClassificationResults
         """
+        error.type_check("<NLP82129006E>", str, text=text)
+        error.type_check("<NLP01414077E>", float, allow_none=True, threshold=threshold)
+
         if threshold is None:
             threshold = self.default_threshold
+        if not text:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            return TokenClassificationResults(results=[])
+
         token_classification_results = []
         if self.classification_task == TextClassificationTask:
             # Split document into spans
@@ -196,10 +204,17 @@ def run_bidi_stream(
         Returns:
             Iterable[TokenClassificationStreamResult]
         """
+        error.type_check("<NLP96166348E>", float, allow_none=True, threshold=threshold)
         # TODO: For optimization implement window based approach.
         if threshold is None:
             threshold = self.default_threshold
+
+        # Types on the stream are checked later on iteration
+        if len(text_stream) == 0:
+            # Allow empty text case to fall through - some tokenizers or
+            # classifiers may error on this
+            yield TokenClassificationStreamResult(results=[], processed_index=0)
+
         for span_output in self._stream_span_output(text_stream):
             classification_result = self.classifier.run(span_output.text)
             results_to_end_of_span = False
@@ -344,6 +359,7 @@ def __update_spans(token):
             return token

         for text in text_stream:
+            error.type_check("<NLP38357927E>", str, text=text)
             stream_accumulator += text
             # In order to avoid processing all of the spans again, we only
             # send out the spans that are not yet finalized in detected_spans
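The net effect of the new guards in this module, as a usage sketch (hedged: model stands for a bootstrapped FilteredSpanClassification instance, and the import mirrors what this repo's tests use; see the new empty-string tests further down):

from caikit.core import data_model

# Non-stream: empty text short-circuits to an empty result
assert model.run("").results == []

# Stream: an empty stream yields exactly one empty result at index 0
chunks = list(model.run_bidi_stream(data_model.DataStream.from_iterable("")))
assert len(chunks) == 1
assert chunks[0].results == [] and chunks[0].processed_index == 0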
32 changes: 17 additions & 15 deletions caikit_nlp/resources/pretrained_model/base.py
@@ -201,25 +201,27 @@ def bootstrap(
else "right"
)

# Load the tokenizer and set up the pad token if needed
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
local_files_only=not get_config().allow_downloads,
padding_side=padding_side,
# We can't disable use_fast otherwise unit test fails
# use_fast=False,
)
with alog.ContextTimer(log.info, "Tokenizer loaded in "):
# Load the tokenizer and set up the pad token if needed
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
local_files_only=not get_config().allow_downloads,
padding_side=padding_side,
# We can't disable use_fast otherwise unit test fails
# use_fast=False,
)

if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

# Load the model
model = cls.MODEL_TYPE.from_pretrained(
model_name,
local_files_only=not get_config().allow_downloads,
torch_dtype=torch_dtype,
**kwargs,
)
with alog.ContextTimer(log.info, f"Model {model_name} loaded in "):
# Load the model
model = cls.MODEL_TYPE.from_pretrained(
model_name,
local_files_only=not get_config().allow_downloads,
torch_dtype=torch_dtype,
**kwargs,
)
log.debug4("Model Details: %s", model)

# Create the class instance
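For reference, a small sketch of the alog.ContextTimer pattern adopted above: the timer calls the supplied log function with the message prefix plus the elapsed time when the with-block exits. This assumes the alchemy-logging package that caikit builds on (installed as alchemy-logging, imported as alog):

import time

import alog

alog.configure(default_level="info")
log = alog.use_channel("DEMO")

with alog.ContextTimer(log.info, "Slow step finished in "):
    time.sleep(0.5)  # stand-in for the tokenizer/model loading above
# on exit this logs something like: "Slow step finished in 0:00:00.50..."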
40 changes: 33 additions & 7 deletions tests/modules/text_embedding/test_embedding.py
@@ -727,19 +727,45 @@ def test__with_retry_happy_path(loaded_model):
     loaded_model._with_retry(print, "hello", "world", sep="<:)>", end="!!!\n")


-def test__with_retry_fail(loaded_model):
-    """fn never works, loops then raises RuntimeError"""
+def test__with_retry_fail(loaded_model, monkeypatch):
+    """fn never works, loops then raises the exception"""

     def fn():
-        assert 0
+        raise (ValueError("always fails with ValueError"))

-    with pytest.raises(RuntimeError):
+    with pytest.raises(ValueError):
         loaded_model._with_retry(fn)


-def test__with_retry_fail_fail_win(loaded_model):
+def test__with_retry_fail_fail(loaded_model, monkeypatch):
+    """fn needs a few tries, tries twice and fails."""
+
+    monkeypatch.setattr(loaded_model, "RETRY_COUNT", 1)  # less than 3 tries
+
+    def generate_ints():
+        yield from range(9)  # More than enough for retry loop
+
+    ints = generate_ints()
+
+    def fail_fail_win():
+        for i in ints:
+            if i < 2:  # fail, fail
+                raise (ValueError(f"fail {i}"))
+            else:  # win and return 3
+                return i + 1
+
+    # Without a third try raises first exception
+    with pytest.raises(ValueError) as e:
+        loaded_model._with_retry(fail_fail_win)
+
+    assert e.value.args[0] == "fail 0", "expected first exception 'fail 0'"
+
+
+def test__with_retry_fail_fail_win(loaded_model, monkeypatch):
     """fn needs a few tries, logs, loops and succeeds"""

+    monkeypatch.setattr(loaded_model, "RETRY_COUNT", 6)  # test needs at least 3 tries
+
     def generate_ints():
         yield from range(9)  # More than enough for retry loop

@@ -748,8 +774,8 @@ def generate_ints():
def fail_fail_win():
for i in ints:
if i < 2: # fail, fail
assert 0
else: # win
raise (ValueError("fail, fail"))
else: # win and return 3
return i + 1

# Third try did not raise an exception. Returns 3.
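One subtlety these tests rely on: monkeypatch.setattr(loaded_model, "RETRY_COUNT", ...) creates an instance attribute that shadows the class-level RETRY_COUNT read via self.RETRY_COUNT, and pytest restores the original state after each test. A minimal illustration of that shadowing (Widget is a hypothetical stand-in, not a class from this repo):

class Widget:
    RETRY_COUNT = 0  # class-level default, like EmbeddingModule.RETRY_COUNT


w = Widget()
assert w.RETRY_COUNT == 0       # instance falls back to the class attribute
w.RETRY_COUNT = 6               # instance attribute now shadows the class one
assert w.RETRY_COUNT == 6
assert Widget.RETRY_COUNT == 0  # the class default is untouched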
@@ -51,6 +51,14 @@
 )
 TOK_CLASSIFICATION_RESULT = TokenClassificationResults(results=[FOX_CLASS, DOG_CLASS])

+# NOTE: First test will test this separately
+BOOTSTRAPPED_MODEL = FilteredSpanClassification.bootstrap(
+    lang="en",
+    tokenizer=SENTENCE_TOKENIZER,
+    classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
+    default_threshold=0.5,
+)
+
 # Modules that already returns token classification for tests
 @module(
     "44d61711-c64b-4774-a39f-a9f40f1fcff0",
@@ -120,13 +128,7 @@ def test_bootstrap_run():

 def test_bootstrap_run_with_threshold():
     """Check if we can bootstrap span classification models with overriden threshold"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-    token_classification_result = model.run(DOCUMENT, threshold=0.0)
+    token_classification_result = BOOTSTRAPPED_MODEL.run(DOCUMENT, threshold=0.0)
     assert isinstance(token_classification_result, TokenClassificationResults)
     assert (
         len(token_classification_result.results) == 4
@@ -187,16 +189,17 @@ def test_bootstrap_run_with_token_classification_no_results():
     assert len(token_classification_result.results) == 0


+def test_bootstrap_run_empty():
+    """Check if span classification model can run with empty string"""
+    token_classification_result = BOOTSTRAPPED_MODEL.run("")
+    assert isinstance(token_classification_result, TokenClassificationResults)
+    assert len(token_classification_result.results) == 0
+
+
 def test_save_load_and_run_model():
     """Check if we can run a saved model successfully"""
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
     with tempfile.TemporaryDirectory() as model_dir:
-        model.save(model_dir)
+        BOOTSTRAPPED_MODEL.save(model_dir)
         assert os.path.exists(os.path.join(model_dir, "config.yml"))
         assert os.path.exists(os.path.join(model_dir, "tokenizer"))
         assert os.path.exists(os.path.join(model_dir, "classification"))
@@ -216,14 +219,9 @@ def test_run_bidi_stream_model():
"""Check if model prediction works as expected for bi-directional stream"""

stream_input = data_model.DataStream.from_iterable(DOCUMENT)
model = FilteredSpanClassification.bootstrap(
lang="en",
tokenizer=SENTENCE_TOKENIZER,
classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
default_threshold=0.5,
streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
stream_input
)

streaming_token_classification_result = model.run_bidi_stream(stream_input)
assert isinstance(streaming_token_classification_result, Iterable)
# Convert to list to more easily check outputs
result_list = list(streaming_token_classification_result)
@@ -351,14 +349,10 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     works as expected for bi-directional stream"""
     doc_stream = (DOCUMENT, " I am another sentence.")
     stream_input = data_model.DataStream.from_iterable(doc_stream)
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )
-
-    streaming_token_classification_result = model.run_bidi_stream(stream_input)
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
     assert isinstance(streaming_token_classification_result, Iterable)
     # Convert to list to more easily check outputs
     result_list = list(streaming_token_classification_result)
@@ -385,22 +379,30 @@ def test_run_bidi_stream_with_multiple_spans_in_chunk():
     assert count == expected_number_of_sentences


+def test_run_bidi_stream_empty():
+    """Check if span classification model can run with empty string for streaming"""
+    stream_input = data_model.DataStream.from_iterable("")
+    streaming_token_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        stream_input
+    )
+    assert isinstance(streaming_token_classification_result, Iterable)
+    # Convert to list to more easily check outputs
+    result_list = list(streaming_token_classification_result)
+    assert len(result_list) == 1
+    assert result_list[0].results == []
+    assert result_list[0].processed_index == 0
+
+
 def test_run_stream_vs_no_stream():
     """Check if model prediction on stream with multiple sentences/spans
     works as expected for bi-directional stream and gives expected span results
     as non-stream"""
     multiple_sentences = (
         "The dragon hoarded gold. The cow ate grass. What is happening? What a day!"
     )
-    model = FilteredSpanClassification.bootstrap(
-        lang="en",
-        tokenizer=SENTENCE_TOKENIZER,
-        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
-        default_threshold=0.5,
-    )

     # Non-stream run
-    nonstream_classification_result = model.run(multiple_sentences)
+    nonstream_classification_result = BOOTSTRAPPED_MODEL.run(multiple_sentences)
     assert len(nonstream_classification_result.results) == 4
     assert nonstream_classification_result.results[0].word == "The dragon hoarded gold."
     assert nonstream_classification_result.results[0].start == 0
@@ -411,7 +413,7 @@ def test_run_stream_vs_no_stream():

     # Char-based stream
     stream_input = data_model.DataStream.from_iterable(multiple_sentences)
-    stream_classification_result = model.run_bidi_stream(stream_input)
+    stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(stream_input)
     # Convert to list to more easily check outputs
     result_list = list(stream_classification_result)
     assert len(result_list) == 4  # one per sentence
@@ -422,7 +424,9 @@ def test_run_stream_vs_no_stream():

     # Chunk-based stream
     chunk_stream_input = data_model.DataStream.from_iterable((multiple_sentences,))
-    chunk_stream_classification_result = model.run_bidi_stream(chunk_stream_input)
+    chunk_stream_classification_result = BOOTSTRAPPED_MODEL.run_bidi_stream(
+        chunk_stream_input
+    )
     result_list = list(chunk_stream_classification_result)
     assert len(result_list) == 4  # one per sentence
     assert result_list[0].processed_index == 24
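Design note on this file's refactor: hoisting the bootstrap into a module-level BOOTSTRAPPED_MODEL avoids re-bootstrapping the same model in every test, at the cost of doing that work at import time. A session-scoped pytest fixture would get the same reuse lazily; a hypothetical sketch reusing the names this test file already defines:

import pytest


@pytest.fixture(scope="session")
def bootstrapped_model():
    return FilteredSpanClassification.bootstrap(
        lang="en",
        tokenizer=SENTENCE_TOKENIZER,
        classifier=BOOTSTRAPPED_SEQ_CLASS_MODEL,
        default_threshold=0.5,
    )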
