diff --git a/README.md b/README.md index 5030b841..ef214c98 100644 --- a/README.md +++ b/README.md @@ -179,6 +179,29 @@ for country, streamed_str in zip(countries, streamed_strs): # 24.67s : Chile - 2186 chars ``` +#### Object Streaming + +Structured outputs can also be streamed from the LLM by using the return type annotation `Iterable` (or `AsyncIterable`). This allows each item to be processed while the next one is being generated. See the example in [examples/quiz](examples/quiz/) for how this can be used to improve user experience by quickly displaying/using the first item returned. + +```python +from collections.abc import Iterable +from time import time + + +@prompt("Create a Superhero team named {name}.") +def create_superhero_team(name: str) -> Iterable[Superhero]: + ... + + +start_time = time() +for hero in create_superhero_team("The Food Dudes"): + print(f"{time() - start_time:.2f}s : {hero}") + +# 2.23s : name='Pizza Man' age=30 power='Can shoot pizza slices from his hands' enemies=['The Hungry Horde', 'The Junk Food Gang'] +# 4.03s : name='Captain Carrot' age=35 power='Super strength and agility from eating carrots' enemies=['The Sugar Squad', 'The Greasy Gang'] +# 6.05s : name='Ice Cream Girl' age=25 power='Can create ice cream out of thin air' enemies=['The Hot Sauce Squad', 'The Healthy Eaters'] +``` + ### Additional Features - The `@prompt` decorator can also be used with `async` function definitions, which enables making concurrent queries to the LLM. diff --git a/examples/quiz/quiz.py b/examples/quiz/0_quiz.py similarity index 86% rename from examples/quiz/quiz.py rename to examples/quiz/0_quiz.py index c47355f3..515ebba0 100644 --- a/examples/quiz/quiz.py +++ b/examples/quiz/0_quiz.py @@ -8,13 +8,13 @@ Run this example within this directory with: ```sh -poetry run python quiz.py +poetry run python 0_quiz.py ``` or if you have installed magentic with pip: ```sh -python quiz.py +python 0_quiz.py ``` --- @@ -22,7 +22,6 @@ Example run: ``` -% poetry run python quiz.py Enter a topic for a quiz: pizza Enter the number of questions: 3 @@ -43,9 +42,9 @@ Quiz complete! You scored: 66% -"Hey pizza enthusiast! Congrats on scoring 66/100 on the pizza quiz! You may not have -aced it, but hey, you've still got a slice of the pie! Keep up the cheesy spirit and -remember, there's always room for improvement... and extra toppings! 🍕🎉" +Hey pizza enthusiast! Congrats on scoring 66/100 on the pizza quiz! You may not have + aced it, but hey, you've still got a slice of the pie! Keep up the cheesy spirit and + remember, there's always room for improvement... and extra toppings! 🍕🎉 ``` """ diff --git a/examples/quiz/quiz_async.py b/examples/quiz/1_quiz_async.py similarity index 93% rename from examples/quiz/quiz_async.py rename to examples/quiz/1_quiz_async.py index fec45eea..83d9c87c 100644 --- a/examples/quiz/quiz_async.py +++ b/examples/quiz/1_quiz_async.py @@ -1,6 +1,6 @@ """A simple quiz game. -This example builds on the `quiz.py` example to demonstrate how using asyncio can greatly speed up queries. The quiz +This example builds on the `0_quiz.py` example to demonstrate how using asyncio can greatly speed up queries. The quiz questions are now generated concurrently which means the quiz starts much more quickly after the user has entered the topic and number of questions. However since the questions are generated independently there is more likelihood of duplicates - increasing the model temperature can help with this. 
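As a companion to the new "Object Streaming" README section above, here is a minimal async sketch for the `AsyncIterable` case it mentions. This is only an illustration, not part of the patch: it assumes the `Superhero` model defined earlier in the README (mirrored below) and the same await-then-iterate convention the README uses for `AsyncStreamedStr`.

```python
import asyncio
from collections.abc import AsyncIterable

from pydantic import BaseModel

from magentic import prompt


class Superhero(BaseModel):
    # Mirrors the Superhero model defined earlier in the README.
    name: str
    age: int
    power: str
    enemies: list[str]


@prompt("Create a Superhero team named {name}.")
async def create_superhero_team(name: str) -> AsyncIterable[Superhero]:
    ...


async def main() -> None:
    # Await the call to start the query, then iterate heroes as they stream in.
    heroes = await create_superhero_team("The Food Dudes")
    async for hero in heroes:
        print(hero)


asyncio.run(main())
```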
@@ -10,13 +10,13 @@ Run this example within this directory with: ```sh -poetry run python quiz.py +poetry run python 1_quiz_async.py ``` or if you have installed magentic with pip: ```sh -python quiz.py +python 1_quiz_async.py ``` --- @@ -24,7 +24,6 @@ Example run: ``` -% poetry run python examples/quiz/quiz_async.py Enter a topic for a quiz: France Enter the number of questions: 3 diff --git a/examples/quiz/2_quiz_streamed.py b/examples/quiz/2_quiz_streamed.py new file mode 100644 index 00000000..88bb48a1 --- /dev/null +++ b/examples/quiz/2_quiz_streamed.py @@ -0,0 +1,110 @@ +"""A simple quiz game. + +This example improves on the `1_quiz_async.py` example by using streaming to generate the questions. In `1_quiz_async.py` the +questions were generated concurrently, which allowed the quiz to start quickly but meant there was a chance of duplicate +questions being generated. In this example the questions are streamed, which allows us to show the first question to the +user as soon as it is ready, while still making a single query to the LLM, which avoids generating duplicate questions. + +The only change from `0_quiz.py` is the return type annotation of the `generate_questions` function, which changes from +`list[Question]` to `Iterable[Question]`. This allows us to iterate through the questions as they are generated. + +--- + +Run this example within this directory with: + +```sh +poetry run python 2_quiz_streamed.py +``` + +or if you have installed magentic with pip: + +```sh +python 2_quiz_streamed.py +``` + +--- + +Example run: + +``` +Enter a topic for a quiz: NASA +Enter the number of questions: 3 + +1 / 3 +Q: When was NASA founded? +A: 1958 +Correct! The answer is: 1958 + +2 / 3 +Q: Who was the first person to walk on the moon? +A: Neil Armstrong +Correct! The answer is: Neil Armstrong + +3 / 3 +Q: What is the largest planet in our solar system? +A: Jupyter +Incorrect! The correct answer is: Jupiter + +Quiz complete! You scored: 66% + +Congratulations on your stellar performance! You may not have reached the moon, + but you definitely rocked that NASA quiz with a score of 66/100! Remember, + even astronauts have their off days. Keep reaching for the stars, and who knows, + maybe next time you'll be the one discovering a new galaxy! + Keep up the astronomical work! 🚀🌟 +``` + +""" + +from collections.abc import Iterable + +from pydantic import BaseModel + +from magentic import prompt + + +class Question(BaseModel): + question: str + answer: str + + +@prompt("Generate {num} quiz questions about {topic}") +def generate_questions(topic: str, num: int) -> Iterable[Question]: + ... + + +@prompt("""Return true if the user's answer is correct. +Question: {question.question} +Answer: {question.answer} +User Answer: {user_answer}""") +def is_answer_correct(question: Question, user_answer: str) -> bool: + ... + + +@prompt( + "Create a short and funny message of celebration or encouragement for someone who" + " scored {score}/100 on a quiz about {topic}." +) +def create_encouragement_message(score: int, topic: str) -> str: + ... + + +topic = input("Enter a topic for a quiz: ") +num_questions = int(input("Enter the number of questions: ")) +questions = generate_questions(topic, num_questions) + +user_points = 0 +for num, question in enumerate(questions, start=1): + print(f"\n{num} / {num_questions}") + print(f"Q: {question.question}") + user_answer = input("A: ") + + if is_answer_correct(question, user_answer): + print(f"Correct! The answer is: {question.answer}") + user_points += 1 + else: + print(f"Incorrect! 
The correct answer is: {question.answer}") + +score = 100 * user_points // num_questions +print(f"\nQuiz complete! You scored: {score}%\n") +print(create_encouragement_message(score, topic)) diff --git a/src/magentic/__init__.py b/src/magentic/__init__.py index 0b8cb24d..62d4df68 100644 --- a/src/magentic/__init__.py +++ b/src/magentic/__init__.py @@ -1,7 +1,7 @@ from magentic.function_call import FunctionCall from magentic.prompt_chain import prompt_chain from magentic.prompt_function import prompt -from magentic.streamed_str import AsyncStreamedStr, StreamedStr +from magentic.streaming import AsyncStreamedStr, StreamedStr __all__ = [ "FunctionCall", diff --git a/src/magentic/chat_model/openai_chat_model.py b/src/magentic/chat_model/openai_chat_model.py index 10ee193d..778fe551 100644 --- a/src/magentic/chat_model/openai_chat_model.py +++ b/src/magentic/chat_model/openai_chat_model.py @@ -1,18 +1,10 @@ import inspect import json -import textwrap +import typing from abc import ABC, abstractmethod +from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterable, Iterator from enum import Enum -from typing import ( - Any, - AsyncIterator, - Callable, - Generic, - Iterable, - Iterator, - TypeVar, - cast, -) +from typing import Any, Generic, TypeVar, cast, get_args, get_origin import openai from pydantic import BaseModel, TypeAdapter, ValidationError, create_model @@ -25,8 +17,13 @@ UserMessage, ) from magentic.function_call import FunctionCall -from magentic.streamed_str import AsyncStreamedStr, StreamedStr -from magentic.typing import is_origin_subclass, name_type +from magentic.streaming import ( + AsyncStreamedStr, + StreamedStr, + aiter_streamed_json_array, + iter_streamed_json_array, +) +from magentic.typing import is_origin_abstract, is_origin_subclass, name_type class StructuredOutputError(Exception): @@ -58,12 +55,21 @@ def dict(self) -> dict[str, Any]: return schema @abstractmethod - def parse_args(self, arguments: str) -> T: + def parse_args(self, arguments: Iterable[str]) -> T: ... - def parse_args_to_message(self, arguments: str) -> AssistantMessage[T]: + async def aparse_args(self, arguments: AsyncIterable[str]) -> T: + # TODO: Convert AsyncIterable to lazy Iterable rather than list + return self.parse_args([arg async for arg in arguments]) + + def parse_args_to_message(self, arguments: Iterable[str]) -> AssistantMessage[T]: return AssistantMessage(self.parse_args(arguments)) + async def aparse_args_to_message( + self, arguments: AsyncIterable[str] + ) -> AssistantMessage[T]: + return AssistantMessage(await self.aparse_args(arguments)) + @abstractmethod def serialize_args(self, value: T) -> str: ... 
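To illustrate the switch above from a single joined `str` to an iterable of streamed argument chunks in `parse_args`, here is a small sketch (not part of the patch) using the `IterableFunctionSchema` added in the next hunk. The chunk boundaries are arbitrary and only mimic how function-call arguments arrive from the streamed response; the behaviour mirrors the `parse_args` test cases added in `tests/test_openai_chat_model.py` further down.

```python
from magentic.chat_model.openai_chat_model import IterableFunctionSchema

schema = IterableFunctionSchema(list[str])
# The chunks join to '{"value": ["One", "Two"]}'; parse_args splits the streamed
# JSON array into items via iter_streamed_json_array and validates each one.
chunks = ['{"va', 'lue": ["On', 'e", "Two"]}']
assert list(schema.parse_args(chunks)) == ["One", "Two"]
```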
@@ -90,13 +96,89 @@ def parameters(self) -> dict[str, Any]: model_schema.pop("description", None) return model_schema - def parse_args(self, arguments: str) -> T: - return self._model.model_validate_json(arguments).value + def parse_args(self, arguments: Iterable[str]) -> T: + return self._model.model_validate_json("".join(arguments)).value def serialize_args(self, value: T) -> str: return self._model(value=value).model_dump_json() +IterableT = TypeVar("IterableT", bound=Iterable[Any]) + + +class IterableFunctionSchema(BaseFunctionSchema[IterableT], Generic[IterableT]): + def __init__(self, output_type: type[IterableT]): + self._output_type = output_type + self._item_type_adapter = TypeAdapter(get_args(output_type)[0]) + # https://github.com/python/mypy/issues/14458 + self._model = Output[output_type] # type: ignore[valid-type] + + @property + def name(self) -> str: + return f"return_{name_type(self._output_type)}" + + @property + def parameters(self) -> dict[str, Any]: + model_schema = self._model.model_json_schema().copy() + model_schema.pop("title", None) + model_schema.pop("description", None) + return model_schema + + def parse_args(self, arguments: Iterable[str]) -> IterableT: + iter_items = ( + self._item_type_adapter.validate_json(item) + for item in iter_streamed_json_array(arguments) + ) + return self._model.model_validate({"value": iter_items}).value + + def serialize_args(self, value: IterableT) -> str: + return self._model(value=value).model_dump_json() + + +AsyncIterableT = TypeVar("AsyncIterableT", bound=AsyncIterable[Any]) + + +class AsyncIterableFunctionSchema( + BaseFunctionSchema[AsyncIterableT], Generic[AsyncIterableT] +): + def __init__(self, output_type: type[AsyncIterableT]): + self._output_type = output_type + self._item_type_adapter = TypeAdapter(get_args(output_type)[0]) + # Convert to list so pydantic can handle for schema generation + # But keep the type hint using AsyncIterableT for type checking + self._model: type[Output[AsyncIterableT]] = Output[list[get_args(output_type)[0]]] # type: ignore + + @property + def name(self) -> str: + return f"return_{name_type(self._output_type)}" + + @property + def parameters(self) -> dict[str, Any]: + model_schema = self._model.model_json_schema().copy() + model_schema.pop("title", None) + model_schema.pop("description", None) + return model_schema + + def parse_args(self, arguments: Iterable[str]) -> AsyncIterableT: + raise NotImplementedError() + + async def aparse_args(self, arguments: AsyncIterable[str]) -> AsyncIterableT: + aiter_items = ( + self._item_type_adapter.validate_json(item) + async for item in aiter_streamed_json_array(arguments) + ) + if (get_origin(self._output_type) or self._output_type) in ( + typing.AsyncIterable, + typing.AsyncIterator, + ) or is_origin_abstract(self._output_type): + return cast(AsyncIterableT, aiter_items) + + raise NotImplementedError() + + def serialize_args(self, value: AsyncIterableT) -> str: + return self._model(value=value).model_dump_json() + + class DictFunctionSchema(BaseFunctionSchema[T], Generic[T]): def __init__(self, output_type: type[T]): self._output_type = output_type @@ -112,8 +194,8 @@ def parameters(self) -> dict[str, Any]: model_schema["properties"] = model_schema.get("properties", {}) return model_schema - def parse_args(self, arguments: str) -> T: - return self._type_adapter.validate_json(arguments) + def parse_args(self, arguments: Iterable[str]) -> T: + return self._type_adapter.validate_json("".join(arguments)) def serialize_args(self, value: T) -> str: return 
self._type_adapter.dump_json(value).decode() @@ -137,8 +219,8 @@ def parameters(self) -> dict[str, Any]: model_schema.pop("description", None) return model_schema - def parse_args(self, arguments: str) -> BaseModelT: - return self._model.model_validate_json(arguments) + def parse_args(self, arguments: Iterable[str]) -> BaseModelT: + return self._model.model_validate_json("".join(arguments)) def serialize_args(self, value: BaseModelT) -> str: return value.model_dump_json() @@ -171,24 +253,41 @@ def parameters(self) -> dict[str, Any]: schema.pop("title", None) return schema - def parse_args(self, arguments: str) -> FunctionCall[T]: - args = self._model.model_validate_json(arguments).model_dump(exclude_unset=True) + def parse_args(self, arguments: Iterable[str]) -> FunctionCall[T]: + args = self._model.model_validate_json("".join(arguments)).model_dump( + exclude_unset=True + ) return FunctionCall(self._func, **args) - def parse_args_to_message(self, arguments: str) -> FunctionCallMessage[T]: + async def aparse_args(self, arguments: AsyncIterable[str]) -> FunctionCall[T]: + return self.parse_args([arg async for arg in arguments]) + + def parse_args_to_message(self, arguments: Iterable[str]) -> FunctionCallMessage[T]: return FunctionCallMessage(self.parse_args(arguments)) + async def aparse_args_to_message( + self, arguments: AsyncIterable[str] + ) -> FunctionCallMessage[T]: + return FunctionCallMessage(await self.aparse_args(arguments)) + def serialize_args(self, value: FunctionCall[T]) -> str: return json.dumps(value.arguments) -def function_schema_for_type(type_: type[T]) -> BaseFunctionSchema[T]: +# TODO: Add type hints here. Possibly use `functools.singledispatch` instead. +def function_schema_for_type(type_: type[Any]) -> BaseFunctionSchema[Any]: """Create a FunctionSchema for the given type.""" if is_origin_subclass(type_, BaseModel): - return BaseModelFunctionSchema(type_) # type: ignore[return-value] + return BaseModelFunctionSchema(type_) if is_origin_subclass(type_, dict): - return DictFunctionSchema(type_) # type: ignore[arg-type] + return DictFunctionSchema(type_) + + if is_origin_subclass(type_, Iterable): + return IterableFunctionSchema(type_) + + if is_origin_subclass(type_, AsyncIterable): + return AsyncIterableFunctionSchema(type_) return AnyFunctionSchema(type_) @@ -259,12 +358,12 @@ def complete( function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ function_schema_for_type(type_) for type_ in output_types - if not issubclass(type_, (str, StreamedStr)) + if not is_origin_subclass(type_, (str, StreamedStr)) ] - str_in_output_types = any(issubclass(cls, str) for cls in output_types) + str_in_output_types = any(is_origin_subclass(cls, str) for cls in output_types) streamed_str_in_output_types = any( - issubclass(cls, StreamedStr) for cls in output_types + is_origin_subclass(cls, StreamedStr) for cls in output_types ) allow_string_output = str_in_output_types or streamed_str_in_output_types @@ -295,19 +394,16 @@ def complete( } function_name = first_chunk_delta["function_call"]["name"] function_schema = function_schema_by_name[function_name] - function_call_args = "".join( - chunk["choices"][0]["delta"]["function_call"]["arguments"] - for chunk in response - if chunk["choices"][0]["delta"] - ) try: - message = function_schema.parse_args_to_message(function_call_args) + message = function_schema.parse_args_to_message( + chunk["choices"][0]["delta"]["function_call"]["arguments"] + for chunk in response + if chunk["choices"][0]["delta"] + ) except 
ValidationError as e: raise StructuredOutputError( - "Failed to parse model output" - f" {textwrap.shorten(function_call_args, 100)!r}." - " You may need to update your prompt to encourage the model to" - " return a specific type." + "Failed to parse model output. You may need to update your prompt" + " to encourage the model to return a specific type." ) from e return message @@ -339,12 +435,12 @@ async def acomplete( function_schemas = [FunctionCallFunctionSchema(f) for f in functions or []] + [ function_schema_for_type(type_) for type_ in output_types - if not issubclass(type_, (str, AsyncStreamedStr)) + if not is_origin_subclass(type_, (str, AsyncStreamedStr)) ] - str_in_output_types = any(issubclass(cls, str) for cls in output_types) + str_in_output_types = any(is_origin_subclass(cls, str) for cls in output_types) async_streamed_str_in_output_types = any( - issubclass(cls, AsyncStreamedStr) for cls in output_types + is_origin_subclass(cls, AsyncStreamedStr) for cls in output_types ) allow_string_output = str_in_output_types or async_streamed_str_in_output_types @@ -375,21 +471,16 @@ async def acomplete( } function_name = first_chunk_delta["function_call"]["name"] function_schema = function_schema_by_name[function_name] - function_call_args = "".join( - [ + try: + message = await function_schema.aparse_args_to_message( chunk["choices"][0]["delta"]["function_call"]["arguments"] async for chunk in response if chunk["choices"][0]["delta"] - ] - ) - try: - message = function_schema.parse_args_to_message(function_call_args) + ) except ValidationError as e: raise StructuredOutputError( - "Failed to parse model output" - f" {textwrap.shorten(function_call_args, 100)!r}." - " You may need to update your prompt to encourage the model to" - " return a specific type." + "Failed to parse model output. You may need to update your prompt" + " to encourage the model to return a specific type." 
) from e return message diff --git a/src/magentic/streamed_str.py b/src/magentic/streamed_str.py deleted file mode 100644 index 50927d92..00000000 --- a/src/magentic/streamed_str.py +++ /dev/null @@ -1,52 +0,0 @@ -from collections.abc import AsyncIterable, Iterable -from typing import AsyncIterator, Iterator, TypeVar - -T = TypeVar("T") - - -async def async_iter(iterable: Iterable[T]) -> AsyncIterator[T]: - """Get an AsyncIterator for an Iterable.""" - for item in iterable: - yield item - - -class StreamedStr(Iterable[str]): - """A string that is generated in chunks.""" - - def __init__(self, chunks: Iterable[str]): - self._chunks = chunks - self._cached_chunks: list[str] = [] - - def __iter__(self) -> Iterator[str]: - yield from self._cached_chunks - for chunk in self._chunks: - self._cached_chunks.append(chunk) - yield chunk - - def __str__(self) -> str: - return "".join(self) - - def to_string(self) -> str: - """Convert the streamed string to a string.""" - return str(self) - - -class AsyncStreamedStr(AsyncIterable[str]): - """Async version of `StreamedStr`.""" - - def __init__(self, chunks: AsyncIterable[str]): - self._chunks = chunks - self._cached_chunks: list[str] = [] - - async def __aiter__(self) -> AsyncIterator[str]: - # Cannot use `yield from` inside an async function - # https://peps.python.org/pep-0525/#asynchronous-yield-from - for chunk in self._cached_chunks: - yield chunk - async for chunk in self._chunks: - self._cached_chunks.append(chunk) - yield chunk - - async def to_string(self) -> str: - """Convert the streamed string to a string.""" - return "".join([item async for item in self]) diff --git a/src/magentic/streaming.py b/src/magentic/streaming.py new file mode 100644 index 00000000..e67fcd88 --- /dev/null +++ b/src/magentic/streaming.py @@ -0,0 +1,139 @@ +from collections.abc import AsyncIterable, Iterable +from dataclasses import dataclass +from itertools import chain, dropwhile +from typing import AsyncIterator, Iterator, TypeVar + +T = TypeVar("T") + + +async def async_iter(iterable: Iterable[T]) -> AsyncIterator[T]: + """Get an AsyncIterator for an Iterable.""" + for item in iterable: + yield item + + +@dataclass +class JsonArrayParserState: + array_level: int = 0 + object_level: int = 0 + in_string: bool = False + is_escaped: bool = False + is_element_separator: bool = False + + def update(self, char: str) -> None: + if self.in_string: + if char == '"' and not self.is_escaped: + self.in_string = False + elif char == '"': + self.in_string = True + elif char == ",": + if self.array_level == 1 and self.object_level == 0: + self.is_element_separator = True + return + elif char == "[": + self.array_level += 1 + elif char == "]": + self.array_level -= 1 + if self.array_level == 0: + self.is_element_separator = True + return + elif char == "{": + self.object_level += 1 + elif char == "}": + self.object_level -= 1 + elif char == "\\": + self.is_escaped = not self.is_escaped + else: + self.is_escaped = False + self.is_element_separator = False + + +def iter_streamed_json_array(chunks: Iterable[str]) -> Iterable[str]: + """Convert a streamed JSON array into an iterable of JSON object strings. + + This ignores all characters before the start of the first array i.e. 
the first "[" + """ + iter_chars: Iterator[str] = chain.from_iterable(chunks) + parser_state = JsonArrayParserState() + + iter_chars = dropwhile(lambda x: x != "[", iter_chars) + parser_state.update(next(iter_chars)) + + item_chars: list[str] = [] + for char in iter_chars: + parser_state.update(char) + if parser_state.is_element_separator: + if item_chars: + yield "".join(item_chars).strip() + item_chars = [] + else: + item_chars.append(char) + + +async def aiter_streamed_json_array(chunks: AsyncIterable[str]) -> AsyncIterable[str]: + """Async version of `iter_streamed_json_array`.""" + + async def chars_generator() -> AsyncIterable[str]: + async for chunk in chunks: + for char in chunk: + yield char + + iter_chars = chars_generator() + parser_state = JsonArrayParserState() + + async for char in iter_chars: + if char == "[": + break + parser_state.update("[") + + item_chars: list[str] = [] + async for char in iter_chars: + parser_state.update(char) + if parser_state.is_element_separator: + if item_chars: + yield "".join(item_chars).strip() + item_chars = [] + else: + item_chars.append(char) + + +class StreamedStr(Iterable[str]): + """A string that is generated in chunks.""" + + def __init__(self, chunks: Iterable[str]): + self._chunks = chunks + self._cached_chunks: list[str] = [] + + def __iter__(self) -> Iterator[str]: + yield from self._cached_chunks + for chunk in self._chunks: + self._cached_chunks.append(chunk) + yield chunk + + def __str__(self) -> str: + return "".join(self) + + def to_string(self) -> str: + """Convert the streamed string to a string.""" + return str(self) + + +class AsyncStreamedStr(AsyncIterable[str]): + """Async version of `StreamedStr`.""" + + def __init__(self, chunks: AsyncIterable[str]): + self._chunks = chunks + self._cached_chunks: list[str] = [] + + async def __aiter__(self) -> AsyncIterator[str]: + # Cannot use `yield from` inside an async function + # https://peps.python.org/pep-0525/#asynchronous-yield-from + for chunk in self._cached_chunks: + yield chunk + async for chunk in self._chunks: + self._cached_chunks.append(chunk) + yield chunk + + async def to_string(self) -> str: + """Convert the streamed string to a string.""" + return "".join([item async for item in self]) diff --git a/src/magentic/typing.py b/src/magentic/typing.py index 4187461e..381c916d 100644 --- a/src/magentic/typing.py +++ b/src/magentic/typing.py @@ -1,3 +1,4 @@ +import inspect import types from typing import ( Any, @@ -25,6 +26,11 @@ def split_union_type(type_: TypeT) -> Sequence[TypeT]: return get_args(type_) if is_union_type(type_) else [type_] +def is_origin_abstract(type_: type) -> bool: + """Return true if the unsubscripted type is an abstract base class (ABC).""" + return inspect.isabstract(get_origin(type_) or type_) + + def is_origin_subclass( type_: type, cls_or_tuple: TypeT | tuple[TypeT, ...] 
) -> TypeGuard[TypeT]: @@ -52,9 +58,11 @@ def name_type(type_: type) -> str: return f"dict_of_{name_type(key_type)}_to_{name_type(value_type)}" if name := getattr(type_, "__name__", None): + assert isinstance(name, str) + if len(args) == 1: - return f"{name}_of_{name_type(args[0])}" + return f"{name.lower()}_of_{name_type(args[0])}" - return type_.__name__.lower() + return name.lower() raise ValueError(f"Unable to name type {type_}") diff --git a/tests/test_openai_chat_model.py b/tests/test_openai_chat_model.py index be6144db..f8f1f38d 100644 --- a/tests/test_openai_chat_model.py +++ b/tests/test_openai_chat_model.py @@ -1,17 +1,22 @@ +import collections.abc import json +import typing from collections import OrderedDict -from typing import Annotated, Any +from typing import Annotated, Any, get_origin import pytest from pydantic import BaseModel, Field from magentic.chat_model.openai_chat_model import ( AnyFunctionSchema, + AsyncIterableFunctionSchema, BaseModelFunctionSchema, DictFunctionSchema, FunctionCallFunctionSchema, + IterableFunctionSchema, ) from magentic.function_call import FunctionCall +from magentic.streaming import async_iter from magentic.typing import is_origin_subclass @@ -152,6 +157,15 @@ def test_any_function_schema_parse_args(type_, args_str, expected_args): assert parsed_args == expected_args +@pytest.mark.parametrize( + ["type_", "args_str", "expected_args"], any_function_schema_args_test_cases +) +@pytest.mark.asyncio +async def test_any_function_schema_aparse_args(type_, args_str, expected_args): + parsed_args = await AnyFunctionSchema(type_).aparse_args(async_iter(args_str)) + assert parsed_args == expected_args + + @pytest.mark.parametrize( ["type_", "expected_args_str", "args"], any_function_schema_args_test_cases ) @@ -160,6 +174,165 @@ def test_any_function_schema_serialize_args(type_, expected_args_str, args): assert json.loads(serialized_args) == json.loads(expected_args_str) +@pytest.mark.parametrize( + ["type_", "json_schema"], + [ + ( + list[str], + { + "name": "return_list_of_str", + "parameters": { + "properties": { + "value": { + "title": "Value", + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["value"], + "type": "object", + }, + }, + ), + ( + typing.Iterable[str], + { + "name": "return_iterable_of_str", + "parameters": { + "properties": { + "value": { + "title": "Value", + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["value"], + "type": "object", + }, + }, + ), + ( + collections.abc.Iterable[str], + { + "name": "return_iterable_of_str", + "parameters": { + "properties": { + "value": { + "title": "Value", + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["value"], + "type": "object", + }, + }, + ), + ], +) +def test_iterable_function_schema(type_, json_schema): + function_schema = IterableFunctionSchema(type_) + assert function_schema.dict() == json_schema + + +iterable_function_schema_args_test_cases = [ + (list[str], '{"value": ["One", "Two"]}', ["One", "Two"]), + (typing.Iterable[str], '{"value": ["One", "Two"]}', ["One", "Two"]), + (collections.abc.Iterable[str], '{"value": ["One", "Two"]}', ["One", "Two"]), +] + + +@pytest.mark.parametrize( + ["type_", "args_str", "expected_args"], iterable_function_schema_args_test_cases +) +def test_iterable_function_schema_parse_args(type_, args_str, expected_args): + parsed_args = IterableFunctionSchema(type_).parse_args(args_str) + assert isinstance(parsed_args, get_origin(type_)) + assert list(parsed_args) == expected_args + + 
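+# NOTE: In the parse_args tests above, args_str is passed as a plain str; this works because a str is itself an Iterable[str] of characters, standing in for the streamed argument chunks.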
+@pytest.mark.parametrize( + ["type_", "expected_args_str", "args"], iterable_function_schema_args_test_cases +) +def test_iterable_function_schema_serialize_args(type_, expected_args_str, args): + serialized_args = IterableFunctionSchema(type_).serialize_args(args) + assert json.loads(serialized_args) == json.loads(expected_args_str) + + +@pytest.mark.parametrize( + ["type_", "json_schema"], + [ + ( + typing.AsyncIterable[str], + { + "name": "return_asynciterable_of_str", + "parameters": { + "properties": { + "value": { + "title": "Value", + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["value"], + "type": "object", + }, + }, + ), + ( + collections.abc.AsyncIterable[str], + { + "name": "return_asynciterable_of_str", + "parameters": { + "properties": { + "value": { + "title": "Value", + "type": "array", + "items": {"type": "string"}, + } + }, + "required": ["value"], + "type": "object", + }, + }, + ), + ], +) +def test_async_iterable_function_schema(type_, json_schema): + function_schema = AsyncIterableFunctionSchema(type_) + assert function_schema.dict() == json_schema + + +async_iterable_function_schema_args_test_cases = [ + (typing.AsyncIterable[str], '{"value": ["One", "Two"]}', ["One", "Two"]), + (collections.abc.AsyncIterable[str], '{"value": ["One", "Two"]}', ["One", "Two"]), +] + + +@pytest.mark.parametrize( + ["type_", "args_str", "expected_args"], + async_iterable_function_schema_args_test_cases, +) +@pytest.mark.asyncio +async def test_async_iterable_function_schema_aparse_args( + type_, args_str, expected_args +): + parsed_args = await AsyncIterableFunctionSchema(type_).aparse_args( + async_iter(args_str) + ) + assert isinstance(parsed_args, get_origin(type_)) + assert [arg async for arg in parsed_args] == expected_args + + +@pytest.mark.parametrize( + ["type_", "expected_args_str", "args"], + async_iterable_function_schema_args_test_cases, +) +def test_async_iterable_function_schema_serialize_args(type_, expected_args_str, args): + serialized_args = AsyncIterableFunctionSchema(type_).serialize_args(args) + assert json.loads(serialized_args) == json.loads(expected_args_str) + + class User(BaseModel): name: str age: int diff --git a/tests/test_prompt_function.py b/tests/test_prompt_function.py index df95e918..7b8a70e4 100644 --- a/tests/test_prompt_function.py +++ b/tests/test_prompt_function.py @@ -8,7 +8,7 @@ from magentic.chat_model.openai_chat_model import StructuredOutputError from magentic.function_call import FunctionCall from magentic.prompt_function import AsyncPromptFunction, PromptFunction, prompt -from magentic.streamed_str import AsyncStreamedStr, StreamedStr +from magentic.streaming import AsyncStreamedStr, StreamedStr @pytest.mark.openai diff --git a/tests/test_streamed_str.py b/tests/test_streaming.py similarity index 59% rename from tests/test_streamed_str.py rename to tests/test_streaming.py index 9bb61e1a..ab284bd5 100644 --- a/tests/test_streamed_str.py +++ b/tests/test_streaming.py @@ -3,7 +3,11 @@ import pytest from magentic import AsyncStreamedStr, StreamedStr -from magentic.streamed_str import async_iter +from magentic.streaming import ( + aiter_streamed_json_array, + async_iter, + iter_streamed_json_array, +) @pytest.mark.asyncio @@ -13,6 +17,27 @@ async def test_async_iter(): assert [chunk async for chunk in output] == ["Hello", " World"] +iter_streamed_json_array_test_cases = [ + (["[]"], []), + (['["He', 'llo", ', '"Wo', 'rld"]'], ['"Hello"', '"World"']), + (["[1, 2, 3]"], ["1", "2", "3"]), + (["[1, ", "2, 3]"], ["1", "2", 
"3"]), + (['[{"a": 1}, {2: "b"}]'], ['{"a": 1}', '{2: "b"}']), + (["{\n", '"value', '":', " [", "1, ", "2, 3", "]"], ["1", "2", "3"]), +] + + +@pytest.mark.parametrize(["input", "expected"], iter_streamed_json_array_test_cases) +def test_iter_streamed_json_array(input, expected): + assert list(iter_streamed_json_array(iter(input))) == expected + + +@pytest.mark.parametrize(["input", "expected"], iter_streamed_json_array_test_cases) +@pytest.mark.asyncio +async def test_aiter_streamed_json_array(input, expected): + assert [x async for x in aiter_streamed_json_array(async_iter(input))] == expected + + def test_streamed_str_iter(): iter_chunks = iter(["Hello", " World"]) streamed_str = StreamedStr(iter_chunks) diff --git a/tests/test_typing.py b/tests/test_typing.py index b21f83ec..33207100 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -1,9 +1,15 @@ +import typing from types import NoneType from typing import Any import pytest -from magentic.typing import is_origin_subclass, name_type, split_union_type +from magentic.typing import ( + is_origin_abstract, + is_origin_subclass, + name_type, + split_union_type, +) @pytest.mark.parametrize( @@ -18,6 +24,18 @@ def test_split_union_type(type_, expected_types): assert [t.__name__ for t in split_union_type(type_)] == expected_types +@pytest.mark.parametrize( + ["type_", "expected_result"], + [ + (str, False), + (list[str], False), + (typing.Iterable[str], True), + ], +) +def test_is_origin_abstract(type_, expected_result): + assert is_origin_abstract(type_) == expected_result + + @pytest.mark.parametrize( ["type_", "cls_or_tuple", "expected_result"], [ @@ -43,6 +61,7 @@ def test_is_origin_subclass(type_, cls_or_tuple, expected_result): (list[str | bool], "list_of_str_or_bool"), (list[str] | bool, "list_of_str_or_bool"), (dict[str, int], "dict_of_str_to_int"), + (typing.Iterable[str], "iterable_of_str"), ], ) def test_name_type(type_, expected_name):