diff --git a/collection.go b/collection.go index f5e89ed..db27c3b 100644 --- a/collection.go +++ b/collection.go @@ -71,7 +71,7 @@ func (c *Collection) MarshalJSON() ([]byte, error) { func (c *Collection) UnmarshalJSON(data []byte) error { var strSlice []string if err := json.Unmarshal(data, &strSlice); err != nil { - return err + return fmt.Errorf("failed to decode JSON input: %w", err) } return c.unmarshal(strSlice) } @@ -85,7 +85,7 @@ func (c *Collection) MarshalYAML() (interface{}, error) { func (c *Collection) UnmarshalYAML(unmarshal func(interface{}) error) error { var strSlice []string if err := unmarshal(&strSlice); err != nil { - return err + return fmt.Errorf("failed to decode YAML input: %w", err) } return c.unmarshal(strSlice) } diff --git a/collection_test.go b/collection_test.go index a5f8104..7129176 100644 --- a/collection_test.go +++ b/collection_test.go @@ -73,3 +73,22 @@ func TestCollectionUnmarshalling(t *testing.T) { assert.Equal(t, "v1.0.1+k0s.1", c[1].String()) }) } + +func TestFailingCollectionUnmarshalling(t *testing.T) { + t.Run("JSON", func(t *testing.T) { + var c Collection + err := json.Unmarshal([]byte(`invalid_json`), &c) + assert.Error(t, err) + err = json.Unmarshal([]byte(`["invalid_version"]`), &c) + assert.Error(t, err) + }) + + t.Run("YAML", func(t *testing.T) { + var c Collection + err := c.UnmarshalYAML(func(i interface{}) error { + *(i.(*[]string)) = []string{"invalid\n"} + return nil + }) + assert.Error(t, err) + }) +} diff --git a/version.go b/version.go index 7815c83..bf9ef37 100644 --- a/version.go +++ b/version.go @@ -123,11 +123,11 @@ func (v *Version) MarshalYAML() (interface{}, error) { func (v *Version) unmarshal(f func(interface{}) error) error { var s string if err := f(&s); err != nil { - return fmt.Errorf("unmarshal failed to decode input: %w", err) + return fmt.Errorf("failed to decode input: %w", err) } newV, err := NewVersion(s) if err != nil { - return fmt.Errorf("failed to unmarshal '%s': %w", s, err) + return fmt.Errorf("failed to unmarshal version: %w", err) } *v = *newV return nil diff --git a/version_test.go b/version_test.go index 4ca853a..562075b 100644 --- a/version_test.go +++ b/version_test.go @@ -2,6 +2,7 @@ package version import ( "encoding/json" + "errors" "testing" "github.com/stretchr/testify/assert" @@ -81,3 +82,26 @@ func TestUnmarshalling(t *testing.T) { assert.Equal(t, "v1.0.0+k0s.1", v.String()) }) } + +func TestFailingUnmarshalling(t *testing.T) { + t.Run("JSON", func(t *testing.T) { + var v Version + err := json.Unmarshal([]byte(`invalid_json`), &v) + assert.Error(t, err) + err = json.Unmarshal([]byte(`"invalid_version"`), &v) + assert.Error(t, err) + }) + + t.Run("YAML", func(t *testing.T) { + var v = &Version{} + err := v.UnmarshalYAML(func(i interface{}) error { + return errors.New("forced error") + }) + assert.Error(t, err) + err = v.UnmarshalYAML(func(i interface{}) error { + *(i.(*string)) = "invalid_version" + return nil + }) + assert.Error(t, err) + }) +}