Merge pull request #26 from opendatahub-io/sync-with-upstream @ v0.4.15
sync with upstream @ v0.4.15
dtrifiro authored Jun 21, 2024
2 parents 3fea23c + 0719637 commit dc75372
Showing 9 changed files with 225 additions and 34 deletions.
16 changes: 15 additions & 1 deletion caikit_nlp/modules/text_embedding/embedding.py
@@ -15,7 +15,7 @@
# Standard
from collections.abc import Sized
from enum import Enum, auto
from typing import Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, TypeVar, Union
import importlib
import os
import time
@@ -178,6 +178,20 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule":

return cls(model)

@property
def public_model_info(cls) -> Dict[str, Any]: # pylint: disable=no-self-argument
"""Helper property to return public metadata about a specific Model. This
function is separate from `metdata` as that contains the entire ModelConfig
which might not want to be shared/exposed.
Returns:
Dict[str, str]: A dictionary of this models's public metadata
"""
return {
"max_seq_length": cls.model.max_seq_length,
"sentence_embedding_dimension": cls.model.get_sentence_embedding_dimension(),
}

@classmethod
def _get_ipex(cls, ipex_flag):
"""Get IPEX optimization library if enabled and available, else return False
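For context, a minimal sketch (not part of this diff) of how the new `public_model_info` property might be consumed after loading an embedding model; the model path is hypothetical and the printed values depend on the underlying sentence-transformers model:

# Sketch only: assumes a caikit-format embedding model exists at the given path.
from caikit_nlp.modules.text_embedding.embedding import EmbeddingModule

module = EmbeddingModule.load("models/my-embedding-model")  # hypothetical path
info = module.public_model_info
print(info["max_seq_length"])                 # e.g. 512
print(info["sentence_embedding_dimension"])   # e.g. 768
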
48 changes: 40 additions & 8 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -14,7 +14,9 @@
"""This file contains a distributed backend implementation for leveraging the PEFT-trained
prompt vectors in TGIS generation requests.
"""

# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

@@ -32,6 +34,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

@@ -68,15 +71,14 @@ def __init__(
prompt_artifacts: Optional[List[str]] = None,
) -> None:
super().__init__()
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None

self._tgis_backend = tgis_backend
if enable_backend:
error.type_check(
"<NLP33971947E>", TGISBackend, tgis_backend=self._tgis_backend
)
# get_client will also launch a local TGIS process and get the model
# loaded when using the local TGIS backend
self._client = tgis_backend.get_client(base_model_name)

# Tell the backend to load all of the available prompt files
if prompt_artifacts:
@@ -107,6 +109,14 @@ def __del__(self):
if tgis_backend and prompt_cache_id and model_id:
tgis_backend.unload_prompt_artifacts(model_id, prompt_cache_id)

@cached_property
def _client(self):
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
if self._tgis_backend:
return self._tgis_backend.get_client(self.base_model_name)

@classmethod
def load(cls, model_path: str, load_backend: BackendBase) -> "PeftPromptTuningTGIS":
"""Load a TGIS Peft Prompt Tuning distributed module. Note that we do not
@@ -182,7 +192,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
@@ -206,6 +216,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -221,6 +232,8 @@
self.enable_backend,
"Backend must be configured and loaded with this module before executing `run` call.",
)
self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.unary_generate(
text=verbalized_text,
@@ -244,7 +257,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
@@ -268,6 +281,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing against the model running in TGIS
@@ -283,6 +297,9 @@
"Backend must be configured and loaded with this module \
before executing `run_stream_out` call.",
)

self._register_model_connection_with_context(context)

verbalized_text = render_verbalizer(self.verbalizer, {"input": text})
return self.tgis_generation_client.stream_generate(
text=verbalized_text,
@@ -306,10 +323,11 @@ def run_stream_out(
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
@@ -320,6 +338,20 @@
TokenizationResults
The token count
"""

self._register_model_connection_with_context(context)

return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
self._tgis_backend.handle_runtime_context(self.base_model_name, context)
self._model_loaded = True
56 changes: 41 additions & 15 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -14,6 +14,7 @@


# Standard
from functools import cached_property
from typing import Iterable, List, Optional, Tuple, Union
import os

@@ -30,6 +31,7 @@
TokenizationResults,
)
from caikit.interfaces.nlp.tasks import TextGenerationTask, TokenizationTask
from caikit.interfaces.runtime.data_model import RuntimeServerContextType
from caikit_tgis_backend import TGISBackend
import alog

@@ -86,28 +88,33 @@ def __init__(
# Set _model_loaded as False by default. This will only get set to True if
# we enable the tgis_backend and we are able to fetch the client successfully.
self._model_loaded = False
# Configure the internal client
# NOTE: This is made optional for the cases where we do not need to execute `.run` function
# for example, bootstrapping a model to caikit format and saving.
self._client = None
if tgis_backend:
self._client = tgis_backend.get_client(model_name)
# mark that the model is loaded so that we can unload it later
self._model_loaded = True
self.tgis_backend = tgis_backend

self._tgis_backend = tgis_backend
self._bos_token = bos_token
self._sep_token = sep_token
self._eos_token = eos_token
self._pad_token = pad_token
self.tgis_generation_client = TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

def __del__(self):
# nothing to unload if we didn't finish loading
if self._model_loaded and self.tgis_backend:
self.tgis_backend.unload_model(self.model_name)
if self._model_loaded and self._tgis_backend:
self._tgis_backend.unload_model(self.model_name)

@cached_property
def _client(self):
# Lazily configure/create the internal tgis backend client
if self._tgis_backend:
return self._tgis_backend.get_client(self.model_name)

@cached_property
def tgis_generation_client(self):
# Lazily create the generation client
# This in turn calls self._client which also lazily gets the tgis backend client
return TGISGenerationClient(
self.model_name, self._eos_token, self._client, self.PRODUCER_ID
)

@classmethod
def bootstrap(cls, model_path: str, load_backend: Union[BackendBase, None] = None):
@@ -207,7 +214,7 @@ def save(self, model_path: str):
)

# pylint: disable=duplicate-code
@TextGenerationTask.taskmethod()
@TextGenerationTask.taskmethod(context_arg="context")
def run(
self,
text: str,
@@ -231,6 +238,7 @@ def run(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> GeneratedTextResult:
f"""Run inference against the model running in TGIS.
@@ -240,6 +248,8 @@
GeneratedTextResult
Generated text result produced by TGIS.
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_generate(
text=text,
@@ -263,7 +273,7 @@ def run(
stop_sequences=stop_sequences,
)

@TextGenerationTask.taskmethod(output_streaming=True)
@TextGenerationTask.taskmethod(output_streaming=True, context_arg="context")
def run_stream_out(
self,
text: str,
@@ -287,6 +297,7 @@ def run_stream_out(
generated_tokens: bool = True,
token_logprobs: bool = True,
token_ranks: bool = True,
context: Optional[RuntimeServerContextType] = None,
) -> Iterable[GeneratedTextStreamResult]:
f"""Run output stream inferencing for text generation module.
@@ -295,6 +306,7 @@
Returns:
Iterable[GeneratedTextStreamResult]
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.stream_generate(
@@ -319,10 +331,11 @@
stop_sequences=stop_sequences,
)

@TokenizationTask.taskmethod()
@TokenizationTask.taskmethod(context_arg="context")
def run_tokenizer(
self,
text: str,
context: Optional[RuntimeServerContextType] = None,
) -> TokenizationResults:
"""Run tokenization task against the model running in TGIS.
@@ -333,7 +346,20 @@
TokenizationResults
The token count
"""
self._register_model_connection_with_context(context)

if self._model_loaded:
return self.tgis_generation_client.unary_tokenize(
text=text,
)

def _register_model_connection_with_context(
self, context: Optional[RuntimeServerContextType]
):
"""
Register a remote model connection with the configured TGISBackend if there is
a context override provided.
"""
if self._tgis_backend:
self._tgis_backend.handle_runtime_context(self.model_name, context)
self._model_loaded = True
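All three task methods above now take an optional runtime context (via `context_arg="context"`) and funnel it through the shared `_register_model_connection_with_context` helper before delegating to the generation client. A self-contained sketch of that flow, using a recording stand-in instead of the real `TGISBackend.handle_runtime_context`:

# Illustrative only: RecordingBackend just records calls; the real backend resolves
# per-request routing from the runtime server context.
from typing import Any, List, Optional, Tuple


class RecordingBackend:
    def __init__(self):
        self.calls: List[Tuple[str, Optional[Any]]] = []

    def handle_runtime_context(self, model_name: str, context: Optional[Any]):
        self.calls.append((model_name, context))


class Module:
    def __init__(self, backend: Optional[RecordingBackend], model_name: str):
        self._tgis_backend = backend
        self.model_name = model_name
        self._model_loaded = False

    def _register_model_connection_with_context(self, context: Optional[Any]):
        # Register a remote model connection only when a backend is configured;
        # this also marks the model as loaded so it can be unloaded later.
        if self._tgis_backend:
            self._tgis_backend.handle_runtime_context(self.model_name, context)
            self._model_loaded = True

    def run(self, text: str, context: Optional[Any] = None) -> str:
        self._register_model_connection_with_context(context)
        return f"generated({text})"


backend = RecordingBackend()
module = Module(backend, "flan-t5-small")
module.run("hello", context={"route-info": "my-tgis-instance"})  # hypothetical context payload
print(backend.calls)  # [('flan-t5-small', {'route-info': 'my-tgis-instance'})]
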
25 changes: 23 additions & 2 deletions caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -11,8 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This file is for helper functions related to TGIS.
"""
"""This file is for helper functions related to TGIS."""

# Standard
from typing import Iterable

@@ -33,6 +33,7 @@
TokenizationResults,
TokenStreamDetails,
)
from caikit_tgis_backend import TGISBackend
from caikit_tgis_backend.protobufs import generation_pb2
import alog

@@ -84,6 +85,11 @@
grpc.StatusCode.UNAUTHENTICATED: CaikitCoreStatusCode.UNAUTHORIZED,
}

# HTTP Header / gRPC Metadata key used to identify a route override
# (forwarded for API compatibility)
ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY
get_route_info = TGISBackend.get_route_info


def raise_caikit_core_exception(rpc_error: grpc.RpcError):
"""Helper to wrap logic of converting from grpc.RpcError ->
@@ -329,6 +335,21 @@ def __init__(

self.tgis_req_timeout = get_config().tgis_request_timeout

if (
not self.tgis_req_timeout
or not isinstance(self.tgis_req_timeout, int)
or self.tgis_req_timeout <= 0
):
log.debug("<RUN57106697I>", "TGIS timeout not set")
self.tgis_req_timeout = None

else:
log.debug(
"<RUN57106696T>",
"Setting TGIS timeout value to %d",
self.tgis_req_timeout,
)

def unary_generate(
self,
text,
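The new timeout handling above accepts the configured `tgis_request_timeout` only when it is a positive integer, otherwise falling back to no timeout. The same check expressed as a standalone helper (a sketch, not part of the diff):

# Sketch: any unset, non-integer, or non-positive value is normalized to None (no timeout).
from typing import Optional


def normalize_tgis_timeout(raw_timeout) -> Optional[int]:
    if not raw_timeout or not isinstance(raw_timeout, int) or raw_timeout <= 0:
        return None
    return raw_timeout


assert normalize_tgis_timeout(None) is None
assert normalize_tgis_timeout(0) is None
assert normalize_tgis_timeout(-5) is None
assert normalize_tgis_timeout("30") is None   # strings are rejected, not coerced
assert normalize_tgis_timeout(30) == 30
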
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -14,8 +14,8 @@ classifiers=[
"License :: OSI Approved :: Apache Software License"
]
dependencies = [
"caikit[runtime-grpc,runtime-http]>=0.26.17,<0.27.0",
"caikit-tgis-backend>=0.1.27,<0.2.0",
"caikit[runtime-grpc,runtime-http]>=0.26.34,<0.27.0",
"caikit-tgis-backend>=0.1.34,<0.2.0",
# TODO: loosen dependencies
"grpcio>=1.62.2", # explicitly pin grpc dependencies to a recent version to avoid pip backtracking
"grpcio-reflection>=1.62.2",
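The dependency floors for `caikit` and `caikit-tgis-backend` are bumped to pick up the runtime-context and route-info support used above. A quick, optional way to confirm an environment meets the new minimums (assumes the third-party `packaging` library is available; it is not a project dependency):

# Sketch: verify installed distributions satisfy the bumped lower bounds.
from importlib.metadata import version
from packaging.version import Version

assert Version(version("caikit")) >= Version("0.26.34")
assert Version(version("caikit-tgis-backend")) >= Version("0.1.34")
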