diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
index ef34ea303..d76a711d8 100644
--- a/.github/workflows/mypy.yml
+++ b/.github/workflows/mypy.yml
@@ -25,10 +25,8 @@ env:
     instructor/dsl/partialjson.py
     instructor/dsl/validators.py
     instructor/function_calls.py
-    instructor/patch.py
     tests/test_function_calls.py
     tests/test_distil.py
-    tests/test_patch.py

 jobs:
   MyPy:
diff --git a/instructor/patch.py b/instructor/patch.py
index 0d3f068b8..e5adde1e5 100644
--- a/instructor/patch.py
+++ b/instructor/patch.py
@@ -3,22 +3,19 @@
 import logging
 from collections.abc import Iterable
 from functools import wraps
-from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError  # type: ignore[import-not-found]
+from tenacity import Retrying, AsyncRetrying, stop_after_attempt, RetryError
 from json import JSONDecodeError
 from typing import (
-    Any,
     Callable,
-    Dict,
-    Generator,
-    get_args,
-    get_origin,
-    List,
     Optional,
-    overload,
+    ParamSpec,
     Protocol,
-    Tuple,
     Type,
+    TypeVar,
     Union,
+    get_args,
+    get_origin,
+    overload,
 )

 from openai import AsyncOpenAI, OpenAI
@@ -37,6 +34,13 @@
 from .function_calls import Mode, OpenAISchema, openai_schema

 logger = logging.getLogger("instructor")
+T = TypeVar("T")
+
+
+T_Model = TypeVar("T_Model", bound=BaseModel)
+T_Retval = TypeVar("T_Retval")
+T_ParamSpec = ParamSpec("T_ParamSpec")
+T = TypeVar("T")


 def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:
@@ -56,8 +60,8 @@ def dump_message(message: ChatCompletionMessage) -> ChatCompletionMessageParam:


 def handle_response_model(
-    response_model: Type[BaseModel], mode: Mode = Mode.TOOLS, **kwargs: Any
-) -> Tuple[Union[Type[OpenAISchema], ParallelBase], Dict[str, Any]]:
+    response_model: T, mode: Mode = Mode.TOOLS, **kwargs
+) -> Union[Type[OpenAISchema], dict]:
     """Prepare the response model type hint, and returns the response_model
     along with the new modified kwargs needed to be able to use the response_model
     parameter with the patch function.
@@ -85,14 +89,15 @@ def handle_response_model(
             new_kwargs["tool_choice"] = "auto"

             # This is a special case for parallel models
-            return ParallelModel(typehint=response_model), new_kwargs
+            response_model = ParallelModel(typehint=response_model)
+            return response_model, new_kwargs

         # This is for all other single model cases
         if get_origin(response_model) is Iterable:
             iterable_element_class = get_args(response_model)[0]
             response_model = IterableModel(iterable_element_class)
         if not issubclass(response_model, OpenAISchema):
-            response_model = openai_schema(response_model)
+            response_model = openai_schema(response_model)  # type: ignore

         if new_kwargs.get("stream", False) and not issubclass(
             response_model, (IterableBase, PartialBase)
@@ -102,8 +107,8 @@ def handle_response_model(
             )

         if mode == Mode.FUNCTIONS:
-            new_kwargs["functions"] = [response_model.openai_schema]
-            new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}
+            new_kwargs["functions"] = [response_model.openai_schema]  # type: ignore
+            new_kwargs["function_call"] = {"name": response_model.openai_schema["name"]}  # type: ignore
         elif mode == Mode.TOOLS:
             new_kwargs["tools"] = [
                 {
@@ -163,14 +168,14 @@ def handle_response_model(


 def process_response(
-    response: ChatCompletion,
+    response: T,
     *,
-    response_model: Union[Type[OpenAISchema], ParallelBase, None],
+    response_model: Type[T_Model],
     stream: bool,
-    validation_context: Optional[Dict[str, Any]] = None,
-    strict: Optional[bool] = None,
+    validation_context: dict = None,
+    strict=None,
     mode: Mode = Mode.FUNCTIONS,
-) -> Union[OpenAISchema, List[OpenAISchema]]:
+) -> Union[T_Model, T]:
     """Processes a OpenAI response with the response model, if available.

     Args:
@@ -208,12 +213,11 @@ def process_response(
     # ? This really hints at the fact that we need a better way of
     # ? attaching usage data and the raw response to the model we return.
     if isinstance(model, IterableBase):
-        return [task for task in model.tasks]  # type: ignore[attr-defined]
+        return [task for task in model.tasks]

     if isinstance(response_model, ParallelBase):
         return model

-    assert hasattr(model, "_raw_response")
     model._raw_response = response
     return model

@@ -221,12 +225,12 @@ async def process_response_async(
     response: ChatCompletion,
     *,
-    response_model: Union[Type[OpenAISchema], ParallelBase, None],
-    stream: bool,
-    validation_context: Optional[Dict[str, Any]] = None,
+    response_model: Type[T_Model],
+    stream: bool = False,
+    validation_context: dict = None,
     strict: Optional[bool] = None,
     mode: Mode = Mode.FUNCTIONS,
-) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
+) -> T:
     """Processes a OpenAI response with the response model, if available.

     It can use `validation_context` and `strict` to validate the response
     via the pydantic model
@@ -246,44 +250,42 @@ async def process_response_async(
         and issubclass(response_model, (IterableBase, PartialBase))
         and stream
     ):
-        await_model = await response_model.from_streaming_response_async(
+        model = await response_model.from_streaming_response_async(
             response,
             mode=mode,
         )
-        return await_model
+        return model

     model = response_model.from_response(
         response,
         validation_context=validation_context,
         strict=strict,
         mode=mode,
-    )  # type: ignore[var-annotated]
+    )

     # ? This really hints at the fact that we need a better way of
     # ? attaching usage data and the raw response to the model we return.
     if isinstance(model, IterableBase):
         #! If the response model is a multitask, return the tasks
-        assert hasattr(model, "tasks")
         return [task for task in model.tasks]

     if isinstance(response_model, ParallelBase):
         return model

-    assert hasattr(model, "_raw_response")
     model._raw_response = response
     return model


-async def retry_async(  # type: ignore[return]
-    func: Callable[..., ChatCompletion],
-    response_model: Union[Type[OpenAISchema], ParallelBase],
-    validation_context: Optional[Dict[str, Any]],
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
-    max_retries: int,
+async def retry_async(
+    func: Callable[T_ParamSpec, T_Retval],
+    response_model: Type[T],
+    validation_context,
+    args,
+    kwargs,
+    max_retries: int | AsyncRetrying = 1,
     strict: Optional[bool] = None,
     mode: Mode = Mode.FUNCTIONS,
-) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
+) -> T:
     total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

     # If max_retries is int, then create a AsyncRetrying object
@@ -325,7 +327,7 @@ async def retry_async(  # type: ignore[return]
                     )
                 except (ValidationError, JSONDecodeError) as e:
                     logger.debug(f"Error response: {response}")
-                    kwargs["messages"].append(dump_message(response.choices[0].message))
+                    kwargs["messages"].append(dump_message(response.choices[0].message))  # type: ignore
                     if mode == Mode.TOOLS:
                         kwargs["messages"].append(
                             {
@@ -359,22 +361,22 @@ async def retry_async(  # type: ignore[return]
         raise e.last_attempt.exception from e


-def retry_sync(  # type: ignore[return]
-    func: Callable[..., ChatCompletion],
-    response_model: Union[Type[OpenAISchema], ParallelBase],
-    validation_context: Optional[Dict[str, Any]],
-    args: Tuple[Any, ...],
-    kwargs: Dict[str, Any],
-    max_retries: int,
+def retry_sync(
+    func: Callable[T_ParamSpec, T_Retval],
+    response_model: Type[T],
+    validation_context: dict,
+    args,
+    kwargs,
+    max_retries: int | Retrying = 1,
     strict: Optional[bool] = None,
     mode: Mode = Mode.FUNCTIONS,
-) -> Union[OpenAISchema, List[OpenAISchema]]:
+):
     total_usage = CompletionUsage(completion_tokens=0, prompt_tokens=0, total_tokens=0)

     # If max_retries is int, then create a Retrying object
     if isinstance(max_retries, int):
         logger.debug(f"max_retries: {max_retries}")
-        max_retries = Retrying(
+        max_retries: Retrying = Retrying(
             stop=stop_after_attempt(max_retries),
             reraise=True,
         )
@@ -442,7 +444,7 @@ def retry_sync(  # type: ignore[return]
         raise e.last_attempt.exception from e

-def is_async(func: Callable[..., Any]) -> bool:
+def is_async(func: Callable) -> bool:
     """Returns true if the callable is async, accounting for wrapped callables"""
     return inspect.iscoroutinefunction(func) or (
         hasattr(func, "__wrapped__") and inspect.iscoroutinefunction(func.__wrapped__)
     )
@@ -472,12 +474,12 @@ def is_async(func: Callable[..., Any]) -> bool:


 class InstructorChatCompletionCreate(Protocol):
     def __call__(
         self,
-        response_model: Union[Type[BaseModel], ParallelBase, None] = None,
-        validation_context: Optional[Dict[str, Any]] = None,
+        response_model: Type[T_Model] = None,
+        validation_context: dict = None,
         max_retries: int = 1,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Type[BaseModel]:
+        *args: T_ParamSpec.args,
+        **kwargs: T_ParamSpec.kwargs,
+    ) -> T_Model:
         ...
@@ -490,7 +492,7 @@ def patch(


 @overload
-def patch(  # type: ignore[misc]
+def patch(
     client: AsyncOpenAI,
     mode: Mode = Mode.FUNCTIONS,
 ) -> AsyncOpenAI:
@@ -499,15 +501,15 @@ def patch(  # type: ignore[misc]

 @overload
 def patch(
-    create: Callable[..., Any],
+    create: Callable[T_ParamSpec, T_Retval],
     mode: Mode = Mode.FUNCTIONS,
 ) -> InstructorChatCompletionCreate:
     ...


-def patch(  # type: ignore[misc]
-    client: Union[OpenAI, AsyncOpenAI, None] = None,
-    create: Optional[Callable[..., Any]] = None,
+def patch(
+    client: Union[OpenAI, AsyncOpenAI] = None,
+    create: Callable[T_ParamSpec, T_Retval] = None,
     mode: Mode = Mode.FUNCTIONS,
 ) -> Union[OpenAI, AsyncOpenAI]:
     """
@@ -534,40 +536,40 @@ def patch(  # type: ignore[misc]

     @wraps(func)
     async def new_create_async(
-        response_model: Type[BaseModel],
-        validation_context: Dict[str, Any],
+        response_model: Type[T_Model] = None,
+        validation_context: dict = None,
         max_retries: int = 1,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]:
-        new_response_model, new_kwargs = handle_response_model(
+        *args: T_ParamSpec.args,
+        **kwargs: T_ParamSpec.kwargs,
+    ) -> T_Model:
+        response_model, new_kwargs = handle_response_model(
             response_model=response_model, mode=mode, **kwargs
         )
         response = await retry_async(
             func=func,
-            response_model=new_response_model,
+            response_model=response_model,
             validation_context=validation_context,
             max_retries=max_retries,
             args=args,
             kwargs=new_kwargs,
             mode=mode,
-        )
+        )  # type: ignore
         return response

     @wraps(func)
     def new_create_sync(
-        response_model: Type[BaseModel],
-        validation_context: Dict[str, Any],
+        response_model: Type[T_Model] = None,
+        validation_context: dict = None,
         max_retries: int = 1,
-        *args: Any,
-        **kwargs: Any,
-    ) -> Union[OpenAISchema, List[OpenAISchema]]:
-        new_response_model, new_kwargs = handle_response_model(
+        *args: T_ParamSpec.args,
+        **kwargs: T_ParamSpec.kwargs,
+    ) -> T_Model:
+        response_model, new_kwargs = handle_response_model(
             response_model=response_model, mode=mode, **kwargs
         )
         response = retry_sync(
             func=func,
-            response_model=new_response_model,
+            response_model=response_model,
             validation_context=validation_context,
             max_retries=max_retries,
             args=args,
@@ -586,7 +588,7 @@ def new_create_sync(
     return new_create


-def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS) -> AsyncOpenAI:
+def apatch(client: AsyncOpenAI, mode: Mode = Mode.FUNCTIONS):
     """
     No longer necessary, use `patch` instead.
diff --git a/tests/test_patch.py b/tests/test_patch.py
index ac28e0f80..0418a1e14 100644
--- a/tests/test_patch.py
+++ b/tests/test_patch.py
@@ -6,40 +6,40 @@
 from instructor.patch import OVERRIDE_DOCS, is_async


-def test_patch_completes_successfully() -> None:
+def test_patch_completes_successfully():
     instructor.patch(OpenAI())


-def test_apatch_completes_successfully() -> None:
+def test_apatch_completes_successfully():
     instructor.apatch(AsyncOpenAI())


-def test_is_async_returns_true_if_function_is_async() -> None:
-    async def async_function() -> None:
+def test_is_async_returns_true_if_function_is_async():
+    async def async_function():
         pass

     assert is_async(async_function) is True


-def test_is_async_returns_false_if_function_is_not_async() -> None:
-    def sync_function() -> None:
+def test_is_async_returns_false_if_function_is_not_async():
+    def sync_function():
         pass

     assert is_async(sync_function) is False


-def test_is_async_returns_true_if_wrapped_function_is_async() -> None:
-    async def async_function() -> None:
+def test_is_async_returns_true_if_wrapped_function_is_async():
+    async def async_function():
         pass

     @functools.wraps(async_function)
-    def wrapped_function() -> None:
+    def wrapped_function():
         pass

     assert is_async(wrapped_function) is True


-def test_override_docs() -> None:
+def test_override_docs():
     assert (
         "response_model" in OVERRIDE_DOCS
     ), "response_model should be in OVERRIDE_DOCS"
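For context, the call pattern these typed signatures describe looks roughly like the sketch below. Only `instructor.patch(OpenAI())` and the extra `response_model` / `validation_context` / `max_retries` keyword arguments come from the code above; the `UserDetail` schema, the model name, and the prompt are illustrative placeholders and are not part of this diff.

import instructor
from openai import OpenAI
from pydantic import BaseModel


class UserDetail(BaseModel):  # illustrative schema, not taken from the diff
    name: str
    age: int


# patch() wraps the client's chat.completions.create so that it accepts
# response_model, validation_context and max_retries, and returns a parsed model.
client = instructor.patch(OpenAI())

user = client.chat.completions.create(
    model="gpt-3.5-turbo",  # placeholder model name
    response_model=UserDetail,
    max_retries=2,  # retried via tenacity on ValidationError / JSONDecodeError
    messages=[{"role": "user", "content": "Jason is 25 years old"}],
)
assert isinstance(user, UserDetail)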