chore: Include types to instructor.distil and tests (#396)
savarin authored Feb 8, 2024
1 parent f65ba6b commit d632682
Showing 3 changed files with 56 additions and 39 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/mypy.yml
@@ -16,8 +16,10 @@ env:
instructor/cli/jobs.py
instructor/cli/usage.py
instructor/exceptions.py
instructor/distil.py
instructor/function_calls.py
tests/test_function_calls.py
tests/test_distil.py
jobs:
MyPy:
62 changes: 37 additions & 25 deletions instructor/distil.py
@@ -5,19 +5,22 @@
import inspect
import functools

from typing import Any, Callable, List, Optional
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar
from pydantic import BaseModel, validate_call

from openai import OpenAI
from instructor.function_calls import openai_schema


T_Retval = TypeVar("T_Retval")


class FinetuneFormat(enum.Enum):
MESSAGES: str = "messages"
RAW: str = "raw"


def get_signature_from_fn(fn: Callable) -> str:
def get_signature_from_fn(fn: Callable[..., Any]) -> str:
"""
Get the function signature as a string.
@@ -43,7 +46,7 @@ def get_signature_from_fn(fn: Callable) -> str:


@functools.lru_cache()
def format_function(func: Callable) -> str:
def format_function(func: Callable[..., Any]) -> str:
"""
Format a function as a string with docstring and body.
"""
@@ -79,14 +82,14 @@ def is_return_type_base_model_or_instance(func: Callable[..., Any]) -> bool:
class Instructions:
def __init__(
self,
name: str = None,
id: str = None,
log_handlers: List[logging.Handler] = None,
name: Optional[str] = None,
id: Optional[str] = None,
log_handlers: Optional[List[logging.Handler]] = None,
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
indent: int = 2,
include_code_body: bool = False,
openai_client: OpenAI = None,
):
openai_client: Optional[OpenAI] = None,
) -> None:
"""
Instructions for distillation and dispatch.
@@ -111,12 +114,15 @@ def __init__(

def distil(
self,
*args,
name: str = None,
*args: Any,
name: Optional[str] = None,
mode: str = "distil",
model: str = "gpt-3.5-turbo",
fine_tune_format: FinetuneFormat = None,
):
fine_tune_format: Optional[FinetuneFormat] = None,
) -> Callable[
[Callable[..., Any]],
Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]],
]:
"""
Decorator to track the function call and response, supports distillation and dispatch modes.
@@ -142,13 +148,16 @@ def distil(
if fine_tune_format is None:
fine_tune_format = self.finetune_format

def _wrap_distil(fn):
def _wrap_distil(
fn: Callable[..., Any],
) -> Callable[[Callable[..., T_Retval]], Callable[..., T_Retval]]:
msg = f"Return type hint for {fn} must subclass `pydantic.BaseModel'"
assert is_return_type_base_model_or_instance(fn), msg
return_base_model = inspect.signature(fn).return_annotation

@functools.wraps(fn)
def _dispatch(*args, **kwargs):
def _dispatch(*args: Any, **kwargs: Any) -> Callable[..., T_Retval]:
name = name if name else fn.__name__
openai_kwargs = self.openai_kwargs(
name=name,
fn=fn,
@@ -161,35 +170,31 @@ def _dispatch(*args, **kwargs):
)

@functools.wraps(fn)
def _distil(*args, **kwargs):
def _distil(*args: Any, **kwargs: Any) -> Callable[..., T_Retval]:
resp = fn(*args, **kwargs)
self.track(
fn, args, kwargs, resp, name=name, finetune_format=fine_tune_format
)

return resp

if mode == "dispatch":
return _dispatch

if mode == "distil":
return _distil
return _dispatch if mode == "dispatch" else _distil

if len(args) == 1 and callable(args[0]):
return _wrap_distil(args[0])

return _wrap_distil

@validate_call
@validate_call # type: ignore[misc]
def track(
self,
fn: Callable[..., Any],
args: tuple,
kwargs: dict,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
resp: BaseModel,
name: Optional[str] = None,
finetune_format: FinetuneFormat = FinetuneFormat.MESSAGES,
):
) -> None:
"""
Track the function call and response in a log file, later used for finetuning.
@@ -229,7 +234,14 @@ def track(
)
self.logger.info(json.dumps(function_body))

def openai_kwargs(self, name, fn, args, kwargs, base_model):
def openai_kwargs(
self,
name: str,
fn: Callable[..., Any],
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
base_model: Type[BaseModel],
) -> Dict[str, Any]:
if self.include_code_body:
func_def = format_function(fn)
else:
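With these annotations in place the decorator chain is checkable: distil requires the wrapped function to declare a pydantic.BaseModel return type, and _wrap_distil, _dispatch, and _distil now carry explicit signatures that mypy can follow. A minimal usage sketch, assuming Instructions is imported straight from instructor.distil as defined in this file; the Multiply model and the instruction name are illustrative, not part of this commit:

from pydantic import BaseModel
from instructor.distil import Instructions

instructions = Instructions(name="three_digit_multiply")  # hypothetical name, for illustration

class Multiply(BaseModel):  # illustrative response model
    a: int
    b: int
    result: int

@instructions.distil
def fn(a: int, b: int) -> Multiply:
    # The BaseModel return hint is mandatory; _wrap_distil asserts it before wrapping.
    return Multiply(a=a, b=b, result=a * b)

resp = fn(3, 4)  # default mode="distil": runs fn and logs the call/response pair for fine-tuning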
31 changes: 17 additions & 14 deletions tests/test_distil.py
@@ -1,3 +1,4 @@
from typing import Any, Dict, Callable, Tuple, cast
import pytest
import instructor

@@ -18,27 +19,27 @@
)


class SimpleModel(BaseModel):
class SimpleModel(BaseModel): # type: ignore[misc]
data: int


def test_must_have_hint():
def test_must_have_hint() -> None:
with pytest.raises(AssertionError):

@instructions.distil
def test_func(x: int):
def test_func(x: int): # type: ignore[no-untyped-def]
return SimpleModel(data=x)


def test_must_be_base_model():
def test_must_be_base_model() -> None:
with pytest.raises(AssertionError):

@instructions.distil
def test_func(x) -> int:
def test_func(x: int) -> int:
return SimpleModel(data=x)


def test_is_return_type_base_model_or_instance():
def test_is_return_type_base_model_or_instance() -> None:
def valid_function() -> SimpleModel:
return SimpleModel(data=1)

@@ -49,8 +50,8 @@ def invalid_function() -> int:
assert not is_return_type_base_model_or_instance(invalid_function)


def test_get_signature_from_fn():
def test_function(a: int, b: str) -> float:
def test_get_signature_from_fn() -> None:
def test_function(a: int, b: str) -> float: # type: ignore[empty-body]
"""Sample docstring"""
pass

@@ -60,7 +61,7 @@ def test_function(a: int, b: str) -> float:
assert "Sample docstring" in result


def test_format_function():
def test_format_function() -> None:
def sample_function(x: int) -> SimpleModel:
"""This is a docstring."""
return SimpleModel(data=x)
@@ -71,26 +72,28 @@ def sample_function(x: int) -> SimpleModel:
assert "return SimpleModel(data=x)" in formatted


def test_distil_decorator_without_arguments():
def test_distil_decorator_without_arguments() -> None:
@instructions.distil
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)

result = test_func(42)
casted_test_func = cast(Callable[[int], SimpleModel], test_func)
result: SimpleModel = casted_test_func(42)
assert result.data == 42


def test_distil_decorator_with_name_argument():
def test_distil_decorator_with_name_argument() -> None:
@instructions.distil(name="custom_name")
def another_test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)

result = another_test_func(55)
casted_another_test_func = cast(Callable[[int], SimpleModel], another_test_func)
result: SimpleModel = casted_another_test_func(55)
assert result.data == 55


# Mock track function for decorator tests
def mock_track(*args, **kwargs):
def mock_track(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None:
pass


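The cast calls in the two decorator tests compensate for the decorator's deliberately broad return annotation: after @instructions.distil, mypy no longer sees the precise (int) -> SimpleModel signature, so the tests cast it back before calling. A self-contained sketch of the same pattern, with an illustrative stand-in for the decorator:

from typing import Any, Callable, cast
from pydantic import BaseModel

class SimpleModel(BaseModel):
    data: int

def loose_decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
    # Stands in for instructions.distil, whose annotation erases the parameter types.
    return fn

@loose_decorator
def make_model(x: int) -> SimpleModel:
    return SimpleModel(data=x)

# cast() is a no-op at runtime; it only restores a precise signature for the type checker.
typed_make_model = cast(Callable[[int], SimpleModel], make_model)
result: SimpleModel = typed_make_model(42)
assert result.data == 42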
