diff --git a/private_gpt/components/llm/custom/sagemaker.py b/private_gpt/components/llm/custom/sagemaker.py
index 9286de6166..08f56a955e 100644
--- a/private_gpt/components/llm/custom/sagemaker.py
+++ b/private_gpt/components/llm/custom/sagemaker.py
@@ -21,8 +21,6 @@
 )
 
 if TYPE_CHECKING:
-    from collections.abc import Callable
-
     from llama_index.callbacks import CallbackManager
     from llama_index.llms import (
         CompletionResponseGen,
@@ -113,10 +111,10 @@ class SagemakerLLM(CustomLLM):
     context_window: int = Field(
         description="The maximum number of context tokens for the model."
     )
-    messages_to_prompt: Callable[..., str] = Field(
+    messages_to_prompt: Any = Field(
        description="The function to convert messages to a prompt.", exclude=True
     )
-    completion_to_prompt: Callable[..., str] = Field(
+    completion_to_prompt: Any = Field(
        description="The function to convert a completion to a prompt.", exclude=True
     )
     generate_kwargs: dict[str, Any] = Field(
diff --git a/private_gpt/components/llm/llm_component.py b/private_gpt/components/llm/llm_component.py
index 2c32897c60..cbd71ce1f5 100644
--- a/private_gpt/components/llm/llm_component.py
+++ b/private_gpt/components/llm/llm_component.py
@@ -37,6 +37,8 @@ def __init__(self) -> None:
 
                 self.llm = SagemakerLLM(
                     endpoint_name=settings.sagemaker.endpoint_name,
+                    messages_to_prompt=messages_to_prompt,
+                    completion_to_prompt=completion_to_prompt,
                 )
            case "openai":
                 from llama_index.llms import OpenAI
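
For context, here is a minimal sketch of the kind of prompt-formatting callables that `llm_component.py` now passes into `SagemakerLLM` (and that the relaxed `Any` field types accept). The function bodies and the llama2-style template below are illustrative assumptions; the diff does not show which helpers `llm_component.py` actually imports.

```python
# Sketch (assumption): plain functions matching the shape SagemakerLLM expects.
# The template format is hypothetical, not the exact helpers used upstream.
from llama_index.llms import ChatMessage, MessageRole


def completion_to_prompt(completion: str) -> str:
    # Wrap a bare completion request in an instruction-style template.
    return f"<s>[INST] {completion.strip()} [/INST]"


def messages_to_prompt(messages: list[ChatMessage]) -> str:
    # Flatten a chat history into a single prompt string.
    parts = []
    for message in messages:
        if message.role == MessageRole.SYSTEM:
            parts.append(f"<<SYS>>\n{message.content}\n<</SYS>>")
        elif message.role == MessageRole.USER:
            parts.append(f"[INST] {message.content} [/INST]")
        else:  # assistant and any other roles
            parts.append(f"{message.content}")
    return "<s>" + "\n".join(parts)
```

Switching the field annotations from `Callable[..., str]` to `Any` sidesteps Pydantic's handling of callable fields while still letting these functions be stored on the model and called at prompt-build time.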