Skip to content

Commit

Permalink
fixify
Browse files Browse the repository at this point in the history
  • Loading branch information
nsmccandlish committed Oct 23, 2024
1 parent 5f61360 commit 92300f7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 16 deletions.
18 changes: 8 additions & 10 deletions computer-use-demo/computer_use_demo/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ async def sampling_loop(
BashTool(),
EditTool(),
)
system = (
f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
system = BetaTextBlockParam(
type="text",
text=f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}",
)

while True:
Expand All @@ -103,6 +104,7 @@ async def sampling_loop(
_inject_prompt_caching(messages)
# Is it ever worth it to bust the cache with prompt caching?
image_truncation_threshold = 50
system["cache_control"] = {"type": "ephemeral"}

if only_n_most_recent_images:
_maybe_filter_to_n_most_recent_images(
Expand All @@ -119,10 +121,8 @@ async def sampling_loop(
max_tokens=max_tokens,
messages=messages,
model=model,
system=system,
tools=tool_collection.to_params(
enable_prompt_caching=enable_prompt_caching
),
system=[system],
tools=tool_collection.to_params(),
betas=betas,
)

Expand Down Expand Up @@ -223,10 +223,8 @@ def _inject_prompt_caching(
messages: list[BetaMessageParam],
):
"""
With the assumption that images are screenshots that are of diminishing value as
the conversation progresses, remove all but the final `images_to_keep` tool_result
images in place, with a chunk of min_removal_threshold to reduce the amount we
break the implicit prompt cache.
Set cache breakpoints for the 3 most recent turns
one cache breakpoint is left for tools/system prompt, to be shared across sessions
"""

breakpoints_remaining = 3
Expand Down
8 changes: 2 additions & 6 deletions computer-use-demo/computer_use_demo/tools/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import Any

from anthropic.types.beta import BetaCacheControlEphemeralParam, BetaToolUnionParam
from anthropic.types.beta import BetaToolUnionParam

from .base import (
BaseAnthropicTool,
Expand All @@ -21,12 +21,8 @@ def __init__(self, *tools: BaseAnthropicTool):

def to_params(
self,
enable_prompt_caching: bool
) -> list[BetaToolUnionParam]:
tools = [tool.to_params() for tool in self.tools]
if enable_prompt_caching:
tools[-1]["cache_control"] = BetaCacheControlEphemeralParam({"type": "ephemeral"})
return tools
return [tool.to_params() for tool in self.tools]

async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
tool = self.tool_map.get(name)
Expand Down

0 comments on commit 92300f7

Please sign in to comment.