Skip to content

Commit

Permalink
fix: minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik committed Jan 9, 2025
1 parent b5f1845 commit 3e6a221
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 44 deletions.
28 changes: 14 additions & 14 deletions aidial_adapter_bedrock/llm/consumer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

class Consumer(ContextManager["Consumer"], ABC):
@abstractmethod
def set_tools_emulator(self, emulator: ToolsEmulator):
def set_tools_emulator(self, tools_emulator: ToolsEmulator):
pass

@abstractmethod
Expand All @@ -49,20 +49,20 @@ def get_usage(self) -> TokenUsage:

@abstractmethod
def set_discarded_messages(
self, discarded_messages: DiscardedMessages | None
self, discarded_messages: Optional[DiscardedMessages]
):
pass

@abstractmethod
def get_discarded_messages(self) -> DiscardedMessages | None:
def get_discarded_messages(self) -> Optional[DiscardedMessages]:
pass

@abstractmethod
def create_function_tool_call(self, call: ToolCall):
def create_function_tool_call(self, tool_call: ToolCall):
pass

@abstractmethod
def create_function_call(self, call: FunctionCall):
def create_function_call(self, function_call: FunctionCall):
pass

@property
Expand Down Expand Up @@ -124,8 +124,8 @@ def __exit__(
self._choice.close()
return False

def set_tools_emulator(self, emulator: ToolsEmulator):
self.tools_emulator = emulator
def set_tools_emulator(self, tools_emulator: ToolsEmulator):
self.tools_emulator = tools_emulator

def _process_content(
self, content: str | None, finish_reason: FinishReason | None = None
Expand Down Expand Up @@ -176,17 +176,17 @@ def set_discarded_messages(
):
self.discarded_messages = discarded_messages

def create_function_tool_call(self, call: ToolCall):
def create_function_tool_call(self, tool_call: ToolCall):
self.choice.create_function_tool_call(
id=call.id,
name=call.function.name,
arguments=call.function.arguments,
id=tool_call.id,
name=tool_call.function.name,
arguments=tool_call.function.arguments,
)

def create_function_call(self, call: FunctionCall):
def create_function_call(self, function_call: FunctionCall):
self.choice.create_function_call(
name=call.name,
arguments=call.arguments,
name=function_call.name,
arguments=function_call.arguments,
)

@property
Expand Down
19 changes: 9 additions & 10 deletions aidial_adapter_bedrock/llm/converse/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,13 @@ async def create(
if tools_support == ToolsSupport.NON_STREAMING_ONLY
else ConverseAdapter
)
return replicator_decorator()(
cls(
deployment=self.deployment,
bedrock=await Bedrock.acreate(self.aws_client_config),
storage=create_file_storage(self.api_key),
input_tokenizer_factory=default_converse_tokenizer_factory,
support_tools=tools_support != ToolsSupport.NONE,
supported_image_types=supported_image_types or [],
supported_document_types=supported_document_types or [],
)
model = cls(
deployment=self.deployment,
bedrock=await Bedrock.acreate(self.aws_client_config),
storage=create_file_storage(self.api_key),
input_tokenizer_factory=default_converse_tokenizer_factory,
support_tools=tools_support != ToolsSupport.NONE,
supported_image_types=supported_image_types or [],
supported_document_types=supported_document_types or [],
)
return replicator_decorator()(model)
8 changes: 4 additions & 4 deletions aidial_adapter_bedrock/llm/converse/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ async def process_streaming(
match params.tools_mode:
case ToolsMode.TOOLS:
consumer.create_function_tool_call(
call=DialToolCall(
tool_call=DialToolCall(
type="function",
id=current_tool_use["toolUseId"],
index=None,
Expand All @@ -75,7 +75,7 @@ async def process_streaming(
# ignoring multiple function calls in one response
if not consumer.has_function_call:
consumer.create_function_call(
call=DialFunctionCall(
function_call=DialFunctionCall(
name=current_tool_use["name"],
arguments=current_tool_use["input"],
)
Expand Down Expand Up @@ -107,7 +107,7 @@ def process_non_streaming(
match params.tools_mode:
case ToolsMode.TOOLS:
consumer.create_function_tool_call(
call=DialToolCall(
tool_call=DialToolCall(
type="function",
id=content_block["toolUse"]["toolUseId"],
index=None,
Expand All @@ -123,7 +123,7 @@ def process_non_streaming(
# ignoring multiple function calls in one response
if not consumer.has_function_call:
consumer.create_function_call(
call=DialFunctionCall(
function_call=DialFunctionCall(
name=content_block["toolUse"]["name"],
arguments=json.dumps(
content_block["toolUse"]["input"]
Expand Down
30 changes: 14 additions & 16 deletions aidial_adapter_bedrock/llm/model/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,25 +105,23 @@ async def get_bedrock_adapter(
ChatCompletionDeployment.STABILITY_STABLE_IMAGE_CORE_V1
| ChatCompletionDeployment.STABILITY_STABLE_IMAGE_ULTRA_V1
):
return replicator_decorator()(
StabilityV2Adapter.create(
await Bedrock.acreate(aws_client_config),
model,
api_key,
image_to_image_supported=False,
)
model = StabilityV2Adapter.create(
await Bedrock.acreate(aws_client_config),
model,
api_key,
image_to_image_supported=False,
)
return replicator_decorator()(model)
case ChatCompletionDeployment.STABILITY_STABLE_DIFFUSION_3_LARGE_V1:
return replicator_decorator()(
StabilityV2Adapter.create(
await Bedrock.acreate(aws_client_config),
model,
api_key,
image_to_image_supported=True,
image_width_constraints=(640, 1536),
image_height_constraints=(640, 1536),
)
model = StabilityV2Adapter.create(
await Bedrock.acreate(aws_client_config),
model,
api_key,
image_to_image_supported=True,
image_width_constraints=(640, 1536),
image_height_constraints=(640, 1536),
)
return replicator_decorator()(model)
case ChatCompletionDeployment.AMAZON_TITAN_TG1_LARGE:
return amazon.create_adapter(
await Bedrock.acreate(aws_client_config), model
Expand Down

0 comments on commit 3e6a221

Please sign in to comment.