Skip to content

Commit

Permalink
Add support for CliMutuallyExclusiveGroup.
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab committed Nov 8, 2024
1 parent 87ad4db commit ccb0d6f
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 5 deletions.
2 changes: 2 additions & 0 deletions pydantic_settings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
AzureKeyVaultSettingsSource,
CliExplicitFlag,
CliImplicitFlag,
CliMutuallyExclusiveGroup,
CliPositionalArg,
CliSettingsSource,
CliSubCommand,
Expand Down Expand Up @@ -34,6 +35,7 @@
'CliPositionalArg',
'CliExplicitFlag',
'CliImplicitFlag',
'CliMutuallyExclusiveGroup',
'InitSettingsSource',
'JsonConfigSettingsSource',
'PyprojectTomlConfigSettingsSource',
Expand Down
44 changes: 39 additions & 5 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,10 @@ def error(self, message: str) -> NoReturn:
super().error(message)


class CliMutuallyExclusiveGroup(BaseModel):
pass


T = TypeVar('T')
CliSubCommand = Annotated[Union[T, None], _CliSubCommand]
CliPositionalArg = Annotated[T, _CliPositionalArg]
Expand Down Expand Up @@ -1515,6 +1519,25 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any:
else:
return parser_method

def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]:
add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')

def add_group_method(parser: Any, model: type[BaseModel], **kwargs: Any) -> Any:
if not issubclass(model, CliMutuallyExclusiveGroup):
kwargs.pop('required')
return add_argument_group(parser, **kwargs)
else:
group = add_argument_group(
parser, **{arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs}
)
if not hasattr(group, 'add_mutually_exclusive_group'):
raise SettingsError(
'cannot connect CLI settings source root parser: add_mutually_exclusive_group is set to `None` but is needed for connecting'
)
return group.add_mutually_exclusive_group(**kwargs)

return add_group_method

def _connect_root_parser(
self,
root_parser: T,
Expand All @@ -1533,7 +1556,7 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace:
parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args
self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method')
self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method')
self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method')
self._add_group = self._connect_group_method(add_argument_group_method)
self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method')
self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method')
self._formatter_class = formatter_class
Expand Down Expand Up @@ -1656,6 +1679,7 @@ def _add_parser_args(
if is_parser_submodel:
self._add_parser_submodels(
parser,
model,
sub_models,
added_args,
arg_prefix,
Expand All @@ -1671,7 +1695,7 @@ def _add_parser_args(
elif not is_alias_path_only:
if group is not None:
if isinstance(group, dict):
group = self._add_argument_group(parser, **group)
group = self._add_group(parser, model, **group)
added_args += list(arg_names)
self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs)
else:
Expand All @@ -1680,7 +1704,7 @@ def _add_parser_args(
parser, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs
)

self._add_parser_alias_paths(parser, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
self._add_parser_alias_paths(parser, model, alias_path_args, added_args, arg_prefix, subcommand_prefix, group)
return parser

def _convert_bool_flag(self, kwargs: dict[str, Any], field_info: FieldInfo, model_default: Any) -> None:
Expand Down Expand Up @@ -1715,6 +1739,7 @@ def _get_arg_names(
def _add_parser_submodels(
self,
parser: Any,
model: type[BaseModel],
sub_models: list[type[BaseModel]],
added_args: list[str],
arg_prefix: str,
Expand All @@ -1727,10 +1752,18 @@ def _add_parser_submodels(
alias_names: tuple[str, ...],
model_default: Any,
) -> None:
if issubclass(model, CliMutuallyExclusiveGroup):
# Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a
# mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion).
# Since nested models result in a group add, raise an exception for nested models in a mutually
# exclusive group.
raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup')

model_group: Any = None
model_group_kwargs: dict[str, Any] = {}
model_group_kwargs['title'] = f'{arg_names[0]} options'
model_group_kwargs['description'] = field_info.description
model_group_kwargs['required'] = kwargs['required']
if self.cli_use_class_docs_for_groups and len(sub_models) == 1:
model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__)

Expand All @@ -1753,7 +1786,7 @@ def _add_parser_submodels(
if not self.cli_avoid_json:
added_args.append(arg_names[0])
kwargs['help'] = f'set {arg_names[0]} from JSON string'
model_group = self._add_argument_group(parser, **model_group_kwargs)
model_group = self._add_group(parser, model, **model_group_kwargs)
self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs)
for model in sub_models:
self._add_parser_args(
Expand All @@ -1770,6 +1803,7 @@ def _add_parser_submodels(
def _add_parser_alias_paths(
self,
parser: Any,
model: type[BaseModel],
alias_path_args: dict[str, str],
added_args: list[str],
arg_prefix: str,
Expand All @@ -1779,7 +1813,7 @@ def _add_parser_alias_paths(
if alias_path_args:
context = parser
if group is not None:
context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group
context = self._add_group(parser, model, **group) if isinstance(group, dict) else group
is_nested_alias_path = arg_prefix.endswith('.')
arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix
for name, metavar in alias_path_args.items():
Expand Down

0 comments on commit ccb0d6f

Please sign in to comment.