Skip to content

Commit

Permalink
Improve error messages and track full command name
Browse files Browse the repository at this point in the history
  • Loading branch information
mfridman committed Dec 27, 2024
1 parent 6ae6f27 commit f27444c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 23 deletions.
28 changes: 16 additions & 12 deletions parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@ func Parse(root *Command, args []string) error {

// Initialize root state
if root.state == nil {
root.state = &State{cmd: root}
root.state = &State{
cmd: root,
fullName: root.Name,
}
}

// First split args at the -- delimiter if present
Expand All @@ -49,11 +52,10 @@ func Parse(root *Command, args []string) error {

// Create combined flags with all parent flags
combinedFlags := flag.NewFlagSet(root.Name, flag.ContinueOnError)
// TODO(mf): revisit this
// TODO(mf): revisit this output location
combinedFlags.SetOutput(io.Discard)

// First pass: process commands and build the flag set. This lets us capture help requests
// before any flag parsing errors
// First pass: process commands and build the flag set
for _, arg := range argsToParse {
if arg == "-h" || arg == "--h" || arg == "-help" || arg == "--help" {
combinedFlags.Usage = func() { _ = current.showHelp() }
Expand All @@ -67,7 +69,10 @@ func Parse(root *Command, args []string) error {
if len(current.SubCommands) > 0 {
if sub := current.findSubCommand(arg); sub != nil {
if sub.state == nil {
sub.state = &State{cmd: sub}
sub.state = &State{
cmd: sub,
fullName: current.state.fullName + " " + sub.Name,
}
}
if sub.Flags == nil {
sub.Flags = flag.NewFlagSet(sub.Name, flag.ContinueOnError)
Expand Down Expand Up @@ -102,7 +107,7 @@ func Parse(root *Command, args []string) error {
return fmt.Errorf("command %q: %w", current.Name, err)
}

// Check required flags by checking if they were actually set to non-default values
// Check required flags
var missingFlags []string
for _, cmd := range commandChain {
if len(cmd.FlagsMetadata) > 0 {
Expand All @@ -112,20 +117,19 @@ func Parse(root *Command, args []string) error {
}
flag := combinedFlags.Lookup(flagMetadata.Name)
if flag == nil {
return fmt.Errorf("command %q: internal error: required flag %q not found in flag set", current.Name, flagMetadata.Name)
return fmt.Errorf("command %q: internal error: required flag %q not found in flag set", current.state.fullName, flagMetadata.Name)
}
// Check if the flag was set by checking its actual value against default
if flag.Value.String() == flag.DefValue {
missingFlags = append(missingFlags, flagMetadata.Name)
missingFlags = append(missingFlags, formatFlagName(flagMetadata.Name))
}
}
}
}
if len(missingFlags) > 0 {
return fmt.Errorf("command %q: required flag(s) %q not set", current.Name, strings.Join(missingFlags, ", "))
return fmt.Errorf("command %q: required flag(s) %q not set", current.state.fullName, strings.Join(missingFlags, ", "))
}

// Skip past command names in remaining args from flag parsing
// Skip past command names in remaining args
parsed := combinedFlags.Args()
startIdx := 0
for _, arg := range parsed {
Expand Down Expand Up @@ -164,7 +168,7 @@ func validateCommands(root *Command, path []string) error {
}
// Ensure name has no spaces
if strings.Contains(root.Name, " ") {
return fmt.Errorf("command name %q contains spaces", root.Name)
return fmt.Errorf("command name %q contains spaces, must be a single word", root.Name)
}

// Add current command to path for nested validation
Expand Down
6 changes: 2 additions & 4 deletions parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,7 @@ func TestParse(t *testing.T) {

err := Parse(s.root, []string{"nested", "hello"})
require.Error(t, err)
// TODO(mf): this error message should have the full path to the command, e.g., "todo nested hello"
require.ErrorContains(t, err, `command "hello": required flag(s) "mandatory-flag" not set`)
require.ErrorContains(t, err, `command "todo nested hello": required flag(s) "-mandatory-flag" not set`)

// Correct type
err = Parse(s.root, []string{"nested", "hello", "--mandatory-flag", "true"})
Expand Down Expand Up @@ -346,7 +345,6 @@ func TestParse(t *testing.T) {
}
err := Parse(cmd, nil)
require.Error(t, err)
// TODO(mf): consider improving this error message so it's a bit more user-friendly
require.ErrorContains(t, err, `command name "sub command" contains spaces`)
require.ErrorContains(t, err, `command name "sub command" contains spaces, must be a single word`)
})
}
16 changes: 11 additions & 5 deletions state.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ type State struct {
Stdin io.Reader
Stdout, Stderr io.Writer

cmd *Command // Reference to the command this state belongs to
// The full name of the command, including parent commands. E.g., "cli todo list all"
fullName string
// Reference to the command this state belongs to
cmd *Command
parent *State
}

Expand All @@ -28,20 +31,19 @@ type State struct {
// count := GetFlag[int](state, "count")
// path := GetFlag[string](state, "path")
//
// If the flag isn't found, it panics with a detailed error message.
// If the flag isn't known, or is the wrong type, it panics with a detailed error message.
//
// Why panic? Because if a flag is missing, it's likely a programming error or a missing flag
// definition, and it's better to fail LOUD and EARLY than to silently ignore the issue and cause
// unexpected behavior.
func GetFlag[T any](s *State, name string) T {
// TODO(mf): we should have a way to get the selected command here to improve error messages
if f := s.cmd.Flags.Lookup(name); f != nil {
if getter, ok := f.Value.(flag.Getter); ok {
value := getter.Get()
if v, ok := value.(T); ok {
return v
}
msg := fmt.Sprintf("internal error: type mismatch for flag %q: registered %T, requested %T", name, value, *new(T))
msg := fmt.Sprintf("internal error: type mismatch for flag %q in command %q: registered %T, requested %T", formatFlagName(name), s.fullName, value, *new(T))
// Flag exists but type doesn't match - this is an internal error
panic(msg)
}
Expand All @@ -51,6 +53,10 @@ func GetFlag[T any](s *State, name string) T {
return GetFlag[T](s.parent, name)
}
// If flag not found anywhere in hierarchy, panic with helpful message
msg := fmt.Sprintf("internal error: flag not found: %q in %s flag set", name, s.cmd.Name)
msg := fmt.Sprintf("internal error: flag %q not found in %q flag set", formatFlagName(name), s.fullName)
panic(msg)
}

func formatFlagName(name string) string {
return "-" + name
}
4 changes: 2 additions & 2 deletions state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ func TestGetFlag(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
assert.Equal(t, `internal error: flag not found: "version" in root flag set`, r)
assert.Equal(t, `internal error: flag "-version" not found in "" flag set`, r)
}()
// Panic because author forgot to define the flag and tried to access it. This is a
// programming error and should be caught early
Expand All @@ -38,7 +38,7 @@ func TestGetFlag(t *testing.T) {
defer func() {
r := recover()
require.NotNil(t, r)
assert.Equal(t, `internal error: type mismatch for flag "version": registered string, requested int`, r)
assert.Equal(t, `internal error: type mismatch for flag "-version" in command "": registered string, requested int`, r)
}()
_ = GetFlag[int](st, "version")
})
Expand Down

0 comments on commit f27444c

Please sign in to comment.