
Commit

Merge remote-tracking branch 'upstream/main'
riprasad committed Aug 3, 2024
2 parents d0a1238 + 79ac58e commit 0c4b65c
Showing 3 changed files with 37 additions and 17 deletions.
16 changes: 8 additions & 8 deletions caikit_nlp/modules/text_generation/peft_tgis_remote.py
@@ -213,10 +213,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -280,10 +280,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing against the model running in TGIS
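In both hunks above, the token-detail flags flip from opt-out to opt-in: run and run_stream_out no longer return generated-token text, log-probabilities, or token ranks unless the caller asks for them, and include_stop_sequence now defaults to None so the effective default can be decided downstream. A minimal sketch of a caller restoring the old behavior; the module handle `model` and the prompt are hypothetical, and only the keyword flags shown in the diff come from the source:

# Hedged sketch: `model` stands for a loaded PeftPromptTuningTGIS module.
result = model.run(
    "What is the capital of France?",
    generated_tokens=True,       # opt back in to per-token text
    token_logprobs=True,         # opt back in to log-probabilities
    token_ranks=True,            # opt back in to token ranks
    include_stop_sequence=True,  # the default is now None (defer downstream)
)
print(result.generated_text)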
16 changes: 8 additions & 8 deletions caikit_nlp/modules/text_generation/text_generation_tgis.py
@@ -235,10 +235,10 @@ def run(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> GeneratedTextResult:
         f"""Run inference against the model running in TGIS.
@@ -296,10 +296,10 @@ def run_stream_out(
         seed: Optional[np.uint64] = None,
         preserve_input_text: bool = False,
         input_tokens: bool = False,
-        generated_tokens: bool = True,
-        token_logprobs: bool = True,
-        token_ranks: bool = True,
-        include_stop_sequence: Optional[bool] = True,
+        generated_tokens: bool = False,
+        token_logprobs: bool = False,
+        token_ranks: bool = False,
+        include_stop_sequence: Optional[bool] = None,
         context: Optional[RuntimeServerContextType] = None,
     ) -> Iterable[GeneratedTextStreamResult]:
         f"""Run output stream inferencing for text generation module.
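The same defaults change applies to this module's unary and streaming entry points. A hedged sketch of consuming run_stream_out under the new defaults, again with a hypothetical `model` handle and prompt:

# Hedged sketch: per-token detail fields in each chunk are now empty
# unless the corresponding flags are explicitly set to True.
for chunk in model.run_stream_out("Tell me a short story"):
    print(chunk.generated_text, end="", flush=True)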
22 changes: 21 additions & 1 deletion caikit_nlp/toolkit/text_generation/tgis_utils.py
@@ -91,6 +91,7 @@
 # HTTP Header / gRPC Metadata key used to identify a route override
 # (forwarded for API compatibility)
 ROUTE_INFO_HEADER_KEY = TGISBackend.ROUTE_INFO_HEADER_KEY
+INACTIVE_RPC_CONN_ERR_MESSAGE = "The underlying TCP connection is closed"
 get_route_info = TGISBackend.get_route_info


@@ -144,7 +145,10 @@ def validate_inf_params(
     error.type_check("<NLP65883540E>", bool, token_logprobs=token_logprobs)
     error.type_check("<NLP65883541E>", bool, token_ranks=token_ranks)
     error.type_check(
-        "<NLP65883542E>", bool, include_stop_sequence=include_stop_sequence
+        "<NLP65883542E>",
+        bool,
+        allow_none=True,
+        include_stop_sequence=include_stop_sequence,
     )
     error.type_check("<NLP85452188E>", str, allow_none=True, eos_token=eos_token)
     error.type_check(
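The widened check mirrors the new Optional[bool] signature above: caikit's error handler verifies each keyword value against the listed types, and allow_none=True additionally admits None. A standalone sketch of the pattern, with a hypothetical channel name and log code:

import alog
from caikit.core.exceptions import error_handler

log = alog.use_channel("DEMO")  # hypothetical channel name
error = error_handler.get(log)

def check(include_stop_sequence):
    # Accepts True, False, or None; raises a type error otherwise.
    error.type_check(
        "<DEM81234567E>",  # hypothetical log code
        bool,
        allow_none=True,
        include_stop_sequence=include_stop_sequence,
    )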
@@ -471,6 +475,14 @@ def unary_generate(
             batch_response = self.tgis_client.Generate(
                 request, timeout=self.tgis_req_timeout
             )
+        except grpc._channel._InactiveRpcError as err:
+            log.error("<NLP30829218E>", err.details)
+            caikit_status_code = GRPC_TO_CAIKIT_CORE_STATUS.get(
+                err.code(), CaikitCoreStatusCode.UNKNOWN
+            )
+            raise CaikitCoreException(
+                caikit_status_code, INACTIVE_RPC_CONN_ERR_MESSAGE
+            ) from err
         except grpc.RpcError as err:
             raise_caikit_core_exception(err)
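This new branch intercepts a dropped connection before the generic grpc.RpcError handler and re-raises it as a structured CaikitCoreException, translating the gRPC status code through GRPC_TO_CAIKIT_CORE_STATUS with UNKNOWN as the fallback. A hedged sketch of that translation in isolation; the real mapping table is defined earlier in tgis_utils, and CONNECTION_ERROR is an assumed member name:

import grpc
from caikit.core.exceptions.caikit_core_exception import (
    CaikitCoreException,
    CaikitCoreStatusCode,
)

# Illustrative subset; the real GRPC_TO_CAIKIT_CORE_STATUS may differ.
GRPC_TO_CAIKIT_CORE_STATUS = {
    grpc.StatusCode.UNAVAILABLE: CaikitCoreStatusCode.CONNECTION_ERROR,
}

def to_caikit_exception(err: grpc.RpcError) -> CaikitCoreException:
    # err.code() is the standard grpc.RpcError accessor for the status code.
    status = GRPC_TO_CAIKIT_CORE_STATUS.get(
        err.code(), CaikitCoreStatusCode.UNKNOWN
    )
    return CaikitCoreException(
        status, "The underlying TCP connection is closed"
    )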
@@ -650,6 +662,14 @@ def stream_generate(
                     input_tokens=input_token_list,
                     details=details,
                 )
+        except grpc._channel._InactiveRpcError as err:
+            log.error("<NLP11829118E>", err.details)
+            caikit_status_code = GRPC_TO_CAIKIT_CORE_STATUS.get(
+                err.code(), CaikitCoreStatusCode.UNKNOWN
+            )
+            raise CaikitCoreException(
+                caikit_status_code, INACTIVE_RPC_CONN_ERR_MESSAGE
+            ) from err
         except grpc.RpcError as err:
             raise_caikit_core_exception(err)
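stream_generate needs the same guard because, in a gRPC server-streaming call, a dead connection usually surfaces while the response iterator is consumed rather than when the RPC is issued, so the try block must enclose the iteration. A minimal sketch of that shape, reusing the hypothetical to_caikit_exception helper above (GenerateStream as the streaming RPC name is an assumption):

def stream_generate_sketch(stub, request, timeout):
    # Server-streaming RPC: gRPC raises during iteration, not at call time,
    # so the except clause must wrap the for loop.
    try:
        for item in stub.GenerateStream(request, timeout=timeout):
            yield item
    except grpc.RpcError as err:
        raise to_caikit_exception(err) from err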