Skip to content

Commit

Permalink
Add test for async/coroutine function with FunctionCall (#11)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins authored Sep 10, 2023
1 parent 81fadb2 commit 943019c
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ for hero in create_superhero_team("The Food Dudes"):
### Additional Features

- The `@prompt` decorator can also be used with `async` function definitions, which enables making concurrent queries to the LLM.
- The `functions` argument to `@prompt` can contain async/coroutine functions. When the corresponding `FunctionCall` objects are called the result must be awaited.
- The `Annotated` type annotation can be used to provide descriptions and other metadata for function parameters. See [the pydantic documentation on using `Field` to describe function arguments](https://docs.pydantic.dev/latest/usage/validation_decorator/#using-field-to-describe-function-arguments).
- The `@prompt` and `@prompt_chain` decorators also accept a `model` argument. You can pass an instance of `OpenaiChatModel` (from `magentic.chat_model.openai_chat_model`) to use GPT4 or configure a different temperature.

Expand Down
13 changes: 13 additions & 0 deletions tests/test_function_call.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import inspect

import pytest

from magentic.function_call import FunctionCall
Expand Down Expand Up @@ -47,3 +49,14 @@ def test_function_call_eq(left, right, equal):
)
def test_function_call_arguments(function_call, arguments):
assert function_call.arguments == arguments


@pytest.mark.asyncio
async def test_function_call_async_function():
async def async_plus(a: int, b: int) -> int:
return a + b

function_call = FunctionCall(async_plus, a=1, b=2)
result = function_call()
assert inspect.isawaitable(result)
assert await result == 3
21 changes: 21 additions & 0 deletions tests/test_prompt_function.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Tests for PromptFunction."""

from inspect import getdoc
from typing import Awaitable

import pytest
from pydantic import BaseModel
Expand Down Expand Up @@ -154,3 +155,23 @@ async def sum_populations(country_one: str, country_two: str) -> FunctionCall[in
assert isinstance(output, FunctionCall)
func_result = output()
assert isinstance(func_result, int)


@pytest.mark.asyncio
@pytest.mark.openai
async def test_async_decorator_return_async_function_call():
async def async_plus(a: int, b: int) -> int:
return a + b

@prompt(
"Sum the populations of {country_one} and {country_two}.",
functions=[async_plus],
)
async def sum_populations(
country_one: str, country_two: str
) -> FunctionCall[Awaitable[int]]:
...

output = await sum_populations("Ireland", "UK")
assert isinstance(output, FunctionCall)
assert isinstance(await output(), int)

0 comments on commit 943019c

Please sign in to comment.