diff --git a/README.md b/README.md index fef8994c..5fc419c6 100644 --- a/README.md +++ b/README.md @@ -215,6 +215,7 @@ for hero in create_superhero_team("The Food Dudes"): ### Additional Features - The `@prompt` decorator can also be used with `async` function definitions, which enables making concurrent queries to the LLM. +- The `functions` argument to `@prompt` can contain async/coroutine functions. When the corresponding `FunctionCall` objects are called the result must be awaited. - The `Annotated` type annotation can be used to provide descriptions and other metadata for function parameters. See [the pydantic documentation on using `Field` to describe function arguments](https://docs.pydantic.dev/latest/usage/validation_decorator/#using-field-to-describe-function-arguments). - The `@prompt` and `@prompt_chain` decorators also accept a `model` argument. You can pass an instance of `OpenaiChatModel` (from `magentic.chat_model.openai_chat_model`) to use GPT4 or configure a different temperature. diff --git a/tests/test_function_call.py b/tests/test_function_call.py index 0d170d10..0bddc4b0 100644 --- a/tests/test_function_call.py +++ b/tests/test_function_call.py @@ -1,3 +1,5 @@ +import inspect + import pytest from magentic.function_call import FunctionCall @@ -47,3 +49,14 @@ def test_function_call_eq(left, right, equal): ) def test_function_call_arguments(function_call, arguments): assert function_call.arguments == arguments + + +@pytest.mark.asyncio +async def test_function_call_async_function(): + async def async_plus(a: int, b: int) -> int: + return a + b + + function_call = FunctionCall(async_plus, a=1, b=2) + result = function_call() + assert inspect.isawaitable(result) + assert await result == 3 diff --git a/tests/test_prompt_function.py b/tests/test_prompt_function.py index 7b8a70e4..83b28498 100644 --- a/tests/test_prompt_function.py +++ b/tests/test_prompt_function.py @@ -1,6 +1,7 @@ """Tests for PromptFunction.""" from inspect import getdoc +from typing import Awaitable import pytest from pydantic import BaseModel @@ -154,3 +155,23 @@ async def sum_populations(country_one: str, country_two: str) -> FunctionCall[in assert isinstance(output, FunctionCall) func_result = output() assert isinstance(func_result, int) + + +@pytest.mark.asyncio +@pytest.mark.openai +async def test_async_decorator_return_async_function_call(): + async def async_plus(a: int, b: int) -> int: + return a + b + + @prompt( + "Sum the populations of {country_one} and {country_two}.", + functions=[async_plus], + ) + async def sum_populations( + country_one: str, country_two: str + ) -> FunctionCall[Awaitable[int]]: + ... + + output = await sum_populations("Ireland", "UK") + assert isinstance(output, FunctionCall) + assert isinstance(await output(), int)