diff --git a/cyclopts/__init__.py b/cyclopts/__init__.py index ef7126dd..70ed42c7 100644 --- a/cyclopts/__init__.py +++ b/cyclopts/__init__.py @@ -16,6 +16,7 @@ "UnusedCliTokensError", "ValidationError", "convert", + "default_name_transform", "types", "validators", ] @@ -36,5 +37,6 @@ from cyclopts.group import Group from cyclopts.parameter import Parameter from cyclopts.protocols import Dispatcher +from cyclopts.utils import default_name_transform from . import types, validators diff --git a/cyclopts/_convert.py b/cyclopts/_convert.py index 5e9506d8..be6efcfe 100644 --- a/cyclopts/_convert.py +++ b/cyclopts/_convert.py @@ -27,7 +27,7 @@ from cyclopts.exceptions import CoercionError -from cyclopts.utils import is_union +from cyclopts.utils import default_name_transform, is_union if TYPE_CHECKING: from cyclopts.parameter import Parameter @@ -83,7 +83,13 @@ def _bytearray(s: str) -> bytearray: } -def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] = None) -> Tuple: +def _convert_tuple( + type_: Type[Any], + *args: str, + converter: Optional[Callable[[Type, str], Any]], + name_transform: Callable[[str], str], +) -> Tuple: + convert = partial(_convert, converter=converter, name_transform=name_transform) inner_types = tuple(x for x in get_args(type_) if x is not ...) inner_token_count, consume_all = token_count(type_) if consume_all: @@ -101,11 +107,10 @@ def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] = raise ValueError("A tuple must have 0 or 1 inner-types.") if inner_token_count == 1: - out = tuple(_convert(inner_type, x, converter=converter) for x in args) + out = tuple(convert(inner_type, x) for x in args) else: out = tuple( - _convert(inner_type, args[i : i + inner_token_count], converter=converter) - for i in range(0, len(args), inner_token_count) + convert(inner_type, args[i : i + inner_token_count]) for i in range(0, len(args), inner_token_count) ) return out else: @@ -116,27 +121,41 @@ def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] = it = iter(args) batched = [[next(it) for _ in range(size)] for size in args_per_convert] batched = [elem[0] if len(elem) == 1 else elem for elem in batched] - out = tuple(_convert(inner_type, arg, converter=converter) for inner_type, arg in zip(inner_types, batched)) + out = tuple(convert(inner_type, arg) for inner_type, arg in zip(inner_types, batched)) return out -def _convert(type_, element, converter=None): - pconvert = partial(_convert, converter=converter) +def _convert( + type_, + element, + *, + converter: Optional[Callable[[Type, str], Any]], + name_transform: Callable[[str], str], +): + """Inner recursive conversion function for public ``convert``. + + Parameters + ---------- + converter: Callable + name_transform: Callable + """ + convert = partial(_convert, converter=converter, name_transform=name_transform) + convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform) origin_type = get_origin(type_) inner_types = [resolve(x) for x in get_args(type_)] if type_ in _implicit_iterable_type_mapping: - return pconvert(_implicit_iterable_type_mapping[type_], element) + return convert(_implicit_iterable_type_mapping[type_], element) if origin_type is collections.abc.Iterable: assert len(inner_types) == 1 - return pconvert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues] + return convert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues] elif is_union(origin_type): for t in inner_types: if t is NoneType: continue try: - return pconvert(t, element) + return convert(t, element) except Exception: pass else: @@ -144,7 +163,7 @@ def _convert(type_, element, converter=None): elif origin_type is Literal: for choice in get_args(type_): try: - res = pconvert(type(choice), (element)) + res = convert(type(choice), (element)) except Exception: continue if res == choice: @@ -157,18 +176,18 @@ def _convert(type_, element, converter=None): gen = zip(*[iter(element)] * count) else: gen = element - return origin_type(pconvert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall] + return origin_type(convert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall] elif origin_type is tuple: if isinstance(element, str): # E.g. Tuple[str] (Annotation: tuple containing a single string) - return _convert_tuple(type_, element, converter=converter) + return convert_tuple(type_, element, converter=converter) else: - return _convert_tuple(type_, *element, converter=converter) + return convert_tuple(type_, *element, converter=converter) elif isclass(type_) and issubclass(type_, Enum): if converter is None: - element_lower = element.lower().replace("-", "_") + element_transformed = name_transform(element) for member in type_: - if member.name.lower().strip("_") == element_lower: + if name_transform(member.name) == element_transformed: return member raise CoercionError(input_value=element, target_type=type_) else: @@ -240,7 +259,12 @@ def resolve_annotated(type_: Type) -> Type: return type_ -def convert(type_: Type, *args: str, converter: Optional[Callable] = None): +def convert( + type_: Type, + *args: str, + converter: Optional[Callable[[Type, str], Any]] = None, + name_transform: Optional[Callable[[str], str]] = None, +): """Coerce variables into a specified type. Internally used to coercing string CLI tokens into python builtin types. @@ -259,8 +283,7 @@ def convert(type_: Type, *args: str, converter: Optional[Callable] = None): A type hint/annotation to coerce ``*args`` into. `*args`: str String tokens to coerce. - converter: Optional[Callable] - + converter: Optional[Callable[[Type, str], Any]] An optional function to convert tokens to the inner-most types. The converter should have signature: @@ -272,12 +295,31 @@ def converter(type_: type, value: str) -> Any: This allows to use the :func:`convert` function to handle the the difficult task of traversing lists/tuples/unions/etc, while leaving the final conversion logic to the caller. + name_transform: Optional[Callable[[str], str]] + Currently only used for ``Enum`` type hints. + A function that transforms enum names and CLI values into a normalized format. + + The function should have signature: + + .. code-block:: python + + def name_transform(s: str) -> str: + ... + + where the returned value is the name to be used on the CLI. + + If ``None``, defaults to ``cyclopts.default_name_transform``. Returns ------- Any Coerced version of input ``*args``. """ + if name_transform is None: + name_transform = default_name_transform + + convert = partial(_convert, converter=converter, name_transform=name_transform) + convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform) type_ = resolve(type_) if type_ is Any: @@ -288,13 +330,13 @@ def converter(type_: type, value: str) -> Any: origin_type = get_origin_and_validate(type_) if origin_type is tuple: - return _convert_tuple(type_, *args, converter=converter) + return convert_tuple(type_, *args) elif (origin_type or type_) in _iterable_types or origin_type is collections.abc.Iterable: - return _convert(type_, args, converter=converter) + return convert(type_, args) elif len(args) == 1: - return _convert(type_, args[0], converter=converter) + return convert(type_, args[0]) else: - return [_convert(type_, item, converter=converter) for item in args] + return [convert(type_, item) for item in args] def token_count(type_: Union[Type[Any], inspect.Parameter]) -> Tuple[int, bool]: diff --git a/cyclopts/core.py b/cyclopts/core.py index 349321d2..01982155 100644 --- a/cyclopts/core.py +++ b/cyclopts/core.py @@ -55,17 +55,13 @@ from cyclopts.parameter import Parameter, validate_command from cyclopts.protocols import Dispatcher from cyclopts.resolve import ResolvedCommand -from cyclopts.utils import optional_to_tuple_converter, to_list_converter, to_tuple_converter +from cyclopts.utils import default_name_transform, optional_to_tuple_converter, to_list_converter, to_tuple_converter with suppress(ImportError): # By importing, makes things like the arrow-keys work. import readline # Not available on windows -def _format_name(name: str): - return name.lower().replace("_", "-").strip("-") - - class _CannotDeriveCallingModuleNameError(Exception): pass @@ -224,6 +220,12 @@ class App: converter: Optional[Callable] = field(default=None, kw_only=True) validator: List[Callable] = field(default=None, converter=to_list_converter, kw_only=True) + _name_transform: Optional[Callable[[str], str]] = field( + default=None, + alias="name_transform", + kw_only=True, + ) + ###################### # Private Attributes # ###################### @@ -307,7 +309,7 @@ def name(self) -> Tuple[str, ...]: name = _get_root_module_name() return (name,) else: - return (_format_name(self.default_command.__name__),) + return (self.name_transform(self.default_command.__name__),) @property def help(self) -> str: @@ -328,6 +330,14 @@ def help(self) -> str: def help(self, value): self._help = value + @property + def name_transform(self): + return self._name_transform if self._name_transform else default_name_transform + + @name_transform.setter + def name_transform(self, value): + self._name_transform = value + def version_print(self) -> None: """Print the application version.""" print(self.version() if callable(self.version) else self.version) @@ -458,6 +468,9 @@ def command( app = App(default_command=obj, **kwargs) # app.name is handled below + if app._name_transform is None: + app.name_transform = self.name_transform + if name is None: name = app.name else: diff --git a/cyclopts/help.py b/cyclopts/help.py index 9042a04c..8e531084 100644 --- a/cyclopts/help.py +++ b/cyclopts/help.py @@ -1,8 +1,8 @@ import inspect from enum import Enum -from functools import lru_cache +from functools import lru_cache, partial from inspect import isclass -from typing import TYPE_CHECKING, List, Literal, Tuple, Type, Union, get_args, get_origin +from typing import TYPE_CHECKING, Callable, List, Literal, Tuple, Type, Union, get_args, get_origin import docstring_parser from attrs import define, field, frozen @@ -190,20 +190,21 @@ def format_doc(root_app, app: "App", format: str = "restructuredtext"): raise ValueError(f'Unknown help_format "{format}"') -def _get_choices(type_: Type) -> str: +def _get_choices(type_: Type, name_transform: Callable[[str], str]) -> str: + get_choices = partial(_get_choices, name_transform=name_transform) choices: str = "" _origin = get_origin(type_) if isclass(type_) and issubclass(type_, Enum): - choices = ",".join(x.name.lower().replace("_", "-") for x in type_) + choices = ",".join(name_transform(x.name) for x in type_) elif _origin is Union: - inner_choices = [_get_choices(inner) for inner in get_args(type_)] + inner_choices = [get_choices(inner) for inner in get_args(type_)] choices = ",".join(x for x in inner_choices if x) elif _origin is Literal: choices = ",".join(str(x) for x in get_args(type_)) elif _origin in (list, set, tuple): args = get_args(type_) if len(args) == 1 or (_origin is tuple and len(args) == 2 and args[1] is Ellipsis): - choices = _get_choices(args[0]) + choices = get_choices(args[0]) return choices @@ -218,6 +219,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter for iparam, cparam in icparams: assert cparam.name is not None + assert cparam.name_transform is not None type_ = get_hint_parameter(iparam)[0] options = list(cparam.name) options.extend(cparam.get_negatives(type_, *options)) @@ -241,7 +243,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter help_components.append(cparam.help) if cparam.show_choices: - choices = _get_choices(type_) + choices = _get_choices(type_, cparam.name_transform) if choices: help_components.append(rf"[dim]\[choices: {choices}][/dim]") @@ -254,7 +256,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter ): default = "" if isclass(type_) and issubclass(type_, Enum): - default = iparam.default.name.lower().replace("_", "-") + default = cparam.name_transform(iparam.default.name) else: default = iparam.default diff --git a/cyclopts/parameter.py b/cyclopts/parameter.py index f8bf597e..f6bc032b 100644 --- a/cyclopts/parameter.py +++ b/cyclopts/parameter.py @@ -1,4 +1,5 @@ import inspect +from functools import partial from typing import Any, Callable, Iterable, Optional, Tuple, Type, Union, cast, get_args, get_origin import attrs @@ -13,7 +14,7 @@ resolve_optional, ) from cyclopts.group import Group -from cyclopts.utils import optional_to_tuple_converter, record_init, to_tuple_converter +from cyclopts.utils import default_name_transform, optional_to_tuple_converter, record_init, to_tuple_converter def _double_hyphen_validator(instance, attribute, values): @@ -48,7 +49,7 @@ class Parameter: converter=lambda x: cast(Tuple[str, ...], to_tuple_converter(x)), ) - converter: Callable = field(default=None, converter=attrs.converters.default_if_none(convert)) + _converter: Callable = field(default=None, alias="converter") # This can ONLY ever be a Tuple[Callable, ...] validator: Union[None, Callable, Iterable[Callable]] = field( @@ -98,6 +99,12 @@ class Parameter: allow_leading_hyphen: bool = field(default=False) + name_transform: Optional[Callable[[str], str]] = field( + default=None, + converter=attrs.converters.default_if_none(default_name_transform), + kw_only=True, + ) + # Populated by the record_attrs_init_args decorator. _provided_args: Tuple[str] = field(default=(), init=False, eq=False) @@ -105,6 +112,10 @@ class Parameter: def show(self): return self._show if self._show is not None else self.parse + @property + def converter(self): + return self._converter if self._converter else partial(convert, name_transform=self.name_transform) + def get_negatives(self, type_, *names: str) -> Tuple[str, ...]: type_ = get_origin(type_) or type_ diff --git a/cyclopts/resolve.py b/cyclopts/resolve.py index 9205b1a7..6a7ed21b 100644 --- a/cyclopts/resolve.py +++ b/cyclopts/resolve.py @@ -169,23 +169,26 @@ def __init__( iparam_to_docstring_cparam = _resolve_docstring(f, signature) if parse_docstring else ParameterDict() empty_help_string_parameter = Parameter(help="") for iparam, groups in self.iparam_to_groups.items(): - if iparam.kind in (iparam.POSITIONAL_ONLY, iparam.VAR_POSITIONAL): - # Name is only used for help-string - names = [iparam.name.upper()] - else: - names = ["--" + iparam.name.replace("_", "-")] - - default_name_parameter = Parameter(name=names) - cparam = get_hint_parameter( iparam, empty_help_string_parameter, app_parameter, *(x.default_parameter for x in groups), iparam_to_docstring_cparam.get(iparam), - default_name_parameter, Parameter(required=iparam.default is iparam.empty), )[1] + + # Resolve name now that ``name_transform`` has been resolved. + if iparam.kind in (iparam.POSITIONAL_ONLY, iparam.VAR_POSITIONAL): + # Name is only used for help-string + names = [iparam.name.upper()] + else: + # cparam.name_transform cannot be None due to: + # attrs.converters.default_if_none(default_name_transform) + assert cparam.name_transform is not None + names = ["--" + cparam.name_transform(iparam.name)] + + cparam = Parameter.combine(Parameter(name=names), cparam) self.iparam_to_cparam[iparam] = cparam self.bind = signature.bind_partial if _has_unparsed_parameters(signature, app_parameter) else signature.bind diff --git a/cyclopts/utils.py b/cyclopts/utils.py index 8a5efa5f..c6886b62 100644 --- a/cyclopts/utils.py +++ b/cyclopts/utils.py @@ -171,3 +171,27 @@ def optional_to_tuple_converter(value: Union[None, Any, Iterable[Any]]) -> Optio return () return to_tuple_converter(value) + + +def default_name_transform(s: str) -> str: + """Converts a python identifier into a CLI token. + + Performs the following operations (in order): + + 1. Convert the string to all lowercase. + 2. Replace ``_`` with ``-``. + 3. Strip any leading/trailing ``-`` (also stripping ``_``, due to point 2). + + Intended to be used with :attr:`App.name_transform` and :attr:`Parameter.name_transform`. + + Parameters + ---------- + s: str + Input python identifier string. + + Returns + ------- + str + Transformed name. + """ + return s.lower().replace("_", "-").strip("-") diff --git a/docs/source/api.rst b/docs/source/api.rst index f162bbc2..0a0e7082 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -161,6 +161,22 @@ API The raised error message will be presented to the user with python-variables prepended with "--" remapped to their CLI counterparts. + .. attribute:: name_transform + :type: Optional[Callable[[str], str]] + :value: None + + A function that converts function names to their CLI command counterparts. + + The function must have signature: + + .. code-block:: python + + def name_transform(s: str) -> str: + ... + + If :obj:`None` (default value), uses :func:`cyclopts.default_name_transform`. + If a subapp, inherits from first non-:obj:`None` parent. + .. autoclass:: cyclopts.Parameter Cyclopts configuration for individual function parameters. @@ -304,6 +320,21 @@ API If multiple environment variables are given, the left-most environment variable with a set value will be used. If no environment variable is set, Cyclopts will fallback to the function-signature default. + .. attribute:: name_transform + :type: Optional[Callable[[str], str]] + :value: None + + A function that converts python parameter names to their CLI command counterparts. + + The function must have signature: + + .. code-block:: python + + def name_transform(s: str) -> str: + ... + + If :obj:`None` (default value), uses :func:`cyclopts.default_name_transform`. + .. automethod:: combine .. automethod:: default @@ -480,6 +511,7 @@ API .. autofunction:: cyclopts.convert +.. autofunction:: cyclopts.default_name_transform .. _API Validators: diff --git a/docs/source/args_and_kwargs.rst b/docs/source/args_and_kwargs.rst index def2ee40..2972bb10 100644 --- a/docs/source/args_and_kwargs.rst +++ b/docs/source/args_and_kwargs.rst @@ -57,7 +57,7 @@ A variable number of keyword arguments consume all remaining CLI tokens starting Individual values are converted to the annotated type. As with normal python ``**kwargs``, the keywords are limited to python identifiers. Most prominently, no spaces allowed. -Keyword name-conversion is the :ref:`same as commands `. +Keyword name-conversion is the :ref:`same as commands `. .. code-block:: python diff --git a/docs/source/commands.rst b/docs/source/commands.rst index 8680e322..84236d03 100644 --- a/docs/source/commands.rst +++ b/docs/source/commands.rst @@ -11,7 +11,6 @@ There are 2 function-registering decorators: This was previously demonstrated in :ref:`Getting Started`. A sub-app **cannot** be registered with :meth:`@app.default `. - The default :meth:`app.default ` handler runs :meth:`app.help_print `. 2. :meth:`@app.command ` - Registers a function or :class:`.App` as a command. @@ -92,17 +91,23 @@ The :meth:`@app.command ` method can also register another The subcommand may have it's own registered ``default`` action. Cyclopts's command structure is fully recursive. -.. _Changing Name: +.. _Command Changing Name: ------------- Changing Name ------------- By default, a command is registered to the function name with underscores replaced with hyphens. Any leading or trailing underscore/hyphens will also be stripped. -For example, ``def _foo_bar()`` will become the command ``foo-bar``. +For example, the function ``_foo_bar()`` will become the command ``foo-bar``. +This automatic command name transform can be configured by :attr:`App.name_transform `. +For example, to make CLI command names be identical to their python function name counterparts, we can configure :class:`~cyclopts.App` as follows: + +.. code-block:: python + + app = App(name_transform=lambda s: s) -The name can be manually changed in the :meth:`@app.command ` decorator. -Manually set names are not subject to this name conversion. +Alternatively, the name can be manually changed in the :meth:`@app.command ` decorator. +Manually set names are not subject to :attr:`App.name_transform `. .. code-block:: python diff --git a/docs/source/parameters.rst b/docs/source/parameters.rst index 91f639bd..8ff6ef7f 100644 --- a/docs/source/parameters.rst +++ b/docs/source/parameters.rst @@ -63,6 +63,22 @@ Prior to Python 3.9, :obj:`~typing.Annotated` has to be imported from ``typing_e :class:`.Parameter` gives complete control on how Cyclopts processes the annotated parameter. See the API page for all configurable options. +------ +Naming +------ +Like :ref:`command names `, commandline parameters are derived from their python function argument counterparts. +This automatic command name transform can be configured by :attr:`Parameter.name_transform `. Note that the resulting string is **before** the standard ``--`` is prepended. + +To change the :attr:`~cyclopts.Parameter.name_transform` across your entire app, add the following to your :class:`~cyclopts.App` configuration: + +.. code-block:: python + + app = App( + default_parameter=Parameter(name_transform=my_custom_name_transform), + ) + +Manually set names via :attr:`Parameter.name ` are not subject to :attr:`Parameter.name_transform `. + ---- Help ---- diff --git a/docs/source/rules.rst b/docs/source/rules.rst index 4088a47b..26c8f742 100644 --- a/docs/source/rules.rst +++ b/docs/source/rules.rst @@ -273,11 +273,12 @@ Enum **** While `Literal`_ is the recommended way of providing the user options, another method is using :class:`~enum.Enum`. -For a user provided token, a **case-insensitive name** lookup is performed. +:attr:`Parameter.name_transform ` gets applied to all :class:`~enum.Enum` names, as well as the CLI provided token. +By default,this means that a **case-insensitive name** lookup is performed. If an enum name contains an underscore, the CLI parameter **may** instead contain a hyphen, ``-``. Leading/Trailing underscores will be stripped. -If coming from Typer_, **Cyclopts Enum handling is reversed compared to Typer**. +If coming from Typer_, **Cyclopts Enum handling is the reverse of Typer**. Typer attempts to match the token to an Enum **value**; Cyclopts attempts to match the token to an Enum **name**. diff --git a/tests/test_coercion.py b/tests/test_coercion.py index 3718a3ac..f44ed534 100644 --- a/tests/test_coercion.py +++ b/tests/test_coercion.py @@ -132,7 +132,10 @@ class SoftwareEnvironment(Enum): PROD = auto() _PROD_OLD = auto() + # tests case-insensitivity assert SoftwareEnvironment.STAGING == convert(SoftwareEnvironment, "staging") + + # tests underscore/hyphen support assert SoftwareEnvironment._PROD_OLD == convert(SoftwareEnvironment, "prod_old") assert SoftwareEnvironment._PROD_OLD == convert(SoftwareEnvironment, "prod-old") diff --git a/tests/test_name_transform.py b/tests/test_name_transform.py new file mode 100644 index 00000000..727f54f7 --- /dev/null +++ b/tests/test_name_transform.py @@ -0,0 +1,187 @@ +import sys +from enum import Enum, auto +from textwrap import dedent + +import pytest + +from cyclopts import App, Parameter, default_name_transform + +if sys.version_info < (3, 9): + from typing_extensions import Annotated # pragma: no cover +else: + from typing import Annotated # pragma: no cover + + +@pytest.mark.parametrize( + "before,after", + [ + ("FOO", "foo"), + ("_FOO", "foo"), + ("_FOO_", "foo"), + ("_F_O_O_", "f-o-o"), + ], +) +def test_default_name_transform(before, after): + assert default_name_transform(before) == after + + +def test_app_name_transform_default(app): + @app.command + def _F_O_O_(): # noqa: N802 + pass + + assert "f-o-o" in app + + +def test_app_name_transform_custom(app): + def name_transform(s: str) -> str: + return "my-custom-name-transform" + + app.name_transform = name_transform + + @app.command + def foo(): + pass + + assert "my-custom-name-transform" in app + + +def test_subapp_name_transform_custom(app): + """A subapp with an explicitly set ``name_transform`` should NOT inherit from parent.""" + + def name_transform_1(s: str) -> str: + return "my-custom-name-transform-1" + + def name_transform_2(s: str) -> str: + return "my-custom-name-transform-2" + + app.name_transform = name_transform_1 + + app.command(subapp := App(name="bar", name_transform=name_transform_2)) + + @subapp.command + def foo(): + pass + + assert "my-custom-name-transform-2" in subapp + + +def test_subapp_name_transform_custom_inherited(app): + """A subapp without an explicitly set ``name_transform`` should inherit it from the first parent.""" + + def name_transform(s: str) -> str: + return "my-custom-name-transform" + + app.name_transform = name_transform + + app.command(subapp := App(name="bar")) + + @subapp.command + def foo(): + pass + + assert "my-custom-name-transform" in subapp + + +def test_parameter_name_transform_default(app, assert_parse_args): + @app.default + def foo(*, b_a_r: int): + pass + + assert_parse_args(foo, "--b-a-r 5", b_a_r=5) + + +def test_parameter_name_transform_custom(app, assert_parse_args): + app.default_parameter = Parameter(name_transform=lambda s: s) + + @app.default + def foo(*, b_a_r: int): + pass + + assert_parse_args(foo, "--b_a_r 5", b_a_r=5) + + +def test_parameter_name_transform_custom_name_override(app, assert_parse_args): + app.default_parameter = Parameter(name_transform=lambda s: s) + + @app.default + def foo(*, b_a_r: Annotated[int, Parameter(name="--buzz")]): + pass + + assert_parse_args(foo, "--buzz 5", b_a_r=5) + + +def test_parameter_name_transform_custom_enum(app, assert_parse_args): + """name_transform should also be applied to enum options.""" + app.default_parameter = Parameter(name_transform=lambda s: s) + + class SoftwareEnvironment(Enum): + DEV = auto() + STAGING = auto() + PROD = auto() + _PROD_OLD = auto() + + @app.default + def foo(*, b_a_r: SoftwareEnvironment = SoftwareEnvironment.STAGING): + pass + + assert_parse_args(foo, "--b_a_r PROD", b_a_r=SoftwareEnvironment.PROD) + + +def test_parameter_name_transform_help(app, console): + app.default_parameter = Parameter(name_transform=lambda s: s) + + @app.default + def foo(*, b_a_r: int): + pass + + with console.capture() as capture: + app.help_print([], console=console) + + actual = capture.get() + expected = dedent( + """\ + Usage: foo COMMAND [OPTIONS] + + ╭─ Commands ─────────────────────────────────────────────────────────╮ + │ --help,-h Display this message and exit. │ + │ --version Display application version. │ + ╰────────────────────────────────────────────────────────────────────╯ + ╭─ Parameters ───────────────────────────────────────────────────────╮ + │ * --b_a_r [required] │ + ╰────────────────────────────────────────────────────────────────────╯ + """ + ) + assert actual == expected + + +def test_parameter_name_transform_help_enum(app, console): + """name_transform should also be applied to enum options on help page.""" + app.default_parameter = Parameter(name_transform=lambda s: s) + + class CompSciProblem(Enum): + FIZZ = "bleep bloop blop" + BUZZ = "blop bleep bloop" + + @app.command + def cmd( + foo: Annotated[CompSciProblem, Parameter(help="Docstring for foo.")] = CompSciProblem.FIZZ, + bar: Annotated[CompSciProblem, Parameter(help="Docstring for bar.")] = CompSciProblem.BUZZ, + ): + pass + + with console.capture() as capture: + app.help_print(["cmd"], console=console) + + actual = capture.get() + expected = dedent( + """\ + Usage: test_name_transform cmd [ARGS] [OPTIONS] + + ╭─ Parameters ───────────────────────────────────────────────────────╮ + │ FOO,--foo Docstring for foo. [choices: FIZZ,BUZZ] [default: FIZZ] │ + │ BAR,--bar Docstring for bar. [choices: FIZZ,BUZZ] [default: BUZZ] │ + ╰────────────────────────────────────────────────────────────────────╯ + """ + ) + assert actual == expected