Skip to content

Commit

Permalink
Merge pull request #150 from BrianPugh/name-transform
Browse files Browse the repository at this point in the history
User customizable name_transform
  • Loading branch information
BrianPugh authored Apr 10, 2024
2 parents f87ea85 + b19224c commit 09320d1
Show file tree
Hide file tree
Showing 14 changed files with 398 additions and 57 deletions.
2 changes: 2 additions & 0 deletions cyclopts/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"UnusedCliTokensError",
"ValidationError",
"convert",
"default_name_transform",
"types",
"validators",
]
Expand All @@ -36,5 +37,6 @@
from cyclopts.group import Group
from cyclopts.parameter import Parameter
from cyclopts.protocols import Dispatcher
from cyclopts.utils import default_name_transform

from . import types, validators
90 changes: 66 additions & 24 deletions cyclopts/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@


from cyclopts.exceptions import CoercionError
from cyclopts.utils import is_union
from cyclopts.utils import default_name_transform, is_union

if TYPE_CHECKING:
from cyclopts.parameter import Parameter
Expand Down Expand Up @@ -83,7 +83,13 @@ def _bytearray(s: str) -> bytearray:
}


def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] = None) -> Tuple:
def _convert_tuple(
type_: Type[Any],
*args: str,
converter: Optional[Callable[[Type, str], Any]],
name_transform: Callable[[str], str],
) -> Tuple:
convert = partial(_convert, converter=converter, name_transform=name_transform)
inner_types = tuple(x for x in get_args(type_) if x is not ...)
inner_token_count, consume_all = token_count(type_)
if consume_all:
Expand All @@ -101,11 +107,10 @@ def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] =
raise ValueError("A tuple must have 0 or 1 inner-types.")

if inner_token_count == 1:
out = tuple(_convert(inner_type, x, converter=converter) for x in args)
out = tuple(convert(inner_type, x) for x in args)
else:
out = tuple(
_convert(inner_type, args[i : i + inner_token_count], converter=converter)
for i in range(0, len(args), inner_token_count)
convert(inner_type, args[i : i + inner_token_count]) for i in range(0, len(args), inner_token_count)
)
return out
else:
Expand All @@ -116,35 +121,49 @@ def _convert_tuple(type_: Type[Any], *args: str, converter: Optional[Callable] =
it = iter(args)
batched = [[next(it) for _ in range(size)] for size in args_per_convert]
batched = [elem[0] if len(elem) == 1 else elem for elem in batched]
out = tuple(_convert(inner_type, arg, converter=converter) for inner_type, arg in zip(inner_types, batched))
out = tuple(convert(inner_type, arg) for inner_type, arg in zip(inner_types, batched))
return out


def _convert(type_, element, converter=None):
pconvert = partial(_convert, converter=converter)
def _convert(
type_,
element,
*,
converter: Optional[Callable[[Type, str], Any]],
name_transform: Callable[[str], str],
):
"""Inner recursive conversion function for public ``convert``.
Parameters
----------
converter: Callable
name_transform: Callable
"""
convert = partial(_convert, converter=converter, name_transform=name_transform)
convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform)
origin_type = get_origin(type_)
inner_types = [resolve(x) for x in get_args(type_)]

if type_ in _implicit_iterable_type_mapping:
return pconvert(_implicit_iterable_type_mapping[type_], element)
return convert(_implicit_iterable_type_mapping[type_], element)

if origin_type is collections.abc.Iterable:
assert len(inner_types) == 1
return pconvert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues]
return convert(List[inner_types[0]], element) # pyright: ignore[reportGeneralTypeIssues]
elif is_union(origin_type):
for t in inner_types:
if t is NoneType:
continue
try:
return pconvert(t, element)
return convert(t, element)
except Exception:
pass
else:
raise CoercionError(input_value=element, target_type=type_)
elif origin_type is Literal:
for choice in get_args(type_):
try:
res = pconvert(type(choice), (element))
res = convert(type(choice), (element))
except Exception:
continue
if res == choice:
Expand All @@ -157,18 +176,18 @@ def _convert(type_, element, converter=None):
gen = zip(*[iter(element)] * count)
else:
gen = element
return origin_type(pconvert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall]
return origin_type(convert(inner_types[0], e) for e in gen) # pyright: ignore[reportOptionalCall]
elif origin_type is tuple:
if isinstance(element, str):
# E.g. Tuple[str] (Annotation: tuple containing a single string)
return _convert_tuple(type_, element, converter=converter)
return convert_tuple(type_, element, converter=converter)
else:
return _convert_tuple(type_, *element, converter=converter)
return convert_tuple(type_, *element, converter=converter)
elif isclass(type_) and issubclass(type_, Enum):
if converter is None:
element_lower = element.lower().replace("-", "_")
element_transformed = name_transform(element)
for member in type_:
if member.name.lower().strip("_") == element_lower:
if name_transform(member.name) == element_transformed:
return member
raise CoercionError(input_value=element, target_type=type_)
else:
Expand Down Expand Up @@ -240,7 +259,12 @@ def resolve_annotated(type_: Type) -> Type:
return type_


