Skip to content

Commit

Permalink
Merge pull request #21 from caikit/main
Browse files Browse the repository at this point in the history
[pull] main from caikit:main
  • Loading branch information
dtrifiro authored May 17, 2024
2 parents af96c62 + 769812f commit c92b98d
Show file tree
Hide file tree
Showing 4 changed files with 233 additions and 48 deletions.
20 changes: 10 additions & 10 deletions .github/workflows/build-image.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@ jobs:
name: Build Image
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v4
with:
python-version: 3.9
- name: Setup tox
run: |
pip install -U pip wheel
pip install tox
- uses: actions/checkout@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Build image
run: |
docker build -t caikit-nlp:latest .
uses: docker/build-push-action@v5
with:
context: .
tags: "caikit-nlp:latest"
load: true
cache-from: type=gha
cache-to: type=gha,mode=max
8 changes: 6 additions & 2 deletions Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ COPY pyproject.toml .
COPY tox.ini .
COPY caikit_nlp caikit_nlp
# .git is required for setuptools-scm to get the version
RUN --mount=source=.git,target=.git,type=bind tox -e build
RUN --mount=source=.git,target=.git,type=bind \
--mount=type=cache,target=/root/.cache/pip \
tox -e build


FROM base as deploy
Expand All @@ -26,7 +28,9 @@ ENV VIRTUAL_ENV=/opt/caikit
ENV PATH="$VIRTUAL_ENV/bin:$PATH"

COPY --from=builder /build/dist/caikit_nlp*.whl /tmp/
RUN pip install --no-cache /tmp/caikit_nlp*.whl && rm /tmp/caikit_nlp*.whl
RUN --mount=type=cache,target=/root/.cache/pip \
pip install /tmp/caikit_nlp*.whl && \
rm /tmp/caikit_nlp*.whl

COPY LICENSE /opt/caikit/
COPY README.md /opt/caikit/
Expand Down
121 changes: 85 additions & 36 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,15 @@
# Standard
from typing import Iterable

# Third Party
import grpc

