diff --git a/CHANGELOG.md b/CHANGELOG.md index 64aa224..f70f169 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,7 +2,10 @@ ## v1.1.0 (not released yet) -- No changes yet. +- Expand functions transform a special sequence $$ to literal $. +- The underlying objects encapsulated by config.Value types will now + have the types determined by the YAML unmarshaller regardless of + whether expansion was performed or not. ## v1.0.2 (2017-08-17) diff --git a/expand.go b/expand.go new file mode 100644 index 0000000..05f9478 --- /dev/null +++ b/expand.go @@ -0,0 +1,197 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package config + +import ( + "bytes" + + "golang.org/x/text/transform" +) + +// expandTransformer implements transform.Transformer +type expandTransformer struct { + transform.NopResetter + expand func(string) (string, error) +} + +// First char of shell variable may be [a-zA-Z_] +func isShellNameFirstChar(c byte) bool { + return c == '_' || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') +} + +// Char's after the first of shell variable may be [a-zA-Z0-9_] +func isShellNameChar(c byte) bool { + return c == '_' || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') +} + +// bytesIndexCFunc returns the index of the byte for which the +// complement of the supplied function is true +func bytesIndexCFunc(buf []byte, f func(b byte) bool) int { + for i, b := range buf { + if !f(b) { + return i + } + } + return -1 +} + +// Transform expands shell-like sequences like $foo and ${foo} using +// the configured expand function. The sequence '$$' is replaced with +// a literal '$'. +func (e *expandTransformer) Transform(dst, src []byte, atEOF bool) (int, int, error) { + var srcPos int + var dstPos int + + for srcPos < len(src) { + + if dstPos == len(dst) { + return dstPos, srcPos, transform.ErrShortDst + } + + end := bytes.IndexByte(src[srcPos:], '$') + + if end == -1 { + // src does not contain '$', copy into dst + cnt := copy(dst[dstPos:], src[srcPos:]) + srcPos += cnt + dstPos += cnt + continue + } else if end > 0 { + // copy chars preceding '$' from src to dst + cnt := copy(dst[dstPos:], src[srcPos:srcPos+end]) + srcPos += cnt + dstPos += cnt + + if dstPos == len(dst) { + return dstPos, srcPos, transform.ErrShortDst + } + } + + // src[srcPos] now points to '$', dstPos < len(dst) + + // If we're at the end of src, but we found a starting + // token, return ErrShortSrc, unless we're also at EOF, + // in which case just copy it dst. + if srcPos+1 == len(src) { + if atEOF { + dst[dstPos] = src[srcPos] + srcPos++ + dstPos++ + continue + } + return dstPos, srcPos, transform.ErrShortSrc + } + + // At this point we know that src[srcPos+1] is populated. + + // If this token sequence represents the special '$$' + // sequence, emit a '$' into dst. + if src[srcPos+1] == '$' { + dst[dstPos] = src[srcPos] + srcPos += 2 + dstPos++ + continue + } + + var token []byte + var tokenEnd int + + // Start of bracketed token ${foo} + if src[srcPos+1] == '{' { + end := bytes.IndexByte(src[srcPos+2:], '}') + if end == -1 { + if atEOF { + // No closing bracket and we're at + // EOF, so it's not a valid bracket + // expression. + if len(dst[dstPos:]) < + len(src[srcPos:]) { + return dstPos, srcPos, + transform.ErrShortDst + } + + cnt := copy(dst[dstPos:], src[srcPos:]) + srcPos += cnt + dstPos += cnt + continue + } + + // Otherwise, we need more bytes in src + return dstPos, srcPos, transform.ErrShortSrc + } + + // Set tokenEnd so it points to the byte + // immediately after the closing '}' + tokenEnd = end + srcPos + 3 + + token = src[srcPos+2 : tokenEnd-1] + } else { // Else start of non-bracketed token $foo + if !isShellNameFirstChar(src[srcPos+1]) { + // If it doesn't conform to the naming + // rules for shell variables, do not + // try to expand, just copy to dst. + dst[dstPos] = src[srcPos] + srcPos++ + dstPos++ + continue + } + + end := bytesIndexCFunc(src[srcPos+2:], isShellNameChar) + + if end == -1 { + // Reached the end of src without finding + // end of shell variable + if !atEOF { + // We need more bytes in src + return dstPos, srcPos, + transform.ErrShortSrc + } + tokenEnd = len(src) + } else { + // Set tokenEnd so it points to the byte + // immediately after the token + tokenEnd = end + srcPos + 2 + } + + token = src[srcPos+1 : tokenEnd] + } + + replacement, err := e.expand(string(token)) + if err != nil { + return dstPos, srcPos, err + } + + if len(dst[dstPos:]) < len(replacement) { + return dstPos, srcPos, transform.ErrShortDst + } + + cnt := copy(dst[dstPos:], replacement) + srcPos = tokenEnd + dstPos += cnt + } + + return dstPos, srcPos, nil +} diff --git a/expand_test.go b/expand_test.go new file mode 100644 index 0000000..3a25fd4 --- /dev/null +++ b/expand_test.go @@ -0,0 +1,295 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package config + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/transform" +) + +// Size of buffer used by the transform package. +const transformBufSize = 4096 + +const orig = `This is a $t3sT$. $$ This is a $$test. + This is not a valid $0ne. But this one $i5_@_valid-one. + $$$$$$$ +${parti` + +const expected = `This is a test$. $ This is a $test. + This is not a valid $0ne. But this one is @_valid-one. + $$$$ +${parti` + +const ends_in_dollar = `There is a dollar at the end$` +const ends_in_ddollar = `There is a dollar at the end$$` +const many_dollars_orig = `$$$$$$$` +const many_dollars_expect = `$$$$` +const ends_in_var = `There is a test at the end: $t3sT` +const ends_in_var_expect = `There is a test at the end: test` + +type oneByteReader struct { + r io.Reader +} + +func (e *oneByteReader) Read(buf []byte) (n int, err error) { + var b [1]byte + + if len(buf) > 0 { + n, err = e.r.Read(b[:]) + buf[0] = b[0] + } + + return +} + +type bufReader struct { + buf []byte + offset int +} + +// Similar to a bytes.Reader except this Reader returns EOF in the same +// Read that reads the end of the buffer. +func (e *bufReader) Read(buf []byte) (int, error) { + var err error + n := copy(buf, e.buf[e.offset:]) + e.offset += n + if e.offset == len(e.buf) { + err = io.EOF + } + return n, err +} + +func TestExpander(t *testing.T) { + t.Parallel() + + r := bytes.NewReader([]byte(orig)) + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + case "i5_": + return "is ", nil + } + + return "NOMATCH", errors.New("No Match") + } + + // Parse whole string + tr := transform.NewReader(r, &expandTransformer{expand: expand_func}) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, expected, string(actual)) + + _, err = r.Seek(0, io.SeekStart) + require.NoError(t, err) + + // Partial parse + var partial [11]byte + tr = transform.NewReader(r, &expandTransformer{expand: expand_func}) + n, err := tr.Read(partial[:]) + require.NoError(t, err) + assert.Exactly(t, n, len(partial)) + assert.Exactly(t, expected[:n], string(partial[:])) + + // Empty string + r = bytes.NewReader([]byte{}) + tr = transform.NewReader(r, &expandTransformer{expand: expand_func}) + actual, err = ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, "", string(actual)) +} + +func TestExpanderOneByteAtATime(t *testing.T) { + t.Parallel() + + r := bytes.NewReader([]byte(orig)) + rr := &oneByteReader{r: r} + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + case "i5_": + return "is ", nil + } + + return "NOMATCH", errors.New("No Match") + } + + tr := transform.NewReader(rr, &expandTransformer{expand: expand_func}) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, expected, string(actual)) +} + +func TestExpanderFailingTransform(t *testing.T) { + t.Parallel() + + r := bytes.NewReader([]byte(orig)) + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + // missing "i5_" case + } + + return "NOMATCH", errors.New("No Match") + } + + tr := transform.NewReader(r, &expandTransformer{expand: expand_func}) + _, err := ioutil.ReadAll(tr) + require.Error(t, err) +} + +func TestExpanderMisc(t *testing.T) { + t.Parallel() + + tests := [...]struct { + orig string + expect string + }{ + {ends_in_dollar, ends_in_dollar}, + {ends_in_ddollar, ends_in_dollar}, + {ends_in_var, ends_in_var_expect}, + {many_dollars_orig, many_dollars_expect}, + } + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + // missing "i5_" case + } + + return "NOMATCH", errors.New("No Match") + } + + for i, tst := range tests { + tst := tst + t.Run(fmt.Sprintf("sub=%d", i), + func(t *testing.T) { + t.Parallel() + tr := transform.NewReader( + bytes.NewReader([]byte(tst.orig)), + &expandTransformer{expand: expand_func}, + ) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, tst.expect, string(actual)) + }, + ) + } +} + +func TestExpanderLongSrc(t *testing.T) { + t.Parallel() + + a := strings.Repeat("a", transformBufSize-1) + + tests := [...]struct { + orig string + expect string + }{ + {"foo${a}" + a, "foo" + a + a}, + {"${a}foo$a", a + "foo" + a}, + {"$a${", a + "${"}, + } + + expand_func := func(s string) (string, error) { + switch s { + case "a": + return a, nil + } + + return "NOMATCH", errors.New("No Match") + } + + for i, tst := range tests { + tst := tst + t.Run(fmt.Sprintf("sub=%d", i), + func(t *testing.T) { + t.Parallel() + tr := transform.NewReader( + &bufReader{buf: []byte(tst.orig)}, + &expandTransformer{expand: expand_func}, + ) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, tst.expect, string(actual)) + }, + ) + } +} + +func TestTransformLimit(t *testing.T) { + t.Parallel() + + a := strings.Repeat("a", transformBufSize-1) + + // The transform package uses an internal fixed-size buffer. + // These tests are expected to fail with specific errors when + // that buffer is exceeded. If they stop failing, then other + // tests (above) have likely stopped working correctly too. + tests := [...]struct { + orig string + err error + }{ + {"$a", transform.ErrShortDst}, + {"$" + a, transform.ErrShortSrc}, + } + + expand_func := func(s string) (string, error) { + switch s { + case "a": + return a + "aa", nil + case a: + return "a", nil + } + + return "NOMATCH", errors.New("No Match") + } + + for i, tst := range tests { + tst := tst + t.Run(fmt.Sprintf("sub=%d", i), + func(t *testing.T) { + t.Parallel() + tr := transform.NewReader( + bytes.NewReader([]byte(tst.orig)), + &expandTransformer{expand: expand_func}, + ) + _, err := ioutil.ReadAll(tr) + require.EqualError(t, err, tst.err.Error()) + }, + ) + } +} diff --git a/glide.lock b/glide.lock index b743924..8fc7259 100644 --- a/glide.lock +++ b/glide.lock @@ -1,21 +1,25 @@ -hash: bc86f2e1ca95a651be64c94f23db29be3b606272d03bc66bb8c8efca2b0a3545 -updated: 2017-07-31T12:01:52.034767994-07:00 +hash: fdcf479340c55a75cf42a047276f952cd2847537e12028db6c017479e45ea626 +updated: 2017-08-09T17:47:23.192205998-07:00 imports: - name: github.com/pkg/errors version: 645ef00459ed84a119197bfb8d8205042c6df63d +- name: golang.org/x/text + version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 + subpackages: + - transform - name: gopkg.in/validator.v2 version: 07ffaad256c8e957050ad83d6472eb97d785013d - name: gopkg.in/yaml.v2 version: 25c4ec802a7d637f88d584ab26798e94ad14c13b testImports: - name: github.com/davecgh/go-spew - version: adab96458c51a58dc1783b3335dcce5461522e75 + version: 6d212800a42e8ab5c146b8ace3490ee17e5225f9 subpackages: - spew - name: github.com/google/gofuzz version: 24818f796faf91cd76ec7bddd72458fbced7a6c1 - name: github.com/pmezard/go-difflib - version: 792786c7400a136282c1664665ae0a8db921c6c2 + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d subpackages: - difflib - name: github.com/spf13/cast diff --git a/glide.yaml b/glide.yaml index f0cbc12..3556b56 100644 --- a/glide.yaml +++ b/glide.yaml @@ -4,6 +4,9 @@ import: - package: gopkg.in/yaml.v2 - package: github.com/pkg/errors version: ~0.8.0 +- package: golang.org/x/text + subpackages: + - transform testImport: - package: github.com/google/gofuzz - package: github.com/stretchr/testify diff --git a/static_provider.go b/static_provider.go index 96a4af2..6a1a136 100644 --- a/static_provider.go +++ b/static_provider.go @@ -23,7 +23,6 @@ package config import ( "bytes" "io" - "io/ioutil" "gopkg.in/yaml.v2" ) @@ -36,7 +35,7 @@ type staticProvider struct { // accessed via Get method. It is using the yaml marshaler to encode data first, // and is subject to panic if data contains a fixed sized array. func NewStaticProvider(data interface{}) (Provider, error) { - c, err := toReadCloser(data) + c, err := toReader(data) if err != nil { return nil, err } @@ -54,7 +53,7 @@ func NewStaticProvider(data interface{}) (Provider, error) { func NewStaticProviderWithExpand( data interface{}, mapping func(string) (string, bool)) (Provider, error) { - reader, err := toReadCloser(data) + reader, err := toReader(data) if err != nil { return nil, err } @@ -72,11 +71,11 @@ func (staticProvider) Name() string { return "static" } -func toReadCloser(data interface{}) (io.ReadCloser, error) { +func toReader(data interface{}) (io.Reader, error) { b, err := yaml.Marshal(data) if err != nil { return nil, err } - return ioutil.NopCloser(bytes.NewBuffer(b)), nil + return bytes.NewBuffer(b), nil } diff --git a/static_provider_test.go b/static_provider_test.go index 79c472d..48af1b5 100644 --- a/static_provider_test.go +++ b/static_provider_test.go @@ -258,12 +258,16 @@ func TestStaticProviderWithExpand(t *testing.T) { p, err := NewStaticProviderWithExpand(map[string]interface{}{ "slice": []interface{}{"one", "${iTwo:2}"}, "value": `${iValue:""}`, + "empty": `${iEmpty:""}`, + "two": `${iTwo:2}`, "map": map[string]interface{}{ "drink?": "${iMap:tea?}", "tea?": "with cream", }, }, func(key string) (string, bool) { switch key { + case "iEmpty": + return "\"\"", true case "iValue": return "null", true case "iTwo": @@ -279,12 +283,37 @@ func TestStaticProviderWithExpand(t *testing.T) { assert.Equal(t, "one", p.Get("slice.0").String()) assert.Equal(t, "3", p.Get("slice.1").String()) - assert.Equal(t, "null", p.Get("value").Value()) + assert.Equal(t, nil, p.Get("value").Value()) + assert.Equal(t, "", p.Get("empty").Value()) + assert.Equal(t, int(3), p.Get("two").Value()) assert.Equal(t, "rum please!", p.Get("map.drink?").String()) assert.Equal(t, "with cream", p.Get("map.tea?").String()) } +func TestStaticProviderWithExpandEscapeHandling(t *testing.T) { + t.Parallel() + + p, err := NewStaticProviderWithExpand(map[string]interface{}{ + "nil": "a", + "one": "$a", + "two": "$$a", + "three": "$$$a", + "four": "$$$$a", + "five": "$$$$$a", + }, func(key string) (string, bool) { + return fmt.Sprint(len(key)), true + }) + + require.NoError(t, err, "can't create a static provider") + assert.Equal(t, "a", p.Get("nil").String()) + assert.Equal(t, "1", p.Get("one").String()) + assert.Equal(t, "$a", p.Get("two").String()) + assert.Equal(t, "$1", p.Get("three").String()) + assert.Equal(t, "$$a", p.Get("four").String()) + assert.Equal(t, "$$1", p.Get("five").String()) +} + func TestPopulateForMapOfDifferentKeyTypes(t *testing.T) { t.Parallel() diff --git a/testdata/base.yaml b/testdata/base.yaml index 68c1795..4ae39d9 100644 --- a/testdata/base.yaml +++ b/testdata/base.yaml @@ -1,2 +1,9 @@ value: base_only value_override: base_setting +a-bool: true +a-empty: "" +a-float: 3.14 +a-int: 3 +a-nil: +a-null: null +a-string: "3.14" diff --git a/yaml.go b/yaml.go index 8b5e3ea..da07cf9 100644 --- a/yaml.go +++ b/yaml.go @@ -27,11 +27,11 @@ import ( "io/ioutil" "os" "reflect" - "runtime" "strconv" "strings" "github.com/pkg/errors" + "golang.org/x/text/transform" "gopkg.in/yaml.v2" ) @@ -44,7 +44,7 @@ var ( _emptyDefault = `""` ) -func newYAMLProviderCore(files ...io.ReadCloser) (*yamlConfigProvider, error) { +func newYAMLProviderCore(files ...io.Reader) (*yamlConfigProvider, error) { var root interface{} for _, v := range files { var curr interface{} @@ -134,33 +134,69 @@ func mergeMaps(dst interface{}, src interface{}) (interface{}, error) { // file names. All the objects are going to be merged and arrays/values // overridden in the order of the files. func NewYAMLProviderFromFiles(files ...string) (Provider, error) { - readers, err := filesToReaders(files...) + readClosers, err := filesToReaders(files...) if err != nil { return nil, err } - p, err := newYAMLProviderCore(readers...) - if err != nil { - return nil, err + readers := make([]io.Reader, len(readClosers)) + for i, r := range readClosers { + readers[i] = r } - return newCachedProvider(p) + provider, err := newYAMLProviderFromReader(readers...) + + for _, r := range readClosers { + nerr := r.Close() + if err == nil { + err = nerr + } + } + + return provider, err } // NewYAMLProviderWithExpand creates a configuration provider from a set of YAML // file names with ${var} or $var values replaced based on the mapping function. +// Variable names not wrapped in curly braces will be parsed according +// to the shell variable naming rules: +// +// ...a word consisting solely of underscores, digits, and +// alphabetics from the portable character set. The first +// character of a name is not a digit. +// +// For variables wrapped in braces, all characters between '{' and '}' +// will be passed to the expand function. e.g. "${foo:13}" will cause +// "foo:13" to be passed to the expand function. The sequence '$$' will +// be replaced by a literal '$'. All other sequences will be ignored +// for expansion purposes. func NewYAMLProviderWithExpand(mapping func(string) (string, bool), files ...string) (Provider, error) { - readers, err := filesToReaders(files...) + readClosers, err := filesToReaders(files...) if err != nil { return nil, err } - return newYAMLProviderFromReaderWithExpand(mapping, readers...) + readers := make([]io.Reader, len(readClosers)) + for i, r := range readClosers { + readers[i] = r + } + + provider, err := newYAMLProviderFromReaderWithExpand(mapping, + readers...) + + for _, r := range readClosers { + nerr := r.Close() + if err == nil { + err = nerr + } + } + + return provider, err } -// NewYAMLProviderFromReader creates a configuration provider from a list of io.ReadClosers. +// NewYAMLProviderFromReader creates a configuration provider from a list of io.Readers. // As above, all the objects are going to be merged and arrays/values overridden in the order of the files. -func newYAMLProviderFromReader(readers ...io.ReadCloser) (Provider, error) { +func newYAMLProviderFromReader(readers ...io.Reader) (Provider, error) { p, err := newYAMLProviderCore(readers...) if err != nil { return nil, err @@ -170,38 +206,34 @@ func newYAMLProviderFromReader(readers ...io.ReadCloser) (Provider, error) { } // NewYAMLProviderFromReaderWithExpand creates a configuration provider from -// a list of `io.ReadClosers and uses the mapping function to expand values +// a list of `io.Readers and uses the mapping function to expand values // in the underlying provider. func newYAMLProviderFromReaderWithExpand( mapping func(string) (string, bool), - readers ...io.ReadCloser) (Provider, error) { - p, err := newYAMLProviderCore(readers...) - if err != nil { - return nil, err - } + readers ...io.Reader) (Provider, error) { - if err := p.root.applyOnAllNodes(replace(mapping)); err != nil { - return nil, err + expandFunc := replace(mapping) + + ereaders := make([]io.Reader, len(readers)) + for i, reader := range readers { + ereaders[i] = transform.NewReader( + reader, + &expandTransformer{expand: expandFunc}) } - return newCachedProvider(p) + return newYAMLProviderFromReader(ereaders...) } // NewYAMLProviderFromBytes creates a config provider from a byte-backed YAML // blobs. As above, all the objects are going to be merged and arrays/values // overridden in the order of the yamls. func NewYAMLProviderFromBytes(yamls ...[]byte) (Provider, error) { - closers := make([]io.ReadCloser, len(yamls)) + readers := make([]io.Reader, len(yamls)) for i, yml := range yamls { - closers[i] = ioutil.NopCloser(bytes.NewReader(yml)) - } - - p, err := newYAMLProviderCore(closers...) - if err != nil { - return nil, err + readers[i] = bytes.NewReader(yml) } - return newCachedProvider(p) + return newYAMLProviderFromReader(readers...) } func filesToReaders(files ...string) ([]io.ReadCloser, error) { @@ -210,6 +242,9 @@ func filesToReaders(files ...string) ([]io.ReadCloser, error) { for _, v := range files { if reader, err := os.Open(v); err != nil { + for _, r := range readers { + r.Close() + } return nil, err } else if reader != nil { readers = append(readers, reader) @@ -326,64 +361,13 @@ func (n *yamlNode) Children() []*yamlNode { return nodes } -// Apply expand to all nested elements of a node. -// There is no need to use reflection, because YAML unmarshaler is using -// map[interface{}]interface{} to store objects and []interface{} -// to store collections. -func recursiveApply(node interface{}, expand func(string) string) interface{} { - if node == nil { - return nil - } - switch t := node.(type) { - case map[interface{}]interface{}: - for k := range t { - t[k] = recursiveApply(t[k], expand) - } - return t - case []interface{}: - for i := range t { - t[i] = recursiveApply(t[i], expand) - } - return t - } - - return os.Expand(fmt.Sprint(node), expand) -} - -func (n *yamlNode) applyOnAllNodes(expand func(string) string) (err error) { - - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - - err = r.(error) - } - }() - - n.value = recursiveApply(n.value, expand) - - for _, c := range n.Children() { - if err := c.applyOnAllNodes(expand); err != nil { - return err - } - } - - return -} - -func unmarshalYAMLValue(reader io.ReadCloser, value interface{}) error { +func unmarshalYAMLValue(reader io.Reader, value interface{}) error { raw, err := ioutil.ReadAll(reader) if err != nil { return errors.Wrap(err, "failed to read the yaml config") } - if err = yaml.Unmarshal(raw, value); err != nil { - return err - } - - return reader.Close() + return yaml.Unmarshal(raw, value) } // Function to expand environment variables in returned values that have form: ${ENV_VAR:DEFAULT_VALUE}. @@ -396,10 +380,8 @@ func unmarshalYAMLValue(reader io.ReadCloser, value interface{}) error { // // In the case that HTTP_PORT is not provided, default value (in this case 8080) // will be used. -// -// TODO: what if someone wanted a literal ${FOO} in config? need a small escape hatch -func replace(lookUp func(string) (string, bool)) func(in string) string { - return func(in string) string { +func replace(lookUp func(string) (string, bool)) func(in string) (string, error) { + return func(in string) (string, error) { sep := strings.Index(in, _envSeparator) var key string var def string @@ -414,16 +396,16 @@ func replace(lookUp func(string) (string, bool)) func(in string) string { } if envVal, ok := lookUp(key); ok { - return envVal + return envVal, nil } if def == "" { - panic(fmt.Errorf(`default is empty for %q (use "" for empty string)`, key)) + return "", fmt.Errorf(`default is empty for %q (use "" for empty string)`, key) } else if def == _emptyDefault { - return "" + return "", nil } - return def + return def, nil } } diff --git a/yaml_test.go b/yaml_test.go index cbf4490..2c5326a 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -70,18 +70,21 @@ func TestYAMLEnvInterpolation(t *testing.T) { } cfg := strings.NewReader(` -name: some name here +name: some $$name here owner: ${OWNER_EMAIL} module: fake: number: ${FAKE_NUMBER:321}`) - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") require.Equal(t, "321", p.Get("module.fake.number").String()) owner := p.Get("owner").String() require.Equal(t, "hello@there.yasss", owner) + + name := p.Get("name").String() + require.Equal(t, "some $name here", name) } func TestYAMLEnvInterpolationMissing(t *testing.T) { @@ -92,7 +95,7 @@ name: some name here email: ${EMAIL_ADDRESS}`) f := func(string) (string, bool) { return "", false } - _, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + _, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.Error(t, err) assert.Contains(t, err.Error(), `default is empty for "EMAIL_ADDRESS"`) } @@ -105,7 +108,7 @@ name: some name here telephone: ${SUPPORT_TEL:}`) f := func(string) (string, bool) { return "", false } - _, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + _, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.Error(t, err) assert.Contains(t, err.Error(), `default is empty for "SUPPORT_TEL" (use "" for empty string)`) @@ -119,7 +122,7 @@ func TestYAMLEnvInterpolationWithColon(t *testing.T) { return "", false } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") require.Equal(t, "this:is:my:value", p.Get("fullValue").String()) @@ -133,7 +136,7 @@ name: ${APP_NAME:my shiny app} fullTel: 1-800-LOLZ${TELEPHONE_EXTENSION:""}`) f := func(string) (string, bool) { return "", false } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") require.Equal(t, "my shiny app", p.Get("name").String()) @@ -218,7 +221,7 @@ func TestNewYAMLProviderFromReader(t *testing.T) { t.Parallel() buff := bytes.NewBuffer([]byte(_yamlConfig1)) - provider, err := newYAMLProviderFromReader(ioutil.NopCloser(buff)) + provider, err := newYAMLProviderFromReader(buff) require.NoError(t, err, "Can't create a YAML provider") cs := &configStruct{} @@ -232,7 +235,7 @@ func TestYAMLNode(t *testing.T) { buff := bytes.NewBuffer([]byte("a: b")) node := &yamlNode{value: make(map[interface{}]interface{})} - require.NoError(t, unmarshalYAMLValue(ioutil.NopCloser(buff), &node.value)) + require.NoError(t, unmarshalYAMLValue(buff, &node.value)) assert.Equal(t, "map[a:b]", node.String()) assert.Equal(t, "map[interface {}]interface {}", node.Type().String()) @@ -1196,7 +1199,7 @@ func TestYAMLEnvInterpolationValueMissing(t *testing.T) { cfg := strings.NewReader(`name:`) f := func(string) (string, bool) { return "", false } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") assert.Equal(t, nil, p.Get("name").Value()) } @@ -1211,7 +1214,7 @@ func TestYAMLEnvInterpolationValueConversion(t *testing.T) { return "3", true } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") assert.Equal(t, "3", p.Get("number").String()) @@ -1745,14 +1748,39 @@ func TestNewYamlProviderWithExpand(t *testing.T) { assert.Equal(t, "base_only", baseValue) } +func TestYamlProvidersProduceSameResults(t *testing.T) { + t.Parallel() + + p, err := NewYAMLProviderFromFiles("./testdata/base.yaml") + require.NoError(t, err, "Can't create a YAML provider") + + pp, err := NewYAMLProviderWithExpand(nil, "./testdata/base.yaml") + require.NoError(t, err, "Can't create a YAML provider with expand") + + assert.IsType(t, true, p.Get("a-bool").Value()) + assert.Exactly(t, p.Get("a-bool").Value(), pp.Get("a-bool").Value()) + assert.IsType(t, "empty", p.Get("a-empty").Value()) + assert.Exactly(t, p.Get("a-empty").Value(), pp.Get("a-empty").Value()) + assert.IsType(t, float64(1.2), p.Get("a-float").Value()) + assert.Exactly(t, p.Get("a-float").Value(), pp.Get("a-float").Value()) + assert.IsType(t, int(12), p.Get("a-int").Value()) + assert.Exactly(t, p.Get("a-int").Value(), pp.Get("a-int").Value()) + assert.IsType(t, nil, p.Get("a-nil").Value()) + assert.Exactly(t, p.Get("a-nil").Value(), pp.Get("a-nil").Value()) + assert.IsType(t, nil, p.Get("a-null").Value()) + assert.Exactly(t, p.Get("a-null").Value(), pp.Get("a-null").Value()) + assert.IsType(t, "string", p.Get("a-string").Value()) + assert.Exactly(t, p.Get("a-string").Value(), pp.Get("a-string").Value()) +} + func TestMergeErrorsFromReaders(t *testing.T) { t.Parallel() t.Run("regular", func(t *testing.T) { - base := ioutil.NopCloser(strings.NewReader(`a: - - b`)) - dev := ioutil.NopCloser(strings.NewReader(`a: - b: c`)) + base := strings.NewReader(`a: + - b`) + dev := strings.NewReader(`a: + b: c`) _, err := newYAMLProviderFromReader(base, dev) require.Error(t, err) @@ -1762,10 +1790,10 @@ func TestMergeErrorsFromReaders(t *testing.T) { t.Run("expand", func(t *testing.T) { expand := func(string) (string, bool) { return "", false } - base := ioutil.NopCloser(strings.NewReader(`a: - - b`)) - dev := ioutil.NopCloser(strings.NewReader(`a: - b: c`)) + base := strings.NewReader(`a: + - b`) + dev := strings.NewReader(`a: + b: c`) _, err := newYAMLProviderFromReaderWithExpand(expand, base, dev) require.Error(t, err) @@ -1805,8 +1833,8 @@ func TestMergeErrorsFromFiles(t *testing.T) { require.NoError(t, err, "Can't read dev file") _, err = newYAMLProviderFromReader( - ioutil.NopCloser(bytes.NewBuffer(b)), - ioutil.NopCloser(bytes.NewBuffer(d))) + bytes.NewBuffer(b), + bytes.NewBuffer(d)) require.Error(t, err) assert.Contains(t, err.Error(), "can't merge map") @@ -1831,8 +1859,8 @@ func TestMergeErrorsFromFiles(t *testing.T) { _, err = newYAMLProviderFromReaderWithExpand( expand, - ioutil.NopCloser(bytes.NewBuffer(b)), - ioutil.NopCloser(bytes.NewBuffer(d))) + bytes.NewBuffer(b), + bytes.NewBuffer(d)) require.Error(t, err) assert.Contains(t, err.Error(), "can't merge map") @@ -1843,13 +1871,15 @@ func TestYAMLProviderWithGarbledPath(t *testing.T) { t.Parallel() t.Run("regular", func(t *testing.T) { - _, err := NewYAMLProviderFromFiles("/some/nonexisting/config") + _, err := NewYAMLProviderFromFiles("./testdata/base.yaml", + "/some/nonexisting/config") require.Error(t, err) assert.Contains(t, err.Error(), "no such file or directory") }) t.Run("expand", func(t *testing.T) { - _, err := NewYAMLProviderWithExpand(nil, "/some/nonexisting/config") + _, err := NewYAMLProviderWithExpand(nil, "./testdata/base.yaml", + "/some/nonexisting/config") require.Error(t, err) assert.Contains(t, err.Error(), "no such file or directory") })