From 6ecc54a36e03e312f777d30774e0223143c79e2e Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Wed, 15 May 2024 19:09:09 -0500 Subject: [PATCH 01/29] :thread: Add timeout configuration for TGIS streaming request as an experiment Signed-off-by: gkumbhat --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 6c003c9d..7bdea64d 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -576,7 +576,7 @@ def stream_generate( # stream GenerationResponse try: - stream_response = self.tgis_client.GenerateStream(request) + stream_response = self.tgis_client.GenerateStream(request, timeout=60) # set timeout to 60s, TODO: Make this configurable for stream_part in stream_response: details = TokenStreamDetails( From 40fb403166291980437f03e7e378a6ba3dd40f65 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Thu, 16 May 2024 22:49:02 -0500 Subject: [PATCH 02/29] :sparkles: Add tgis req timeout as configurable parameter Signed-off-by: gkumbhat --- caikit_nlp/config/config.yml | 4 ++++ caikit_nlp/toolkit/text_generation/tgis_utils.py | 9 ++++++--- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index f6e7f026..ec18b4e7 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -56,3 +56,7 @@ embedding: runtime: library: caikit_nlp + + +# Configure request timeout for TGIS backend (in seconds) +tgis_request_timeout: 60 diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 7bdea64d..e2216378 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -20,6 +20,7 @@ import grpc # First Party +from caikit import get_config from caikit.core.exceptions import error_handler from caikit.core.exceptions.caikit_core_exception import ( CaikitCoreException, @@ -326,6 +327,8 @@ def __init__( self.producer_id = producer_id self.prefix_id = prefix_id + self.tgis_req_timeout = get_config().tgis_request_timeout + def unary_generate( self, text, @@ -432,7 +435,7 @@ def unary_generate( # Currently, we send a batch request of len(x)==1, so we expect one response back with alog.ContextTimer(log.trace, "TGIS request duration: "): try: - batch_response = self.tgis_client.Generate(request) + batch_response = self.tgis_client.Generate(request, timeout=self.tgis_req_timeout) except grpc.RpcError as err: raise_caikit_core_exception(err) @@ -576,7 +579,7 @@ def stream_generate( # stream GenerationResponse try: - stream_response = self.tgis_client.GenerateStream(request, timeout=60) # set timeout to 60s, TODO: Make this configurable + stream_response = self.tgis_client.GenerateStream(request, timeout=self.tgis_req_timeout) for stream_part in stream_response: details = TokenStreamDetails( @@ -645,7 +648,7 @@ def unary_tokenize( # Currently, we send a batch request of len(x)==1, so we expect one response back with alog.ContextTimer(log.trace, "TGIS request duration: "): try: - batch_response = self.tgis_client.Tokenize(request) + batch_response = self.tgis_client.Tokenize(request, timeout=self.tgis_req_timeout) except grpc.RpcError as err: raise_caikit_core_exception(err) From 48216ebc7a54439bf87d0fef556e4e9252246920 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Thu, 16 May 2024 23:21:58 -0500 Subject: [PATCH 03/29] :white_check_mark: Fix tgis client fixture for acceting kwargs Signed-off-by: gkumbhat --- tests/fixtures/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index e7440e83..6f643476 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -198,13 +198,13 @@ class StubTGISClient: def __init__(self, base_model_name): pass - def Generate(self, request): + def Generate(self, request, **kwargs): return StubTGISClient.unary_generate(request) - def GenerateStream(self, request): + def GenerateStream(self, request, **kwargs): return StubTGISClient.stream_generate(request) - def Tokenize(self, request): + def Tokenize(self, request, **kwargs): return StubTGISClient.tokenize(request) @staticmethod From 401510de8d82e7060e9fa35fe3785097fa32f9f1 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Fri, 17 May 2024 08:04:47 -0500 Subject: [PATCH 04/29] :art: Fix formatting Signed-off-by: gkumbhat --- caikit_nlp/config/config.yml | 1 - caikit_nlp/toolkit/text_generation/tgis_utils.py | 12 +++++++++--- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index ec18b4e7..6f440a22 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -57,6 +57,5 @@ embedding: runtime: library: caikit_nlp - # Configure request timeout for TGIS backend (in seconds) tgis_request_timeout: 60 diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index e2216378..8dd10d43 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -435,7 +435,9 @@ def unary_generate( # Currently, we send a batch request of len(x)==1, so we expect one response back with alog.ContextTimer(log.trace, "TGIS request duration: "): try: - batch_response = self.tgis_client.Generate(request, timeout=self.tgis_req_timeout) + batch_response = self.tgis_client.Generate( + request, timeout=self.tgis_req_timeout + ) except grpc.RpcError as err: raise_caikit_core_exception(err) @@ -579,7 +581,9 @@ def stream_generate( # stream GenerationResponse try: - stream_response = self.tgis_client.GenerateStream(request, timeout=self.tgis_req_timeout) + stream_response = self.tgis_client.GenerateStream( + request, timeout=self.tgis_req_timeout + ) for stream_part in stream_response: details = TokenStreamDetails( @@ -648,7 +652,9 @@ def unary_tokenize( # Currently, we send a batch request of len(x)==1, so we expect one response back with alog.ContextTimer(log.trace, "TGIS request duration: "): try: - batch_response = self.tgis_client.Tokenize(request, timeout=self.tgis_req_timeout) + batch_response = self.tgis_client.Tokenize( + request, timeout=self.tgis_req_timeout + ) except grpc.RpcError as err: raise_caikit_core_exception(err) From fc5ebdaff1532458c0240ca733a7d6fb66ba6748 Mon Sep 17 00:00:00 2001 From: gkumbhat Date: Fri, 17 May 2024 08:17:47 -0500 Subject: [PATCH 05/29] :bug::white_check_mark: Fix fixture for tgis tests Signed-off-by: gkumbhat --- tests/fixtures/__init__.py | 6 +++--- tests/toolkit/text_generation/test_tgis_utils.py | 9 +++------ 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 6f643476..1d5e5d2b 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -208,7 +208,7 @@ def Tokenize(self, request, **kwargs): return StubTGISClient.tokenize(request) @staticmethod - def unary_generate(request): + def unary_generate(request, **kwargs): fake_response = mock.Mock() fake_result = mock.Mock() fake_result.stop_reason = 5 @@ -229,7 +229,7 @@ def unary_generate(request): return fake_response @staticmethod - def stream_generate(request): + def stream_generate(request, **kwargs): fake_stream = mock.Mock() fake_stream.stop_reason = 5 fake_stream.generated_token_count = 1 @@ -250,7 +250,7 @@ def stream_generate(request): yield fake_stream @staticmethod - def tokenize(request): + def tokenize(request, **kwargs): fake_response = mock.Mock() fake_result = mock.Mock() fake_result.token_count = 1 diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index 7c1093a7..e6339108 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -54,22 +54,19 @@ def _maybe_raise(self, error_type: Type[grpc.RpcError], *args): ) def Generate( - self, - request: generation_pb2.BatchedGenerationRequest, + self, request: generation_pb2.BatchedGenerationRequest, **kwargs ) -> generation_pb2.BatchedGenerationResponse: self._maybe_raise(grpc._channel._InactiveRpcError) return generation_pb2.BatchedGenerationResponse() def GenerateStream( - self, - request: generation_pb2.SingleGenerationRequest, + self, request: generation_pb2.SingleGenerationRequest, **kwargs ) -> Iterable[generation_pb2.GenerationResponse]: self._maybe_raise(grpc._channel._MultiThreadedRendezvous, None, None, None) yield generation_pb2.GenerationResponse() def Tokenize( - self, - request: generation_pb2.BatchedTokenizeRequest, + self, request: generation_pb2.BatchedTokenizeRequest, **kwargs ) -> generation_pb2.BatchedTokenizeResponse: self._maybe_raise(grpc._channel._InactiveRpcError) return generation_pb2.BatchedTokenizeResponse() From 2dd91d2a34b8ae5df751a764b32114d1e50e31ce Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Fri, 3 May 2024 08:45:50 -0300 Subject: [PATCH 06/29] Expose model information for embeddings service Signed-off-by: Flavia Beo --- caikit_nlp/modules/text_embedding/embedding.py | 16 +++++++++++++++- tests/modules/text_embedding/test_embedding.py | 15 +++++++++++++++ 2 files changed, 30 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index d4447ebe..c1eb7173 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -15,7 +15,7 @@ # Standard from collections.abc import Sized from enum import Enum, auto -from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union +from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union import importlib import os import time @@ -178,6 +178,20 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": return cls(model) + @property + def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument + """Helper property to return public metadata about a specific Model. This + function is separate from `metdata` as that contains the entire ModelConfig + which might not want to be shared/exposed. + + Returns: + Dict[str, str]: A dictionary of this models's public metadata + """ + return { + "max_seq_length": cls.model.max_seq_length, + "sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(), + } + @classmethod def _get_ipex(cls, ipex_flag): """Get IPEX optimization library if enabled and available, else return False diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index e625588b..d630b747 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -197,6 +197,21 @@ def test_save_load_and_run(): _assert_is_expected_embedding_result(result) +def test_public_model_info(): + """Check if we can get model info successfully""" + model_id = "model_id" + with tempfile.TemporaryDirectory(suffix="-1st") as model_dir: + model_path = os.path.join(model_dir, model_id) + BOOTSTRAPPED_MODEL.save(model_path) + new_model = EmbeddingModule.load(model_path) + + result = new_model.public_model_info + assert "max_seq_length" in result + assert "sentence_embedding_dimension" in result + assert type(result["max_seq_length"]) is int + assert type(result["sentence_embedding_dimension"]) is int + + @pytest.mark.parametrize( "model_path", ["", " ", " " * 100], ids=["empty", "space", "spaces"] ) From 756b1e191f125ecee6f46ef94667a13132f3e0f4 Mon Sep 17 00:00:00 2001 From: Flavia Beo Date: Wed, 29 May 2024 13:53:50 -0300 Subject: [PATCH 07/29] Bump lower caikit version Signed-off-by: Flavia Beo --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bcb82db0..3fc25ee2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0", + "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", "caikit-tgis-backend>=0.1.27,<0.2.0", # TODO: loosen dependencies "accelerate>=0.22.0", From 7e328c7d74216748d8b9be9f882d5ce522f5f9f1 Mon Sep 17 00:00:00 2001 From: Shonda-Adena-Witherspoon Date: Tue, 4 Jun 2024 14:32:20 -0500 Subject: [PATCH 08/29] added logging around tgis timout config setting Signed-off-by: Shonda-Adena-Witherspoon --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 8dd10d43..25026a65 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -329,6 +329,21 @@ def __init__( self.tgis_req_timeout = get_config().tgis_request_timeout + if ( + not self.tgis_req_timeout + or not isinstance(self.tgis_req_timeout, int) + or self.tgis_req_timeout <= 0 + ): + log.info("", "TGIS timeout not set") + self.tgis_req_timeout = None + + else: + log.info( + "", + "Setting TGIS timeout value to %d", + self.tgis_req_timeout, + ) + def unary_generate( self, text, From f4496b95b268d8326b9a3a6646704827ea67eb2a Mon Sep 17 00:00:00 2001 From: swith004 Date: Tue, 4 Jun 2024 18:03:19 -0400 Subject: [PATCH 09/29] Update caikit_nlp/toolkit/text_generation/tgis_utils.py change logging to debug Co-authored-by: Gaurav Kumbhat Signed-off-by: swith004 Signed-off-by: Shonda-Adena-Witherspoon --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 25026a65..57d1b2e8 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -334,7 +334,7 @@ def __init__( or not isinstance(self.tgis_req_timeout, int) or self.tgis_req_timeout <= 0 ): - log.info("", "TGIS timeout not set") + log.debug("", "TGIS timeout not set") self.tgis_req_timeout = None else: From 7211497c771946dc90d0a46ed559b879013fe45d Mon Sep 17 00:00:00 2001 From: swith004 Date: Tue, 4 Jun 2024 18:03:42 -0400 Subject: [PATCH 10/29] Update caikit_nlp/toolkit/text_generation/tgis_utils.py change logging to debug Co-authored-by: Gaurav Kumbhat Signed-off-by: swith004 Signed-off-by: Shonda-Adena-Witherspoon --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 57d1b2e8..f6185540 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -338,7 +338,7 @@ def __init__( self.tgis_req_timeout = None else: - log.info( + log.debug "", "Setting TGIS timeout value to %d", self.tgis_req_timeout, From b710556a587ca79f261879017764b74718b1e7bc Mon Sep 17 00:00:00 2001 From: Shonda-Adena-Witherspoon Date: Tue, 4 Jun 2024 17:12:37 -0500 Subject: [PATCH 11/29] fixed formatting Signed-off-by: Shonda-Adena-Witherspoon --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index f6185540..d10a0531 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -338,7 +338,7 @@ def __init__( self.tgis_req_timeout = None else: - log.debug + log.debug( "", "Setting TGIS timeout value to %d", self.tgis_req_timeout, From e278915de20a6c6301971210805e3517694023d6 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Thu, 6 Jun 2024 10:56:01 -0400 Subject: [PATCH 12/29] add get_route_info Signed-off-by: Mynhardt Burger --- .../toolkit/text_generation/tgis_utils.py | 36 ++++++++++++++++-- .../text_generation/test_tgis_utils.py | 37 +++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index d10a0531..96598864 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -"""This file is for helper functions related to TGIS. -""" +"""This file is for helper functions related to TGIS.""" + # Standard -from typing import Iterable +from typing import Iterable, Optional, Tuple # Third Party +import fastapi import grpc # First Party @@ -33,6 +34,7 @@ TokenizationResults, TokenStreamDetails, ) +from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit_tgis_backend.protobufs import generation_pb2 import alog @@ -683,3 +685,31 @@ def unary_tokenize( return TokenizationResults( token_count=response.token_count, ) + + +def get_route_info( + context: Optional[RuntimeServerContextType], +) -> Tuple[bool, Optional[str]]: + """ + Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in the headers/metadata. + + Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the context or if context is None. + """ + if context is None: + return False, None + + if isinstance(context, grpc.ServicerContext): + route_info = dict(context.invocation_metadata()).get("x-route-info") + if route_info: + return True, route_info + elif isinstance(context, fastapi.Request): + route_info = context.headers.get("x-route-info") + if route_info: + return True, route_info + else: + error.log_raise( + "", + ValueError(f"context is of an unsupported type: {type(context)}"), + ) + + return False, None diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index e6339108..cb9822ab 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -14,10 +14,12 @@ """ Tests for tgis_utils """ + # Standard from typing import Iterable, Optional, Type # Third Party +import fastapi import grpc import grpc._channel import pytest @@ -25,6 +27,7 @@ # First Party from caikit.core.data_model import ProducerId from caikit.core.exceptions.caikit_core_exception import CaikitCoreException +from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit_tgis_backend.protobufs import generation_pb2 # Local @@ -127,3 +130,37 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): ) rpc_err = context.value.__context__ assert isinstance(rpc_err, grpc.RpcError) + + +@pytest.mark.parametrize( + argnames=["context", "ok", "route_info"], + argvalues=[ + ( + fastapi.Request( + {"type": "http", "headers": [(b"x-route-info", b"sometext")]} + ), + True, + "sometext", + ), + ( + fastapi.Request( + {"type": "http", "headers": [(b"route-info", b"sometext")]} + ), + False, + None, + ), + ("should raise ValueError", False, None), + (None, False, None), + # Uncertain how to create a grpc.ServicerContext object + ], +) +def test_get_route_info( + context: RuntimeServerContextType, ok: bool, route_info: Optional[str] +): + if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): + with pytest.raises(ValueError): + tgis_utils.get_route_info(context) + else: + actual_ok, actual_route_info = tgis_utils.get_route_info(context) + assert actual_ok == ok + assert actual_route_info == route_info From e473c3302bcd2e890ae5b170cf6f6580d7ac3686 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Thu, 6 Jun 2024 10:57:04 -0400 Subject: [PATCH 13/29] lazily create model_connection and _client Signed-off-by: Mynhardt Burger --- .../text_generation/text_generation_tgis.py | 47 +++++++++++++++---- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 7c55cc2c..da6a2482 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -14,6 +14,7 @@ # Standard +from functools import cached_property from typing import Iterable, List, Optional, Tuple, Union import os @@ -30,6 +31,7 @@ TokenizationResults, ) from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask +from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit_tgis_backend import TGISBackend import alog @@ -43,6 +45,7 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, + get_route_info, ) from .text_generation_local import TextGeneration @@ -86,14 +89,7 @@ def __init__( # Set _model_loaded as False by default. This will only get set to True if # we enable the tgis_backend and we are able to fetch the client successfully. self._model_loaded = False - # Configure the internal client - # NOTE: This is made optional for the cases where we do not need to execute `.run` function - # for example, bootstrapping a model to caikit format and saving. - self._client = None if tgis_backend: - self._client = tgis_backend.get_client(model_name) - # mark that the model is loaded so that we can unload it later - self._model_loaded = True self.tgis_backend = tgis_backend self._bos_token = bos_token @@ -109,6 +105,14 @@ def __del__(self): if self._model_loaded and self.tgis_backend: self.tgis_backend.unload_model(self.model_name) + @cached_property + def _client(self): + # Configure the internal client + # NOTE: This is made optional for the cases where we do not need to execute `.run` function + # for example, bootstrapping a model to caikit format and saving. + if hasattr(self, "tgis_backend") and self.tgis_backend: + return self.tgis_backend.get_client(self.model_name) + @classmethod def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None): """Function to bootstrap a pre-trained transformers model and @@ -207,7 +211,7 @@ def save(self, model_path: str): ) # pylint: disable=duplicate-code - @TextGenerationTask.taskmethod() + @TextGenerationTask.taskmethod(context_arg="context") def run( self, text: str, @@ -231,6 +235,7 @@ def run( generated_tokens: bool = True, token_logprobs: bool = True, token_ranks: bool = True, + context: Optional[RuntimeServerContextType] = None, ) -> GeneratedTextResult: f"""Run inference against the model running in TGIS. @@ -240,6 +245,9 @@ def run( GeneratedTextResult Generated text result produced by TGIS. """ + if self.tgis_backend: + self._register_model_connection_with_context(context) + if self._model_loaded: return self.tgis_generation_client.unary_generate( text=text, @@ -263,7 +271,7 @@ def run( stop_sequences=stop_sequences, ) - @TextGenerationTask.taskmethod(output_streaming=True) + @TextGenerationTask.taskmethod(output_streaming=True, context_arg="context") def run_stream_out( self, text: str, @@ -287,6 +295,7 @@ def run_stream_out( generated_tokens: bool = True, token_logprobs: bool = True, token_ranks: bool = True, + context: Optional[RuntimeServerContextType] = None, ) -> Iterable[GeneratedTextStreamResult]: f"""Run output stream inferencing for text generation module. @@ -295,6 +304,8 @@ def run_stream_out( Returns: Iterable[GeneratedTextStreamResult] """ + if self.tgis_backend: + self._register_model_connection_with_context(context) if self._model_loaded: return self.tgis_generation_client.stream_generate( @@ -319,10 +330,11 @@ def run_stream_out( stop_sequences=stop_sequences, ) - @TokenizationTask.taskmethod() + @TokenizationTask.taskmethod(context_arg="context") def run_tokenizer( self, text: str, + context: Optional[RuntimeServerContextType] = None, ) -> TokenizationResults: """Run tokenization task against the model running in TGIS. @@ -333,7 +345,22 @@ def run_tokenizer( TokenizationResults The token count """ + if self.tgis_backend: + self._register_model_connection_with_context(context) + if self._model_loaded: return self.tgis_generation_client.unary_tokenize( text=text, ) + + def _register_model_connection_with_context( + self, context: Optional[RuntimeServerContextType] + ): + ok, route_info = get_route_info(context) + if ok: + self.tgis_backend.register_model_connection( + self.model_name, {"hostname": route_info} + ) + else: + self.tgis_backend.register_model_connection(self.model_name) + self._model_loaded = True From 5578b7a0619ac7e6ce415266ebd08684d954bcb0 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Thu, 6 Jun 2024 11:28:59 -0400 Subject: [PATCH 14/29] lazy load model_connection and tgis client for peft Signed-off-by: Mynhardt Burger --- .../text_generation/peft_tgis_remote.py | 53 ++++++++++++++++--- 1 file changed, 45 insertions(+), 8 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 6a87ab45..661c8f0a 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -14,7 +14,9 @@ """This file contains a distributed backend implementation for leveraging the PEFT-trained prompt vectors in TGIS generation requests. """ + # Standard +from functools import cached_property from typing import Iterable, List, Optional, Tuple, Union import os @@ -32,6 +34,7 @@ TokenizationResults, ) from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask +from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit_tgis_backend import TGISBackend import alog @@ -40,6 +43,7 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, + get_route_info, ) from ...toolkit.verbalizer_utils import render_verbalizer from . import PeftPromptTuning @@ -68,15 +72,15 @@ def __init__( prompt_artifacts: Optional[List[str]] = None, ) -> None: super().__init__() - # Configure the internal client - # NOTE: This is made optional for the cases where we do not need to execute `.run` function - # for example, bootstrapping a model to caikit format and saving. - self._client = None + + # self._client = None self._tgis_backend = tgis_backend if enable_backend: + error.type_check( + "", TGISBackend, tgis_backend=self._tgis_backend + ) # get_client will also launch a local TGIS process and get the model # loaded when using the local TGIS backend - self._client = tgis_backend.get_client(base_model_name) # Tell the backend to load all of the available prompt files if prompt_artifacts: @@ -107,6 +111,14 @@ def __del__(self): if tgis_backend and prompt_cache_id and model_id: tgis_backend.unload_prompt_artifacts(model_id, prompt_cache_id) + @cached_property + def _client(self): + # Configure the internal client + # NOTE: This is made optional for the cases where we do not need to execute `.run` function + # for example, bootstrapping a model to caikit format and saving. + if hasattr(self, "tgis_backend") and self._tgis_backend: + return self._tgis_backend.get_client(self.base_model_name) + @classmethod def load(cls, model_path: str, load_backend: BackendBase) -> "PeftPromptTuningTGIS": """Load a TGIS Peft Prompt Tuning distributed module. Note that we do not @@ -182,7 +194,7 @@ def save(self, model_path: str): ) # pylint: disable=duplicate-code - @TextGenerationTask.taskmethod() + @TextGenerationTask.taskmethod(context_arg="context") def run( self, text: str, @@ -206,6 +218,7 @@ def run( generated_tokens: bool = True, token_logprobs: bool = True, token_ranks: bool = True, + context: Optional[RuntimeServerContextType] = None, ) -> GeneratedTextResult: f"""Run inference against the model running in TGIS. @@ -221,6 +234,9 @@ def run( self.enable_backend, "Backend must be configured and loaded with this module before executing `run` call.", ) + if self._tgis_backend: + self._register_model_connection_with_context(context) + verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.unary_generate( text=verbalized_text, @@ -244,7 +260,7 @@ def run( stop_sequences=stop_sequences, ) - @TextGenerationTask.taskmethod(output_streaming=True) + @TextGenerationTask.taskmethod(output_streaming=True, context_arg="context") def run_stream_out( self, text: str, @@ -268,6 +284,7 @@ def run_stream_out( generated_tokens: bool = True, token_logprobs: bool = True, token_ranks: bool = True, + context: Optional[RuntimeServerContextType] = None, ) -> Iterable[GeneratedTextStreamResult]: f"""Run output stream inferencing against the model running in TGIS @@ -283,6 +300,10 @@ def run_stream_out( "Backend must be configured and loaded with this module \ before executing `run_stream_out` call.", ) + + if self._tgis_backend: + self._register_model_connection_with_context(context) + verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.stream_generate( text=verbalized_text, @@ -306,10 +327,11 @@ def run_stream_out( stop_sequences=stop_sequences, ) - @TokenizationTask.taskmethod() + @TokenizationTask.taskmethod(context_arg="context") def run_tokenizer( self, text: str, + context: Optional[RuntimeServerContextType] = None, ) -> TokenizationResults: """Run tokenization task against the model running in TGIS. @@ -320,6 +342,21 @@ def run_tokenizer( TokenizationResults The token count """ + if self._tgis_backend: + self._register_model_connection_with_context(context) + return self.tgis_generation_client.unary_tokenize( text=text, ) + + def _register_model_connection_with_context( + self, context: Optional[RuntimeServerContextType] + ): + ok, route_info = get_route_info(context) + if ok: + self._tgis_backend.register_model_connection( + self.base_model_name, {"hostname": route_info} + ) + else: + self._tgis_backend.register_model_connection(self.base_model_name) + self._model_loaded = True From 71371bd2cc6961380f7a59f4aed3896d3453e9e7 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Thu, 6 Jun 2024 11:44:14 -0400 Subject: [PATCH 15/29] remove commented out code Signed-off-by: Mynhardt Burger --- caikit_nlp/modules/text_generation/peft_tgis_remote.py | 1 - 1 file changed, 1 deletion(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 661c8f0a..2d19e8e4 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -73,7 +73,6 @@ def __init__( ) -> None: super().__init__() - # self._client = None self._tgis_backend = tgis_backend if enable_backend: error.type_check( From 21192c06db76916c86a843e522a3718fffff8805 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 09:52:20 -0400 Subject: [PATCH 16/29] Address review comments Signed-off-by: Mynhardt Burger --- .../text_generation/peft_tgis_remote.py | 7 +++-- .../text_generation/text_generation_tgis.py | 28 ++++++++----------- .../toolkit/text_generation/tgis_utils.py | 6 ++-- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 2d19e8e4..a22d95a9 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -115,7 +115,7 @@ def _client(self): # Configure the internal client # NOTE: This is made optional for the cases where we do not need to execute `.run` function # for example, bootstrapping a model to caikit format and saving. - if hasattr(self, "tgis_backend") and self._tgis_backend: + if self._tgis_backend: return self._tgis_backend.get_client(self.base_model_name) @classmethod @@ -351,10 +351,13 @@ def run_tokenizer( def _register_model_connection_with_context( self, context: Optional[RuntimeServerContextType] ): + """ + Register a model connection with the configured TGISBackend. + """ ok, route_info = get_route_info(context) if ok: self._tgis_backend.register_model_connection( - self.base_model_name, {"hostname": route_info} + self.base_model_name, {"hostname": route_info}, fill_with_defaults=True ) else: self._tgis_backend.register_model_connection(self.base_model_name) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index da6a2482..4bfbf74f 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -92,26 +92,22 @@ def __init__( if tgis_backend: self.tgis_backend = tgis_backend + self._tgis_backend = tgis_backend self._bos_token = bos_token self._sep_token = sep_token self._eos_token = eos_token self._pad_token = pad_token - self.tgis_generation_client = TGISGenerationClient( - self.model_name, self._eos_token, self._client, self.PRODUCER_ID - ) def __del__(self): # nothing to unload if we didn't finish loading - if self._model_loaded and self.tgis_backend: - self.tgis_backend.unload_model(self.model_name) + if self._model_loaded and self._tgis_backend: + self._tgis_backend.unload_model(self.model_name) @cached_property def _client(self): - # Configure the internal client - # NOTE: This is made optional for the cases where we do not need to execute `.run` function - # for example, bootstrapping a model to caikit format and saving. - if hasattr(self, "tgis_backend") and self.tgis_backend: - return self.tgis_backend.get_client(self.model_name) + # Lazily configure/create the internal tgis backend client + if self._tgis_backend: + return self._tgis_backend.get_client(self.model_name) @classmethod def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None): @@ -245,7 +241,7 @@ def run( GeneratedTextResult Generated text result produced by TGIS. """ - if self.tgis_backend: + if self._tgis_backend: self._register_model_connection_with_context(context) if self._model_loaded: @@ -304,7 +300,7 @@ def run_stream_out( Returns: Iterable[GeneratedTextStreamResult] """ - if self.tgis_backend: + if self._tgis_backend: self._register_model_connection_with_context(context) if self._model_loaded: @@ -345,7 +341,7 @@ def run_tokenizer( TokenizationResults The token count """ - if self.tgis_backend: + if self._tgis_backend: self._register_model_connection_with_context(context) if self._model_loaded: @@ -358,9 +354,9 @@ def _register_model_connection_with_context( ): ok, route_info = get_route_info(context) if ok: - self.tgis_backend.register_model_connection( - self.model_name, {"hostname": route_info} + self._tgis_backend.register_model_connection( + self.model_name, {"hostname": route_info}, fill_with_defaults=True ) else: - self.tgis_backend.register_model_connection(self.model_name) + self._tgis_backend.register_model_connection(self.model_name) self._model_loaded = True diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 96598864..b5d41d1f 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -86,6 +86,8 @@ grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED, } +ROUTE_INFO_KEY = "x-route-info" + def raise_caikit_core_exception(rpc_error: grpc.RpcError): """Helper to wrap logic of converting from grpc.RpcError -> @@ -699,11 +701,11 @@ def get_route_info( return False, None if isinstance(context, grpc.ServicerContext): - route_info = dict(context.invocation_metadata()).get("x-route-info") + route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_KEY) if route_info: return True, route_info elif isinstance(context, fastapi.Request): - route_info = context.headers.get("x-route-info") + route_info = context.headers.get(ROUTE_INFO_KEY) if route_info: return True, route_info else: From 1c235275beec1fdb835d86c7252ac4c1609fbb99 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 09:52:57 -0400 Subject: [PATCH 17/29] Expand test_get_route_info Signed-off-by: Mynhardt Burger --- tests/fixtures/__init__.py | 19 ++++++++++++++++--- .../text_generation/test_tgis_utils.py | 16 +++++++++++++++- 2 files changed, 31 insertions(+), 4 deletions(-) diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 1d5e5d2b..662f6832 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -1,8 +1,8 @@ -"""Helpful fixtures for configuring individual unit tests. -""" +"""Helpful fixtures for configuring individual unit tests.""" + # Standard from contextlib import contextmanager -from typing import Iterable, Optional +from typing import Iterable, Optional, Union from unittest import mock import json import os @@ -191,6 +191,7 @@ def requires_determinism(request): ### Common TGIS stub classes + # Helper stubs / mocks; we use these to patch caikit so that we don't actually # test the TGIS backend directly, and instead stub the client and inspect the # args that we pass to it. @@ -342,3 +343,15 @@ def temp_config(**overrides): with mock.patch.object(caikit.config.config, "_IMMUTABLE_CONFIG", local_config): yield local_config + + +class TestServicerContext: + """ + A dummy class for mimicking ServicerContext invocation metadata storage. + """ + + def __init__(self, metadata: dict[str, Union[str, bytes]]): + self.metadata = metadata + + def invocation_metadata(self): + return list(self.metadata.items()) diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index cb9822ab..beca3e9c 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -32,6 +32,7 @@ # Local from caikit_nlp.toolkit.text_generation import tgis_utils +from tests.fixtures import TestServicerContext ## Helpers ##################################################################### @@ -137,7 +138,10 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): argvalues=[ ( fastapi.Request( - {"type": "http", "headers": [(b"x-route-info", b"sometext")]} + { + "type": "http", + "headers": [(tgis_utils.ROUTE_INFO_KEY.encode(), b"sometext")], + } ), True, "sometext", @@ -149,6 +153,16 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): False, None, ), + ( + TestServicerContext({tgis_utils.ROUTE_INFO_KEY: "sometext"}), + True, + "sometext", + ), + ( + TestServicerContext({"route-info": "sometext"}), + False, + None, + ), ("should raise ValueError", False, None), (None, False, None), # Uncertain how to create a grpc.ServicerContext object From a4e8539dac33818e2626ce2c51d63063424a3de4 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 09:53:56 -0400 Subject: [PATCH 18/29] Lazily create generation client Signed-off-by: Mynhardt Burger --- .../text_generation/text_generation_tgis.py | 8 ++++++ .../test_text_generation_tgis.py | 26 ++++++++++++++++--- 2 files changed, 31 insertions(+), 3 deletions(-) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 4bfbf74f..fead58c5 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -109,6 +109,14 @@ def _client(self): if self._tgis_backend: return self._tgis_backend.get_client(self.model_name) + @cached_property + def tgis_generation_client(self): + # Lazily create the generation client + # This in turn calls self._client which also lazily gets the tgis backend client + return TGISGenerationClient( + self.model_name, self._eos_token, self._client, self.PRODUCER_ID + ) + @classmethod def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None): """Function to bootstrap a pre-trained transformers model and diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py index 741f1e27..d938688d 100644 --- a/tests/modules/text_generation/test_text_generation_tgis.py +++ b/tests/modules/text_generation/test_text_generation_tgis.py @@ -1,5 +1,5 @@ -"""Tests for text-generation module -""" +"""Tests for text-generation module""" + # Standard from unittest import mock import os @@ -12,6 +12,7 @@ # First Party from caikit.interfaces.nlp.data_model import GeneratedTextResult +from caikit_tgis_backend import TGISBackend import caikit # Local @@ -23,7 +24,6 @@ SEQ2SEQ_LM_MODEL, StubTGISBackend, StubTGISClient, - set_cpu_device, ) SAMPLE_TEXT = "Hello stub" @@ -152,6 +152,26 @@ def test_remote_tgis_only_model(): TextGenerationTGIS.load(model_dir, load_backend=tgis_backend) +def test_client_lazy_load(): + """ + Test that the TGISBackend client is lazy loaded + """ + model_name = "model-name" + tgis_backend = TGISBackend( + {"connection": {"hostname": "{model_id}.localhost:1234"}} + ) + model = TextGenerationTGIS(model_name, tgis_backend=tgis_backend) + + # No tgis_backend client and _model_loaded still False + assert "_client" not in model.__dict__ + assert not model.__dict__.get("_model_loaded", True) + + # Client gets created on accessing ._client + client = model._client + assert client is not None + assert client + + ### Output streaming tests ############################################################## From be72a79254664532d7be725d291ebb75749eb234 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 11:28:22 -0400 Subject: [PATCH 19/29] Update minimum caikit-tgis-backend version Signed-off-by: Mynhardt Burger --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8fc7a8ae..4d0d7e44 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers=[ ] dependencies = [ "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", - "caikit-tgis-backend>=0.1.27,<0.2.0", + "caikit-tgis-backend>=0.1.33,<0.2.0", # TODO: loosen dependencies "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking "grpcio-reflection>=1.62.2", From d5893d9914d31317e596890abe1ff4adec8644e1 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 12:19:21 -0400 Subject: [PATCH 20/29] Add debug logs Signed-off-by: Mynhardt Burger --- caikit_nlp/modules/text_generation/peft_tgis_remote.py | 4 ++++ caikit_nlp/modules/text_generation/text_generation_tgis.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index a22d95a9..b7847025 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -356,6 +356,10 @@ def _register_model_connection_with_context( """ ok, route_info = get_route_info(context) if ok: + log.debug( + " Registering remote model connection with context override: 'hostname: %s'", + route_info, + ) self._tgis_backend.register_model_connection( self.base_model_name, {"hostname": route_info}, fill_with_defaults=True ) diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index fead58c5..c03d539d 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -362,6 +362,10 @@ def _register_model_connection_with_context( ): ok, route_info = get_route_info(context) if ok: + log.debug( + " Registering remote model connection with context override: 'hostname: %s'", + route_info, + ) self._tgis_backend.register_model_connection( self.model_name, {"hostname": route_info}, fill_with_defaults=True ) From 984894354dd3361534b4e6071313bdfb24d469a4 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 12:44:24 -0400 Subject: [PATCH 21/29] Linting Signed-off-by: Mynhardt Burger --- caikit_nlp/modules/text_generation/peft_tgis_remote.py | 3 ++- caikit_nlp/modules/text_generation/text_generation_tgis.py | 3 ++- caikit_nlp/toolkit/text_generation/tgis_utils.py | 6 ++++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index b7847025..5e556b31 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -357,7 +357,8 @@ def _register_model_connection_with_context( ok, route_info = get_route_info(context) if ok: log.debug( - " Registering remote model connection with context override: 'hostname: %s'", + " Registering remote model connection with context " + "override: 'hostname: %s'", route_info, ) self._tgis_backend.register_model_connection( diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index c03d539d..99fe79e9 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -363,7 +363,8 @@ def _register_model_connection_with_context( ok, route_info = get_route_info(context) if ok: log.debug( - " Registering remote model connection with context override: 'hostname: %s'", + " Registering remote model connection with context " + "override: 'hostname: %s'", route_info, ) self._tgis_backend.register_model_connection( diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index b5d41d1f..6e235489 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -693,9 +693,11 @@ def get_route_info( context: Optional[RuntimeServerContextType], ) -> Tuple[bool, Optional[str]]: """ - Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in the headers/metadata. + Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in + the headers/metadata. - Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the context or if context is None. + Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the + context or if context is None. """ if context is None: return False, None From 5df0533c52bb3795098d3075e0c4e308329aab28 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 12:50:46 -0400 Subject: [PATCH 22/29] linting Signed-off-by: Mynhardt Burger --- tests/modules/text_generation/test_text_generation_tgis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/modules/text_generation/test_text_generation_tgis.py b/tests/modules/text_generation/test_text_generation_tgis.py index d938688d..33fdbb39 100644 --- a/tests/modules/text_generation/test_text_generation_tgis.py +++ b/tests/modules/text_generation/test_text_generation_tgis.py @@ -19,6 +19,7 @@ from caikit_nlp.data_model import ExponentialDecayLengthPenalty, GenerationTrainRecord from caikit_nlp.modules.text_generation import TextGeneration, TextGenerationTGIS from caikit_nlp.resources.pretrained_model.hf_auto_seq2seq_lm import HFAutoSeq2SeqLM +from tests.fixtures import set_cpu_device # noqa from tests.fixtures import ( CAUSAL_LM_MODEL, SEQ2SEQ_LM_MODEL, From 6bae66a895ea95810067546a2b832a41e2755cea Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 13:13:03 -0400 Subject: [PATCH 23/29] Update caikit_nlp/toolkit/text_generation/tgis_utils.py Co-authored-by: Gabe Goodhart Signed-off-by: Mynhardt Burger --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 6e235489..1af2b372 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -86,7 +86,8 @@ grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED, } -ROUTE_INFO_KEY = "x-route-info" +# HTTP Header / gRPC Metadata key used to identify a route override +ROUTE_INFO_HEADER_KEY = "x-route-info" def raise_caikit_core_exception(rpc_error: grpc.RpcError): From 46ae07336a6354f3ef906c54cfff1551b65157b1 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 13:29:28 -0400 Subject: [PATCH 24/29] review comments Signed-off-by: Mynhardt Burger --- .../text_generation/peft_tgis_remote.py | 39 +++++++++---------- .../text_generation/text_generation_tgis.py | 37 +++++++++--------- .../toolkit/text_generation/tgis_utils.py | 16 ++++---- .../text_generation/test_tgis_utils.py | 23 +++++------ 4 files changed, 54 insertions(+), 61 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 5e556b31..67921bd3 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -233,8 +233,7 @@ def run( self.enable_backend, "Backend must be configured and loaded with this module before executing `run` call.", ) - if self._tgis_backend: - self._register_model_connection_with_context(context) + self._register_model_connection_with_context(context) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.unary_generate( @@ -300,8 +299,7 @@ def run_stream_out( before executing `run_stream_out` call.", ) - if self._tgis_backend: - self._register_model_connection_with_context(context) + self._register_model_connection_with_context(context) verbalized_text = render_verbalizer(self.verbalizer, {"input": text}) return self.tgis_generation_client.stream_generate( @@ -341,8 +339,8 @@ def run_tokenizer( TokenizationResults The token count """ - if self._tgis_backend: - self._register_model_connection_with_context(context) + + self._register_model_connection_with_context(context) return self.tgis_generation_client.unary_tokenize( text=text, @@ -352,18 +350,19 @@ def _register_model_connection_with_context( self, context: Optional[RuntimeServerContextType] ): """ - Register a model connection with the configured TGISBackend. + Register a remote model connection with the configured TGISBackend if there is + a context override provided. """ - ok, route_info = get_route_info(context) - if ok: - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.base_model_name, {"hostname": route_info}, fill_with_defaults=True - ) - else: - self._tgis_backend.register_model_connection(self.base_model_name) - self._model_loaded = True + if self._tgis_backend: + if route_info := get_route_info(context): + log.debug( + " Registering remote model connection with context " + "override: 'hostname: %s'", + route_info, + ) + self._tgis_backend.register_model_connection( + self.base_model_name, + {"hostname": route_info}, + fill_with_defaults=True, + ) + self._model_loaded = True diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index 99fe79e9..b1c344ba 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -249,8 +249,7 @@ def run( GeneratedTextResult Generated text result produced by TGIS. """ - if self._tgis_backend: - self._register_model_connection_with_context(context) + self._register_model_connection_with_context(context) if self._model_loaded: return self.tgis_generation_client.unary_generate( @@ -308,8 +307,7 @@ def run_stream_out( Returns: Iterable[GeneratedTextStreamResult] """ - if self._tgis_backend: - self._register_model_connection_with_context(context) + self._register_model_connection_with_context(context) if self._model_loaded: return self.tgis_generation_client.stream_generate( @@ -349,8 +347,7 @@ def run_tokenizer( TokenizationResults The token count """ - if self._tgis_backend: - self._register_model_connection_with_context(context) + self._register_model_connection_with_context(context) if self._model_loaded: return self.tgis_generation_client.unary_tokenize( @@ -360,16 +357,18 @@ def run_tokenizer( def _register_model_connection_with_context( self, context: Optional[RuntimeServerContextType] ): - ok, route_info = get_route_info(context) - if ok: - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.model_name, {"hostname": route_info}, fill_with_defaults=True - ) - else: - self._tgis_backend.register_model_connection(self.model_name) - self._model_loaded = True + """ + Register a remote model connection with the configured TGISBackend if there is + a context override provided. + """ + if self._tgis_backend: + if route_info := get_route_info(context): + log.debug( + " Registering remote model connection with context " + "override: 'hostname: %s'", + route_info, + ) + self._tgis_backend.register_model_connection( + self.model_name, {"hostname": route_info}, fill_with_defaults=True + ) + self._model_loaded = True diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 1af2b372..e09aac0f 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -14,7 +14,7 @@ """This file is for helper functions related to TGIS.""" # Standard -from typing import Iterable, Optional, Tuple +from typing import Iterable, Optional # Third Party import fastapi @@ -692,7 +692,7 @@ def unary_tokenize( def get_route_info( context: Optional[RuntimeServerContextType], -) -> Tuple[bool, Optional[str]]: +) -> Optional[str]: """ Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in the headers/metadata. @@ -701,20 +701,20 @@ def get_route_info( context or if context is None. """ if context is None: - return False, None + return None if isinstance(context, grpc.ServicerContext): - route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_KEY) + route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY) if route_info: - return True, route_info + return route_info elif isinstance(context, fastapi.Request): - route_info = context.headers.get(ROUTE_INFO_KEY) + route_info = context.headers.get(ROUTE_INFO_HEADER_KEY) if route_info: - return True, route_info + return route_info else: error.log_raise( "", ValueError(f"context is of an unsupported type: {type(context)}"), ) - return False, None + return None diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index beca3e9c..2e97a5e8 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -134,47 +134,42 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): @pytest.mark.parametrize( - argnames=["context", "ok", "route_info"], + argnames=["context", "route_info"], argvalues=[ ( fastapi.Request( { "type": "http", - "headers": [(tgis_utils.ROUTE_INFO_KEY.encode(), b"sometext")], + "headers": [ + (tgis_utils.ROUTE_INFO_HEADER_KEY.encode(), b"sometext") + ], } ), - True, "sometext", ), ( fastapi.Request( {"type": "http", "headers": [(b"route-info", b"sometext")]} ), - False, None, ), ( - TestServicerContext({tgis_utils.ROUTE_INFO_KEY: "sometext"}), - True, + TestServicerContext({tgis_utils.ROUTE_INFO_HEADER_KEY: "sometext"}), "sometext", ), ( TestServicerContext({"route-info": "sometext"}), - False, None, ), - ("should raise ValueError", False, None), - (None, False, None), + ("should raise ValueError", None), + (None, None), # Uncertain how to create a grpc.ServicerContext object ], ) -def test_get_route_info( - context: RuntimeServerContextType, ok: bool, route_info: Optional[str] -): +def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]): if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): with pytest.raises(ValueError): tgis_utils.get_route_info(context) else: - actual_ok, actual_route_info = tgis_utils.get_route_info(context) - assert actual_ok == ok + actual_route_info = tgis_utils.get_route_info(context) assert actual_route_info == route_info From 527b4555adde51e9a36a6a8a17865adc0cf5f332 Mon Sep 17 00:00:00 2001 From: Mynhardt Burger Date: Fri, 7 Jun 2024 13:36:51 -0400 Subject: [PATCH 25/29] remove unreachable code Signed-off-by: Mynhardt Burger --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index e09aac0f..627bea1e 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -716,5 +716,3 @@ def get_route_info( "", ValueError(f"context is of an unsupported type: {type(context)}"), ) - - return None From cbc6b33d88bd5b7484230fe367885b8c50a1f92b Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Tue, 18 Jun 2024 17:02:00 -0600 Subject: [PATCH 26/29] RouteInfoFromBackend: Forward get_route_info and ROUTE_INFO_HEADER_KEY from backend Signed-off-by: Gabe Goodhart --- .../text_generation/peft_tgis_remote.py | 13 +------- .../text_generation/text_generation_tgis.py | 11 +------ .../toolkit/text_generation/tgis_utils.py | 33 +++---------------- .../text_generation/test_tgis_utils.py | 5 ++- 4 files changed, 10 insertions(+), 52 deletions(-) diff --git a/caikit_nlp/modules/text_generation/peft_tgis_remote.py b/caikit_nlp/modules/text_generation/peft_tgis_remote.py index 67921bd3..c6067e15 100644 --- a/caikit_nlp/modules/text_generation/peft_tgis_remote.py +++ b/caikit_nlp/modules/text_generation/peft_tgis_remote.py @@ -43,7 +43,6 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, - get_route_info, ) from ...toolkit.verbalizer_utils import render_verbalizer from . import PeftPromptTuning @@ -354,15 +353,5 @@ def _register_model_connection_with_context( a context override provided. """ if self._tgis_backend: - if route_info := get_route_info(context): - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.base_model_name, - {"hostname": route_info}, - fill_with_defaults=True, - ) + self._tgis_backend.handle_runtime_context(self.base_model_name, context) self._model_loaded = True diff --git a/caikit_nlp/modules/text_generation/text_generation_tgis.py b/caikit_nlp/modules/text_generation/text_generation_tgis.py index b1c344ba..2ee32a7b 100644 --- a/caikit_nlp/modules/text_generation/text_generation_tgis.py +++ b/caikit_nlp/modules/text_generation/text_generation_tgis.py @@ -45,7 +45,6 @@ from ...toolkit.text_generation.tgis_utils import ( GENERATE_FUNCTION_TGIS_ARGS, TGISGenerationClient, - get_route_info, ) from .text_generation_local import TextGeneration @@ -362,13 +361,5 @@ def _register_model_connection_with_context( a context override provided. """ if self._tgis_backend: - if route_info := get_route_info(context): - log.debug( - " Registering remote model connection with context " - "override: 'hostname: %s'", - route_info, - ) - self._tgis_backend.register_model_connection( - self.model_name, {"hostname": route_info}, fill_with_defaults=True - ) + self._tgis_backend.handle_runtime_context(self.model_name, context) self._model_loaded = True diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index 627bea1e..a8e55c2b 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -35,6 +35,7 @@ TokenStreamDetails, ) from caikit.interfaces.runtime.data_model import RuntimeServerContextType +from caikit_tgis_backend import TGISBackend from caikit_tgis_backend.protobufs import generation_pb2 import alog @@ -87,7 +88,9 @@ } # HTTP Header / gRPC Metadata key used to identify a route override -ROUTE_INFO_HEADER_KEY = "x-route-info" +# (forwarded for API compatibility) +ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY +get_route_info = TGISBackend.get_route_info def raise_caikit_core_exception(rpc_error: grpc.RpcError): @@ -688,31 +691,3 @@ def unary_tokenize( return TokenizationResults( token_count=response.token_count, ) - - -def get_route_info( - context: Optional[RuntimeServerContextType], -) -> Optional[str]: - """ - Returns a tuple `(True, x-route-info)` from context if "x-route-info" was found in - the headers/metadata. - - Otherwise returns a tuple `(False, None)` if "x-route-info" was not found in the - context or if context is None. - """ - if context is None: - return None - - if isinstance(context, grpc.ServicerContext): - route_info = dict(context.invocation_metadata()).get(ROUTE_INFO_HEADER_KEY) - if route_info: - return route_info - elif isinstance(context, fastapi.Request): - route_info = context.headers.get(ROUTE_INFO_HEADER_KEY) - if route_info: - return route_info - else: - error.log_raise( - "", - ValueError(f"context is of an unsupported type: {type(context)}"), - ) diff --git a/tests/toolkit/text_generation/test_tgis_utils.py b/tests/toolkit/text_generation/test_tgis_utils.py index 2e97a5e8..696a9acf 100644 --- a/tests/toolkit/text_generation/test_tgis_utils.py +++ b/tests/toolkit/text_generation/test_tgis_utils.py @@ -133,6 +133,9 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): assert isinstance(rpc_err, grpc.RpcError) +# NOTE: This test is preserved in caikit-nlp despite being duplicated in +# caikit-tgis-backend so that we guarantee that the functionality is accessible +# in a version-compatible way here. @pytest.mark.parametrize( argnames=["context", "route_info"], argvalues=[ @@ -168,7 +171,7 @@ def test_TGISGenerationClient_rpc_errors(status_code, method): ) def test_get_route_info(context: RuntimeServerContextType, route_info: Optional[str]): if not isinstance(context, (fastapi.Request, grpc.ServicerContext, type(None))): - with pytest.raises(ValueError): + with pytest.raises(TypeError): tgis_utils.get_route_info(context) else: actual_route_info = tgis_utils.get_route_info(context) From 1818a306a28d7561a3da47af53d156a25f1a8c32 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 15:29:38 -0600 Subject: [PATCH 27/29] RouteInfoFromBackend: Bump caikit-tgis-backend Signed-off-by: Gabe Goodhart --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4d0d7e44..872d4c3e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ classifiers=[ ] dependencies = [ "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", - "caikit-tgis-backend>=0.1.33,<0.2.0", + "caikit-tgis-backend>=0.1.34,<0.2.0", # TODO: loosen dependencies "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking "grpcio-reflection>=1.62.2", From 0e23e10bfbbf3b4e3ed99b89459ba15b4f881e06 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 15:37:05 -0600 Subject: [PATCH 28/29] RouteInfoFromBackend: Bump caikit for context registration in backend Signed-off-by: Gabe Goodhart --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 872d4c3e..14be273f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers=[ "License :: OSI Approved :: Apache Software License" ] dependencies = [ - "caikit[runtime-grpc,runtime-http]>=0.26.27,<0.27.0", + "caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0", "caikit-tgis-backend>=0.1.34,<0.2.0", # TODO: loosen dependencies "grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking From ff7f05682e56db424c048138d3d876e7ea7bbb30 Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Wed, 19 Jun 2024 15:39:44 -0600 Subject: [PATCH 29/29] RouteInfoFromBackend: Remove unused imports Signed-off-by: Gabe Goodhart --- caikit_nlp/toolkit/text_generation/tgis_utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/caikit_nlp/toolkit/text_generation/tgis_utils.py b/caikit_nlp/toolkit/text_generation/tgis_utils.py index a8e55c2b..392455c0 100644 --- a/caikit_nlp/toolkit/text_generation/tgis_utils.py +++ b/caikit_nlp/toolkit/text_generation/tgis_utils.py @@ -14,10 +14,9 @@ """This file is for helper functions related to TGIS.""" # Standard -from typing import Iterable, Optional +from typing import Iterable # Third Party -import fastapi import grpc # First Party @@ -34,7 +33,6 @@ TokenizationResults, TokenStreamDetails, ) -from caikit.interfaces.runtime.data_model import RuntimeServerContextType from caikit_tgis_backend import TGISBackend from caikit_tgis_backend.protobufs import generation_pb2 import alog