Commit
Revert "chore: Include types to instructor.patch (#422)"
This reverts commit f0d7889.
jxnl authored Feb 15, 2024
1 parent 4f8f2c0 commit c5db786
Showing 3 changed files with 88 additions and 88 deletions.
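
For context, instructor/patch.py (the module this revert touches) wraps an OpenAI client so that chat.completions.create accepts a response_model and returns a validated pydantic object. A minimal usage sketch, not part of the commit itself — the UserDetail model, the model name, and the prompt are placeholders:

    import instructor
    from openai import OpenAI
    from pydantic import BaseModel

    class UserDetail(BaseModel):  # hypothetical example model
        name: str
        age: int

    # patch() swaps in the wrapped create function defined in the diff below
    client = instructor.patch(OpenAI())

    user = client.chat.completions.create(
        model="gpt-3.5-turbo",
        response_model=UserDetail,
        messages=[{"role": "user", "content": "Jason is 25 years old"}],
    )
    # user is a validated UserDetail; the raw completion is attached as user._raw_response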
2 changes: 0 additions & 2 deletions .github/workflows/mypy.yml
@@ -25,10 +25,8 @@ env:
instructor/dsl/partialjson.py
instructor/dsl/validators.py
instructor/function_calls.py
instructor/patch.py
tests/test_function_calls.py
tests/test_distil.py
tests/test_patch.py
jobs:
MyPy:
154 changes: 78 additions & 76 deletions instructor/patch.py
@@ -3,22 +3,19 @@
import logging
from collections.abc import Iterable
from functools import wraps
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError # type: ignore[import-not-found]
from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
from json import JSONDecodeError
from typing import (
Any,
Callable,
Dict,
Generator,
get_args,
get_origin,
List,
Optional,
overload,
ParamSpec,
Protocol,
Tuple,
Type,
TypeVar,
Union,
get_args,
get_origin,
overload,
)

from openai import AsyncOpenAI, OpenAI
@@ -37,6 +34,13 @@
from .function_calls import Mode, OpenAISchema, openai_schema

logger = logging.getLogger("instructor")
T = TypeVar("T")


T_Model = TypeVar("T_Model", bound=BaseModel)
T_Retval = TypeVar("T_Retval")
T_ParamSpec = ParamSpec("T_ParamSpec")
T = TypeVar("T")


def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
@@ -56,8 +60,8 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:


def handle_response_model(
response_model: Type[BaseModel], mode: Mode = Mode.TOOLS, **kwargs: Any
) -> Tuple[Union[Type[OpenAISchema], ParallelBase], Dict[str, Any]]:
response_model: T, mode: Mode = Mode.TOOLS, **kwargs
) -> Union[Type[OpenAISchema], dict]:
"""Prepare the response model type hint, and returns the response_model
along with the new modified kwargs needed to be able to use the response_model
parameter with the patch function.
@@ -85,14 +89,15 @@ def handle_response_model(
new_kwargs["tool_choice"] = "auto"

# This is a special case for parallel models
return ParallelModel(typehint=response_model), new_kwargs
response_model = ParallelModel(typehint=response_model)
return response_model, new_kwargs

# This is for all other single model cases
if get_origin(response_model) is Iterable:
iterable_element_class = get_args(response_model)[0]
response_model = IterableModel(iterable_element_class)
if not issubclass(response_model, OpenAISchema):
response_model = openai_schema(response_model)
response_model = openai_schema(response_model) # type: ignore

if new_kwargs.get("stream", False) and not issubclass(
response_model, (IterableBase, PartialBase)
@@ -102,8 +107,8 @@
)

if mode == Mode.FUNCTIONS:
new_kwargs["functions"] = [response_model.openai_schema]
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
new_kwargs["functions"] = [response_model.openai_schema] # type: ignore
new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]} # type: ignore
elif mode == Mode.TOOLS:
new_kwargs["tools"] = [
{
@@ -163,14 +168,14 @@ def handle_response_model(


def process_response(
response: ChatCompletion,
response: T,
*,
response_model: Union[Type[OpenAISchema], ParallelBase, None],
response_model: Type[T_Model],
stream: bool,
validation_context: Optional[Dict[str, Any]] = None,
strict: Optional[bool] = None,
validation_context: dict = None,
strict=None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAISchema, List[OpenAISchema]]:
) -> Union[T_Model, T]:
"""Processes a OpenAI response with the response model, if available.
Args:
@@ -208,25 +213,24 @@ def process_response(
# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
return [task for task in model.tasks] # type: ignore[attr-defined]
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
return model

assert hasattr(model, "_raw_response")
model._raw_response = response
return model


async def process_response_async(
response: ChatCompletion,
*,
response_model: Union[Type[OpenAISchema], ParallelBase, None],
stream: bool,
validation_context: Optional[Dict[str, Any]] = None,
response_model: Type[T_Model],
stream: bool = False,
validation_context: dict = None,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
) -> T:
"""Processes a OpenAI response with the response model, if available.
It can use `validation_context` and `strict` to validate the response
via the pydantic model
@@ -246,44 +250,42 @@ async def process_response_async(
and issubclass(response_model, (IterableBase, PartialBase))
and stream
):
await_model = await response_model.from_streaming_response_async(
model = await response_model.from_streaming_response_async(
response,
mode=mode,
)
return await_model
return model

model = response_model.from_response(
response,
validation_context=validation_context,
strict=strict,
mode=mode,
) # type: ignore[var-annotated]
)

# ? This really hints at the fact that we need a better way of
# ? attaching usage data and the raw response to the model we return.
if isinstance(model, IterableBase):
#! If the response model is a multitask, return the tasks
assert hasattr(model, "tasks")
return [task for task in model.tasks]

if isinstance(response_model, ParallelBase):
return model

assert hasattr(model, "_raw_response")
model._raw_response = response
return model


async def retry_async( # type: ignore[return]
func: Callable[..., ChatCompletion],
response_model: Union[Type[OpenAISchema], ParallelBase],
validation_context: Optional[Dict[str, Any]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
max_retries: int,
async def retry_async(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context,
args,
kwargs,
max_retries: int | AsyncRetrying = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
) -> T:
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

# If max_retries is int, then create a AsyncRetrying object
@@ -325,7 +327,7 @@ async def retry_async(  # type: ignore[return]
)
except (ValidationError, JSONDecodeError) as e:
logger.debug(f"Error response: {response}")
kwargs["messages"].append(dump_message(response.choices[0].message))
kwargs["messages"].append(dump_message(response.choices[0].message)) # type: ignore
if mode == Mode.TOOLS:
kwargs["messages"].append(
{
@@ -359,22 +361,22 @@ async def retry_async(  # type: ignore[return]
raise e.last_attempt.exception from e


def retry_sync( # type: ignore[return]
func: Callable[..., ChatCompletion],
response_model: Union[Type[OpenAISchema], ParallelBase],
validation_context: Optional[Dict[str, Any]],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
max_retries: int,
def retry_sync(
func: Callable[T_ParamSpec, T_Retval],
response_model: Type[T],
validation_context: dict,
args,
kwargs,
max_retries: int | Retrying = 1,
strict: Optional[bool] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAISchema, List[OpenAISchema]]:
):
total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

# If max_retries is int, then create a Retrying object
if isinstance(max_retries, int):
logger.debug(f"max_retries: {max_retries}")
max_retries = Retrying(
max_retries: Retrying = Retrying(
stop=stop_after_attempt(max_retries),
reraise=True,
)
@@ -442,7 +444,7 @@ def retry_sync(  # type: ignore[return]
raise e.last_attempt.exception from e


def is_async(func: Callable[..., Any]) -> bool:
def is_async(func: Callable) -> bool:
"""Returns true if the callable is async, accounting for wrapped callables"""
return inspect.iscoroutinefunction(func) or (
hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
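
The is_async helper above also treats functools-wrapped coroutine functions as async by inspecting __wrapped__. A small illustrative sketch with hypothetical names, not part of the diff:

    import inspect
    from functools import wraps

    async def acreate(**kwargs):  # stand-in for an async create function
        ...

    @wraps(acreate)
    def proxy(**kwargs):
        return acreate(**kwargs)

    # inspect.iscoroutinefunction(proxy) is False, but proxy.__wrapped__ is a
    # coroutine function, so is_async(proxy) returns True and patch() would
    # route the call through the async code path.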
@@ -472,12 +474,12 @@ def is_async(func: Callable[..., Any]) -> bool:
class InstructorChatCompletionCreate(Protocol):
def __call__(
self,
response_model: Union[Type[BaseModel], ParallelBase, None] = None,
validation_context: Optional[Dict[str, Any]] = None,
response_model: Type[T_Model] = None,
validation_context: dict = None,
max_retries: int = 1,
*args: Any,
**kwargs: Any,
) -> Type[BaseModel]:
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
...


@@ -490,7 +492,7 @@ def patch(


@overload
def patch( # type: ignore[misc]
def patch(
client: AsyncOpenAI,
mode: Mode = Mode.FUNCTIONS,
) -> AsyncOpenAI:
@@ -499,15 +501,15 @@ def patch(  # type: ignore[misc]

@overload
def patch(
create: Callable[..., Any],
create: Callable[T_ParamSpec, T_Retval],
mode: Mode = Mode.FUNCTIONS,
) -> InstructorChatCompletionCreate:
...


def patch( # type: ignore[misc]
client: Union[OpenAI, AsyncOpenAI, None] = None,
create: Optional[Callable[..., Any]] = None,
def patch(
client: Union[OpenAI, AsyncOpenAI] = None,
create: Callable[T_ParamSpec, T_Retval] = None,
mode: Mode = Mode.FUNCTIONS,
) -> Union[OpenAI, AsyncOpenAI]:
"""
@@ -534,40 +536,40 @@ def patch(  # type: ignore[misc]

@wraps(func)
async def new_create_async(
response_model: Type[BaseModel],
validation_context: Dict[str, Any],
response_model: Type[T_Model] = None,
validation_context: dict = None,
max_retries: int = 1,
*args: Any,
**kwargs: Any,
) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
new_response_model, new_kwargs = handle_response_model(
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
response = await retry_async(
func=func,
response_model=new_response_model,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
kwargs=new_kwargs,
mode=mode,
)
) # type: ignore
return response

@wraps(func)
def new_create_sync(
response_model: Type[BaseModel],
validation_context: Dict[str, Any],
response_model: Type[T_Model] = None,
validation_context: dict = None,
max_retries: int = 1,
*args: Any,
**kwargs: Any,
) -> Union[OpenAISchema, List[OpenAISchema]]:
new_response_model, new_kwargs = handle_response_model(
*args: T_ParamSpec.args,
**kwargs: T_ParamSpec.kwargs,
) -> T_Model:
response_model, new_kwargs = handle_response_model(
response_model=response_model, mode=mode, **kwargs
)
response = retry_sync(
func=func,
response_model=new_response_model,
response_model=response_model,
validation_context=validation_context,
max_retries=max_retries,
args=args,
@@ -586,7 +588,7 @@ def new_create_sync(
return new_create


def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS) -> AsyncOpenAI:
def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
"""
No longer necessary, use `patch` instead.
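
Since apatch is documented above as no longer necessary, the async client can be wrapped with the same patch() call. A hedged sketch of the equivalent async usage — the model name, prompt, and UserDetail model are placeholders, not part of this commit:

    import asyncio
    import instructor
    from openai import AsyncOpenAI
    from pydantic import BaseModel

    class UserDetail(BaseModel):  # same hypothetical model as in the earlier sketch
        name: str
        age: int

    # patch() handles AsyncOpenAI directly; apatch() is kept only for backwards compatibility
    aclient = instructor.patch(AsyncOpenAI())

    async def main() -> None:
        user = await aclient.chat.completions.create(
            model="gpt-3.5-turbo",
            response_model=UserDetail,
            messages=[{"role": "user", "content": "Jason is 25 years old"}],
        )
        print(user)

    asyncio.run(main())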
(Diff for the remaining changed file did not load on this page.)
