diff --git a/computer-use-demo/computer_use_demo/loop.py b/computer-use-demo/computer_use_demo/loop.py index 2feffa81..9db525a3 100644 --- a/computer-use-demo/computer_use_demo/loop.py +++ b/computer-use-demo/computer_use_demo/loop.py @@ -17,21 +17,22 @@ APIResponseValidationError, APIStatusError, ) -from anthropic.types import ( - ToolResultBlockParam, -) from anthropic.types.beta import ( - BetaContentBlock, + BetaCacheControlEphemeralParam, BetaContentBlockParam, BetaImageBlockParam, + BetaMessage, BetaMessageParam, + BetaTextBlock, BetaTextBlockParam, BetaToolResultBlockParam, + BetaToolUseBlockParam, ) from .tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult -BETA_FLAG = "computer-use-2024-10-22" +COMPUTER_USE_BETA_FLAG = "computer-use-2024-10-22" +PROMPT_CACHING_BETA_FLAG = "prompt-caching-2024-07-31" class APIProvider(StrEnum): @@ -75,7 +76,7 @@ async def sampling_loop( provider: APIProvider, system_prompt_suffix: str, messages: list[BetaMessageParam], - output_callback: Callable[[BetaContentBlock], None], + output_callback: Callable[[BetaContentBlockParam], None], tool_output_callback: Callable[[ToolResult, str], None], api_response_callback: Callable[ [httpx.Request, httpx.Response | object | None, Exception | None], None @@ -92,21 +93,37 @@ async def sampling_loop( BashTool(), EditTool(), ) - system = ( - f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}" + system = BetaTextBlockParam( + type="text", + text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}", ) while True: - if only_n_most_recent_images: - _maybe_filter_to_n_most_recent_images(messages, only_n_most_recent_images) - + enable_prompt_caching = False + betas = [COMPUTER_USE_BETA_FLAG] + image_truncation_threshold = 10 if provider == APIProvider.ANTHROPIC: client = Anthropic(api_key=api_key) + enable_prompt_caching = True elif provider == APIProvider.VERTEX: client = AnthropicVertex() elif provider == 
APIProvider.BEDROCK: client = AnthropicBedrock() + if enable_prompt_caching: + betas.append(PROMPT_CACHING_BETA_FLAG) + _inject_prompt_caching(messages) + # Is it ever worth it to bust the cache with prompt caching? + image_truncation_threshold = 50 + system["cache_control"] = {"type": "ephemeral"} + + if only_n_most_recent_images: + _maybe_filter_to_n_most_recent_images( + messages, + only_n_most_recent_images, + min_removal_threshold=image_truncation_threshold, + ) + # Call the API # we use raw_response to provide debug information to streamlit. Your # implementation may be able call the SDK directly with: @@ -116,9 +133,9 @@ async def sampling_loop( max_tokens=max_tokens, messages=messages, model=model, - system=system, + system=[system], tools=tool_collection.to_params(), - betas=[BETA_FLAG], + betas=betas, ) except (APIStatusError, APIResponseValidationError) as e: api_response_callback(e.request, e.response, e) @@ -133,25 +150,26 @@ async def sampling_loop( response = raw_response.parse() + response_params = _response_to_params(response) messages.append( { "role": "assistant", - "content": cast(list[BetaContentBlockParam], response.content), + "content": response_params, } ) tool_result_content: list[BetaToolResultBlockParam] = [] - for content_block in cast(list[BetaContentBlock], response.content): + for content_block in response_params: output_callback(content_block) - if content_block.type == "tool_use": + if content_block["type"] == "tool_use": result = await tool_collection.run( - name=content_block.name, - tool_input=cast(dict[str, Any], content_block.input), + name=content_block["name"], + tool_input=cast(dict[str, Any], content_block["input"]), ) tool_result_content.append( - _make_api_tool_result(result, content_block.id) + _make_api_tool_result(result, content_block["id"]) ) - tool_output_callback(result, content_block.id) + tool_output_callback(result, content_block["id"]) if not tool_result_content: return messages @@ -162,7 +180,7 @@ async def 
sampling_loop( def _maybe_filter_to_n_most_recent_images( messages: list[BetaMessageParam], images_to_keep: int, - min_removal_threshold: int = 10, + min_removal_threshold: int, ): """ With the assumption that images are screenshots that are of diminishing value as @@ -174,7 +192,7 @@ def _maybe_filter_to_n_most_recent_images( return messages tool_result_blocks = cast( - list[ToolResultBlockParam], + list[BetaToolResultBlockParam], [ item for message in messages @@ -208,6 +226,42 @@ def _maybe_filter_to_n_most_recent_images( tool_result["content"] = new_content +def _response_to_params( + response: BetaMessage, +) -> list[BetaTextBlockParam | BetaToolUseBlockParam]: + res: list[BetaTextBlockParam | BetaToolUseBlockParam] = [] + for block in response.content: + if isinstance(block, BetaTextBlock): + res.append({"type": "text", "text": block.text}) + else: + res.append(cast(BetaToolUseBlockParam, block.model_dump())) + return res + + +def _inject_prompt_caching( + messages: list[BetaMessageParam], +): + """ + Set cache breakpoints for the 3 most recent turns + one cache breakpoint is left for tools/system prompt, to be shared across sessions + """ + + breakpoints_remaining = 3 + for message in reversed(messages): + if message["role"] == "user" and isinstance( + content := message["content"], list + ): + if breakpoints_remaining: + breakpoints_remaining -= 1 + content[-1]["cache_control"] = BetaCacheControlEphemeralParam( + {"type": "ephemeral"} + ) + else: + content[-1].pop("cache_control", None) + # we'll only ever have one extra turn per loop + break + + def _make_api_tool_result( result: ToolResult, tool_use_id: str ) -> BetaToolResultBlockParam: diff --git a/computer-use-demo/computer_use_demo/streamlit.py b/computer-use-demo/computer_use_demo/streamlit.py index 7db927e0..45318ed9 100644 --- a/computer-use-demo/computer_use_demo/streamlit.py +++ b/computer-use-demo/computer_use_demo/streamlit.py @@ -16,11 +16,10 @@ import httpx import streamlit as st from 
anthropic import RateLimitError -from anthropic.types import ( - TextBlock, +from anthropic.types.beta import ( + BetaContentBlockParam, + BetaTextBlockParam, ) -from anthropic.types.beta import BetaTextBlock, BetaToolUseBlock -from anthropic.types.tool_use_block import ToolUseBlock from streamlit.delta_generator import DeltaGenerator from computer_use_demo.loop import ( @@ -184,7 +183,7 @@ def _reset_api_provider(): else: _render_message( message["role"], - cast(BetaTextBlock | BetaToolUseBlock, block), + cast(BetaContentBlockParam | ToolResult, block), ) # render past http exchanges @@ -196,7 +195,7 @@ def _reset_api_provider(): st.session_state.messages.append( { "role": Sender.USER, - "content": [TextBlock(type="text", text=new_message)], + "content": [BetaTextBlockParam(type="text", text=new_message)], } ) _render_message(Sender.USER, new_message) @@ -345,15 +344,11 @@ def _render_error(error: Exception): def _render_message( sender: Sender, - message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, + message: str | BetaContentBlockParam | ToolResult, ): """Convert input from the user or output from the agent to a streamlit message.""" # streamlit's hotreloading breaks isinstance checks, so we need to check for class names - is_tool_result = not isinstance(message, str) and ( - isinstance(message, ToolResult) - or message.__class__.__name__ == "ToolResult" - or message.__class__.__name__ == "CLIResult" - ) + is_tool_result = not isinstance(message, str | dict) if not message or ( is_tool_result and st.session_state.hide_images @@ -373,10 +368,14 @@ def _render_message( st.error(message.error) if message.base64_image and not st.session_state.hide_images: st.image(base64.b64decode(message.base64_image)) - elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock): - st.write(message.text) - elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock): - st.code(f"Tool Use: {message.name}\nInput: {message.input}") + elif 
isinstance(message, dict): + if message["type"] == "text": + st.write(message["text"]) + elif message["type"] == "tool_use": + st.code(f'Tool Use: {message["name"]}\nInput: {message["input"]}') + else: + # only expected return types are text and tool_use + raise Exception(f'Unexpected response type {message["type"]}') else: st.markdown(message) diff --git a/computer-use-demo/tests/loop_test.py b/computer-use-demo/tests/loop_test.py index 4985dbee..9572c460 100644 --- a/computer-use-demo/tests/loop_test.py +++ b/computer-use-demo/tests/loop_test.py @@ -1,7 +1,7 @@ from unittest import mock from anthropic.types import TextBlock, ToolUseBlock -from anthropic.types.beta import BetaMessage, BetaMessageParam +from anthropic.types.beta import BetaMessage, BetaMessageParam, BetaTextBlockParam from computer_use_demo.loop import APIProvider, sampling_loop @@ -58,7 +58,9 @@ async def test_loop(): tool_collection.run.assert_called_once_with( name="computer", tool_input={"action": "test"} ) - output_callback.assert_called_with(TextBlock(text="Done!", type="text")) + output_callback.assert_called_with( + BetaTextBlockParam(text="Done!", type="text") + ) assert output_callback.call_count == 3 assert tool_output_callback.call_count == 1 assert api_response_callback.call_count == 2 diff --git a/computer-use-demo/tests/streamlit_test.py b/computer-use-demo/tests/streamlit_test.py index 25cd586b..d9a42936 100644 --- a/computer-use-demo/tests/streamlit_test.py +++ b/computer-use-demo/tests/streamlit_test.py @@ -1,9 +1,10 @@ from unittest import mock import pytest +from anthropic.types import TextBlockParam from streamlit.testing.v1 import AppTest -from computer_use_demo.streamlit import Sender, TextBlock +from computer_use_demo.streamlit import Sender @pytest.fixture @@ -18,6 +19,9 @@ def test_streamlit(streamlit_app: AppTest): streamlit_app.chat_input[0].set_value("Hello").run() assert patch.called assert patch.call_args.kwargs["messages"] == [ - {"role": Sender.USER, "content": 
[TextBlock(text="Hello", type="text")]} + { + "role": Sender.USER, + "content": [TextBlockParam(text="Hello", type="text")], + } ] assert not streamlit_app.exception