Skip to content

Commit

Permalink
Add usage for Anthropic model responses
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins committed May 16, 2024
1 parent 62612c2 commit 9ad92c2
Showing 1 changed file with 18 additions and 6 deletions.
24 changes: 18 additions & 6 deletions src/magentic/chat_model/anthropic_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@
AssistantMessage,
Message,
SystemMessage,
Usage,
UserMessage,
_assistant_message_with_usage,
)
from magentic.function_call import (
AsyncParallelFunctionCall,
Expand All @@ -44,6 +46,7 @@
ToolsBetaMessageParam,
ToolUseBlock,
)
from anthropic.types.usage import Usage as AnthropicUsage
except ImportError as error:
msg = "To use AnthropicChatModel you must install the `anthropic` package using `pip install 'magentic[anthropic]'`."
raise ImportError(msg) from error
Expand Down Expand Up @@ -220,6 +223,15 @@ def _extract_system_message(
)


def _assistant_message(content: T, usage: AnthropicUsage) -> AssistantMessage[T]:
    """Create an AssistantMessage with the given content and Anthropic usage object.

    Translates the Anthropic SDK's usage object into magentic's own `Usage`
    type so token counts are attached to the returned message.
    """
    # Map the provider-specific usage fields onto magentic's Usage model
    _usage = Usage(
        input_tokens=usage.input_tokens,
        output_tokens=usage.output_tokens,
    )
    # NOTE(review): usage_pointer takes a list (mutable container) — presumably
    # to allow deferred/streamed population elsewhere; confirm against
    # _assistant_message_with_usage's definition in magentic's message module.
    return _assistant_message_with_usage(content, usage_pointer=[_usage])


R = TypeVar("R")


Expand Down Expand Up @@ -345,19 +357,19 @@ def complete(
allow_string_output=allow_string_output,
streamed=streamed_str_in_output_types,
)
return AssistantMessage(str_content) # type: ignore[return-value]
return _assistant_message(str_content, response.usage) # type: ignore[return-value]

if last_content.type == "tool_use":
try:
if is_any_origin_subclass(output_types, ParallelFunctionCall):
content = ParallelFunctionCall(
parse_tool_calls(response, tool_schemas)
)
return AssistantMessage(content) # type: ignore[return-value]
return _assistant_message(content, response.usage) # type: ignore[return-value]
# Take only the first tool_call, silently ignore extra chunks
# TODO: Create generator here that raises error or warns if multiple tool_calls
content = next(parse_tool_calls(response, tool_schemas))
return AssistantMessage(content) # type: ignore[return-value]
return _assistant_message(content, response.usage) # type: ignore[return-value]
except ValidationError as e:
msg = (
"Failed to parse model output. You may need to update your prompt"
Expand Down Expand Up @@ -439,19 +451,19 @@ async def acomplete(
allow_string_output=allow_string_output,
streamed=async_streamed_str_in_output_types,
)
return AssistantMessage(str_content) # type: ignore[return-value]
return _assistant_message(str_content, response.usage) # type: ignore[return-value]

if last_content.type == "tool_use":
try:
if is_any_origin_subclass(output_types, AsyncParallelFunctionCall):
content = AsyncParallelFunctionCall(
aparse_tool_calls(response, tool_schemas)
)
return AssistantMessage(content) # type: ignore[return-value]
return _assistant_message(content, response.usage) # type: ignore[return-value]
# Take only the first tool_call, silently ignore extra chunks
# TODO: Create generator here that raises error or warns if multiple tool_calls
content = await anext(aparse_tool_calls(response, tool_schemas))
return AssistantMessage(content) # type: ignore[return-value]
return _assistant_message(content, response.usage) # type: ignore[return-value]
except ValidationError as e:
msg = (
"Failed to parse model output. You may need to update your prompt"
Expand Down

0 comments on commit 9ad92c2

Please sign in to comment.