diff --git a/convert.go b/convert.go index b27f698..31795d7 100644 --- a/convert.go +++ b/convert.go @@ -5,6 +5,7 @@ package flags import ( + "encoding" "fmt" "reflect" "strconv" @@ -52,11 +53,15 @@ func getBase(options multiTag, base int) (int, error) { } func convertMarshal(val reflect.Value) (bool, string, error) { - // Check first for the Marshaler interface + // Check for a MarshalFlag or MarshalText interface: if val.IsValid() && val.Type().NumMethod() > 0 && val.CanInterface() { - if marshaler, ok := val.Interface().(Marshaler); ok { + switch marshaler := val.Interface().(type) { + case Marshaler: ret, err := marshaler.MarshalFlag() return true, ret, err + case encoding.TextMarshaler: + ret, err := marshaler.MarshalText() + return true, string(ret), err } } @@ -166,15 +171,22 @@ func convertToString(val reflect.Value, options multiTag) (string, error) { func convertUnmarshal(val string, retval reflect.Value) (bool, error) { if retval.Type().NumMethod() > 0 && retval.CanInterface() { - if unmarshaler, ok := retval.Interface().(Unmarshaler); ok { + iFace := retval.Interface() + switch unmarshaler := iFace.(type) { + case Unmarshaler: if retval.IsNil() { retval.Set(reflect.New(retval.Type().Elem())) - // Re-assign from the new value unmarshaler = retval.Interface().(Unmarshaler) } - return true, unmarshaler.UnmarshalFlag(val) + case encoding.TextUnmarshaler: + if retval.IsNil() { + retval.Set(reflect.New(retval.Type().Elem())) + // Re-assign from the new value + unmarshaler = retval.Interface().(encoding.TextUnmarshaler) + } + return true, unmarshaler.UnmarshalText([]byte(val)) } } diff --git a/convert_test.go b/convert_test.go index 21982ae..fda6e87 100644 --- a/convert_test.go +++ b/convert_test.go @@ -176,3 +176,68 @@ func TestConvertToMapWithDelimiter(t *testing.T) { assertString(t, opts.StringStringMap["key"], "value") } + +type testEnum int + +const ( + one testEnum = iota + two + three +) + +func (t *testEnum) UnmarshalText(text []byte) error { + switch string(text) { + case "one": + *t = one + case "two": + *t = two + case "three": + *t = three + default: + return newErrorf(ErrMarshal, "invalid value %q", text) + } + return nil +} + +func (t testEnum) MarshalText() ([]byte, error) { + switch t { + case one: + return []byte("one"), nil + case two: + return []byte("two"), nil + case three: + return []byte("three"), nil + default: + return nil, newErrorf(ErrMarshal, "invalid value %q", t) + } +} + +func TestConvertUsesUnmarshalText(t *testing.T) { + var opt = struct { + Enum testEnum `long:"enum" required:"true"` + }{0} + + p := NewNamedParser("test", Default) + _, err := p.AddCommand("mycmd", "test", "test", &opt) + if err != nil { + t.Fatalf("error not expected %+v", err) + } + _, err = p.ParseArgs([]string{"mycmd", "--enum=three"}) + if err != nil { + t.Fatalf("error not expected %+v", err) + } + if opt.Enum != three { + t.Fatalf("expected three, got %v", opt.Enum) + } + + grp, _ := p.AddGroup("test group", "", &opt) + o := grp.Options()[0] + + marshalled, err := convertToString(o.value, o.tag) + if err != nil { + t.Fatalf("error not expected %+v", err) + } + if marshalled != "three" { + t.Fatalf("expected three, got %v", marshalled) + } +}