diff --git a/changelog/1121.feature.rst b/changelog/1121.feature.rst new file mode 100644 index 0000000000..1294ba4044 --- /dev/null +++ b/changelog/1121.feature.rst @@ -0,0 +1 @@ +Make :class:`Interaction` and subtypes accept the bot type as a generic parameter to denote the type returned by the :attr:`~Interaction.bot` and :attr:`~Interaction.client` properties. diff --git a/disnake/ext/commands/params.py b/disnake/ext/commands/params.py index 2ab93359d2..0e702385ad 100644 --- a/disnake/ext/commands/params.py +++ b/disnake/ext/commands/params.py @@ -31,7 +31,6 @@ Type, TypeVar, Union, - get_args, get_origin, get_type_hints, ) @@ -110,17 +109,26 @@ def issubclass_(obj: Any, tp: Union[TypeT, Tuple[TypeT, ...]]) -> TypeGuard[TypeT]: + """Similar to the builtin `issubclass`, but more lenient. + Can also handle unions (`issubclass(Union[int, str], int)`) and + generic types (`issubclass(X[T], X)`) in the first argument. + """ if not isinstance(tp, (type, tuple)): return False - elif not isinstance(obj, type): - # Assume we have a type hint - if get_origin(obj) in (Union, UnionType, Optional): - obj = get_args(obj) - return any(isinstance(o, type) and issubclass(o, tp) for o in obj) - else: - # Other type hint specializations are not supported - return False - return issubclass(obj, tp) + elif isinstance(obj, type): + # common case + return issubclass(obj, tp) + + # At this point, `obj` is likely a generic type hint + if (origin := get_origin(obj)) is None: + return False + + if origin in (Union, UnionType): + # If we have a Union, try matching any of its args + # (recursively, to handle possibly generic types inside this union) + return any(issubclass_(o, tp) for o in obj.__args__) + else: + return isinstance(origin, type) and issubclass(origin, tp) def remove_optionals(annotation: Any) -> Any: @@ -912,7 +920,6 @@ def isolate_self( parametersl.pop(0) if parametersl: annot = parametersl[0].annotation - annot = get_origin(annot) or annot if issubclass_(annot, ApplicationCommandInteraction) or annot is inspect.Parameter.empty: inter_param = parameters.pop(parametersl[0].name) @@ -984,9 +991,7 @@ def collect_params( injections[parameter.name] = default elif parameter.annotation in Injection._registered: injections[parameter.name] = Injection._registered[parameter.annotation] - elif issubclass_( - get_origin(parameter.annotation) or parameter.annotation, ApplicationCommandInteraction - ): + elif issubclass_(parameter.annotation, ApplicationCommandInteraction): if inter_param is None: inter_param = parameter else: @@ -1120,10 +1125,7 @@ def expand_params(command: AnySlashCommand) -> List[Option]: if param.autocomplete: command.autocompleters[param.name] = param.autocomplete - if issubclass_( - get_origin(annot := sig.parameters[inter_param].annotation) or annot, - disnake.GuildCommandInteraction, - ): + if issubclass_(sig.parameters[inter_param].annotation, disnake.GuildCommandInteraction): command._guild_only = True return [param.to_option() for param in params] diff --git a/tests/ext/commands/test_params.py b/tests/ext/commands/test_params.py index a3b4ea4289..8e8ca91304 100644 --- a/tests/ext/commands/test_params.py +++ b/tests/ext/commands/test_params.py @@ -10,6 +10,7 @@ import disnake from disnake import Member, Role, User from disnake.ext import commands +from disnake.ext.commands import params OptionType = disnake.OptionType @@ -66,6 +67,53 @@ async def test_verify_type__invalid_member(self, annotation, arg_types) -> None: with pytest.raises(commands.errors.MemberNotFound): await info.verify_type(mock.Mock(), arg_mock) + def test_isolate_self(self) -> None: + def func(a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_inter(self) -> None: + def func(i: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_cog_inter(self) -> None: + def func(self, i: disnake.ApplicationCommandInteraction, a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is not None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_generic(self) -> None: + def func(i: disnake.ApplicationCommandInteraction[commands.Bot], a: int) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + + def test_isolate_self_union(self) -> None: + def func( + i: Union[commands.Context, disnake.ApplicationCommandInteraction[commands.Bot]], a: int + ) -> None: + ... + + (cog, inter), parameters = params.isolate_self(params.signature(func)) + assert cog is None + assert inter is not None + assert parameters == ({"a": mock.ANY}) + # this uses `Range` for testing `_BaseRange`, `String` should work equally class TestBaseRange: @@ -189,7 +237,6 @@ def test_string(self) -> None: assert info.max_value is None assert info.type == annotation.underlying_type - # uses lambdas since new union syntax isn't supported on all versions @pytest.mark.parametrize( "annotation_str", [