Skip to content

Commit

Permalink
chore: include types to instructor.dsl (#419)
Browse files Browse the repository at this point in the history
  • Loading branch information
savarin authored Feb 9, 2024
1 parent 805161b commit 25f8214
Show file tree
Hide file tree
Showing 8 changed files with 127 additions and 69 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/mypy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ env:
instructor/cli/usage.py
instructor/exceptions.py
instructor/distil.py
instructor/dsl/citation.py
instructor/dsl/iterable.py
instructor/dsl/maybe.py
instructor/dsl/parallel.py
instructor/dsl/partial.py
instructor/dsl/partialjson.py
instructor/dsl/validators.py
instructor/function_calls.py
tests/test_function_calls.py
tests/test_distil.py
Expand Down
14 changes: 8 additions & 6 deletions instructor/dsl/citation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from pydantic import BaseModel, Field, FieldValidationInfo, model_validator
from typing import List
from typing import Generator, List, Tuple


class CitationMixin(BaseModel):
class CitationMixin(BaseModel): # type: ignore[misc]
"""
    Helpful mixin that can use `validation_context={"context": context}` in `from_response` to find the span of the substring_phrase in the context.
Expand Down Expand Up @@ -57,7 +57,7 @@ class User(BaseModel):
description="List of unique and specific substrings of the quote that was used to answer the question.",
)

@model_validator(mode="after")
@model_validator(mode="after") # type: ignore[misc]
def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin":
"""
For each substring_phrase, find the span of the substring_phrase in the context.
Expand All @@ -75,8 +75,10 @@ def validate_sources(self, info: FieldValidationInfo) -> "CitationMixin":
self.substring_quotes = [text_chunks[span[0] : span[1]] for span in spans]
return self

def _get_span(self, quote, context, errs=5):
import regex
def _get_span(
self, quote: str, context: str, errs: int = 5
) -> Generator[Tuple[int, int], None, None]:
import regex # type: ignore[import-untyped]

minor = quote
major = context
Expand All @@ -90,6 +92,6 @@ def _get_span(self, quote, context, errs=5):
if s is not None:
yield from s.spans()

def get_spans(self, context):
def get_spans(self, context: str) -> Generator[Tuple[int, int], None, None]:
for quote in self.substring_quotes:
yield from self._get_span(quote, context)
48 changes: 31 additions & 17 deletions instructor/dsl/iterable.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,31 @@
from typing import List, Optional, Type, Any
from typing import Any, AsyncGenerator, Generator, Iterable, List, Optional, Tuple, Type

from pydantic import BaseModel, Field, create_model

from instructor.function_calls import OpenAISchema, Mode


class IterableBase:
task_type = None # type: ignore
task_type = None # type: ignore[var-annotated]

@classmethod
def from_streaming_response(cls, completion, mode: Mode, **kwargs: Any): # noqa: ARG003
def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[BaseModel, None, None]: # noqa: ARG003
json_chunks = cls.extract_json(completion, mode)
yield from cls.tasks_from_chunks(json_chunks)
yield from cls.tasks_from_chunks(json_chunks, **kwargs)

@classmethod
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
json_chunks = cls.extract_json_async(completion, mode)
return cls.tasks_from_chunks_async(json_chunks, **kwargs)

@classmethod
def tasks_from_chunks(cls, json_chunks, **kwargs):
def tasks_from_chunks(
cls, json_chunks: Iterable[str], **kwargs: Any
) -> Generator[BaseModel, None, None]:
started = False
potential_object = ""
for chunk in json_chunks:
Expand All @@ -32,11 +38,14 @@ def tasks_from_chunks(cls, json_chunks, **kwargs):

task_json, potential_object = cls.get_object(potential_object, 0)
if task_json:
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
assert cls.task_type is not None
obj = cls.task_type.model_validate_json(task_json, **kwargs)
yield obj

@classmethod
async def tasks_from_chunks_async(cls, json_chunks, **kwargs):
async def tasks_from_chunks_async(
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
) -> AsyncGenerator[BaseModel, None]:
started = False
potential_object = ""
async for chunk in json_chunks:
Expand All @@ -49,11 +58,14 @@ async def tasks_from_chunks_async(cls, json_chunks, **kwargs):

task_json, potential_object = cls.get_object(potential_object, 0)
if task_json:
obj = cls.task_type.model_validate_json(task_json, **kwargs) # type: ignore
assert cls.task_type is not None
obj = cls.task_type.model_validate_json(task_json, **kwargs)
yield obj

@staticmethod
def extract_json(completion, mode: Mode):
def extract_json(
completion: Iterable[Any], mode: Mode
) -> Generator[str, None, None]:
for chunk in completion:
try:
if chunk.choices:
Expand All @@ -74,7 +86,9 @@ def extract_json(completion, mode: Mode):
pass

@staticmethod
async def extract_json_async(completion, mode: Mode):
async def extract_json_async(
completion: AsyncGenerator[Any, None], mode: Mode
) -> AsyncGenerator[str, None]:
async for chunk in completion:
try:
if chunk.choices:
Expand All @@ -95,15 +109,15 @@ async def extract_json_async(completion, mode: Mode):
pass

@staticmethod
def get_object(str, stack):
for i, c in enumerate(str):
def get_object(s: str, stack: int) -> Tuple[Optional[str], str]:
for i, c in enumerate(s):
if c == "{":
stack += 1
if c == "}":
stack -= 1
if stack == 0:
return str[: i + 1], str[i + 2 :]
return None, str
return s[: i + 1], s[i + 2 :]
return None, s


def IterableModel(
Expand Down Expand Up @@ -166,7 +180,7 @@ def from_streaming_response(cls, completion) -> Generator[User]:
name = f"Iterable{task_name}"

list_tasks = (
List[subtask_class],
List[subtask_class], # type: ignore[valid-type]
Field(
default_factory=list,
repr=False,
Expand All @@ -177,7 +191,7 @@ def from_streaming_response(cls, completion) -> Generator[User]:
new_cls = create_model(
name,
tasks=list_tasks,
__base__=(OpenAISchema, IterableBase), # type: ignore
__base__=(OpenAISchema, IterableBase),
)
# set the class constructor BaseModel
new_cls.task_type = subtask_class
Expand Down
8 changes: 4 additions & 4 deletions instructor/dsl/maybe.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pydantic import BaseModel, Field, create_model
from typing import Type, Optional, TypeVar, Generic
from typing import Generic, Optional, Type, TypeVar

T = TypeVar("T", bound=BaseModel)


class MaybeBase(BaseModel, Generic[T]):
class MaybeBase(BaseModel, Generic[T]): # type: ignore[misc]
"""
Extract a result from a model, if any, otherwise set the error and message fields.
"""
Expand All @@ -13,8 +13,8 @@ class MaybeBase(BaseModel, Generic[T]):
error: bool = Field(default=False)
message: Optional[str]

def __bool__(self):
return self.result is not None # type: ignore
def __bool__(self) -> bool:
return self.result is not None


def Maybe(model: Type[T]) -> Type[MaybeBase[T]]:
Expand Down
32 changes: 22 additions & 10 deletions instructor/dsl/parallel.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from typing import Type, TypeVar, Union, get_origin, get_args
from types import UnionType
from typing import (
Any,
Dict,
Generator,
List,
Optional,
Tuple,
Type,
TypeVar,
Union,
get_args,
get_origin,
)
from types import UnionType # type: ignore[attr-defined]

from instructor.function_calls import OpenAISchema, Mode, openai_schema
from collections.abc import Iterable

T = TypeVar("T")
T = TypeVar("T", bound=OpenAISchema)


class ParallelBase:
Expand All @@ -16,11 +28,11 @@ def __init__(self, *models: Type[OpenAISchema]):

def from_response(
self,
response,
response: Any,
mode: Mode,
validation_context=None,
strict: bool = None,
) -> Iterable[Union[T]]:
validation_context: Optional[Any] = None,
strict: Optional[bool] = None,
) -> Generator[T, None, None]:
#! We expect this from the OpenAISchema class, We should address
#! this with a protocol or an abstract class... @jxnlco
assert mode == Mode.PARALLEL_TOOLS, "Mode must be PARALLEL_TOOLS"
Expand All @@ -32,7 +44,7 @@ def from_response(
)


def get_types_array(typehint: Type[Iterable[Union[T]]]):
def get_types_array(typehint: Type[Iterable[Union[T]]]) -> Tuple[Type[T], ...]:
should_be_iterable = get_origin(typehint)
assert should_be_iterable is Iterable

Expand All @@ -50,14 +62,14 @@ def get_types_array(typehint: Type[Iterable[Union[T]]]):
return get_args(typehint)


def handle_parallel_model(typehint: Type[Iterable[Union[T]]]):
def handle_parallel_model(typehint: Type[Iterable[Union[T]]]) -> List[Dict[str, Any]]:
the_types = get_types_array(typehint)
return [
{"type": "function", "function": openai_schema(model).openai_schema}
for model in the_types
]


def ParallelModel(typehint):
def ParallelModel(typehint: Type[Iterable[Union[T]]]) -> ParallelBase:
the_types = get_types_array(typehint)
return ParallelBase(*[model for model in the_types])
43 changes: 33 additions & 10 deletions instructor/dsl/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,18 @@

from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing import TypeVar, NoReturn, get_args, get_origin, Optional, Generic
from typing import (
Any,
AsyncGenerator,
Generator,
Generic,
get_args,
get_origin,
Iterable,
NoReturn,
Optional,
TypeVar,
)
from copy import deepcopy

from instructor.function_calls import Mode
Expand All @@ -21,17 +32,23 @@

class PartialBase:
@classmethod
def from_streaming_response(cls, completion, mode: Mode, **kwargs):
def from_streaming_response(
cls, completion: Iterable[Any], mode: Mode, **kwargs: Any
) -> Generator[Model, None, None]:
json_chunks = cls.extract_json(completion, mode)
yield from cls.model_from_chunks(json_chunks, **kwargs)

@classmethod
async def from_streaming_response_async(cls, completion, mode: Mode, **kwargs):
async def from_streaming_response_async(
cls, completion: AsyncGenerator[Any, None], mode: Mode, **kwargs: Any
) -> AsyncGenerator[Model, None]:
json_chunks = cls.extract_json_async(completion, mode)
return cls.model_from_chunks_async(json_chunks, **kwargs)

@classmethod
def model_from_chunks(cls, json_chunks, **kwargs):
def model_from_chunks(
cls, json_chunks: Iterable[Any], **kwargs: Any
) -> Generator[Model, None, None]:
prev_obj = None
potential_object = ""
for chunk in json_chunks:
Expand All @@ -42,7 +59,7 @@ def model_from_chunks(cls, json_chunks, **kwargs):
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
if obj != prev_obj:
obj.__dict__[
"chunk"
Expand All @@ -51,7 +68,9 @@ def model_from_chunks(cls, json_chunks, **kwargs):
yield obj

@classmethod
async def model_from_chunks_async(cls, json_chunks, **kwargs):
async def model_from_chunks_async(
cls, json_chunks: AsyncGenerator[str, None], **kwargs: Any
) -> AsyncGenerator[Model, None]:
potential_object = ""
prev_obj = None
async for chunk in json_chunks:
Expand All @@ -62,7 +81,7 @@ async def model_from_chunks_async(cls, json_chunks, **kwargs):
parser.parse(potential_object) if potential_object.strip() else None
)
if task_json:
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore
obj = cls.model_validate(task_json, strict=None, **kwargs) # type: ignore[attr-defined]
if obj != prev_obj:
obj.__dict__[
"chunk"
Expand All @@ -71,7 +90,9 @@ async def model_from_chunks_async(cls, json_chunks, **kwargs):
yield obj

@staticmethod
def extract_json(completion, mode: Mode):
def extract_json(
completion: Iterable[Any], mode: Mode
) -> Generator[str, None, None]:
for chunk in completion:
try:
if chunk.choices:
Expand All @@ -92,7 +113,9 @@ def extract_json(completion, mode: Mode):
pass

@staticmethod
async def extract_json_async(completion, mode: Mode):
async def extract_json_async(
completion: AsyncGenerator[Any, None], mode: Mode
) -> AsyncGenerator[str, None]:
async for chunk in completion:
try:
if chunk.choices:
Expand Down Expand Up @@ -169,7 +192,7 @@ def _make_field_optional(

# Recursively apply Partial to each of the generic arguments
modified_args = tuple(
Partial[arg]
Partial[arg] # type: ignore[valid-type]
if isinstance(arg, type) and issubclass(arg, BaseModel)
else arg
for arg in generic_args
Expand Down
Loading

0 comments on commit 25f8214

Please sign in to comment.