Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use encoding.TextMarshaler/encoding.TextUnmarshaler if present #413

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package flags

import (
"encoding"
"fmt"
"reflect"
"strconv"
Expand Down Expand Up @@ -52,11 +53,15 @@ func getBase(options multiTag, base int) (int, error) {
}

func convertMarshal(val reflect.Value) (bool, string, error) {
// Check first for the Marshaler interface
// Check for a MarshalFlag or MarshalText interface:
if val.IsValid() && val.Type().NumMethod() > 0 && val.CanInterface() {
if marshaler, ok := val.Interface().(Marshaler); ok {
switch marshaler := val.Interface().(type) {
case Marshaler:
ret, err := marshaler.MarshalFlag()
return true, ret, err
case encoding.TextMarshaler:
ret, err := marshaler.MarshalText()
return true, string(ret), err
}
}

Expand Down Expand Up @@ -166,15 +171,22 @@ func convertToString(val reflect.Value, options multiTag) (string, error) {

func convertUnmarshal(val string, retval reflect.Value) (bool, error) {
if retval.Type().NumMethod() > 0 && retval.CanInterface() {
if unmarshaler, ok := retval.Interface().(Unmarshaler); ok {
iFace := retval.Interface()
switch unmarshaler := iFace.(type) {
case Unmarshaler:
if retval.IsNil() {
retval.Set(reflect.New(retval.Type().Elem()))

// Re-assign from the new value
unmarshaler = retval.Interface().(Unmarshaler)
}

return true, unmarshaler.UnmarshalFlag(val)
case encoding.TextUnmarshaler:
if retval.IsNil() {
retval.Set(reflect.New(retval.Type().Elem()))
// Re-assign from the new value
unmarshaler = retval.Interface().(encoding.TextUnmarshaler)
}
return true, unmarshaler.UnmarshalText([]byte(val))
}
}

Expand Down
65 changes: 65 additions & 0 deletions convert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,3 +176,68 @@ func TestConvertToMapWithDelimiter(t *testing.T) {

assertString(t, opts.StringStringMap["key"], "value")
}

type testEnum int

const (
one testEnum = iota
two
three
)

func (t *testEnum) UnmarshalText(text []byte) error {
switch string(text) {
case "one":
*t = one
case "two":
*t = two
case "three":
*t = three
default:
return newErrorf(ErrMarshal, "invalid value %q", text)
}
return nil
}

func (t testEnum) MarshalText() ([]byte, error) {
switch t {
case one:
return []byte("one"), nil
case two:
return []byte("two"), nil
case three:
return []byte("three"), nil
default:
return nil, newErrorf(ErrMarshal, "invalid value %q", t)
}
}

func TestConvertUsesUnmarshalText(t *testing.T) {
var opt = struct {
Enum testEnum `long:"enum" required:"true"`
}{0}

p := NewNamedParser("test", Default)
_, err := p.AddCommand("mycmd", "test", "test", &opt)
if err != nil {
t.Fatalf("error not expected %+v", err)
}
_, err = p.ParseArgs([]string{"mycmd", "--enum=three"})
if err != nil {
t.Fatalf("error not expected %+v", err)
}
if opt.Enum != three {
t.Fatalf("expected three, got %v", opt.Enum)
}

grp, _ := p.AddGroup("test group", "", &opt)
o := grp.Options()[0]

marshalled, err := convertToString(o.value, o.tag)
if err != nil {
t.Fatalf("error not expected %+v", err)
}
if marshalled != "three" {
t.Fatalf("expected three, got %v", marshalled)
}
}