From 3adf6a0dcd224527223ffdec30eed6ae363e6a14 Mon Sep 17 00:00:00 2001 From: nsmccandlish Date: Tue, 22 Oct 2024 17:11:15 -0700 Subject: [PATCH 1/4] enable prompt caching, move everything to be param shapes --- computer-use-demo/computer_use_demo/loop.py | 83 ++++++++++++++----- .../computer_use_demo/streamlit.py | 31 ++++--- .../computer_use_demo/tools/collection.py | 8 +- computer-use-demo/tests/streamlit_test.py | 5 +- 4 files changed, 87 insertions(+), 40 deletions(-) diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index bb959e4b..df0b4046 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -9,9 +9,6 @@ from typing import Any, cast from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse -from anthropic.types import ( - ToolResultBlockParam, -) from anthropic.types.beta import ( BetaContentBlock, BetaContentBlockParam, @@ -20,6 +17,10 @@ BetaMessageParam, BetaTextBlockParam, BetaToolResultBlockParam, + BetaToolUseBlockParam, + BetaTextBlock, + BetaToolUseBlock, + BetaCacheControlEphemeralParam, ) from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult @@ -68,7 +69,7 @@ async def sampling_loop( provider: APIProvider, system_prompt_suffix: str, messages: list[BetaMessageParam], - output_callback: Callable[[BetaContentBlock], None], + output_callback: Callable[[BetaContentBlockParam], None], tool_output_callback: Callable[[ToolResult, str], None], api_response_callback: Callable[[APIResponse[BetaMessage]], None], api_key: str, @@ -88,16 +89,27 @@ async def sampling_loop( ) while True: - if only_n_most_recent_images: - _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images) + enable_prompt_caching = False + betas=["computer-use-2024-10-22"] + image_truncation_threshold = 10 if provider == APIProvider.ANTHROPIC: client = Anthropic(api_key=api_key) + enable_prompt_caching = True elif provider == APIProvider.VERTEX: client = AnthropicVertex() elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() + + if enable_prompt_caching: + betas.append("prompt-caching-2024-07-31") + _inject_prompt_caching(messages) + # Is it ever worth it to bust the cache with prompt caching? + image_truncation_threshold = 50 + if only_n_most_recent_images: + _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images, min_removal_threshold=image_truncation_threshold) + # Call the API # we use raw_response to provide debug information to streamlit. 
Your # implementation may be able call the SDK directly with: @@ -107,33 +119,35 @@ async def sampling_loop( messages=messages, model=model, system=system, - tools=tool_collection.to_params(), - betas=["computer-use-2024-10-22"], + tools=tool_collection.to_params(enable_prompt_caching=enable_prompt_caching), + betas=betas ) + + raw_response = cast(APIResponse[BetaMessage], raw_response) - api_response_callback(cast(APIResponse[BetaMessage], raw_response)) - + api_response_callback(raw_response) response = raw_response.parse() + response_params = _response_to_params(response) messages.append( { "role": "assistant", - "content": cast(list[BetaContentBlockParam], response.content), + "content": response_params, } ) tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in cast(list[BetaContentBlock], response.content): + for content_block in response_params: output_callback(content_block) - if content_block.type == "tool_use": + if content_block["type"] == "tool_use": result = await tool_collection.run( - name=content_block.name, - tool_input=cast(dict[str, Any], content_block.input), + name=content_block["name"], + tool_input=cast(dict[str, Any], content_block["input"]), ) tool_result_content.append( - _make_api_tool_result(result, content_block.id) + _make_api_tool_result(result, content_block["id"]) ) - tool_output_callback(result, content_block.id) + tool_output_callback(result, content_block["id"]) if not tool_result_content: return messages @@ -144,7 +158,7 @@ async def sampling_loop( def _maybe_filter_to_n_most_recent_images( messages: list[BetaMessageParam], images_to_keep: int, - min_removal_threshold: int = 10, + min_removal_threshold: int, ): """ With the assumption that images are screenshots that are of diminishing value as @@ -156,7 +170,7 @@ def _maybe_filter_to_n_most_recent_images( return messages tool_result_blocks = cast( - list[ToolResultBlockParam], + list[BetaToolResultBlockParam], [ item for message in messages @@ -189,6 +203,37 @@ def _maybe_filter_to_n_most_recent_images( new_content.append(content) tool_result["content"] = new_content +def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: + res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] + for block in response.content: + if isinstance(block, BetaTextBlock): + res.append({"type": "text", "text": block.text}) + else: + res.append(cast(BetaToolUseBlockParam, block.model_dump())) + return res + +def _inject_prompt_caching( + messages: list[BetaMessageParam], +): + """ + With the assumption that images are screenshots that are of diminishing value as + the conversation progresses, remove all but the final `images_to_keep` tool_result + images in place, with a chunk of min_removal_threshold to reduce the amount we + break the implicit prompt cache. 
+ """ + + breakpoints_remaining = 3 + for message in reversed(messages): + if message["role"] == "user" and isinstance(content := message["content"], list): + if breakpoints_remaining: + breakpoints_remaining -= 1 + content[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"}) + else: + content[-1].pop("cache_control", None) + # we'll only every have one extra turn per loop + break + + def _make_api_tool_result( result: ToolResult, tool_use_id: str diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py index 6750029c..0f29fb76 100644 --- a/computer-use-demo/computer_use_demo/streamlit.py +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -12,13 +12,10 @@ from pathlib import PosixPath from typing import cast +from anthropic.types.beta.beta_tool_use_block_param import BetaToolUseBlockParam import streamlit as st from anthropic import APIResponse -from anthropic.types import ( - TextBlock, -) -from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock -from anthropic.types.tool_use_block import ToolUseBlock +from anthropic.types.beta import BetaContentBlockParam, BetaMessage, BetaTextBlockParam, BetaToolResultBlockParam, BetaToolUseBlock from streamlit.delta_generator import DeltaGenerator from computer_use_demo.loop import ( @@ -182,7 +179,7 @@ def _reset_api_provider(): else: _render_message( message["role"], - cast(BetaTextBlock | BetaToolUseBlock, block), + cast(BetaContentBlockParam | ToolResult, block), ) # render past http exchanges @@ -194,7 +191,7 @@ def _reset_api_provider(): st.session_state.messages.append( { "role": Sender.USER, - "content": [TextBlock(type="text", text=new_message)], + "content": [BetaTextBlockParam(type="text", text=new_message)], } ) _render_message(Sender.USER, new_message) @@ -317,15 +314,11 @@ def _render_api_response( def _render_message( sender: Sender, - message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, + message: str | BetaContentBlockParam | ToolResult, ): """Convert input from the user or output from the agent to a streamlit message.""" # streamlit's hotreloading breaks isinstance checks, so we need to check for class names - is_tool_result = not isinstance(message, str) and ( - isinstance(message, ToolResult) - or message.__class__.__name__ == "ToolResult" - or message.__class__.__name__ == "CLIResult" - ) + is_tool_result = not isinstance(message, str | dict) if not message or ( is_tool_result and st.session_state.hide_images @@ -345,10 +338,14 @@ def _render_message( st.error(message.error) if message.base64_image and not st.session_state.hide_images: st.image(base64.b64decode(message.base64_image)) - elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock): - st.write(message.text) - elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock): - st.code(f"Tool Use: {message.name}\nInput: {message.input}") + elif isinstance(message, dict): + if message["type"] == "text": + st.write(message["text"]) + elif message["type"] == "tool_use": + st.code(f'Tool Use: {message["name"]}\nInput: {message["input"]}') + else: + # only expected return types are text and tool_use + raise Exception(f'Unexpected response type {message["type"]}') else: st.markdown(message) diff --git a/computer-use-demo/computer_use_demo/tools/collection.py b/computer-use-demo/computer_use_demo/tools/collection.py index c4e8c95c..72c05ecb 100644 --- a/computer-use-demo/computer_use_demo/tools/collection.py +++ 
b/computer-use-demo/computer_use_demo/tools/collection.py @@ -2,7 +2,7 @@ from typing import Any -from anthropic.types.beta import BetaToolUnionParam +from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaToolUnionParam from .base import ( BaseAnthropicTool, @@ -21,8 +21,12 @@ def __init__(self, *tools: BaseAnthropicTool): def to_params( self, + enable_prompt_caching: bool ) -> list[BetaToolUnionParam]: - return [tool.to_params() for tool in self.tools] + tools = [tool.to_params() for tool in self.tools] + if enable_prompt_caching: + tools[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"}) + return tools async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: tool = self.tool_map.get(name) diff --git a/computer-use-demo/tests/streamlit_test.py b/computer-use-demo/tests/streamlit_test.py index 25cd586b..7a76c1fa 100644 --- a/computer-use-demo/tests/streamlit_test.py +++ b/computer-use-demo/tests/streamlit_test.py @@ -1,9 +1,10 @@ from unittest import mock +from anthropic.types import TextBlockParam import pytest from streamlit.testing.v1 import AppTest -from computer_use_demo.streamlit import Sender, TextBlock +from computer_use_demo.streamlit import Sender @pytest.fixture @@ -18,6 +19,6 @@ def test_streamlit(streamlit_app: AppTest): streamlit_app.chat_input[0].set_value("Hello").run() assert patch.called assert patch.call_args.kwargs["messages"] == [ - {"role": Sender.USER, "content": [TextBlock(text="Hello", type="text")]} + {"role": Sender.USER, "content": [TextBlockParam(text="Hello", type="text")]} ] assert not streamlit_app.exception From 18e497fbaf500ce6c0c6ce185c2e27c6ca79b264 Mon Sep 17 00:00:00 2001 From: nsmccandlish Date: Wed, 23 Oct 2024 11:31:27 -0700 Subject: [PATCH 2/4] fix test --- computer-use-demo/tests/loop_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/computer-use-demo/tests/loop_test.py b/computer-use-demo/tests/loop_test.py index 4985dbee..9572c460 100644 --- a/computer-use-demo/tests/loop_test.py +++ b/computer-use-demo/tests/loop_test.py @@ -1,7 +1,7 @@ from unittest import mock from anthropic.types import TextBlock, ToolUseBlock -from anthropic.types.beta import BetaMessage, BetaMessageParam +from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlockParam from computer_use_demo.loop import APIProvider, sampling_loop @@ -58,7 +58,9 @@ async def test_loop(): tool_collection.run.assert_called_once_with( name="computer", tool_input={"action": "test"} ) - output_callback.assert_called_with(TextBlock(text="Done!", type="text")) + output_callback.assert_called_with( + BetaTextBlockParam(text="Done!", type="text") + ) assert output_callback.call_count == 3 assert tool_output_callback.call_count == 1 assert api_response_callback.call_count == 2 From 5f61360dc25104248e8107771b00a5ce69875e61 Mon Sep 17 00:00:00 2001 From: nsmccandlish Date: Wed, 23 Oct 2024 11:32:26 -0700 Subject: [PATCH 3/4] ruff? 
--- computer-use-demo/computer_use_demo/loop.py | 44 ++++++++++++------- .../computer_use_demo/streamlit.py | 7 ++- computer-use-demo/tests/streamlit_test.py | 7 ++- 3 files changed, 37 insertions(+), 21 deletions(-) diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index df0b4046..73040e7f 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -10,17 +10,15 @@ from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse from anthropic.types.beta import ( - BetaContentBlock, + BetaCacheControlEphemeralParam, BetaContentBlockParam, BetaImageBlockParam, BetaMessage, BetaMessageParam, + BetaTextBlock, BetaTextBlockParam, BetaToolResultBlockParam, BetaToolUseBlockParam, - BetaTextBlock, - BetaToolUseBlock, - BetaCacheControlEphemeralParam, ) from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult @@ -89,9 +87,8 @@ async def sampling_loop( ) while True: - enable_prompt_caching = False - betas=["computer-use-2024-10-22"] + betas = ["computer-use-2024-10-22"] image_truncation_threshold = 10 if provider == APIProvider.ANTHROPIC: client = Anthropic(api_key=api_key) @@ -100,7 +97,7 @@ async def sampling_loop( client = AnthropicVertex() elif provider == APIProvider.BEDROCK: client = AnthropicBedrock() - + if enable_prompt_caching: betas.append("prompt-caching-2024-07-31") _inject_prompt_caching(messages) @@ -108,8 +105,12 @@ async def sampling_loop( image_truncation_threshold = 50 if only_n_most_recent_images: - _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images, min_removal_threshold=image_truncation_threshold) - + _maybe_filter_to_n_most_recent_images( + messages, + only_n_most_recent_images, + min_removal_threshold=image_truncation_threshold, + ) + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: @@ -119,10 +120,12 @@ async def sampling_loop( messages=messages, model=model, system=system, - tools=tool_collection.to_params(enable_prompt_caching=enable_prompt_caching), - betas=betas + tools=tool_collection.to_params( + enable_prompt_caching=enable_prompt_caching + ), + betas=betas, ) - + raw_response = cast(APIResponse[BetaMessage], raw_response) api_response_callback(raw_response) @@ -203,7 +206,10 @@ def _maybe_filter_to_n_most_recent_images( new_content.append(content) tool_result["content"] = new_content -def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: + +def _response_to_params( + response: BetaMessage, +) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] for block in response.content: if isinstance(block, BetaTextBlock): @@ -212,6 +218,7 @@ def _response_to_params(response: BetaMessage) -> list[BetaTextBlockParam | Beta res.append(cast(BetaToolUseBlockParam, block.model_dump())) return res + def _inject_prompt_caching( messages: list[BetaMessageParam], ): @@ -221,20 +228,23 @@ def _inject_prompt_caching( images in place, with a chunk of min_removal_threshold to reduce the amount we break the implicit prompt cache. 
""" - + breakpoints_remaining = 3 for message in reversed(messages): - if message["role"] == "user" and isinstance(content := message["content"], list): + if message["role"] == "user" and isinstance( + content := message["content"], list + ): if breakpoints_remaining: breakpoints_remaining -= 1 - content[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"}) + content[-1]["cache_control"] = BetaCacheControlEphemeralParam( + {"type": "ephemeral"} + ) else: content[-1].pop("cache_control", None) # we'll only every have one extra turn per loop break - def _make_api_tool_result( result: ToolResult, tool_use_id: str ) -> BetaToolResultBlockParam: diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py index 0f29fb76..c0350dd2 100644 --- a/computer-use-demo/computer_use_demo/streamlit.py +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -12,10 +12,13 @@ from pathlib import PosixPath from typing import cast -from anthropic.types.beta.beta_tool_use_block_param import BetaToolUseBlockParam import streamlit as st from anthropic import APIResponse -from anthropic.types.beta import BetaContentBlockParam, BetaMessage, BetaTextBlockParam, BetaToolResultBlockParam, BetaToolUseBlock +from anthropic.types.beta import ( + BetaContentBlockParam, + BetaMessage, + BetaTextBlockParam, +) from streamlit.delta_generator import DeltaGenerator from computer_use_demo.loop import ( diff --git a/computer-use-demo/tests/streamlit_test.py b/computer-use-demo/tests/streamlit_test.py index 7a76c1fa..d9a42936 100644 --- a/computer-use-demo/tests/streamlit_test.py +++ b/computer-use-demo/tests/streamlit_test.py @@ -1,7 +1,7 @@ from unittest import mock -from anthropic.types import TextBlockParam import pytest +from anthropic.types import TextBlockParam from streamlit.testing.v1 import AppTest from computer_use_demo.streamlit import Sender @@ -19,6 +19,9 @@ def test_streamlit(streamlit_app: AppTest): streamlit_app.chat_input[0].set_value("Hello").run() assert patch.called assert patch.call_args.kwargs["messages"] == [ - {"role": Sender.USER, "content": [TextBlockParam(text="Hello", type="text")]} + { + "role": Sender.USER, + "content": [TextBlockParam(text="Hello", type="text")], + } ] assert not streamlit_app.exception From 92300f77f84f7f6fa47a77198126fdc9e83dbe5a Mon Sep 17 00:00:00 2001 From: nsmccandlish Date: Wed, 23 Oct 2024 14:58:03 -0700 Subject: [PATCH 4/4] fixify --- computer-use-demo/computer_use_demo/loop.py | 18 ++++++++---------- .../computer_use_demo/tools/collection.py | 8 ++------ 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index 73040e7f..dfbb9661 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -82,8 +82,9 @@ async def sampling_loop( BashTool(), EditTool(), ) - system = ( - f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}" + system = BetaTextBlockParam( + type="text", + text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}", ) while True: @@ -103,6 +104,7 @@ async def sampling_loop( _inject_prompt_caching(messages) # Is it ever worth it to bust the cache with prompt caching? 
image_truncation_threshold = 50 + system["cache_control"] = {"type": "ephemeral"} if only_n_most_recent_images: _maybe_filter_to_n_most_recent_images( @@ -119,10 +121,8 @@ async def sampling_loop( max_tokens=max_tokens, messages=messages, model=model, - system=system, - tools=tool_collection.to_params( - enable_prompt_caching=enable_prompt_caching - ), + system=[system], + tools=tool_collection.to_params(), betas=betas, ) @@ -223,10 +223,8 @@ def _inject_prompt_caching( messages: list[BetaMessageParam], ): """ - With the assumption that images are screenshots that are of diminishing value as - the conversation progresses, remove all but the final `images_to_keep` tool_result - images in place, with a chunk of min_removal_threshold to reduce the amount we - break the implicit prompt cache. + Set cache breakpoints for the 3 most recent turns + one cache breakpoint is left for tools/system prompt, to be shared across sessions """ breakpoints_remaining = 3 diff --git a/computer-use-demo/computer_use_demo/tools/collection.py b/computer-use-demo/computer_use_demo/tools/collection.py index 72c05ecb..c4e8c95c 100644 --- a/computer-use-demo/computer_use_demo/tools/collection.py +++ b/computer-use-demo/computer_use_demo/tools/collection.py @@ -2,7 +2,7 @@ from typing import Any -from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaToolUnionParam +from anthropic.types.beta import BetaToolUnionParam from .base import ( BaseAnthropicTool, @@ -21,12 +21,8 @@ def __init__(self, *tools: BaseAnthropicTool): def to_params( self, - enable_prompt_caching: bool ) -> list[BetaToolUnionParam]: - tools = [tool.to_params() for tool in self.tools] - if enable_prompt_caching: - tools[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"}) - return tools + return [tool.to_params() for tool in self.tools] async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult: tool = self.tool_map.get(name)
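
For reference, a minimal dependency-free sketch of the caching scheme this series converges on: one ephemeral cache_control breakpoint on the system text block (covering the tool definitions and system prompt as a shared prefix), plus ephemeral breakpoints on the three most recent user turns so earlier turns stay in the prefix cache across loop iterations. The function name inject_prompt_caching, the placeholder system text, and the plain dicts are illustrative stand-ins for the _inject_prompt_caching helper and the BetaMessageParam / BetaTextBlockParam shapes in the diffs above; this is a sketch of the idea, not the shipped implementation.

from typing import Any


def inject_prompt_caching(messages: list[dict[str, Any]], breakpoints: int = 3) -> None:
    """Mark the last `breakpoints` user turns with an ephemeral cache_control, in place."""
    remaining = breakpoints
    for message in reversed(messages):
        content = message.get("content")
        if message.get("role") == "user" and isinstance(content, list) and content:
            if remaining:
                remaining -= 1
                content[-1]["cache_control"] = {"type": "ephemeral"}
            else:
                # At most one stale breakpoint is left over from the previous
                # loop iteration, so clear it and stop scanning.
                content[-1].pop("cache_control", None)
                break


system = {
    "type": "text",
    "text": "<system prompt goes here>",  # placeholder, not the real SYSTEM_PROMPT
    "cache_control": {"type": "ephemeral"},  # shared breakpoint for tools + system,
}                                            # passed as system=[system] in create()

messages = [
    {"role": "user", "content": [{"type": "text", "text": f"turn {i}"}]} for i in range(5)
]
inject_prompt_caching(messages)
print([i for i, m in enumerate(messages) if "cache_control" in m["content"][-1]])
# -> [2, 3, 4]: only the three most recent user turns carry breakpoints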
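
The image-truncation side interacts with caching as well: _maybe_filter_to_n_most_recent_images drops the oldest tool_result screenshots in chunks of min_removal_threshold (raised from 10 to 50 once caching is enabled) so the cached prefix is invalidated as rarely as possible. Below is a rough sketch of that chunking arithmetic as described by the function's docstring; the names images_to_remove, chunk, and excess are illustrative and not the function's real internals.

def images_to_remove(total_images: int, images_to_keep: int, chunk: int) -> int:
    """How many of the oldest screenshots to drop, rounded down to a multiple
    of `chunk` so the prompt-cache prefix is broken as rarely as possible."""
    excess = max(total_images - images_to_keep, 0)
    return excess - (excess % chunk)


# With caching enabled the loop passes min_removal_threshold=50, so nothing is
# dropped until at least 50 surplus screenshots have accumulated, and then a
# whole chunk of 50 goes at once.
print(images_to_remove(total_images=58, images_to_keep=10, chunk=50))  # -> 0
print(images_to_remove(total_images=61, images_to_keep=10, chunk=50))  # -> 50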