Skip to content

Commit

Permalink
Add tests and doc.
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab committed Nov 9, 2024
1 parent ccb0d6f commit cff7488
Show file tree
Hide file tree
Showing 3 changed files with 198 additions and 26 deletions.
38 changes: 38 additions & 0 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -969,6 +969,44 @@ For `BaseModel` and `pydantic.dataclasses.dataclass` types, `CliApp.run` will in
The alias generator for kebab case does not propagate to subcommands or submodels and will have to be manually set
in these cases.

### Mutually Exclusive Groups

CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class.

!!! note
A `CliMutuallyExclusiveGroup` cannot be used in a union or contain nested models.

```py
from typing import Optional

from pydantic import BaseModel

from pydantic_settings import CliApp, CliMutuallyExclusiveGroup, SettingsError


class Circle(CliMutuallyExclusiveGroup):
radius: Optional[float] = None
diameter: Optional[float] = None
perimeter: Optional[float] = None


class Settings(BaseModel):
circle: Circle


try:
CliApp.run(
Settings,
cli_args=['--circle.radius=1', '--circle.diameter=2'],
cli_exit_on_error=False,
)
except SettingsError as e:
print(e)
"""
error parsing CLI: argument --circle.diameter: not allowed with argument --circle.radius
"""
```

### Customizing the CLI Experience

The below flags can be used to customise the CLI experience to your needs.
Expand Down
31 changes: 18 additions & 13 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,7 +1487,7 @@ def _connect_parser_method(
if (
parser_method is not None
and self.case_sensitive is False
and method_name == 'parsed_args_method'
and method_name == 'parse_args_method'
and isinstance(self._root_parser, _CliInternalArgParser)
):

Expand Down Expand Up @@ -1522,17 +1522,18 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any:
def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]:
add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')

