diff --git a/.github/workflows/mypy.yml b/.github/workflows/mypy.yml
index 4986b7bfe..26897e439 100644
--- a/.github/workflows/mypy.yml
+++ b/.github/workflows/mypy.yml
@@ -16,8 +16,10 @@ env:
     instructor/cli/jobs.py
     instructor/cli/usage.py
     instructor/exceptions.py
+    instructor/distil.py
     instructor/function_calls.py
     tests/test_function_calls.py
+    tests/test_distil.py
 
 jobs:
   MyPy:
diff --git a/instructor/distil.py b/instructor/distil.py
index d6cb11843..d7fb22aa2 100644
--- a/instructor/distil.py
+++ b/instructor/distil.py
@@ -5,19 +5,22 @@
 import inspect
 import functools
 
-from typing import Any, Callable, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
 
 from pydantic import BaseModel, validate_call
 from openai import OpenAI
 
 from instructor.function_calls import openai_schema
 
+T_Retval = TypeVar("T_Retval")
+
+
 class FinetuneFormat(enum.Enum):
     MESSAGES: str = "messages"
     RAW: str = "raw"
 
 
-def get_signature_from_fn(fn: Callable) -> str:
+def get_signature_from_fn(fn: Callable[..., Any]) -> str:
     """
     Get the function signature as a string.
 
@@ -43,7 +46,7 @@ def get_signature_from_fn(fn: Callable) -> str:
 
 
 @functools.lru_cache()
-def format_function(func: Callable) -> str:
+def format_function(func: Callable[..., Any]) -> str:
     """
     Format a function as a string with docstring and body.
     """
@@ -79,14 +82,14 @@ def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
 
 class Instructions:
     def __init__(
         self,
-        name: str = None,
-        id: str = None,
-        log_handlers: List[logging.Handler] = None,
+        name: Optional[str] = None,
+        id: Optional[str] = None,
+        log_handlers: Optional[List[logging.Handler]] = None,
         finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
         indent: int = 2,
         include_code_body: bool = False,
-        openai_client: OpenAI = None,
-    ):
+        openai_client: Optional[OpenAI] = None,
+    ) -> None:
         """
         Instructions for distillation and dispatch.
@@ -111,12 +114,12 @@ def __init__(
 
     def distil(
         self,
-        *args,
-        name: str = None,
+        *args: Any,
+        name: Optional[str] = None,
         mode: str = "distil",
         model: str = "gpt-3.5-turbo",
-        fine_tune_format: FinetuneFormat = None,
-    ):
+        fine_tune_format: Optional[FinetuneFormat] = None,
+    ) -> Callable[..., Any]:
         """
         Decorator to track the function call and response, supports distillation and dispatch modes.
 
@@ -142,13 +145,19 @@ def distil(
         if fine_tune_format is None:
             fine_tune_format = self.finetune_format
 
-        def _wrap_distil(fn):
+        def _wrap_distil(
+            fn: Callable[..., T_Retval],
+        ) -> Callable[..., T_Retval]:
             msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
             assert is_return_type_base_model_or_instance(fn), msg
             return_base_model = inspect.signature(fn).return_annotation
+            # Default the tracked name once, here: assigning to `name` inside
+            # _dispatch would make it local there and raise UnboundLocalError.
+            nonlocal name
+            name = name if name else fn.__name__
 
             @functools.wraps(fn)
-            def _dispatch(*args, **kwargs):
+            def _dispatch(*args: Any, **kwargs: Any) -> T_Retval:
                 openai_kwargs = self.openai_kwargs(
                     name=name,
                     fn=fn,
@@ -161,7 +170,7 @@ def _dispatch(*args, **kwargs):
             )
 
             @functools.wraps(fn)
-            def _distil(*args, **kwargs):
+            def _distil(*args: Any, **kwargs: Any) -> T_Retval:
                 resp = fn(*args, **kwargs)
                 self.track(
                     fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format
@@ -169,27 +178,23 @@ def _distil(*args, **kwargs):
                 )
                 return resp
 
-            if mode == "dispatch":
-                return _dispatch
-
-            if mode == "distil":
-                return _distil
+            return _dispatch if mode == "dispatch" else _distil
 
         if len(args) == 1 and callable(args[0]):
             return _wrap_distil(args[0])
 
         return _wrap_distil
 
-    @validate_call
+    @validate_call  # type: ignore[misc]
     def track(
         self,
         fn: Callable[..., Any],
-        args: tuple,
-        kwargs: dict,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
         resp: BaseModel,
         name: Optional[str] = None,
         finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
-    ):
+    ) -> None:
         """
         Track the function call and response in a log file, later used for finetuning.
 
