Skip to content

Commit

Permalink
Merge pull request #15 from BrianPugh/more-tests
Browse files Browse the repository at this point in the history
Add unsupported Tuple[...] check.
  • Loading branch information
BrianPugh authored Dec 7, 2023
2 parents b5a95c7 + ac13ab8 commit 1f6640b
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 6 deletions.
9 changes: 6 additions & 3 deletions cyclopts/parameter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
import typing
from functools import lru_cache
from typing import Iterable, List, Optional, Tuple, Type, Union, get_origin
from typing import Iterable, List, Optional, Tuple, Type, Union, get_args, get_origin

from attrs import field, frozen
from typing_extensions import Annotated
Expand Down Expand Up @@ -43,9 +43,12 @@ def validate_command(f):
signature = inspect.signature(f)
for iparam in signature.parameters.values():
_ = get_origin_and_validate(iparam.annotation)
_, cparam = get_hint_parameter(iparam.annotation)
type_, cparam = get_hint_parameter(iparam.annotation)
if not cparam.parse and iparam.kind is not iparam.KEYWORD_ONLY:
raise ValueError("Parameter.parse=False must be used with a KEYWORD_ONLY function parameter.")
if get_origin(type_) is tuple:
if ... in get_args(type_):
raise ValueError("Cannot use a variable-length tuple.")


@frozen
Expand Down Expand Up @@ -207,7 +210,7 @@ def get_hint_parameter(type_: Type) -> Tuple[Type, Parameter]:

if type(type_) is AnnotatedType:
annotations = type_.__metadata__ # pyright: ignore[reportGeneralTypeIssues]
type_ = typing.get_args(type_)[0]
type_ = get_args(type_)[0]
cyclopts_parameters = [x for x in annotations if isinstance(x, Parameter)]
if len(cyclopts_parameters) > 2:
raise MultipleParameterAnnotationError
Expand Down
20 changes: 20 additions & 0 deletions tests/test_bind_kwargs.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,23 @@ def foo(a: int, **kwargs: int):
actual_command, actual_bind = app.parse_args("foo 1")
assert actual_command == foo
assert actual_bind == expected_bind


def test_args_and_kwargs_int(app):
@app.command
def foo(a: int, *args: int, **kwargs: int):
pass

signature = inspect.signature(foo)
expected_bind = signature.bind(1, 2, 3, 4, 5, bar=2, baz=3)

actual_command, actual_bind = app.parse_args("foo 1 2 3 4 5 --bar=2 --baz 3")
assert actual_command == foo
assert actual_bind == expected_bind

signature = inspect.signature(foo)
expected_bind = signature.bind(1)

actual_command, actual_bind = app.parse_args("foo 1")
assert actual_command == foo
assert actual_bind == expected_bind
15 changes: 12 additions & 3 deletions tests/test_validate_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,24 @@ def f4(a: Tuple[int, int], b: str):

validate_command(f4)

# Python automatically deduplicates the double None.
def f5(a: Union[None, None]):
pass

validate_command(f5)


def test_validate_command_exception():
def test_validate_command_exception_bare_tuple():
def f1(a: tuple):
pass

with pytest.raises(TypeError):
validate_command(f1)

def f2(a: Union[None, None]):

def test_validate_command_exception_elipsis_tuple():
def f1(a: Tuple[int, ...]):
pass

validate_command(f2)
with pytest.raises(ValueError):
validate_command(f1)

0 comments on commit 1f6640b

Please sign in to comment.