# First Party
from caikit.core.exceptions import error_handler
from caikit.core.exceptions.caikit_core_exception import (
CaikitCoreException,
CaikitCoreStatusCode,
)
from caikit.interfaces.nlp.data_model import (
GeneratedTextResult,
GeneratedTextStreamResult,
Expand All @@ -41,20 +48,53 @@
Whether or not the source string should be contained in the generated output,
e.g., as a prefix.
input_tokens: bool
Whether or not to include list of input tokens.
Whether or not to include list of input tokens.
generated_tokens: bool
Whether or not to include list of individual generated tokens.
Whether or not to include list of individual generated tokens.
token_logprobs: bool
Whether or not to include logprob for each returned token.
Whether or not to include logprob for each returned token.
Applicable only if generated_tokens == true and/or input_tokens == true
token_ranks: bool
Whether or not to include rank of each returned token.
Applicable only if generated_tokens == true and/or input_tokens == true
Whether or not to include rank of each returned token.
Applicable only if generated_tokens == true and/or input_tokens == true
""".format(
GENERATE_FUNCTION_ARGS
)


# Mapping from grpc status codes to caikit status codes. There is not a 1:1
# mapping at the moment, so this conversion is lossy!
GRPC_TO_CAIKIT_CORE_STATUS = {
    grpc_code: caikit_code
    for grpc_code, caikit_code in (
        (grpc.StatusCode.CANCELLED, CaikitCoreStatusCode.CONNECTION_ERROR),
        (grpc.StatusCode.UNKNOWN, CaikitCoreStatusCode.UNKNOWN),
        (grpc.StatusCode.INVALID_ARGUMENT, CaikitCoreStatusCode.INVALID_ARGUMENT),
        (grpc.StatusCode.DEADLINE_EXCEEDED, CaikitCoreStatusCode.CONNECTION_ERROR),
        (grpc.StatusCode.NOT_FOUND, CaikitCoreStatusCode.NOT_FOUND),
        (grpc.StatusCode.ALREADY_EXISTS, CaikitCoreStatusCode.INVALID_ARGUMENT),
        (grpc.StatusCode.PERMISSION_DENIED, CaikitCoreStatusCode.FORBIDDEN),
        (grpc.StatusCode.RESOURCE_EXHAUSTED, CaikitCoreStatusCode.INVALID_ARGUMENT),
        (grpc.StatusCode.FAILED_PRECONDITION, CaikitCoreStatusCode.INVALID_ARGUMENT),
        (grpc.StatusCode.ABORTED, CaikitCoreStatusCode.CONNECTION_ERROR),
        (grpc.StatusCode.OUT_OF_RANGE, CaikitCoreStatusCode.INVALID_ARGUMENT),
        (grpc.StatusCode.UNIMPLEMENTED, CaikitCoreStatusCode.UNKNOWN),
        (grpc.StatusCode.INTERNAL, CaikitCoreStatusCode.FATAL),
        (grpc.StatusCode.UNAVAILABLE, CaikitCoreStatusCode.CONNECTION_ERROR),
        (grpc.StatusCode.DATA_LOSS, CaikitCoreStatusCode.CONNECTION_ERROR),
        (grpc.StatusCode.UNAUTHENTICATED, CaikitCoreStatusCode.UNAUTHORIZED),
    )
}


def raise_caikit_core_exception(rpc_error: grpc.RpcError):
    """Translate a grpc.RpcError into a CaikitCoreException and raise it.

    The grpc status code is mapped through GRPC_TO_CAIKIT_CORE_STATUS, falling
    back to UNKNOWN for unmapped codes. The original grpc error is chained as
    the cause of the raised exception.

    Args:
        rpc_error: The error raised by a downstream grpc RPC call.

    Raises:
        CaikitCoreException: Always raised; carries the mapped status code and
            the grpc error details (or a generic message when none are set).
    """
    status = GRPC_TO_CAIKIT_CORE_STATUS.get(
        rpc_error.code(), CaikitCoreStatusCode.UNKNOWN
    )
    details = rpc_error.details()
    if not details:
        details = f"Unknown RpcError: {rpc_error}"
    raise CaikitCoreException(status, details) from rpc_error


def validate_inf_params(
text,
preserve_input_text,
Expand Down Expand Up @@ -391,7 +431,10 @@ def unary_generate(

# Currently, we send a batch request of len(x)==1, so we expect one response back
with alog.ContextTimer(log.trace, "TGIS request duration: "):
batch_response = self.tgis_client.Generate(request)
try:
batch_response = self.tgis_client.Generate(request)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

error.value_check(
"<NLP38899018E>",
Expand Down Expand Up @@ -532,37 +575,40 @@ def stream_generate(
)

# stream GenerationResponse
stream_response = self.tgis_client.GenerateStream(request)

for stream_part in stream_response:
details = TokenStreamDetails(
finish_reason=stream_part.stop_reason,
generated_tokens=stream_part.generated_token_count,
seed=stream_part.seed,
input_token_count=stream_part.input_token_count,
)
token_list = []
if stream_part.tokens is not None:
for token in stream_part.tokens:
token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
try:
stream_response = self.tgis_client.GenerateStream(request)

for stream_part in stream_response:
details = TokenStreamDetails(
finish_reason=stream_part.stop_reason,
generated_tokens=stream_part.generated_token_count,
seed=stream_part.seed,
input_token_count=stream_part.input_token_count,
)
token_list = []
if stream_part.tokens is not None:
for token in stream_part.tokens:
token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
)
)
)
input_token_list = []
if stream_part.input_tokens is not None:
for token in stream_part.input_tokens:
input_token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
input_token_list = []
if stream_part.input_tokens is not None:
for token in stream_part.input_tokens:
input_token_list.append(
GeneratedToken(
text=token.text, logprob=token.logprob, rank=token.rank
)
)
)
yield GeneratedTextStreamResult(
generated_text=stream_part.text,
tokens=token_list,
input_tokens=input_token_list,
details=details,
)
yield GeneratedTextStreamResult(
generated_text=stream_part.text,
tokens=token_list,
input_tokens=input_token_list,
details=details,
)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

def unary_tokenize(
self,
Expand Down Expand Up @@ -598,7 +644,10 @@ def unary_tokenize(

# Currently, we send a batch request of len(x)==1, so we expect one response back
with alog.ContextTimer(log.trace, "TGIS request duration: "):
batch_response = self.tgis_client.Tokenize(request)
try:
batch_response = self.tgis_client.Tokenize(request)
except grpc.RpcError as err:
raise_caikit_core_exception(err)

error.value_check(
"<NLP38899081E>",
Expand Down
132 changes: 132 additions & 0 deletions tests/toolkit/text_generation/test_tgis_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
# Copyright The Caikit Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Tests for tgis_utils
"""
# Standard
from typing import Iterable, Optional, Type

# Third Party
import grpc
import grpc._channel
import pytest

# First Party
from caikit.core.data_model import ProducerId
from caikit.core.exceptions.caikit_core_exception import CaikitCoreException
from caikit_tgis_backend.protobufs import generation_pb2

# Local
from caikit_nlp.toolkit.text_generation import tgis_utils

## Helpers #####################################################################


class MockTgisClient:
    """Stand-in TGIS client that never talks to a real server.

    Each RPC method either returns an empty response object or, when a
    non-OK status code was configured, raises the grpc error type that the
    corresponding real RPC would raise.
    """

    def __init__(
        self,
        status_code: Optional[grpc.StatusCode],
        error_message: str = "Yikes",
    ):
        # Status/message used to build the raised grpc error, if any
        self._code = status_code
        self._msg = error_message

    def _maybe_raise(self, error_type: Type[grpc.RpcError], *args):
        """Raise error_type with the configured state unless the call is OK."""
        if self._code is None or self._code == grpc.StatusCode.OK:
            return
        rpc_state = grpc._channel._RPCState(
            [], [], [], code=self._code, details=self._msg
        )
        raise error_type(rpc_state, *args)

    def Generate(
        self,
        request: generation_pb2.BatchedGenerationRequest,
    ) -> generation_pb2.BatchedGenerationResponse:
        """Unary generate; may raise an _InactiveRpcError."""
        self._maybe_raise(grpc._channel._InactiveRpcError)
        return generation_pb2.BatchedGenerationResponse()

    def GenerateStream(
        self,
        request: generation_pb2.SingleGenerationRequest,
    ) -> Iterable[generation_pb2.GenerationResponse]:
        """Streaming generate; may raise a _MultiThreadedRendezvous."""
        self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None)
        yield generation_pb2.GenerationResponse()

    def Tokenize(
        self,
        request: generation_pb2.BatchedTokenizeRequest,
    ) -> generation_pb2.BatchedTokenizeResponse:
        """Unary tokenize; may raise an _InactiveRpcError."""
        self._maybe_raise(grpc._channel._InactiveRpcError)
        return generation_pb2.BatchedTokenizeResponse()


