
sync release with main #27

Merged
35 commits merged on Jun 21, 2024

Commits
6ecc54a
:thread: Add timeout configuration for TGIS streaming request as an e…
gkumbhat May 16, 2024
40fb403
:sparkles: Add tgis req timeout as configurable parameter
gkumbhat May 17, 2024
48216eb
:white_check_mark: Fix tgis client fixture for accepting kwargs
gkumbhat May 17, 2024
401510d
:art: Fix formatting
gkumbhat May 17, 2024
fc5ebda
:bug::white_check_mark: Fix fixture for tgis tests
gkumbhat May 17, 2024
98578c9
Merge pull request #358 from caikit/add_tgis_timeout
gkumbhat May 17, 2024
2dd91d2
Expose model information for embeddings service
flaviabeo May 3, 2024
756b1e1
Bump lower caikit version
flaviabeo May 29, 2024
75cd9ea
Merge pull request #353 from flaviabeo/embeddings_model_info
gkumbhat May 31, 2024
7e328c7
added logging around tgis timeout config setting
Jun 4, 2024
f4496b9
Update caikit_nlp/toolkit/text_generation/tgis_utils.py
swith004 Jun 4, 2024
7211497
Update caikit_nlp/toolkit/text_generation/tgis_utils.py
swith004 Jun 4, 2024
b710556
fixed formatting
Jun 4, 2024
4bf53fd
Merge pull request #361 from swith004/server-logging-931
gkumbhat Jun 5, 2024
e278915
add get_route_info
mynhardtburger Jun 6, 2024
e473c33
lazily create model_connection and _client
mynhardtburger Jun 6, 2024
5578b7a
lazy load model_connection and tgis client for peft
mynhardtburger Jun 6, 2024
71371bd
remove commented out code
mynhardtburger Jun 6, 2024
21192c0
Address review comments
mynhardtburger Jun 7, 2024
1c23527
Expand test_get_route_info
mynhardtburger Jun 7, 2024
a4e8539
Lazily create generation client
mynhardtburger Jun 7, 2024
be72a79
Update minimum caikit-tgis-backend version
mynhardtburger Jun 7, 2024
d5893d9
Add debug logs
mynhardtburger Jun 7, 2024
9848943
Linting
mynhardtburger Jun 7, 2024
5df0533
linting
mynhardtburger Jun 7, 2024
6bae66a
Update caikit_nlp/toolkit/text_generation/tgis_utils.py
mynhardtburger Jun 7, 2024
46ae073
review comments
mynhardtburger Jun 7, 2024
527b455
remove unreachable code
mynhardtburger Jun 7, 2024
3b2f8fd
Merge pull request #363 from mynhardtburger/lazy-model-connection
gabe-l-hart Jun 7, 2024
cbc6b33
RouteInfoFromBackend: Forward get_route_info and ROUTE_INFO_HEADER_KE…
gabe-l-hart Jun 18, 2024
1818a30
RouteInfoFromBackend: Bump caikit-tgis-backend
gabe-l-hart Jun 19, 2024
0e23e10
RouteInfoFromBackend: Bump caikit for context registration in backend
gabe-l-hart Jun 19, 2024
ff7f056
RouteInfoFromBackend: Remove unused imports
gabe-l-hart Jun 19, 2024
4ce8435
Merge pull request #364 from gabe-l-hart/RouteInfoFromBackend
gabe-l-hart Jun 19, 2024
0719637
sync with upstream @ v0.4.15
dtrifiro Jun 21, 2024
16 changes: 15 additions & 1 deletion caikit_nlp/modules/text_embedding/embedding.py
@@ -15,7 +15,7 @@
# Standard
from collections.abc import Sized
from enum import Enum, auto
from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
import importlib
import os
import time
@@ -178,6 +178,20 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":

return cls(model)

@property
def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument
"""Helper property to return public metadata about a specific Model. This
function is separate from `metdata` as that contains the entire ModelConfig
which might not want to be shared/exposed.

Returns:
Dict[str, Any]: A dictionary of this model's public metadata
"""
return {
"max_seq_length": cls.model.max_seq_length,
"sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(),
}

@classmethod
def _get_ipex(cls, ipex_flag):
"""Get IPEX optimization library if enabled and available, else return False
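The embedding.py change adds `public_model_info`, exposing only the safe, user-facing subset of model metadata rather than the full ModelConfig. A minimal consumption sketch, assuming an embedding model already saved in caikit format (the import path follows this repo's layout; the model path is a placeholder):

# Sketch: reading the metadata exposed by the new property.
from caikit_nlp.modules.text_embedding import EmbeddingModule

module = EmbeddingModule.load("path/to/embedding-model")
info = module.public_model_info
print(info["max_seq_length"])                # encoder context window
print(info["sentence_embedding_dimension"])  # width of each output vector
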
48 changes: 40 additions & 8 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -14,7 +14,9 @@
"""This file contains a distributed backend implementation for leveraging the PEFT-trained
prompt vectors in TGIS generation requests.
"""

# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

@@ -32,6 +34,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

@@ -68,15 +71,14 @@ def __init__(
prompt_artifacts: Optional[List[str]] = None,
) -> None:
super().__init__()
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None

self._tgis_backend = tgis_backend
if enable_backend:
error.type_check(
"<NLP33971947E>", TGISBackend, tgis_backend=self._tgis_backend
)
# get_client will also launch a local TGIS process and get the model
# loaded when using the local TGIS backend
self._client = tgis_backend.get_client(base_model_name)

# Tell the backend to load all of the available prompt files
if prompt_artifacts:
@@ -107,6 +109,14 @@ def __del__(self):
if tgis_backend and prompt_cache_id and model_id:
tgis_backend.unload_prompt_artifacts(model_id, prompt_cache_id)

@cached_property
def _client(self):
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
if self._tgis_backend:
return self._tgis_backend.get_client(self.base_model_name)

@classmethod
def load(cls, model_path: str, load_backend: BackendBase) -> "PeftPromptTuningTGIS":
"""Load a TGIS Peft Prompt Tuning distributed module. Note that we do not
@@ -182,7 +192,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
@@ -206,6 +216,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -221,6 +232,8 @@
self.enable_backend,
"Backend must be configured and loaded with this module before executing `run` call.",
)
self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.unary_generate(
text=verbalized_text,
@@ -244,7 +257,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
@@ -268,6 +281,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing against the model running in TGIS
@@ -283,6 +297,9 @@
"Backend must be configured and loaded with this module \
before executing `run_stream_out` call.",
)

self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.stream_generate(
text=verbalized_text,
@@ -306,10 +323,11 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
@@ -320,6 +338,20 @@
TokenizationResults
The token count
"""

self._register_model_connection_with_context(context)

return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
self._tgis_backend.handle_runtime_context(self.base_model_name, context)
self._model_loaded = True
56 changes: 41 additions & 15 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -14,6 +14,7 @@


# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

@@ -30,6 +31,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

@@ -86,28 +88,33 @@ def __init__(
# Set _model_loaded as False by default. This will only get set to True if
# we enable the tgis_backend and we are able to fetch the client successfully.
self._model_loaded = False
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None
if tgis_backend:
self._client = tgis_backend.get_client(model_name)
# mark that the model is loaded so that we can unload it later
self._model_loaded = True
self.tgis_backend = tgis_backend

self._tgis_backend = tgis_backend
self._bos_token = bos_token
self._sep_token = sep_token
self._eos_token = eos_token
self._pad_token = pad_token
self.tgis_generation_client = TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

def __del__(self):
# nothing to unload if we didn't finish loading
if self._model_loaded and self.tgis_backend:
self.tgis_backend.unload_model(self.model_name)
if self._model_loaded and self._tgis_backend:
self._tgis_backend.unload_model(self.model_name)

@cached_property
def _client(self):
# Lazily configure/create the internal tgis backend client
if self._tgis_backend:
return self._tgis_backend.get_client(self.model_name)

@cached_property
def tgis_generation_client(self):
# Lazily create the generation client
# This in turn calls self._client which also lazily gets the tgis backend client
return TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

@classmethod
def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None):
@@ -207,7 +214,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
@@ -231,6 +238,7 @@
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -240,6 +248,8 @@
GeneratedTextResult
Generated text result produced by TGIS.
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_generate(
text=text,
@@ -263,7 +273,7 @@
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
@@ -287,6 +297,7 @@
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing for text generation module.
@@ -295,6 +306,7 @@
Returns:
Iterable[GeneratedTextStreamResult]
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.stream_generate(
@@ -319,10 +331,11 @@
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
@@ -333,7 +346,20 @@
TokenizationResults
The token count
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
self._tgis_backend.handle_runtime_context(self.model_name, context)
self._model_loaded = True
25 changes: 23 additions & 2 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is for helper functions related to TGIS.
"""
"""This file is for helper functions related to TGIS."""

# Standard
from typing import Iterable

@@ -33,6 +33,7 @@
TokenizationResults,
TokenStreamDetails,
)
from caikit_tgis_backend import TGISBackend
from caikit_tgis_backend.protobufs import generation_pb2
import alog

@@ -84,6 +85,11 @@
grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED,
}

# HTTP Header / gRPC Metadata key used to identify a route override
# (forwarded for API compatibility)
ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY
get_route_info = TGISBackend.get_route_info


def raise_caikit_core_exception(rpc_error: grpc.RpcError):
"""Helper to wrap logic of converting from grpc.RpcError ->
@@ -329,6 +335,21 @@ def __init__(

self.tgis_req_timeout = get_config().tgis_request_timeout

if (
not self.tgis_req_timeout
or not isinstance(self.tgis_req_timeout, int)
or self.tgis_req_timeout <= 0
):
log.debug("<RUN57106697I>", "TGIS timeout not set")
self.tgis_req_timeout = None

else:
log.debug(
"<RUN57106696T>",
"Setting TGIS timeout value to %d",
self.tgis_req_timeout,
)

def unary_generate(
self,
text,
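The tgis_utils.py hunk normalizes `tgis_request_timeout` at client construction: anything unset, non-integer, or non-positive becomes None, so gRPC requests either carry a positive deadline or none at all. The same guard as a standalone, testable function (the function name is illustrative):

from typing import Optional

def normalize_timeout(raw) -> Optional[int]:
    # Accept only a positive int; anything else means "no deadline".
    # Caveat: isinstance(True, int) is True, so stricter code might
    # also exclude bools explicitly.
    if not raw or not isinstance(raw, int) or raw <= 0:
        return None
    return raw

assert normalize_timeout(30) == 30       # valid timeout passes through
assert normalize_timeout(0) is None      # non-positive -> no deadline
assert normalize_timeout("30") is None   # wrong type -> no deadline
assert normalize_timeout(None) is None   # unset -> no deadline
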
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -14,8 +14,8 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0",
"caikit-tgis-backend>=0.1.27,<0.2.0",
"caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0",
"caikit-tgis-backend>=0.1.34,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
"grpcio-reflection>=1.62.2",