def convert(type_: Type, *args: str, converter: Optional[Callable] = None):
def convert(
type_: Type,
*args: str,
converter: Optional[Callable[[Type, str], Any]] = None,
name_transform: Optional[Callable[[str], str]] = None,
):
"""Coerce variables into a specified type.
Internally used to coercing string CLI tokens into python builtin types.
Expand All @@ -259,8 +283,7 @@ def convert(type_: Type, *args: str, converter: Optional[Callable] = None):
A type hint/annotation to coerce ``*args`` into.
`*args`: str
String tokens to coerce.
converter: Optional[Callable]
converter: Optional[Callable[[Type, str], Any]]
An optional function to convert tokens to the inner-most types.
The converter should have signature:
Expand All @@ -272,12 +295,31 @@ def converter(type_: type, value: str) -> Any:
This allows to use the :func:`convert` function to handle the the difficult task
of traversing lists/tuples/unions/etc, while leaving the final conversion logic to
the caller.
name_transform: Optional[Callable[[str], str]]
Currently only used for ``Enum`` type hints.
A function that transforms enum names and CLI values into a normalized format.
The function should have signature:
.. code-block:: python
def name_transform(s: str) -> str:
...
where the returned value is the name to be used on the CLI.
If ``None``, defaults to ``cyclopts.default_name_transform``.
Returns
-------
Any
Coerced version of input ``*args``.
"""
if name_transform is None:
name_transform = default_name_transform

convert = partial(_convert, converter=converter, name_transform=name_transform)
convert_tuple = partial(_convert_tuple, converter=converter, name_transform=name_transform)
type_ = resolve(type_)

if type_ is Any:
Expand All @@ -288,13 +330,13 @@ def converter(type_: type, value: str) -> Any:
origin_type = get_origin_and_validate(type_)

if origin_type is tuple:
return _convert_tuple(type_, *args, converter=converter)
return convert_tuple(type_, *args)
elif (origin_type or type_) in _iterable_types or origin_type is collections.abc.Iterable:
return _convert(type_, args, converter=converter)
return convert(type_, args)
elif len(args) == 1:
return _convert(type_, args[0], converter=converter)
return convert(type_, args[0])
else:
return [_convert(type_, item, converter=converter) for item in args]
return [convert(type_, item) for item in args]


def token_count(type_: Union[Type[Any], inspect.Parameter]) -> Tuple[int, bool]:
Expand Down
25 changes: 19 additions & 6 deletions cyclopts/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,13 @@
from cyclopts.parameter import Parameter, validate_command
from cyclopts.protocols import Dispatcher
from cyclopts.resolve import ResolvedCommand
from cyclopts.utils import optional_to_tuple_converter, to_list_converter, to_tuple_converter
from cyclopts.utils import default_name_transform, optional_to_tuple_converter, to_list_converter, to_tuple_converter

with suppress(ImportError):
# By importing, makes things like the arrow-keys work.
import readline # Not available on windows


def _format_name(name: str):
return name.lower().replace("_", "-").strip("-")


class _CannotDeriveCallingModuleNameError(Exception):
pass

Expand Down Expand Up @@ -224,6 +220,12 @@ class App:
converter: Optional[Callable] = field(default=None, kw_only=True)
validator: List[Callable] = field(default=None, converter=to_list_converter, kw_only=True)

_name_transform: Optional[Callable[[str], str]] = field(
default=None,
alias="name_transform",
kw_only=True,
)

######################
# Private Attributes #
######################
Expand Down Expand Up @@ -307,7 +309,7 @@ def name(self) -> Tuple[str, ...]:
name = _get_root_module_name()
return (name,)
else:
return (_format_name(self.default_command.__name__),)
return (self.name_transform(self.default_command.__name__),)

@property
def help(self) -> str:
Expand All @@ -328,6 +330,14 @@ def help(self) -> str:
def help(self, value):
self._help = value

@property
def name_transform(self):
return self._name_transform if self._name_transform else default_name_transform

@name_transform.setter
def name_transform(self, value):
self._name_transform = value

