Allow choices restriction #384

Open
Borda opened this issue Feb 8, 2022 · 7 comments

@Borda
Contributor

Borda commented Feb 8, 2022

Hello, and thank you for this great CLI!
Recently I ran into a situation where I wanted to restrict the options for a given argument, similar to what the built-in argparse does with its choices option (see docs: https://docs.python.org/3/library/argparse.html#choices). I checked the Fire docs but could not find anything similar.
Looking at alternative CLI packages, I found an approach that is simple yet elegant and would fit the Fire style well. It leverages Python's Enum class:

from enum import Enum
import fire

class Direction(str, Enum):
    up = "up"
    down = "down"
    left = "left"
    right = "right"


def main(move: Direction = Direction.left):
    print(f"Moving in given direction: {move.value}")


if __name__ == "__main__":
    fire.Fire(main)

For clarification, the example above is borrowed and adapted from Typer/enum.
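
Until something like this is supported natively, the idea can be approximated with a small decorator that reads the annotations itself. This is only a minimal, stdlib-only sketch: `enforce_enum_choices` is a hypothetical helper name, not part of Fire, and `main` returns its message instead of printing so the result is easy to check.

```python
import functools
import inspect
from enum import Enum


class Direction(str, Enum):
    up = "up"
    down = "down"
    left = "left"
    right = "right"


def enforce_enum_choices(f):
    """Hypothetical helper: coerce arguments annotated with an Enum subclass."""
    sig = inspect.signature(f)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        bound.apply_defaults()
        for name, value in list(bound.arguments.items()):
            annotation = sig.parameters[name].annotation
            # Only parameters annotated with an Enum subclass are restricted.
            if inspect.isclass(annotation) and issubclass(annotation, Enum):
                try:
                    bound.arguments[name] = annotation(value)
                except ValueError:
                    valid = ", ".join(str(m.value) for m in annotation)
                    raise ValueError(
                        f"Invalid choice {value!r} for '{name}'. "
                        f"Valid choices are: {valid}"
                    ) from None
        return f(*bound.args, **bound.kwargs)

    return wrapper


@enforce_enum_choices
def main(move: Direction = Direction.left):
    return f"Moving in given direction: {move.value}"
```

Passing the decorated `main` to `fire.Fire` should then reject values outside the Enum, although how Fire displays the resulting `ValueError` is not verified here.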

@chris-clem

That would be great! I use choices quite a lot.

@narothsolo

Please help Star

@dbieber
Member

dbieber commented Feb 22, 2022

Great idea. We don't currently use type annotations in fire to impose restrictions (but we could in a future version, though no one is actively working toward it atm).

Side note: One alternative that works today is to use a decorator, roughly like this:

from fire.core import FireError

def restrict_choices(choices):
  def decorator(f):
    def new_f(x):
      # Reject any value outside the allowed choices before calling f.
      if x not in choices:
        raise FireError("Invalid choice")
      return f(x)
    return new_f
  return decorator

@restrict_choices(['left', 'right'])
def main(move):
    print(f"Moving in given direction: {move}")

See also SetParseFns in https://github.com/google/python-fire/blob/master/fire/decorators.py
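
For the parse-function route, the piece you would need to write yourself is a function that accepts a raw value and either returns it or rejects it. A minimal sketch, assuming nothing about Fire itself (`choice_parser` and `parse_move` are hypothetical names):

```python
def choice_parser(choices):
    """Hypothetical helper: build a parse function that accepts only `choices`."""
    allowed = tuple(choices)

    def parse(value):
        # Return the value unchanged when it is allowed, otherwise fail loudly.
        if value not in allowed:
            valid = ", ".join(map(str, allowed))
            raise ValueError(
                f"Invalid choice {value!r}. Valid choices are: {valid}")
        return value

    return parse


parse_move = choice_parser(["left", "right"])
```

Such a function could presumably be attached to `main` with `fire.decorators.SetParseFns(move=parse_move)`; that wiring, and how Fire reports an exception raised inside a parse function, are assumptions not verified here.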

@keyboardAnt

@Borda
Contributor Author

Borda commented Dec 8, 2023

> You might also find the HfArgumentParser relevant: https://github.com/huggingface/transformers/blob/514de24abfd4416aeba6a6455ad5920f57f3567d/src/transformers/hf_argparser.py#L109

Not really if you have to install full HF package for it...

@keyboardAnt

> > You might also find the HfArgumentParser relevant: https://github.com/huggingface/transformers/blob/514de24abfd4416aeba6a6455ad5920f57f3567d/src/transformers/hf_argparser.py#L109
>
> Not really if you have to install full HF package for it...

The alternative below doesn't need the HF package. It is simple and readable but creates the Config object twice.

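import fire
from pydantic import BaseModel

class Config(BaseModel):
    ...

def main(**kwargs):
    # Build a default Config, then copy it with the CLI overrides applied.
    # Note: model_copy(update=...) does not validate the updated values.
    config = Config().model_copy(update=kwargs)

if __name__ == "__main__":
    fire.Fire(main)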

@hesic73

hesic73 commented Oct 22, 2024

Thanks! Here is a more general version, written with the help of GPT. The problem is that if sig.bind raises a TypeError, Fire won't work.

from typing import Union, List, Any
import inspect
from fire.core import FireError


def restrict_choices(arg_name_or_position: Union[int, str], choices: List[Any]):
    def decorator(f):
        sig = inspect.signature(f)  # Get the function signature

        def new_f(*args, **kwargs):
            # Map arguments by position and name
            bound_args = sig.bind(*args, **kwargs)
            bound_args.apply_defaults()  # Handle any default arguments

            # Determine if we're restricting by name or position
            if isinstance(arg_name_or_position, str):
                # Restrict by argument name
                if arg_name_or_position in bound_args.arguments:
                    restricted_arg = bound_args.arguments[arg_name_or_position]
                    arg_identifier = f"argument '{arg_name_or_position}'"
                else:
                    raise FireError(
                        f"Argument '{arg_name_or_position}' not found")
            elif isinstance(arg_name_or_position, int):
                # Restrict by argument position
                if arg_name_or_position < len(bound_args.args):
                    restricted_arg = bound_args.args[arg_name_or_position]
                    arg_identifier = f"position {arg_name_or_position}"
                else:
                    raise FireError(
                        f"Argument position {arg_name_or_position} is out of range")
            else:
                raise FireError(
                    "Invalid argument specifier, must be a name (str) or position (int)")

            # Check if the restricted argument is in the allowed choices
            if restricted_arg not in choices:
                raise FireError(
                    f"Invalid choice '{restricted_arg}' for {arg_identifier}. "
                    f"Valid choices are: {choices}")

            # Call the original function if the check passes
            return f(*args, **kwargs)

        return new_f
    return decorator

Examples:

@restrict_choices('direction', ['left', 'right'])
def move(direction, speed):
    print(f"Moving {direction} at speed {speed}")


@restrict_choices(1, ['left', 'right'])
def move(speed, direction):
    print(f"Moving {direction} at speed {speed}")
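
One possible way to soften the sig.bind problem is to catch the TypeError and re-raise it as the same exception type used for invalid choices. The sketch below is a simplified, by-name-only variant and is an assumption, not a tested fix: `ChoiceError` is a hypothetical stand-in for `fire.core.FireError` so the code runs without Fire installed, and `move` returns its message instead of printing.

```python
import functools
import inspect


class ChoiceError(Exception):
    """Stand-in for fire.core.FireError (assumption, so the sketch is stdlib-only)."""


def restrict_choices(arg_name, choices):
    def decorator(f):
        sig = inspect.signature(f)

        @functools.wraps(f)
        def new_f(*args, **kwargs):
            try:
                bound_args = sig.bind(*args, **kwargs)
            except TypeError as error:
                # Missing or unexpected arguments: report them with the same
                # exception type as invalid choices instead of a bare TypeError.
                raise ChoiceError(f"Could not bind arguments: {error}") from error
            bound_args.apply_defaults()

            if arg_name not in bound_args.arguments:
                raise ChoiceError(f"Argument '{arg_name}' not found")

            value = bound_args.arguments[arg_name]
            if value not in choices:
                raise ChoiceError(
                    f"Invalid choice '{value}' for argument '{arg_name}'. "
                    f"Valid choices are: {choices}")
            return f(*args, **kwargs)

        return new_f
    return decorator


@restrict_choices("direction", ["left", "right"])
def move(direction, speed):
    return f"Moving {direction} at speed {speed}"
```

Whether this fully resolves the issue depends on how Fire treats the wrapper's signature and the raised exception, which is not verified here.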
