Skip to content

Commit

Permalink
include types to instructor.distil tests
Browse files Browse the repository at this point in the history
  • Loading branch information
savarin committed Feb 6, 2024
1 parent 62f9938 commit 33d2109
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 14 deletions.
1 change: 1 addition & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ env:
instructor/distil.py
instructor/function_calls.py
tests/test_function_calls.py
tests/test_distil.py
jobs:
MyPy:
Expand Down
31 changes: 17 additions & 14 deletions tests/test_distil.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Any, Dict, Callable, Tuple, cast
import pytest
import instructor

Expand All @@ -18,27 +19,27 @@
)


class SimpleModel(BaseModel):
class SimpleModel(BaseModel): # type: ignore[misc]
data: int


def test_must_have_hint():
def test_must_have_hint() -> None:
with pytest.raises(AssertionError):

@instructions.distil
def test_func(x: int):
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)


def test_must_be_base_model():
def test_must_be_base_model() -> None:
with pytest.raises(AssertionError):

@instructions.distil
def test_func(x) -> int:
def test_func(x: int) -> int:
return SimpleModel(data=x)


def test_is_return_type_base_model_or_instance():
def test_is_return_type_base_model_or_instance() -> None:
def valid_function() -> SimpleModel:
return SimpleModel(data=1)

Expand All @@ -49,8 +50,8 @@ def invalid_function() -> int:
assert not is_return_type_base_model_or_instance(invalid_function)


def test_get_signature_from_fn():
def test_function(a: int, b: str) -> float:
def test_get_signature_from_fn() -> None:
def test_function(a: int, b: str) -> float: # type: ignore[empty-body]
"""Sample docstring"""
pass

Expand All @@ -60,7 +61,7 @@ def test_function(a: int, b: str) -> float:
assert "Sample docstring" in result


def test_format_function():
def test_format_function() -> None:
def sample_function(x: int) -> SimpleModel:
"""This is a docstring."""
return SimpleModel(data=x)
Expand All @@ -71,26 +72,28 @@ def sample_function(x: int) -> SimpleModel:
assert "return SimpleModel(data=x)" in formatted


def test_distil_decorator_without_arguments():
def test_distil_decorator_without_arguments() -> None:
@instructions.distil
def test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)

result = test_func(42)
casted_test_func = cast(Callable[[int], SimpleModel], test_func)
result: SimpleModel = casted_test_func(42)
assert result.data == 42


def test_distil_decorator_with_name_argument():
def test_distil_decorator_with_name_argument() -> None:
@instructions.distil(name="custom_name")
def another_test_func(x: int) -> SimpleModel:
return SimpleModel(data=x)

result = another_test_func(55)
casted_another_test_func = cast(Callable[[int], SimpleModel], another_test_func)
result: SimpleModel = casted_another_test_func(55)
assert result.data == 55


# Mock track function for decorator tests
def mock_track(*args, **kwargs):
def mock_track(*args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> None:
pass


Expand Down

0 comments on commit 33d2109

Please sign in to comment.