def version_print(self) -> None:
"""Print the application version."""
print(self.version() if callable(self.version) else self.version)
Expand Down Expand Up @@ -458,6 +468,9 @@ def command(
app = App(default_command=obj, **kwargs)
# app.name is handled below

if app._name_transform is None:
app.name_transform = self.name_transform

if name is None:
name = app.name
else:
Expand Down
18 changes: 10 additions & 8 deletions cyclopts/help.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import inspect
from enum import Enum
from functools import lru_cache
from functools import lru_cache, partial
from inspect import isclass
from typing import TYPE_CHECKING, List, Literal, Tuple, Type, Union, get_args, get_origin
from typing import TYPE_CHECKING, Callable, List, Literal, Tuple, Type, Union, get_args, get_origin

import docstring_parser
from attrs import define, field, frozen
Expand Down Expand Up @@ -190,20 +190,21 @@ def format_doc(root_app, app: "App", format: str = "restructuredtext"):
raise ValueError(f'Unknown help_format "{format}"')


def _get_choices(type_: Type) -> str:
def _get_choices(type_: Type, name_transform: Callable[[str], str]) -> str:
get_choices = partial(_get_choices, name_transform=name_transform)
choices: str = ""
_origin = get_origin(type_)
if isclass(type_) and issubclass(type_, Enum):
choices = ",".join(x.name.lower().replace("_", "-") for x in type_)
choices = ",".join(name_transform(x.name) for x in type_)
elif _origin is Union:
inner_choices = [_get_choices(inner) for inner in get_args(type_)]
inner_choices = [get_choices(inner) for inner in get_args(type_)]
choices = ",".join(x for x in inner_choices if x)
elif _origin is Literal:
choices = ",".join(str(x) for x in get_args(type_))
elif _origin in (list, set, tuple):
args = get_args(type_)
if len(args) == 1 or (_origin is tuple and len(args) == 2 and args[1] is Ellipsis):
choices = _get_choices(args[0])
choices = get_choices(args[0])
return choices


Expand All @@ -218,6 +219,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter

for iparam, cparam in icparams:
assert cparam.name is not None
assert cparam.name_transform is not None
type_ = get_hint_parameter(iparam)[0]
options = list(cparam.name)
options.extend(cparam.get_negatives(type_, *options))
Expand All @@ -241,7 +243,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter
help_components.append(cparam.help)

if cparam.show_choices:
choices = _get_choices(type_)
choices = _get_choices(type_, cparam.name_transform)
if choices:
help_components.append(rf"[dim]\[choices: {choices}][/dim]")

Expand All @@ -254,7 +256,7 @@ def create_parameter_help_panel(group: "Group", iparams, cparams: List[Parameter
):
default = ""
if isclass(type_) and issubclass(type_, Enum):
default = iparam.default.name.lower().replace("_", "-")
default = cparam.name_transform(iparam.default.name)
else:
default = iparam.default

Expand Down
15 changes: 13 additions & 2 deletions cyclopts/parameter.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
from functools import partial
from typing import Any, Callable, Iterable, Optional, Tuple, Type, Union, cast, get_args, get_origin

import attrs
Expand All @@ -13,7 +14,7 @@
resolve_optional,
)
from cyclopts.group import Group
from cyclopts.utils import optional_to_tuple_converter, record_init, to_tuple_converter
from cyclopts.utils import default_name_transform, optional_to_tuple_converter, record_init, to_tuple_converter


def _double_hyphen_validator(instance, attribute, values):
Expand Down Expand Up @@ -48,7 +49,7 @@ class Parameter:
converter=lambda x: cast(Tuple[str, ...], to_tuple_converter(x)),
)

converter: Callable = field(default=None, converter=attrs.converters.default_if_none(convert))
_converter: Callable = field(default=None, alias="converter")

# This can ONLY ever be a Tuple[Callable, ...]
validator: Union[None, Callable, Iterable[Callable]] = field(
Expand Down Expand Up @@ -98,13 +99,23 @@ class Parameter:

allow_leading_hyphen: bool = field(default=False)

name_transform: Optional[Callable[[str], str]] = field(
default=None,
converter=attrs.converters.default_if_none(default_name_transform),
kw_only=True,
)

# Populated by the record_attrs_init_args decorator.
_provided_args: Tuple[str] = field(default=(), init=False, eq=False)

@property
def show(self):
return self._show if self._show is not None else self.parse

@property
def converter(self):
return self._converter if self._converter else partial(convert, name_transform=self.name_transform)

def get_negatives(self, type_, *names: str) -> Tuple[str, ...]:
type_ = get_origin(type_) or type_

Expand Down
Loading

0 comments on commit 09320d1

Please sign in to comment.