def add_group_method(parser: Any, model: type[BaseModel], **kwargs: Any) -> Any:
if not issubclass(model, CliMutuallyExclusiveGroup):
def add_group_method(parser: Any, **kwargs: Any) -> Any:
if not kwargs.pop('_is_cli_mutually_exclusive_group'):
kwargs.pop('required')
return add_argument_group(parser, **kwargs)
else:
group = add_argument_group(
parser, **{arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
)
main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
main_group_kwargs['title'] += ' (mutually exclusive)'
group = add_argument_group(parser, **main_group_kwargs)
if not hasattr(group, 'add_mutually_exclusive_group'):
raise SettingsError(
'cannot connect CLI settings source root parser: add_mutually_exclusive_group is set to `None` but is needed for connecting'
'cannot connect CLI settings source root parser: '
'group object is missing add_mutually_exclusive_group but is needed for connecting'
)
return group.add_mutually_exclusive_group(**kwargs)

Expand All @@ -1554,7 +1555,7 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
self._root_parser = root_parser
if parse_args_method is None:
parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args
self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method')
self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method')
self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method')
self._add_group = self._connect_group_method(add_argument_group_method)
self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method')
Expand Down Expand Up @@ -1695,7 +1696,7 @@ def _add_parser_args(
elif not is_alias_path_only:
if group is not None:
if isinstance(group, dict):
group = self._add_group(parser, model, **group)
group = self._add_group(parser, **group)
added_args += list(arg_names)
self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs)
else:
Expand All @@ -1704,7 +1705,7 @@ def _add_parser_args(
parser, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs
)

self._add_parser_alias_paths(parser, model, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
return parser

def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None:
Expand Down Expand Up @@ -1764,6 +1765,11 @@ def _add_parser_submodels(
model_group_kwargs['title'] = f'{arg_names[0]} options'
model_group_kwargs['description'] = field_info.description
model_group_kwargs['required'] = kwargs['required']
model_group_kwargs['_is_cli_mutually_exclusive_group'] = any(
issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models
)
if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1:
raise SettingsError('cannot use union with CliMutuallyExclusiveGroup')
if self.cli_use_class_docs_for_groups and len(sub_models) == 1:
model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__)

Expand All @@ -1786,7 +1792,7 @@ def _add_parser_submodels(
if not self.cli_avoid_json:
added_args.append(arg_names[0])
kwargs['help'] = f'set {arg_names[0]} from JSON string'
model_group = self._add_group(parser, model, **model_group_kwargs)
model_group = self._add_group(parser, **model_group_kwargs)
self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs)
for model in sub_models:
self._add_parser_args(
Expand All @@ -1803,7 +1809,6 @@ def _add_parser_submodels(
def _add_parser_alias_paths(
self,
parser: Any,
model: type[BaseModel],
alias_path_args: dict[str, str],
added_args: list[str],
arg_prefix: str,
Expand All @@ -1813,7 +1818,7 @@ def _add_parser_alias_paths(
if alias_path_args:
context = parser
if group is not None:
context = self._add_group(parser, model, **group) if isinstance(group, dict) else group
context = self._add_group(parser, **group) if isinstance(group, dict) else group
is_nested_alias_path = arg_prefix.endswith('.')
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
for name, metavar in alias_path_args.items():
Expand Down
155 changes: 142 additions & 13 deletions tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
CLI_SUPPRESS,
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
Expand Down Expand Up @@ -79,30 +80,30 @@ class SettingWithIgnoreEmpty(BaseSettings):
class CliDummyArgGroup(BaseModel, arbitrary_types_allowed=True):
group: argparse._ArgumentGroup

def add_argument(self, *args, **kwargs) -> None:
def add_argument(self, *args: Any, **kwargs: Any) -> None:
self.group.add_argument(*args, **kwargs)


class CliDummySubParsers(BaseModel, arbitrary_types_allowed=True):
sub_parser: argparse._SubParsersAction

def add_parser(self, *args, **kwargs) -> 'CliDummyParser':
def add_parser(self, *args: Any, **kwargs: Any) -> 'CliDummyParser':
return CliDummyParser(parser=self.sub_parser.add_parser(*args, **kwargs))


class CliDummyParser(BaseModel, arbitrary_types_allowed=True):
parser: argparse.ArgumentParser = Field(default_factory=lambda: argparse.ArgumentParser())

def add_argument(self, *args, **kwargs) -> None:
def add_argument(self, *args: Any, **kwargs: Any) -> None:
self.parser.add_argument(*args, **kwargs)

def add_argument_group(self, *args, **kwargs) -> CliDummyArgGroup:
def add_argument_group(self, *args: Any, **kwargs: Any) -> CliDummyArgGroup:
return CliDummyArgGroup(group=self.parser.add_argument_group(*args, **kwargs))

def add_subparsers(self, *args, **kwargs) -> CliDummySubParsers:
def add_subparsers(self, *args: Any, **kwargs: Any) -> CliDummySubParsers:
return CliDummySubParsers(sub_parser=self.parser.add_subparsers(*args, **kwargs))

def parse_args(self, *args, **kwargs) -> argparse.Namespace:
def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:
return self.parser.parse_args(*args, **kwargs)


Expand Down Expand Up @@ -1786,40 +1787,40 @@ class Cfg(BaseSettings):

args = ['--fruit', 'pear']
parsed_args = parser.parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {
assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'bird',
'command': None,
}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {
assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'bird',
'command': None,
}

arg_prefix = f'{prefix}.' if prefix else ''
args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog']
parsed_args = parser.parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == {
assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'dog',
'command': None,
}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {
assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'dog',
'command': None,
}

parsed_args = parser.parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat'])
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == {
assert CliApp.run(Cfg, cli_args=vars(parsed_args), cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'cat',
'command': None,
}

args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog', 'command', '--name', 'ralph', '--command', 'roll']
parsed_args = parser.parse_args(args)
assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == {
assert CliApp.run(Cfg, cli_args=vars(parsed_args), cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'dog',
'command': {'name': 'ralph', 'command': 'roll'},
}
assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == {
assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == {
'pet': 'dog',
'command': {'name': 'ralph', 'command': 'roll'},
}
Expand Down Expand Up @@ -2045,3 +2046,131 @@ class Settings(BaseSettings, cli_parse_args=True):
-h, --help show this help message and exit
"""
)


def test_cli_mutually_exclusive_group(capsys):
class Circle(CliMutuallyExclusiveGroup):
radius: Optional[float] = 21
diameter: Optional[float] = 22
perimeter: Optional[float] = 23

class Settings(BaseModel):
circle_optional: Circle = Circle(radius=None, diameter=None, perimeter=24)
circle_required: Circle

CliApp.run(Settings, cli_args=['--circle-required.radius=1', '--circle-optional.radius=1']).model_dump() == {
'circle_optional': {'radius': 1, 'diameter': 22, 'perimeter': 24},
'circle_required': {'radius': 1, 'diameter': 22, 'perimeter': 23},
}

with pytest.raises(SystemExit):
CliApp.run(Settings, cli_args=['--circle-required.radius=1', '--circle-required.diameter=2'])
assert (
'error: argument --circle-required.diameter: not allowed with argument --circle-required.radius'
in capsys.readouterr().err
)

with pytest.raises(SystemExit):
CliApp.run(
Settings,
cli_args=['--circle-required.radius=1', '--circle-optional.radius=1', '--circle-optional.diameter=2'],
)
assert (
'error: argument --circle-optional.diameter: not allowed with argument --circle-optional.radius'
in capsys.readouterr().err
)

with pytest.raises(SystemExit):
CliApp.run(Settings, cli_args=['--help'])
assert (
capsys.readouterr().out
== f"""usage: example.py [-h] [--circle-optional.radius float |
--circle-optional.diameter float |
--circle-optional.perimeter float]
(--circle-required.radius float |
--circle-required.diameter float |
--circle-required.perimeter float)
{ARGPARSE_OPTIONS_TEXT}:
-h, --help show this help message and exit
circle-optional options (mutually exclusive):
--circle-optional.radius float
(default: None)
--circle-optional.diameter float
(default: None)
--circle-optional.perimeter float
(default: 24.0)
circle-required options (mutually exclusive):
--circle-required.radius float
(default: 21)
--circle-required.diameter float
(default: 22)
--circle-required.perimeter float
(default: 23)
"""
)


def test_cli_mutually_exclusive_group_exceptions():
class Circle(CliMutuallyExclusiveGroup):
radius: Optional[float] = 21
diameter: Optional[float] = 22
perimeter: Optional[float] = 23

class Settings(BaseSettings):
circle: Circle

parser = CliDummyParser()
with pytest.raises(
SettingsError,
match='cannot connect CLI settings source root parser: group object is missing add_mutually_exclusive_group but is needed for connecting',
):
CliSettingsSource(
Settings,
root_parser=parser,
parse_args_method=CliDummyParser.parse_args,
add_argument_method=CliDummyParser.add_argument,
add_argument_group_method=CliDummyParser.add_argument_group,
add_parser_method=CliDummySubParsers.add_parser,
add_subparsers_method=CliDummyParser.add_subparsers,
)

class SubModel(BaseModel):
pass

class SettingsInvalidUnion(BaseSettings):
union: Union[Circle, SubModel]

with pytest.raises(SettingsError, match='cannot use union with CliMutuallyExclusiveGroup'):
CliApp.run(SettingsInvalidUnion)

class CircleInvalidSubModel(Circle):
square: Optional[SubModel] = None

class SettingsInvalidOptSubModel(BaseModel):
circle: CircleInvalidSubModel = CircleInvalidSubModel()

class SettingsInvalidReqSubModel(BaseModel):
circle: CircleInvalidSubModel

for settings in [SettingsInvalidOptSubModel, SettingsInvalidReqSubModel]:
with pytest.raises(SettingsError, match='cannot have nested models in a CliMutuallyExclusiveGroup'):
CliApp.run(settings)

class CircleRequiredField(Circle):
length: float

class SettingsOptCircleReqField(BaseModel):
circle: CircleRequiredField = CircleRequiredField(length=2)

assert CliApp.run(SettingsOptCircleReqField, cli_args=[]).model_dump() == {
'circle': {'diameter': 22.0, 'length': 2.0, 'perimeter': 23.0, 'radius': 21.0}
}

class SettingsInvalidReqCircleReqField(BaseModel):
circle: CircleRequiredField

with pytest.raises(ValueError, match='mutually exclusive arguments must be optional'):
CliApp.run(SettingsInvalidReqCircleReqField)

0 comments on commit cff7488

Please sign in to comment.