From 6bcaf5b7db34accbc0334e8023bcab8c0d781ffc Mon Sep 17 00:00:00 2001 From: Jack Collins <6640905+jackmpcollins@users.noreply.github.com> Date: Thu, 16 May 2024 00:39:26 -0700 Subject: [PATCH] Add tests for usage --- tests/chat_model/test_anthropic_chat_model.py | 28 ++++++++++++++++++- tests/chat_model/test_openai_chat_model.py | 27 ++++++++++++++++++ 2 files changed, 54 insertions(+), 1 deletion(-) diff --git a/tests/chat_model/test_anthropic_chat_model.py b/tests/chat_model/test_anthropic_chat_model.py index 3ab918fc..093329f7 100644 --- a/tests/chat_model/test_anthropic_chat_model.py +++ b/tests/chat_model/test_anthropic_chat_model.py @@ -2,12 +2,13 @@ from magentic.chat_model.anthropic_chat_model import AnthropicChatModel from magentic.chat_model.base import StructuredOutputError -from magentic.chat_model.message import UserMessage +from magentic.chat_model.message import Usage, UserMessage from magentic.function_call import ( AsyncParallelFunctionCall, FunctionCall, ParallelFunctionCall, ) +from magentic.streaming import AsyncStreamedStr, StreamedStr @pytest.mark.parametrize( @@ -28,6 +29,18 @@ def test_anthropic_chat_model_complete(prompt, output_types, expected_output_typ assert isinstance(message.content, expected_output_type) +@pytest.mark.anthropic +def test_anthropic_chat_model_complete_usage(): + chat_model = AnthropicChatModel("claude-3-haiku-20240307") + message = chat_model.complete( + messages=[UserMessage("Say hello!")], output_types=[StreamedStr] + ) + str(message.content) # Finish the stream + assert isinstance(message.usage, Usage) + assert message.usage.input_tokens > 0 + assert message.usage.output_tokens > 0 + + @pytest.mark.anthropic def test_anthropic_chat_model_complete_raise_structured_output_error(): chat_model = AnthropicChatModel("claude-3-haiku-20240307") @@ -100,6 +113,19 @@ async def test_anthropic_chat_model_acomplete( assert isinstance(message.content, expected_output_type) +@pytest.mark.asyncio +@pytest.mark.anthropic +async def test_anthropic_chat_model_acomplete_usage(): + chat_model = AnthropicChatModel("claude-3-haiku-20240307") + message = await chat_model.acomplete( + messages=[UserMessage("Say hello!")], output_types=[AsyncStreamedStr] + ) + await message.content.to_string() # Finish the stream + assert isinstance(message.usage, Usage) + assert message.usage.input_tokens > 0 + assert message.usage.output_tokens > 0 + + @pytest.mark.asyncio @pytest.mark.anthropic async def test_anthropic_chat_model_acomplete_function_call(): diff --git a/tests/chat_model/test_openai_chat_model.py b/tests/chat_model/test_openai_chat_model.py index 4a431b3c..d530e54d 100644 --- a/tests/chat_model/test_openai_chat_model.py +++ b/tests/chat_model/test_openai_chat_model.py @@ -10,6 +10,7 @@ FunctionResultMessage, Message, SystemMessage, + Usage, UserMessage, ) from magentic.chat_model.openai_chat_model import ( @@ -17,6 +18,7 @@ message_to_openai_message, ) from magentic.function_call import FunctionCall, ParallelFunctionCall +from magentic.streaming import AsyncStreamedStr, StreamedStr def plus(a: int, b: int) -> int: @@ -141,6 +143,18 @@ def test_openai_chat_model_complete_seed(): assert message1.content == message2.content +@pytest.mark.openai +def test_openai_chat_model_complete_usage(): + chat_model = OpenaiChatModel("gpt-3.5-turbo") + message = chat_model.complete( + messages=[UserMessage("Say hello!")], output_types=[StreamedStr] + ) + str(message.content) # Finish the stream + assert isinstance(message.usage, Usage) + assert message.usage.input_tokens > 0 + assert message.usage.output_tokens > 0 + + @pytest.mark.openai def test_openai_chat_model_complete_no_structured_output_error(): chat_model = OpenaiChatModel("gpt-3.5-turbo") @@ -152,3 +166,16 @@ def test_openai_chat_model_complete_no_structured_output_error(): output_types=[int, bool], ) assert isinstance(message.content, int | bool) + + +@pytest.mark.asyncio +@pytest.mark.openai +async def test_openai_chat_model_acomplete_usage(): + chat_model = OpenaiChatModel("gpt-3.5-turbo") + message = await chat_model.acomplete( + messages=[UserMessage("Say hello!")], output_types=[AsyncStreamedStr] + ) + await message.content.to_string() # Finish the stream + assert isinstance(message.usage, Usage) + assert message.usage.input_tokens > 0 + assert message.usage.output_tokens > 0