## TGISGenerationClient ########################################################


@pytest.mark.parametrize(
    "status_code",
    [code for code in grpc.StatusCode if code != grpc.StatusCode.OK],
)
@pytest.mark.parametrize(
    "method", ["unary_generate", "stream_generate", "unary_tokenize"]
)
def test_TGISGenerationClient_rpc_errors(status_code, method):
    """Each downstream RPC error surfaces as a CaikitCoreException carrying
    the mapped status code, with the grpc error chained as its context.
    """
    gen_client = tgis_utils.TGISGenerationClient(
        "foo",
        "bar",
        MockTgisClient(status_code),
        ProducerId("foobar"),
    )

    # The generate methods take the full set of inference params; tokenize
    # takes only the text
    call_kwargs = {}
    if method.endswith("_generate"):
        call_kwargs = dict(
            preserve_input_text=True,
            input_tokens=True,
            generated_tokens=True,
            token_logprobs=True,
            token_ranks=True,
            max_new_tokens=20,
            min_new_tokens=20,
            truncate_input_tokens=True,
            decoding_method="GREEDY",
            top_k=None,
            top_p=None,
            typical_p=None,
            temperature=None,
            seed=None,
            repetition_penalty=0.5,
            max_time=None,
            exponential_decay_length_penalty=None,
            stop_sequences=["asdf"],
        )

    with pytest.raises(CaikitCoreException) as context:
        result = getattr(gen_client, method)(text="foobar", **call_kwargs)
        # Streaming methods are lazy; pull one item to trigger the RPC
        if method.startswith("stream_"):
            next(result)

    expected_status = tgis_utils.GRPC_TO_CAIKIT_CORE_STATUS[status_code]
    assert context.value.status_code == expected_status
    assert isinstance(context.value.__context__, grpc.RpcError)

0 comments on commit c92b98d

Please sign in to comment.