Skip to content

Commit

Permalink
Enable more linting rules with Ruff. Update pre-commit hooks (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackmpcollins authored Sep 13, 2023
1 parent a74bee9 commit f18c1e6
Show file tree
Hide file tree
Showing 8 changed files with 56 additions and 46 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
rev: v0.0.289
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
Expand All @@ -19,6 +19,6 @@ repos:
hooks:
- id: black
- repo: https://github.com/adamchainz/blacken-docs
rev: "1.15.0"
rev: "1.16.0"
hooks:
- id: blacken-docs
18 changes: 11 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,27 @@ ignore = [
"C90", # mccabe
"D", # pydocstyle
"ANN", # flake8-annotations
"S", # flake8-bandit
"A", # flake8-builtins
"COM", # flake8-commas
"EM", # flake8-errmsg
"FA", # flake8-future-annotations
"INP", # flake8-no-pep420
"T20", # flake8-print
"PT", # flake8-pytest-style
"SLF", # flake8-self
"ARG", # flake8-unused-arguments
"TD", # flake8-todos
"FIX", # flake8-fixme
"PGH", # pygrep-hooks
"PL", # Pylint
"TRY", # tryceratops
"RUF", # Ruff-specific rules
]

[tool.ruff.flake8-pytest-style]
mark-parentheses = false

[tool.ruff.isort]
known-first-party = ["magentic"]

[tool.ruff.per-file-ignores]
"examples/*" = [
"T20", # flake8-print
]
"tests/*" = [
"S", # flake8-bandit
]
22 changes: 14 additions & 8 deletions src/magentic/chat_model/openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def __init__(self, output_type: type[AsyncIterableT]):
self._item_type_adapter = TypeAdapter(get_args(output_type)[0])
# Convert to list so pydantic can handle for schema generation
# But keep the type hint using AsyncIterableT for type checking
self._model: type[Output[AsyncIterableT]] = Output[list[get_args(output_type)[0]]] # type: ignore
self._model: type[Output[AsyncIterableT]] = Output[list[get_args(output_type)[0]]] # type: ignore[index,misc]

@property
def name(self) -> str:
Expand Down Expand Up @@ -306,7 +306,9 @@ class OpenaiChatCompletionFunctionCall(BaseModel):

def get_name_or_raise(self) -> str:
"""Return the name, raising an error if it doesn't exist."""
assert self.name is not None
if self.name is None:
msg = "OpenAI function call name is None"
raise ValueError(msg)
return self.name


Expand Down Expand Up @@ -504,17 +506,19 @@ def complete(
if chunk.choices[0].delta.function_call
)
except ValidationError as e:
raise StructuredOutputError(
msg = (
"Failed to parse model output. You may need to update your prompt"
" to encourage the model to return a specific type."
) from e
)
raise StructuredOutputError(msg) from e
return message

if not allow_string_output:
raise ValueError(
msg = (
"String was returned by model but not expected. You may need to update"
" your prompt to encourage the model to return a specific type."
)
raise ValueError(msg)
streamed_str = StreamedStr(
chunk.choices[0].delta.content
for chunk in response
Expand Down Expand Up @@ -577,17 +581,19 @@ async def acomplete(
if chunk.choices[0].delta.function_call
)
except ValidationError as e:
raise StructuredOutputError(
msg = (
"Failed to parse model output. You may need to update your prompt"
" to encourage the model to return a specific type."
) from e
)
raise StructuredOutputError(msg) from e
return message

if not allow_string_output:
raise ValueError(
msg = (
"String was returned by model but not expected. You may need to update"
" your prompt to encourage the model to return a specific type."
)
raise ValueError(msg)
async_streamed_str = AsyncStreamedStr(
chunk.choices[0].delta.content
async for chunk in response
Expand Down
10 changes: 5 additions & 5 deletions src/magentic/typing.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import inspect
import types
from collections.abc import Mapping, Sequence
from typing import (
Any,
Mapping,
Sequence,
TypeGuard,
TypeVar,
Union,
Expand All @@ -15,7 +14,7 @@
def is_union_type(type_: type) -> bool:
"""Return True if the type is a union type."""
type_ = get_origin(type_) or type_
return type_ is Union or type_ is types.UnionType # noqa: E721
return type_ is Union or type_ is types.UnionType


TypeT = TypeVar("TypeT", bound=type)
Expand Down Expand Up @@ -58,11 +57,12 @@ def name_type(type_: type) -> str:
return f"dict_of_{name_type(key_type)}_to_{name_type(value_type)}"

if name := getattr(type_, "__name__", None):
assert isinstance(name, str)
assert isinstance(name, str) # noqa: S101

if len(args) == 1:
return f"{name.lower()}_of_{name_type(args[0])}"

return name.lower()

raise ValueError(f"Unable to name type {type_}")
msg = f"Unable to name type {type_}"
raise ValueError(msg)
4 changes: 2 additions & 2 deletions tests/test_function_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def plus_default_value(a: int, b: int = 3) -> int:


@pytest.mark.parametrize(
["left", "right", "equal"],
("left", "right", "equal"),
[
(FunctionCall(plus, a=1, b=2), FunctionCall(plus, a=1, b=2), True),
(FunctionCall(plus, a=1, b=2), FunctionCall(plus, a=1, b=33), False),
Expand All @@ -40,7 +40,7 @@ def test_function_call_eq(left, right, equal):


@pytest.mark.parametrize(
["function_call", "arguments"],
("function_call", "arguments"),
[
(FunctionCall(plus, a=1, b=2), {"a": 1, "b": 2}),
(FunctionCall(plus, 1, 2), {"a": 1, "b": 2}),
Expand Down
32 changes: 16 additions & 16 deletions tests/test_openai_chat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@pytest.mark.parametrize(
["type_", "json_schema"],
("type_", "json_schema"),
[
(
str,
Expand Down Expand Up @@ -151,15 +151,15 @@ def test_any_function_schema(type_, json_schema):


@pytest.mark.parametrize(
["type_", "args_str", "expected_args"], any_function_schema_args_test_cases
("type_", "args_str", "expected_args"), any_function_schema_args_test_cases
)
def test_any_function_schema_parse_args(type_, args_str, expected_args):
parsed_args = AnyFunctionSchema(type_).parse_args(args_str)
assert parsed_args == expected_args


@pytest.mark.parametrize(
["type_", "args_str", "expected_args"], any_function_schema_args_test_cases
("type_", "args_str", "expected_args"), any_function_schema_args_test_cases
)
@pytest.mark.asyncio
async def test_any_function_schema_aparse_args(type_, args_str, expected_args):
Expand All @@ -168,15 +168,15 @@ async def test_any_function_schema_aparse_args(type_, args_str, expected_args):


@pytest.mark.parametrize(
["type_", "expected_args_str", "args"], any_function_schema_args_test_cases
("type_", "expected_args_str", "args"), any_function_schema_args_test_cases
)
def test_any_function_schema_serialize_args(type_, expected_args_str, args):
serialized_args = AnyFunctionSchema(type_).serialize_args(args)
assert json.loads(serialized_args) == json.loads(expected_args_str)


@pytest.mark.parametrize(
["type_", "json_schema"],
("type_", "json_schema"),
[
(
list[str],
Expand Down Expand Up @@ -244,7 +244,7 @@ def test_iterable_function_schema(type_, json_schema):


@pytest.mark.parametrize(
["type_", "args_str", "expected_args"], iterable_function_schema_args_test_cases
("type_", "args_str", "expected_args"), iterable_function_schema_args_test_cases
)
def test_iterable_function_schema_parse_args(type_, args_str, expected_args):
parsed_args = IterableFunctionSchema(type_).parse_args(args_str)
Expand All @@ -253,15 +253,15 @@ def test_iterable_function_schema_parse_args(type_, args_str, expected_args):


@pytest.mark.parametrize(
["type_", "expected_args_str", "args"], iterable_function_schema_args_test_cases
("type_", "expected_args_str", "args"), iterable_function_schema_args_test_cases
)
def test_iterable_function_schema_serialize_args(type_, expected_args_str, args):
serialized_args = IterableFunctionSchema(type_).serialize_args(args)
assert json.loads(serialized_args) == json.loads(expected_args_str)


@pytest.mark.parametrize(
["type_", "json_schema"],
("type_", "json_schema"),
[
(
typing.AsyncIterable[str],
Expand Down Expand Up @@ -311,7 +311,7 @@ def test_async_iterable_function_schema(type_, json_schema):


@pytest.mark.parametrize(
["type_", "args_str", "expected_args"],
("type_", "args_str", "expected_args"),
async_iterable_function_schema_args_test_cases,
)
@pytest.mark.asyncio
Expand All @@ -326,7 +326,7 @@ async def test_async_iterable_function_schema_aparse_args(


@pytest.mark.parametrize(
["type_", "expected_args_str", "args"],
("type_", "expected_args_str", "args"),
async_iterable_function_schema_args_test_cases,
)
def test_async_iterable_function_schema_serialize_args(type_, expected_args_str, args):
Expand All @@ -340,7 +340,7 @@ class User(BaseModel):


@pytest.mark.parametrize(
["type_", "json_schema"],
("type_", "json_schema"),
[
(
dict,
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_dict_function_schema(type_, json_schema):


@pytest.mark.parametrize(
["type_", "args_str", "expected_args"], dict_function_schema_args_test_cases
("type_", "args_str", "expected_args"), dict_function_schema_args_test_cases
)
def test_dict_function_schema_parse_args(type_, args_str, expected_args):
parsed_args = DictFunctionSchema(type_).parse_args(args_str)
Expand All @@ -426,7 +426,7 @@ def test_dict_function_schema_parse_args(type_, args_str, expected_args):


@pytest.mark.parametrize(
["type_", "expected_args_str", "args"], dict_function_schema_args_test_cases
("type_", "expected_args_str", "args"), dict_function_schema_args_test_cases
)
def test_dict_function_schema_serialize_args(type_, expected_args_str, args):
serialized_args = DictFunctionSchema(type_).serialize_args(args)
Expand Down Expand Up @@ -478,7 +478,7 @@ def plus_with_annotated(


@pytest.mark.parametrize(
["function", "json_schema"],
("function", "json_schema"),
[
(
plus,
Expand Down Expand Up @@ -589,7 +589,7 @@ def test_function_call_function_schema_with_default_value():


@pytest.mark.parametrize(
["function", "args_str", "expected_args"],
("function", "args_str", "expected_args"),
function_call_function_schema_args_test_cases,
)
def test_function_call_function_schema_parse_args(function, args_str, expected_args):
Expand All @@ -598,7 +598,7 @@ def test_function_call_function_schema_parse_args(function, args_str, expected_a


@pytest.mark.parametrize(
["function", "expected_args_str", "args"],
("function", "expected_args_str", "args"),
function_call_function_schema_args_test_cases,
)
def test_function_call_function_schema_serialize_args(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@ async def test_async_iter():
]


@pytest.mark.parametrize(["input", "expected"], iter_streamed_json_array_test_cases)
@pytest.mark.parametrize(("input", "expected"), iter_streamed_json_array_test_cases)
def test_iter_streamed_json_array(input, expected):
assert list(iter_streamed_json_array(iter(input))) == expected


@pytest.mark.parametrize(["input", "expected"], iter_streamed_json_array_test_cases)
@pytest.mark.parametrize(("input", "expected"), iter_streamed_json_array_test_cases)
@pytest.mark.asyncio
async def test_aiter_streamed_json_array(input, expected):
assert [x async for x in aiter_streamed_json_array(async_iter(input))] == expected
Expand Down
8 changes: 4 additions & 4 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


@pytest.mark.parametrize(
["type_", "expected_types"],
("type_", "expected_types"),
[
(str, ["str"]),
(int, ["int"]),
Expand All @@ -25,7 +25,7 @@ def test_split_union_type(type_, expected_types):


@pytest.mark.parametrize(
["type_", "expected_result"],
("type_", "expected_result"),
[
(str, False),
(list[str], False),
Expand All @@ -37,7 +37,7 @@ def test_is_origin_abstract(type_, expected_result):


@pytest.mark.parametrize(
["type_", "cls_or_tuple", "expected_result"],
("type_", "cls_or_tuple", "expected_result"),
[
(str, str, True),
(str, int, False),
Expand All @@ -52,7 +52,7 @@ def test_is_origin_subclass(type_, cls_or_tuple, expected_result):


@pytest.mark.parametrize(
["type_", "expected_name"],
("type_", "expected_name"),
[
(str, "str"),
(int, "int"),
Expand Down

0 comments on commit f18c1e6

Please sign in to comment.