Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(commands): handle interactions in union types correctly #1121

Merged
merged 7 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/1121.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Make :class:`Interaction` and subtypes accept the bot type as a generic parameter to denote the type returned by the :attr:`~Interaction.bot` and :attr:`~Interaction.client` properties.
38 changes: 20 additions & 18 deletions disnake/ext/commands/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
Type,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
Expand Down Expand Up @@ -110,17 +109,26 @@


def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[TypeT]:
"""Similar to the builtin `issubclass`, but more lenient.
Can also handle unions (`issubclass(Union[int, str], int)`) and
generic types (`issubclass(X[T], X)`) in the first argument.
"""
if not isinstance(tp, (type, tuple)):
return False
elif not isinstance(obj, type):
# Assume we have a type hint
if get_origin(obj) in (Union, UnionType, Optional):
obj = get_args(obj)
return any(isinstance(o, type) and issubclass(o, tp) for o in obj)
else:
# Other type hint specializations are not supported
return False
return issubclass(obj, tp)
elif isinstance(obj, type):
# common case
return issubclass(obj, tp)

# At this point, `obj` is likely a generic type hint
if (origin := get_origin(obj)) is None:
return False

if origin in (Union, UnionType):
# If we have a Union, try matching any of its args
# (recursively, to handle possibly generic types inside this union)
return any(issubclass_(o, tp) for o in obj.__args__)
else:
return isinstance(origin, type) and issubclass(origin, tp)


def remove_optionals(annotation: Any) -> Any:
Expand Down Expand Up @@ -912,7 +920,6 @@ def isolate_self(
parametersl.pop(0)
if parametersl:
annot = parametersl[0].annotation
annot = get_origin(annot) or annot
if issubclass_(annot, ApplicationCommandInteraction) or annot is inspect.Parameter.empty:
inter_param = parameters.pop(parametersl[0].name)

Expand Down Expand Up @@ -984,9 +991,7 @@ def collect_params(
injections[parameter.name] = default
elif parameter.annotation in Injection._registered:
injections[parameter.name] = Injection._registered[parameter.annotation]
elif issubclass_(
get_origin(parameter.annotation) or parameter.annotation, ApplicationCommandInteraction
):
elif issubclass_(parameter.annotation, ApplicationCommandInteraction):
if inter_param is None:
inter_param = parameter
else:
Expand Down Expand Up @@ -1120,10 +1125,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]:
if param.autocomplete:
command.autocompleters[param.name] = param.autocomplete

if issubclass_(
get_origin(annot := sig.parameters[inter_param].annotation) or annot,
disnake.GuildCommandInteraction,
):
if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction):
command._guild_only = True

return [param.to_option() for param in params]
Expand Down
49 changes: 48 additions & 1 deletion tests/ext/commands/test_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import disnake
from disnake import Member, Role, User
from disnake.ext import commands
from disnake.ext.commands import params

OptionType = disnake.OptionType

Expand Down Expand Up @@ -66,6 +67,53 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None:
with pytest.raises(commands.errors.MemberNotFound):
await info.verify_type(mock.Mock(), arg_mock)

def test_isolate_self(self) -> None:
def func(a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_inter(self) -> None:
def func(i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_cog_inter(self) -> None:
def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is not None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_generic(self) -> None:
def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})

def test_isolate_self_union(self) -> None:
def func(
i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int
) -> None:
...

(cog, inter), parameters = params.isolate_self(params.signature(func))
assert cog is None
assert inter is not None
assert parameters == ({"a": mock.ANY})


# this uses `Range` for testing `_BaseRange`, `String` should work equally
class TestBaseRange:
Expand Down Expand Up @@ -189,7 +237,6 @@ def test_string(self) -> None:
assert info.max_value is None
assert info.type == annotation.underlying_type

# uses lambdas since new union syntax isn't supported on all versions
@pytest.mark.parametrize(
"annotation_str",
[
Expand Down