@@ -229,7 +234,14 @@ def track(
         )
         self.logger.info(json.dumps(function_body))
 
-    def openai_kwargs(self, name, fn, args, kwargs, base_model):
+    def openai_kwargs(
+        self,
+        name: str,
+        fn: Callable[..., Any],
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        base_model: Type[BaseModel],
+    ) -> Dict[str, Any]:
         if self.include_code_body:
             func_def = format_function(fn)
         else:
diff --git a/tests/test_distil.py b/tests/test_distil.py
index e2156fc8f..03eda3fec 100644
--- a/tests/test_distil.py
+++ b/tests/test_distil.py
@@ -1,3 +1,4 @@
+from typing import Any, Callable, cast
 import pytest
 import instructor
 
@@ -18,27 +19,27 @@
 )
 
 
-class SimpleModel(BaseModel):
+class SimpleModel(BaseModel):  # type: ignore[misc]
     data: int
 
 
-def test_must_have_hint():
+def test_must_have_hint() -> None:
     with pytest.raises(AssertionError):
 
         @instructions.distil
-        def test_func(x: int):
+        def test_func(x: int):  # type: ignore[no-untyped-def]
             return SimpleModel(data=x)
 
 
-def test_must_be_base_model():
+def test_must_be_base_model() -> None:
     with pytest.raises(AssertionError):
 
         @instructions.distil
-        def test_func(x) -> int:
+        def test_func(x: int) -> int:
             return SimpleModel(data=x)
 
 
-def test_is_return_type_base_model_or_instance():
+def test_is_return_type_base_model_or_instance() -> None:
     def valid_function() -> SimpleModel:
         return SimpleModel(data=1)
 
@@ -49,8 +50,8 @@ def invalid_function() -> int:
     assert not is_return_type_base_model_or_instance(invalid_function)
 
 
-def test_get_signature_from_fn():
-    def test_function(a: int, b: str) -> float:
+def test_get_signature_from_fn() -> None:
+    def test_function(a: int, b: str) -> float:  # type: ignore[empty-body]
         """Sample docstring"""
         pass
 
@@ -60,7 +61,7 @@ def test_function(a: int, b: str) -> float:
     assert "Sample docstring" in result
 
 
-def test_format_function():
+def test_format_function() -> None:
     def sample_function(x: int) -> SimpleModel:
         """This is a docstring."""
         return SimpleModel(data=x)
@@ -71,26 +72,28 @@ def sample_function(x: int) -> SimpleModel:
     assert "return SimpleModel(data=x)" in formatted
 
 
-def test_distil_decorator_without_arguments():
+def test_distil_decorator_without_arguments() -> None:
     @instructions.distil
     def test_func(x: int) -> SimpleModel:
         return SimpleModel(data=x)
 
-    result = test_func(42)
+    casted_test_func = cast(Callable[[int], SimpleModel], test_func)
+    result: SimpleModel = casted_test_func(42)
     assert result.data == 42
 
 
-def test_distil_decorator_with_name_argument():
+def test_distil_decorator_with_name_argument() -> None:
     @instructions.distil(name="custom_name")
     def another_test_func(x: int) -> SimpleModel:
         return SimpleModel(data=x)
 
-    result = another_test_func(55)
+    casted_another_test_func = cast(Callable[[int], SimpleModel], another_test_func)
+    result: SimpleModel = casted_another_test_func(55)
     assert result.data == 55
 
 
 # Mock track function for decorator tests
-def mock_track(*args, **kwargs):
+def mock_track(*args: Any, **kwargs: Any) -> None:
     pass
 
 