Use pydantic models for OpenAI inputs/outputs (#9)
jackmpcollins authored Sep 10, 2023
1 parent 3d517fe commit c056d17
Showing 1 changed file with 144 additions and 58 deletions.
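This change replaces hand-rolled dicts and raw `openai.ChatCompletion` response dicts with pydantic models, so streamed chunks are validated once and then read with typed attribute access instead of nested key lookups. A minimal sketch of that validation step (the payload below is illustrative rather than a captured API response, and assumes the models in this diff are importable from `magentic.chat_model.openai_chat_model`):

from magentic.chat_model.openai_chat_model import OpenaiChatCompletionChunk

# Illustrative streamed chunk carrying the field subset these models declare.
raw_chunk = {"choices": [{"delta": {"role": "assistant", "content": "Hello"}}]}

chunk = OpenaiChatCompletionChunk.model_validate(raw_chunk)
assert chunk.choices[0].delta.content == "Hello"
assert chunk.choices[0].delta.function_call is None  # unset fields default to None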
202 changes: 144 additions & 58 deletions src/magentic/chat_model/openai_chat_model.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections.abc import AsyncIterable, AsyncIterator, Callable, Iterable, Iterator
 from enum import Enum
-from typing import Any, Generic, TypeVar, cast, get_args, get_origin
+from typing import Any, Generic, Literal, TypeVar, cast, get_args, get_origin
 
 import openai
 from pydantic import BaseModel, TypeAdapter, ValidationError, create_model
@@ -299,43 +299,137 @@ class OpenaiMessageRole(Enum):
     USER = "user"
 
 
-def message_to_openai_message(message: Message[Any]) -> dict[str, Any]:
+class OpenaiChatCompletionFunctionCall(BaseModel):
+    name: str | None = None
+    arguments: str
+
+    def get_name_or_raise(self) -> str:
+        """Return the name, raising an error if it doesn't exist."""
+        assert self.name is not None
+        return self.name
+
+
+class OpenaiChatCompletionDelta(BaseModel):
+    role: OpenaiMessageRole | None = None
+    content: str | None = None
+    function_call: OpenaiChatCompletionFunctionCall | None = None
+
+
+class OpenaiChatCompletionChunkChoice(BaseModel):
+    delta: OpenaiChatCompletionDelta
+
+
+class OpenaiChatCompletionChunk(BaseModel):
+    choices: list[OpenaiChatCompletionChunkChoice]
+
+
+class OpenaiChatCompletionChoiceMessage(BaseModel):
+    role: OpenaiMessageRole
+    name: str | None = None
+    content: str | None
+    function_call: OpenaiChatCompletionFunctionCall | None = None
+
+
+class OpenaiChatCompletionChoice(BaseModel):
+    message: OpenaiChatCompletionDelta
+
+
+class OpenaiChatCompletion(BaseModel):
+    choices: list[OpenaiChatCompletionChoice]
+
+
+def message_to_openai_message(
+    message: Message[Any],
+) -> OpenaiChatCompletionChoiceMessage:
     """Convert a `Message` to an OpenAI message."""
     if isinstance(message, UserMessage):
-        return {"role": OpenaiMessageRole.USER.value, "content": message.content}
+        return OpenaiChatCompletionChoiceMessage(
+            role=OpenaiMessageRole.USER, content=message.content
+        )
 
     if isinstance(message, AssistantMessage):
         if isinstance(message.content, str):
-            return {
-                "role": OpenaiMessageRole.ASSISTANT.value,
-                "content": message.content,
-            }
+            return OpenaiChatCompletionChoiceMessage(
+                role=OpenaiMessageRole.ASSISTANT, content=message.content
+            )
 
         function_schema: BaseFunctionSchema[Any]
         if isinstance(message.content, FunctionCall):
             function_schema = FunctionCallFunctionSchema(message.content.function)
         else:
             function_schema = function_schema_for_type(type(message.content))
 
-        return {
-            "role": OpenaiMessageRole.ASSISTANT.value,
-            "content": None,
-            "function_call": {
-                "name": function_schema.name,
-                "arguments": function_schema.serialize_args(message.content),
-            },
-        }
+        return OpenaiChatCompletionChoiceMessage(
+            role=OpenaiMessageRole.ASSISTANT,
+            content=None,
+            function_call=OpenaiChatCompletionFunctionCall(
+                name=function_schema.name,
+                arguments=function_schema.serialize_args(message.content),
+            ),
+        )
 
     if isinstance(message, FunctionResultMessage):
-        return {
-            "role": OpenaiMessageRole.FUNCTION.value,
-            "name": FunctionCallFunctionSchema(message.function_call.function).name,
-            "content": json.dumps(message.content),
-        }
+        return OpenaiChatCompletionChoiceMessage(
+            role=OpenaiMessageRole.FUNCTION,
+            name=FunctionCallFunctionSchema(message.function_call.function).name,
+            content=json.dumps(message.content),
+        )
 
     raise NotImplementedError(type(message))
 
 
+def openai_chatcompletion_create(
+    model: str,
+    messages: Iterable[OpenaiChatCompletionChoiceMessage],
+    temperature: float | None = None,
+    functions: list[dict[str, Any]] | None = None,
+    function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
+) -> Iterator[OpenaiChatCompletionChunk]:
+    """Type-annotated version of `openai.ChatCompletion.create`."""
+    # `openai.ChatCompletion.create` doesn't accept `None`
+    # so only pass function args if there are functions
+    kwargs: dict[str, Any] = {}
+    if functions:
+        kwargs["functions"] = functions
+    if function_call:
+        kwargs["function_call"] = function_call
+
+    response: Iterator[dict[str, Any]] = openai.ChatCompletion.create(  # type: ignore[no-untyped-call]
+        model=model,
+        messages=[m.model_dump(mode="json", exclude_unset=True) for m in messages],
+        temperature=temperature,
+        stream=True,
+        **kwargs,
+    )
+    return (OpenaiChatCompletionChunk.model_validate(chunk) for chunk in response)
+
+
+async def openai_chatcompletion_acreate(
+    model: str,
+    messages: Iterable[OpenaiChatCompletionChoiceMessage],
+    temperature: float | None = None,
+    functions: list[dict[str, Any]] | None = None,
+    function_call: Literal["auto", "none"] | dict[str, Any] | None = None,
+) -> AsyncIterator[OpenaiChatCompletionChunk]:
+    """Type-annotated version of `openai.ChatCompletion.acreate`."""
+    # `openai.ChatCompletion.acreate` doesn't accept `None`
+    # so only pass function args if there are functions
+    kwargs: dict[str, Any] = {}
+    if functions:
+        kwargs["functions"] = functions
+    if function_call:
+        kwargs["function_call"] = function_call
+
+    response: AsyncIterator[dict[str, Any]] = await openai.ChatCompletion.acreate(  # type: ignore[no-untyped-call]
+        model=model,
+        messages=[m.model_dump(mode="json", exclude_unset=True) for m in messages],
+        temperature=temperature,
+        stream=True,
+        **kwargs,
+    )
+    return (OpenaiChatCompletionChunk.model_validate(chunk) async for chunk in response)
 
 
 R = TypeVar("R")
 FuncR = TypeVar("FuncR")
 
@@ -367,38 +461,34 @@ def complete(
         )
         allow_string_output = str_in_output_types or streamed_str_in_output_types
 
-        # `openai.ChatCompletion.create` doesn't accept `None`
-        # so only pass function args if there are functions
-        function_args: dict[str, Any] = {}
-        if function_schemas:
-            function_args["functions"] = [schema.dict() for schema in function_schemas]
-            if len(function_schemas) == 1 and not allow_string_output:
-                # Force the model to call the function
-                function_args["function_call"] = {"name": function_schemas[0].name}
-
-        response: Iterator[dict[str, Any]] = openai.ChatCompletion.create(  # type: ignore[no-untyped-call]
+        openai_functions = [schema.dict() for schema in function_schemas]
+        response = openai_chatcompletion_create(
             model=self._model,
             messages=[message_to_openai_message(m) for m in messages],
             temperature=self._temperature,
-            **function_args,
-            stream=True,
+            functions=openai_functions,
+            function_call=(
+                {"name": openai_functions[0]["name"]}
+                if len(openai_functions) == 1 and not allow_string_output
+                else None
+            ),
         )
 
         first_chunk = next(response)
-        first_chunk_delta = first_chunk["choices"][0]["delta"]
+        first_chunk_delta = first_chunk.choices[0].delta
 
-        if first_chunk_delta.get("function_call"):
+        if first_chunk_delta.function_call:
             function_schema_by_name = {
                 function_schema.name: function_schema
                 for function_schema in function_schemas
             }
-            function_name = first_chunk_delta["function_call"]["name"]
+            function_name = first_chunk_delta.function_call.get_name_or_raise()
             function_schema = function_schema_by_name[function_name]
             try:
                 message = function_schema.parse_args_to_message(
-                    chunk["choices"][0]["delta"]["function_call"]["arguments"]
+                    chunk.choices[0].delta.function_call.arguments
                     for chunk in response
-                    if chunk["choices"][0]["delta"]
+                    if chunk.choices[0].delta.function_call
                 )
             except ValidationError as e:
                 raise StructuredOutputError(
@@ -413,9 +503,9 @@ def complete(
                 " your prompt to encourage the model to return a specific type."
             )
         streamed_str = StreamedStr(
-            chunk["choices"][0]["delta"]["content"]
+            chunk.choices[0].delta.content
             for chunk in response
-            if chunk["choices"][0]["delta"]
+            if chunk.choices[0].delta.content is not None
         )
         if streamed_str_in_output_types:
             return cast(AssistantMessage[R], AssistantMessage(streamed_str))
@@ -444,38 +534,34 @@ async def acomplete(
         )
         allow_string_output = str_in_output_types or async_streamed_str_in_output_types
 
-        # `openai.ChatCompletion.acreate` doesn't accept `None`
-        # so only pass function args if there are functions
-        function_args: dict[str, Any] = {}
-        if function_schemas:
-            function_args["functions"] = [schema.dict() for schema in function_schemas]
-            if len(function_schemas) == 1 and not allow_string_output:
-                # Force the model to call the function
-                function_args["function_call"] = {"name": function_schemas[0].name}
-
-        response: AsyncIterator[dict[str, Any]] = await openai.ChatCompletion.acreate(  # type: ignore[no-untyped-call]
+        openai_functions = [schema.dict() for schema in function_schemas]
+        response = await openai_chatcompletion_acreate(
             model=self._model,
             messages=[message_to_openai_message(m) for m in messages],
             temperature=self._temperature,
-            **function_args,
-            stream=True,
+            functions=openai_functions,
+            function_call=(
+                {"name": openai_functions[0]["name"]}
+                if len(openai_functions) == 1 and not allow_string_output
+                else None
+            ),
         )
 
         first_chunk = await anext(response)
-        first_chunk_delta = first_chunk["choices"][0]["delta"]
+        first_chunk_delta = first_chunk.choices[0].delta
 
-        if first_chunk_delta.get("function_call"):
+        if first_chunk_delta.function_call:
             function_schema_by_name = {
                 function_schema.name: function_schema
                 for function_schema in function_schemas
             }
-            function_name = first_chunk_delta["function_call"]["name"]
+            function_name = first_chunk_delta.function_call.get_name_or_raise()
             function_schema = function_schema_by_name[function_name]
             try:
                 message = await function_schema.aparse_args_to_message(
-                    chunk["choices"][0]["delta"]["function_call"]["arguments"]
+                    chunk.choices[0].delta.function_call.arguments
                     async for chunk in response
-                    if chunk["choices"][0]["delta"]
+                    if chunk.choices[0].delta.function_call
                 )
             except ValidationError as e:
                 raise StructuredOutputError(
@@ -490,9 +576,9 @@ async def acomplete(
                 " your prompt to encourage the model to return a specific type."
             )
         async_streamed_str = AsyncStreamedStr(
-            chunk["choices"][0]["delta"]["content"]
+            chunk.choices[0].delta.content
             async for chunk in response
-            if chunk["choices"][0]["delta"]
+            if chunk.choices[0].delta.content is not None
         )
         if async_streamed_str_in_output_types:
             return cast(AssistantMessage[R], AssistantMessage(async_streamed_str))

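For reference, here is a sketch of how the reworked `message_to_openai_message` feeds the wrappers above: the returned model is serialized with `model_dump(mode="json", exclude_unset=True)`, which is what keeps never-set optional fields out of the request payload. The `UserMessage` import path is assumed, not shown in this diff.

from magentic.chat_model.message import UserMessage  # assumed import path
from magentic.chat_model.openai_chat_model import message_to_openai_message

openai_message = message_to_openai_message(UserMessage("What is 2 + 2?"))
# `exclude_unset=True` drops the optional fields that were never set
# (here `name` and `function_call`), matching the old hand-built dicts.
print(openai_message.model_dump(mode="json", exclude_unset=True))
# {'role': 'user', 'content': 'What is 2 + 2?'}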
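Finally, a sketch of calling the new typed wrapper directly. It assumes the legacy pre-1.0 `openai` SDK that `openai.ChatCompletion` belongs to, an `OPENAI_API_KEY` set in the environment, and an arbitrary model name.

from magentic.chat_model.openai_chat_model import (
    OpenaiChatCompletionChoiceMessage,
    OpenaiMessageRole,
    openai_chatcompletion_create,
)

messages = [
    OpenaiChatCompletionChoiceMessage(role=OpenaiMessageRole.USER, content="Say hello.")
]

# Chunks arrive already validated as OpenaiChatCompletionChunk models,
# so streamed content is read with attribute access rather than indexing.
for chunk in openai_chatcompletion_create(model="gpt-3.5-turbo", messages=messages):
    delta = chunk.choices[0].delta
    if delta.content is not None:
        print(delta.content, end="")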