From 9cfd5eb3d541e948328784c256db0b47956de620 Mon Sep 17 00:00:00 2001 From: Ezzeri Esa Date: Fri, 9 Feb 2024 08:39:42 -0800 Subject: [PATCH] narrow type to OpenAISchema --- instructor/patch.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/instructor/patch.py b/instructor/patch.py index 8c4056945..e4db2532a 100644 --- a/instructor/patch.py +++ b/instructor/patch.py @@ -8,7 +8,6 @@ from typing import ( Any, Callable, - Coroutine, Dict, Generator, get_args, @@ -277,7 +276,7 @@ async def process_response_async( async def retry_async( # type: ignore[return] - func: Callable[..., Coroutine[Any, Any, ChatCompletion]], + func: Callable[..., ChatCompletion], response_model: Union[Type[OpenAISchema], ParallelBase], validation_context: Optional[Dict[str, Any]], args: Tuple[Any, ...], @@ -541,7 +540,7 @@ async def new_create_async( max_retries: int = 1, *args: Any, **kwargs: Any, - ) -> Union[BaseModel, List[BaseModel]]: + ) -> Union[OpenAISchema, List[OpenAISchema], Dict[str, Any], Generator[Any, None, None]]: new_response_model, new_kwargs = handle_response_model( response_model=response_model, mode=mode, **kwargs ) @@ -563,7 +562,7 @@ def new_create_sync( max_retries: int = 1, *args: Any, **kwargs: Any, - ) -> Union[BaseModel, List[BaseModel]]: + ) -> Union[OpenAISchema, List[OpenAISchema]]: new_response_model, new_kwargs = handle_response_model( response_model=response_model, mode=mode, **kwargs )