diff --git a/aidial_adapter_bedrock/llm/consumer.py b/aidial_adapter_bedrock/llm/consumer.py index d535b9b..eacb6f9 100644 --- a/aidial_adapter_bedrock/llm/consumer.py +++ b/aidial_adapter_bedrock/llm/consumer.py @@ -24,7 +24,7 @@ class Consumer(ContextManager["Consumer"], ABC): @abstractmethod - def set_tools_emulator(self, emulator: ToolsEmulator): + def set_tools_emulator(self, tools_emulator: ToolsEmulator): pass @abstractmethod @@ -49,20 +49,20 @@ def get_usage(self) -> TokenUsage: @abstractmethod def set_discarded_messages( - self, discarded_messages: DiscardedMessages | None + self, discarded_messages: Optional[DiscardedMessages] ): pass @abstractmethod - def get_discarded_messages(self) -> DiscardedMessages | None: + def get_discarded_messages(self) -> Optional[DiscardedMessages]: pass @abstractmethod - def create_function_tool_call(self, call: ToolCall): + def create_function_tool_call(self, tool_call: ToolCall): pass @abstractmethod - def create_function_call(self, call: FunctionCall): + def create_function_call(self, function_call: FunctionCall): pass @property @@ -124,8 +124,8 @@ def __exit__( self._choice.close() return False - def set_tools_emulator(self, emulator: ToolsEmulator): - self.tools_emulator = emulator + def set_tools_emulator(self, tools_emulator: ToolsEmulator): + self.tools_emulator = tools_emulator def _process_content( self, content: str | None, finish_reason: FinishReason | None = None @@ -176,17 +176,17 @@ def set_discarded_messages( ): self.discarded_messages = discarded_messages - def create_function_tool_call(self, call: ToolCall): + def create_function_tool_call(self, tool_call: ToolCall): self.choice.create_function_tool_call( - id=call.id, - name=call.function.name, - arguments=call.function.arguments, + id=tool_call.id, + name=tool_call.function.name, + arguments=tool_call.function.arguments, ) - def create_function_call(self, call: FunctionCall): + def create_function_call(self, function_call: FunctionCall): self.choice.create_function_call( - name=call.name, - arguments=call.arguments, + name=function_call.name, + arguments=function_call.arguments, ) @property diff --git a/aidial_adapter_bedrock/llm/converse/factory.py b/aidial_adapter_bedrock/llm/converse/factory.py index f0c62ee..e85ee8c 100644 --- a/aidial_adapter_bedrock/llm/converse/factory.py +++ b/aidial_adapter_bedrock/llm/converse/factory.py @@ -43,14 +43,13 @@ async def create( if tools_support == ToolsSupport.NON_STREAMING_ONLY else ConverseAdapter ) - return replicator_decorator()( - cls( - deployment=self.deployment, - bedrock=await Bedrock.acreate(self.aws_client_config), - storage=create_file_storage(self.api_key), - input_tokenizer_factory=default_converse_tokenizer_factory, - support_tools=tools_support != ToolsSupport.NONE, - supported_image_types=supported_image_types or [], - supported_document_types=supported_document_types or [], - ) + model = cls( + deployment=self.deployment, + bedrock=await Bedrock.acreate(self.aws_client_config), + storage=create_file_storage(self.api_key), + input_tokenizer_factory=default_converse_tokenizer_factory, + support_tools=tools_support != ToolsSupport.NONE, + supported_image_types=supported_image_types or [], + supported_document_types=supported_document_types or [], ) + return replicator_decorator()(model) diff --git a/aidial_adapter_bedrock/llm/converse/output.py b/aidial_adapter_bedrock/llm/converse/output.py index 5f78e1e..83d0781 100644 --- a/aidial_adapter_bedrock/llm/converse/output.py +++ b/aidial_adapter_bedrock/llm/converse/output.py @@ -61,7 +61,7 @@ async def process_streaming( match params.tools_mode: case ToolsMode.TOOLS: consumer.create_function_tool_call( - call=DialToolCall( + tool_call=DialToolCall( type="function", id=current_tool_use["toolUseId"], index=None, @@ -75,7 +75,7 @@ async def process_streaming( # ignoring multiple function calls in one response if not consumer.has_function_call: consumer.create_function_call( - call=DialFunctionCall( + function_call=DialFunctionCall( name=current_tool_use["name"], arguments=current_tool_use["input"], ) @@ -107,7 +107,7 @@ def process_non_streaming( match params.tools_mode: case ToolsMode.TOOLS: consumer.create_function_tool_call( - call=DialToolCall( + tool_call=DialToolCall( type="function", id=content_block["toolUse"]["toolUseId"], index=None, @@ -123,7 +123,7 @@ def process_non_streaming( # ignoring multiple function calls in one response if not consumer.has_function_call: consumer.create_function_call( - call=DialFunctionCall( + function_call=DialFunctionCall( name=content_block["toolUse"]["name"], arguments=json.dumps( content_block["toolUse"]["input"] diff --git a/aidial_adapter_bedrock/llm/model/adapter.py b/aidial_adapter_bedrock/llm/model/adapter.py index 96a652e..0d8a6a9 100644 --- a/aidial_adapter_bedrock/llm/model/adapter.py +++ b/aidial_adapter_bedrock/llm/model/adapter.py @@ -105,25 +105,23 @@ async def get_bedrock_adapter( ChatCompletionDeployment.STABILITY_STABLE_IMAGE_CORE_V1 | ChatCompletionDeployment.STABILITY_STABLE_IMAGE_ULTRA_V1 ): - return replicator_decorator()( - StabilityV2Adapter.create( - await Bedrock.acreate(aws_client_config), - model, - api_key, - image_to_image_supported=False, - ) + model = StabilityV2Adapter.create( + await Bedrock.acreate(aws_client_config), + model, + api_key, + image_to_image_supported=False, ) + return replicator_decorator()(model) case ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_3_LARGE_V1: - return replicator_decorator()( - StabilityV2Adapter.create( - await Bedrock.acreate(aws_client_config), - model, - api_key, - image_to_image_supported=True, - image_width_constraints=(640, 1536), - image_height_constraints=(640, 1536), - ) + model = StabilityV2Adapter.create( + await Bedrock.acreate(aws_client_config), + model, + api_key, + image_to_image_supported=True, + image_width_constraints=(640, 1536), + image_height_constraints=(640, 1536), ) + return replicator_decorator()(model) case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE: return amazon.create_adapter( await Bedrock.acreate(aws_client_config), model