From 78fd1806b547a4b0766757ffa14b47546a7a3bdb Mon Sep 17 00:00:00 2001 From: Brandon Casey Date: Wed, 9 Aug 2017 16:24:01 -0700 Subject: [PATCH 1/4] yaml.go: use newYAMLProviderFromReader() in a couple more places We have a helper function newYAMLProviderFromReader() which calls newYAMLProviderCore() and returns newCachedProvider(). Let's use it instead of repeating the same code in NewYAMLProviderFromFiles() and NewYAMLProviderFromBytes(). --- yaml.go | 14 ++------------ 1 file changed, 2 insertions(+), 12 deletions(-) diff --git a/yaml.go b/yaml.go index 8b5e3ea..adcf4fa 100644 --- a/yaml.go +++ b/yaml.go @@ -139,12 +139,7 @@ func NewYAMLProviderFromFiles(files ...string) (Provider, error) { return nil, err } - p, err := newYAMLProviderCore(readers...) - if err != nil { - return nil, err - } - - return newCachedProvider(p) + return newYAMLProviderFromReader(readers...) } // NewYAMLProviderWithExpand creates a configuration provider from a set of YAML @@ -196,12 +191,7 @@ func NewYAMLProviderFromBytes(yamls ...[]byte) (Provider, error) { closers[i] = ioutil.NopCloser(bytes.NewReader(yml)) } - p, err := newYAMLProviderCore(closers...) - if err != nil { - return nil, err - } - - return newCachedProvider(p) + return newYAMLProviderFromReader(closers...) } func filesToReaders(files ...string) ([]io.ReadCloser, error) { From 158145329aaeda4b6c4f6ff9feb79d89ea8fb779 Mon Sep 17 00:00:00 2001 From: Brandon Casey Date: Tue, 29 Aug 2017 13:05:19 -0700 Subject: [PATCH 2/4] yaml.go: don't leak file handles on error Currently, yaml.go:unmarshalYAMLValue() is passed an io.ReadCloser and assumes responsibility for closing the io.ReadCloser object that is passed in. If this function encounters an error, it will not close the io.ReadCloser, and neither will any of its callers. This could result in a leak of file descriptors. Let's modify the interface to unmarshalYAMLValue() and most callers so that instead of being passed an io.ReadCloser(), or a slice of io.ReadCloser()'s, and assuming the responsibility for closing them all, we'll just pass them io.Reader()'s, and leave the responsibility for closing any readers that need closing in the high-level callers that opened the files in the first place, e.g. NewYAMLProviderFromFiles and NewYAMLProviderWithExpand. This has the additional benefit of simplifying some callers in static_provider.go and tests which no longer need to wrap their objects in ioutil.NopCloser(). --- static_provider.go | 9 +++---- yaml.go | 64 +++++++++++++++++++++++++++++++++------------- yaml_test.go | 48 +++++++++++++++++----------------- 3 files changed, 75 insertions(+), 46 deletions(-) diff --git a/static_provider.go b/static_provider.go index 96a4af2..6a1a136 100644 --- a/static_provider.go +++ b/static_provider.go @@ -23,7 +23,6 @@ package config import ( "bytes" "io" - "io/ioutil" "gopkg.in/yaml.v2" ) @@ -36,7 +35,7 @@ type staticProvider struct { // accessed via Get method. It is using the yaml marshaler to encode data first, // and is subject to panic if data contains a fixed sized array. func NewStaticProvider(data interface{}) (Provider, error) { - c, err := toReadCloser(data) + c, err := toReader(data) if err != nil { return nil, err } @@ -54,7 +53,7 @@ func NewStaticProvider(data interface{}) (Provider, error) { func NewStaticProviderWithExpand( data interface{}, mapping func(string) (string, bool)) (Provider, error) { - reader, err := toReadCloser(data) + reader, err := toReader(data) if err != nil { return nil, err } @@ -72,11 +71,11 @@ func (staticProvider) Name() string { return "static" } -func toReadCloser(data interface{}) (io.ReadCloser, error) { +func toReader(data interface{}) (io.Reader, error) { b, err := yaml.Marshal(data) if err != nil { return nil, err } - return ioutil.NopCloser(bytes.NewBuffer(b)), nil + return bytes.NewBuffer(b), nil } diff --git a/yaml.go b/yaml.go index adcf4fa..d9e3b98 100644 --- a/yaml.go +++ b/yaml.go @@ -44,7 +44,7 @@ var ( _emptyDefault = `""` ) -func newYAMLProviderCore(files ...io.ReadCloser) (*yamlConfigProvider, error) { +func newYAMLProviderCore(files ...io.Reader) (*yamlConfigProvider, error) { var root interface{} for _, v := range files { var curr interface{} @@ -134,28 +134,57 @@ func mergeMaps(dst interface{}, src interface{}) (interface{}, error) { // file names. All the objects are going to be merged and arrays/values // overridden in the order of the files. func NewYAMLProviderFromFiles(files ...string) (Provider, error) { - readers, err := filesToReaders(files...) + readClosers, err := filesToReaders(files...) if err != nil { return nil, err } - return newYAMLProviderFromReader(readers...) + readers := make([]io.Reader, len(readClosers)) + for i, r := range readClosers { + readers[i] = r + } + + provider, err := newYAMLProviderFromReader(readers...) + + for _, r := range readClosers { + nerr := r.Close() + if err == nil { + err = nerr + } + } + + return provider, err } // NewYAMLProviderWithExpand creates a configuration provider from a set of YAML // file names with ${var} or $var values replaced based on the mapping function. func NewYAMLProviderWithExpand(mapping func(string) (string, bool), files ...string) (Provider, error) { - readers, err := filesToReaders(files...) + readClosers, err := filesToReaders(files...) if err != nil { return nil, err } - return newYAMLProviderFromReaderWithExpand(mapping, readers...) + readers := make([]io.Reader, len(readClosers)) + for i, r := range readClosers { + readers[i] = r + } + + provider, err := newYAMLProviderFromReaderWithExpand(mapping, + readers...) + + for _, r := range readClosers { + nerr := r.Close() + if err == nil { + err = nerr + } + } + + return provider, err } -// NewYAMLProviderFromReader creates a configuration provider from a list of io.ReadClosers. +// NewYAMLProviderFromReader creates a configuration provider from a list of io.Readers. // As above, all the objects are going to be merged and arrays/values overridden in the order of the files. -func newYAMLProviderFromReader(readers ...io.ReadCloser) (Provider, error) { +func newYAMLProviderFromReader(readers ...io.Reader) (Provider, error) { p, err := newYAMLProviderCore(readers...) if err != nil { return nil, err @@ -165,11 +194,11 @@ func newYAMLProviderFromReader(readers ...io.ReadCloser) (Provider, error) { } // NewYAMLProviderFromReaderWithExpand creates a configuration provider from -// a list of `io.ReadClosers and uses the mapping function to expand values +// a list of `io.Readers and uses the mapping function to expand values // in the underlying provider. func newYAMLProviderFromReaderWithExpand( mapping func(string) (string, bool), - readers ...io.ReadCloser) (Provider, error) { + readers ...io.Reader) (Provider, error) { p, err := newYAMLProviderCore(readers...) if err != nil { return nil, err @@ -186,12 +215,12 @@ func newYAMLProviderFromReaderWithExpand( // blobs. As above, all the objects are going to be merged and arrays/values // overridden in the order of the yamls. func NewYAMLProviderFromBytes(yamls ...[]byte) (Provider, error) { - closers := make([]io.ReadCloser, len(yamls)) + readers := make([]io.Reader, len(yamls)) for i, yml := range yamls { - closers[i] = ioutil.NopCloser(bytes.NewReader(yml)) + readers[i] = bytes.NewReader(yml) } - return newYAMLProviderFromReader(closers...) + return newYAMLProviderFromReader(readers...) } func filesToReaders(files ...string) ([]io.ReadCloser, error) { @@ -200,6 +229,9 @@ func filesToReaders(files ...string) ([]io.ReadCloser, error) { for _, v := range files { if reader, err := os.Open(v); err != nil { + for _, r := range readers { + r.Close() + } return nil, err } else if reader != nil { readers = append(readers, reader) @@ -363,17 +395,13 @@ func (n *yamlNode) applyOnAllNodes(expand func(string) string) (err error) { return } -func unmarshalYAMLValue(reader io.ReadCloser, value interface{}) error { +func unmarshalYAMLValue(reader io.Reader, value interface{}) error { raw, err := ioutil.ReadAll(reader) if err != nil { return errors.Wrap(err, "failed to read the yaml config") } - if err = yaml.Unmarshal(raw, value); err != nil { - return err - } - - return reader.Close() + return yaml.Unmarshal(raw, value) } // Function to expand environment variables in returned values that have form: ${ENV_VAR:DEFAULT_VALUE}. diff --git a/yaml_test.go b/yaml_test.go index cbf4490..9cc5bcb 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -76,7 +76,7 @@ module: fake: number: ${FAKE_NUMBER:321}`) - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") require.Equal(t, "321", p.Get("module.fake.number").String()) @@ -92,7 +92,7 @@ name: some name here email: ${EMAIL_ADDRESS}`) f := func(string) (string, bool) { return "", false } - _, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + _, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.Error(t, err) assert.Contains(t, err.Error(), `default is empty for "EMAIL_ADDRESS"`) } @@ -105,7 +105,7 @@ name: some name here telephone: ${SUPPORT_TEL:}`) f := func(string) (string, bool) { return "", false } - _, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + _, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.Error(t, err) assert.Contains(t, err.Error(), `default is empty for "SUPPORT_TEL" (use "" for empty string)`) @@ -119,7 +119,7 @@ func TestYAMLEnvInterpolationWithColon(t *testing.T) { return "", false } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") require.Equal(t, "this:is:my:value", p.Get("fullValue").String()) @@ -133,7 +133,7 @@ name: ${APP_NAME:my shiny app} fullTel: 1-800-LOLZ${TELEPHONE_EXTENSION:""}`) f := func(string) (string, bool) { return "", false } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") require.Equal(t, "my shiny app", p.Get("name").String()) @@ -218,7 +218,7 @@ func TestNewYAMLProviderFromReader(t *testing.T) { t.Parallel() buff := bytes.NewBuffer([]byte(_yamlConfig1)) - provider, err := newYAMLProviderFromReader(ioutil.NopCloser(buff)) + provider, err := newYAMLProviderFromReader(buff) require.NoError(t, err, "Can't create a YAML provider") cs := &configStruct{} @@ -232,7 +232,7 @@ func TestYAMLNode(t *testing.T) { buff := bytes.NewBuffer([]byte("a: b")) node := &yamlNode{value: make(map[interface{}]interface{})} - require.NoError(t, unmarshalYAMLValue(ioutil.NopCloser(buff), &node.value)) + require.NoError(t, unmarshalYAMLValue(buff, &node.value)) assert.Equal(t, "map[a:b]", node.String()) assert.Equal(t, "map[interface {}]interface {}", node.Type().String()) @@ -1196,7 +1196,7 @@ func TestYAMLEnvInterpolationValueMissing(t *testing.T) { cfg := strings.NewReader(`name:`) f := func(string) (string, bool) { return "", false } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") assert.Equal(t, nil, p.Get("name").Value()) } @@ -1211,7 +1211,7 @@ func TestYAMLEnvInterpolationValueConversion(t *testing.T) { return "3", true } - p, err := newYAMLProviderFromReaderWithExpand(f, ioutil.NopCloser(cfg)) + p, err := newYAMLProviderFromReaderWithExpand(f, cfg) require.NoError(t, err, "Can't create a YAML provider") assert.Equal(t, "3", p.Get("number").String()) @@ -1749,10 +1749,10 @@ func TestMergeErrorsFromReaders(t *testing.T) { t.Parallel() t.Run("regular", func(t *testing.T) { - base := ioutil.NopCloser(strings.NewReader(`a: - - b`)) - dev := ioutil.NopCloser(strings.NewReader(`a: - b: c`)) + base := strings.NewReader(`a: + - b`) + dev := strings.NewReader(`a: + b: c`) _, err := newYAMLProviderFromReader(base, dev) require.Error(t, err) @@ -1762,10 +1762,10 @@ func TestMergeErrorsFromReaders(t *testing.T) { t.Run("expand", func(t *testing.T) { expand := func(string) (string, bool) { return "", false } - base := ioutil.NopCloser(strings.NewReader(`a: - - b`)) - dev := ioutil.NopCloser(strings.NewReader(`a: - b: c`)) + base := strings.NewReader(`a: + - b`) + dev := strings.NewReader(`a: + b: c`) _, err := newYAMLProviderFromReaderWithExpand(expand, base, dev) require.Error(t, err) @@ -1805,8 +1805,8 @@ func TestMergeErrorsFromFiles(t *testing.T) { require.NoError(t, err, "Can't read dev file") _, err = newYAMLProviderFromReader( - ioutil.NopCloser(bytes.NewBuffer(b)), - ioutil.NopCloser(bytes.NewBuffer(d))) + bytes.NewBuffer(b), + bytes.NewBuffer(d)) require.Error(t, err) assert.Contains(t, err.Error(), "can't merge map") @@ -1831,8 +1831,8 @@ func TestMergeErrorsFromFiles(t *testing.T) { _, err = newYAMLProviderFromReaderWithExpand( expand, - ioutil.NopCloser(bytes.NewBuffer(b)), - ioutil.NopCloser(bytes.NewBuffer(d))) + bytes.NewBuffer(b), + bytes.NewBuffer(d)) require.Error(t, err) assert.Contains(t, err.Error(), "can't merge map") @@ -1843,13 +1843,15 @@ func TestYAMLProviderWithGarbledPath(t *testing.T) { t.Parallel() t.Run("regular", func(t *testing.T) { - _, err := NewYAMLProviderFromFiles("/some/nonexisting/config") + _, err := NewYAMLProviderFromFiles("./testdata/base.yaml", + "/some/nonexisting/config") require.Error(t, err) assert.Contains(t, err.Error(), "no such file or directory") }) t.Run("expand", func(t *testing.T) { - _, err := NewYAMLProviderWithExpand(nil, "/some/nonexisting/config") + _, err := NewYAMLProviderWithExpand(nil, "./testdata/base.yaml", + "/some/nonexisting/config") require.Error(t, err) assert.Contains(t, err.Error(), "no such file or directory") }) From 878c41277282fe44219abd5eb90707189cdae21d Mon Sep 17 00:00:00 2001 From: Brandon Casey Date: Wed, 9 Aug 2017 18:27:28 -0700 Subject: [PATCH 3/4] Demonstrate that YAML providers produce different data types Currently, the non-expanding YAML provider and the one that can expand the shell-like variable syntax produce different output from the same input even when no expansion is involved. This is because the expanding YAML provider converts all fields to string as a side-effect of its operation. Let's add some tests to demonstrate this incongruity. --- testdata/base.yaml | 7 +++++++ yaml_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/testdata/base.yaml b/testdata/base.yaml index 68c1795..4ae39d9 100644 --- a/testdata/base.yaml +++ b/testdata/base.yaml @@ -1,2 +1,9 @@ value: base_only value_override: base_setting +a-bool: true +a-empty: "" +a-float: 3.14 +a-int: 3 +a-nil: +a-null: null +a-string: "3.14" diff --git a/yaml_test.go b/yaml_test.go index 9cc5bcb..675c258 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -1745,6 +1745,31 @@ func TestNewYamlProviderWithExpand(t *testing.T) { assert.Equal(t, "base_only", baseValue) } +func TestYamlProvidersProduceSameResults(t *testing.T) { + t.Parallel() + + p, err := NewYAMLProviderFromFiles("./testdata/base.yaml") + require.NoError(t, err, "Can't create a YAML provider") + + pp, err := NewYAMLProviderWithExpand(nil, "./testdata/base.yaml") + require.NoError(t, err, "Can't create a YAML provider with expand") + + assert.IsType(t, true, p.Get("a-bool").Value()) + assert.Exactly(t, p.Get("a-bool").Value(), pp.Get("a-bool").Value()) + assert.IsType(t, "empty", p.Get("a-empty").Value()) + assert.Exactly(t, p.Get("a-empty").Value(), pp.Get("a-empty").Value()) + assert.IsType(t, float64(1.2), p.Get("a-float").Value()) + assert.Exactly(t, p.Get("a-float").Value(), pp.Get("a-float").Value()) + assert.IsType(t, int(12), p.Get("a-int").Value()) + assert.Exactly(t, p.Get("a-int").Value(), pp.Get("a-int").Value()) + assert.IsType(t, nil, p.Get("a-nil").Value()) + assert.Exactly(t, p.Get("a-nil").Value(), pp.Get("a-nil").Value()) + assert.IsType(t, nil, p.Get("a-null").Value()) + assert.Exactly(t, p.Get("a-null").Value(), pp.Get("a-null").Value()) + assert.IsType(t, "string", p.Get("a-string").Value()) + assert.Exactly(t, p.Get("a-string").Value(), pp.Get("a-string").Value()) +} + func TestMergeErrorsFromReaders(t *testing.T) { t.Parallel() From bc0991eb33ac4945034272fc4d484f63a6fe1279 Mon Sep 17 00:00:00 2001 From: Brandon Casey Date: Wed, 9 Aug 2017 16:16:00 -0700 Subject: [PATCH 4/4] Introduce a text.Transformer to perform variable expansion Currently expansion of variables using a shell-like syntax like ${foo} or $foo is supported, but it is performed on the unmarshalled data structures _after_ YAML parsing. A side-effect of this is that any data type resolution that the YAML parser may perform is lost and the resulting fields are all represented as strings. Since the non-expanding YAML provider does not perform this filtering step on the result of the YAML unmarshaller, it does not suffer from this problem, but it does mean that the data structures constructed by the non-expanding YAML provider differ from the expanding YAML provider. Let's introduce a text.Transformer to perform the parsing of the shell-like syntax and apply the expansion function on the variables. This allows us to simplify the code paths in yaml.go and also to define an escape sequence '$$' that can be used to avoid the variable expansion when a literal '$foo' is desired. Since we're no longer using os.Expand, we can modify the interface of the function produced by replace() so that it can return an error on failure instead of using panic/recover. Update tests and add a few new ones. This fixes the tests in TestYamlProvidersProduceSameResults in yaml_test.go --- CHANGELOG.md | 7 + expand.go | 197 +++++++++++++++++++++++++++ expand_test.go | 295 ++++++++++++++++++++++++++++++++++++++++ glide.lock | 12 +- glide.yaml | 3 + static_provider_test.go | 31 ++++- yaml.go | 90 ++++-------- yaml_test.go | 5 +- 8 files changed, 571 insertions(+), 69 deletions(-) create mode 100644 expand.go create mode 100644 expand_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index b08c923..348ff99 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## v1.1.0 (unreleased) + +- Expand functions transform a special sequence $$ to literal $. +- The underlying objects encapsulated by config.Value types will now + have the types determined by the YAML unmarshaller regardless of + whether expansion was performed or not. + ## v1.0.1 (2017-08-04) - Fixed unmarshal text on missing value. diff --git a/expand.go b/expand.go new file mode 100644 index 0000000..05f9478 --- /dev/null +++ b/expand.go @@ -0,0 +1,197 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package config + +import ( + "bytes" + + "golang.org/x/text/transform" +) + +// expandTransformer implements transform.Transformer +type expandTransformer struct { + transform.NopResetter + expand func(string) (string, error) +} + +// First char of shell variable may be [a-zA-Z_] +func isShellNameFirstChar(c byte) bool { + return c == '_' || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') +} + +// Char's after the first of shell variable may be [a-zA-Z0-9_] +func isShellNameChar(c byte) bool { + return c == '_' || + (c >= 'a' && c <= 'z') || + (c >= 'A' && c <= 'Z') || + (c >= '0' && c <= '9') +} + +// bytesIndexCFunc returns the index of the byte for which the +// complement of the supplied function is true +func bytesIndexCFunc(buf []byte, f func(b byte) bool) int { + for i, b := range buf { + if !f(b) { + return i + } + } + return -1 +} + +// Transform expands shell-like sequences like $foo and ${foo} using +// the configured expand function. The sequence '$$' is replaced with +// a literal '$'. +func (e *expandTransformer) Transform(dst, src []byte, atEOF bool) (int, int, error) { + var srcPos int + var dstPos int + + for srcPos < len(src) { + + if dstPos == len(dst) { + return dstPos, srcPos, transform.ErrShortDst + } + + end := bytes.IndexByte(src[srcPos:], '$') + + if end == -1 { + // src does not contain '$', copy into dst + cnt := copy(dst[dstPos:], src[srcPos:]) + srcPos += cnt + dstPos += cnt + continue + } else if end > 0 { + // copy chars preceding '$' from src to dst + cnt := copy(dst[dstPos:], src[srcPos:srcPos+end]) + srcPos += cnt + dstPos += cnt + + if dstPos == len(dst) { + return dstPos, srcPos, transform.ErrShortDst + } + } + + // src[srcPos] now points to '$', dstPos < len(dst) + + // If we're at the end of src, but we found a starting + // token, return ErrShortSrc, unless we're also at EOF, + // in which case just copy it dst. + if srcPos+1 == len(src) { + if atEOF { + dst[dstPos] = src[srcPos] + srcPos++ + dstPos++ + continue + } + return dstPos, srcPos, transform.ErrShortSrc + } + + // At this point we know that src[srcPos+1] is populated. + + // If this token sequence represents the special '$$' + // sequence, emit a '$' into dst. + if src[srcPos+1] == '$' { + dst[dstPos] = src[srcPos] + srcPos += 2 + dstPos++ + continue + } + + var token []byte + var tokenEnd int + + // Start of bracketed token ${foo} + if src[srcPos+1] == '{' { + end := bytes.IndexByte(src[srcPos+2:], '}') + if end == -1 { + if atEOF { + // No closing bracket and we're at + // EOF, so it's not a valid bracket + // expression. + if len(dst[dstPos:]) < + len(src[srcPos:]) { + return dstPos, srcPos, + transform.ErrShortDst + } + + cnt := copy(dst[dstPos:], src[srcPos:]) + srcPos += cnt + dstPos += cnt + continue + } + + // Otherwise, we need more bytes in src + return dstPos, srcPos, transform.ErrShortSrc + } + + // Set tokenEnd so it points to the byte + // immediately after the closing '}' + tokenEnd = end + srcPos + 3 + + token = src[srcPos+2 : tokenEnd-1] + } else { // Else start of non-bracketed token $foo + if !isShellNameFirstChar(src[srcPos+1]) { + // If it doesn't conform to the naming + // rules for shell variables, do not + // try to expand, just copy to dst. + dst[dstPos] = src[srcPos] + srcPos++ + dstPos++ + continue + } + + end := bytesIndexCFunc(src[srcPos+2:], isShellNameChar) + + if end == -1 { + // Reached the end of src without finding + // end of shell variable + if !atEOF { + // We need more bytes in src + return dstPos, srcPos, + transform.ErrShortSrc + } + tokenEnd = len(src) + } else { + // Set tokenEnd so it points to the byte + // immediately after the token + tokenEnd = end + srcPos + 2 + } + + token = src[srcPos+1 : tokenEnd] + } + + replacement, err := e.expand(string(token)) + if err != nil { + return dstPos, srcPos, err + } + + if len(dst[dstPos:]) < len(replacement) { + return dstPos, srcPos, transform.ErrShortDst + } + + cnt := copy(dst[dstPos:], replacement) + srcPos = tokenEnd + dstPos += cnt + } + + return dstPos, srcPos, nil +} diff --git a/expand_test.go b/expand_test.go new file mode 100644 index 0000000..3a25fd4 --- /dev/null +++ b/expand_test.go @@ -0,0 +1,295 @@ +// Copyright (c) 2017 Uber Technologies, Inc. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +// THE SOFTWARE. + +package config + +import ( + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/text/transform" +) + +// Size of buffer used by the transform package. +const transformBufSize = 4096 + +const orig = `This is a $t3sT$. $$ This is a $$test. + This is not a valid $0ne. But this one $i5_@_valid-one. + $$$$$$$ +${parti` + +const expected = `This is a test$. $ This is a $test. + This is not a valid $0ne. But this one is @_valid-one. + $$$$ +${parti` + +const ends_in_dollar = `There is a dollar at the end$` +const ends_in_ddollar = `There is a dollar at the end$$` +const many_dollars_orig = `$$$$$$$` +const many_dollars_expect = `$$$$` +const ends_in_var = `There is a test at the end: $t3sT` +const ends_in_var_expect = `There is a test at the end: test` + +type oneByteReader struct { + r io.Reader +} + +func (e *oneByteReader) Read(buf []byte) (n int, err error) { + var b [1]byte + + if len(buf) > 0 { + n, err = e.r.Read(b[:]) + buf[0] = b[0] + } + + return +} + +type bufReader struct { + buf []byte + offset int +} + +// Similar to a bytes.Reader except this Reader returns EOF in the same +// Read that reads the end of the buffer. +func (e *bufReader) Read(buf []byte) (int, error) { + var err error + n := copy(buf, e.buf[e.offset:]) + e.offset += n + if e.offset == len(e.buf) { + err = io.EOF + } + return n, err +} + +func TestExpander(t *testing.T) { + t.Parallel() + + r := bytes.NewReader([]byte(orig)) + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + case "i5_": + return "is ", nil + } + + return "NOMATCH", errors.New("No Match") + } + + // Parse whole string + tr := transform.NewReader(r, &expandTransformer{expand: expand_func}) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, expected, string(actual)) + + _, err = r.Seek(0, io.SeekStart) + require.NoError(t, err) + + // Partial parse + var partial [11]byte + tr = transform.NewReader(r, &expandTransformer{expand: expand_func}) + n, err := tr.Read(partial[:]) + require.NoError(t, err) + assert.Exactly(t, n, len(partial)) + assert.Exactly(t, expected[:n], string(partial[:])) + + // Empty string + r = bytes.NewReader([]byte{}) + tr = transform.NewReader(r, &expandTransformer{expand: expand_func}) + actual, err = ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, "", string(actual)) +} + +func TestExpanderOneByteAtATime(t *testing.T) { + t.Parallel() + + r := bytes.NewReader([]byte(orig)) + rr := &oneByteReader{r: r} + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + case "i5_": + return "is ", nil + } + + return "NOMATCH", errors.New("No Match") + } + + tr := transform.NewReader(rr, &expandTransformer{expand: expand_func}) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, expected, string(actual)) +} + +func TestExpanderFailingTransform(t *testing.T) { + t.Parallel() + + r := bytes.NewReader([]byte(orig)) + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + // missing "i5_" case + } + + return "NOMATCH", errors.New("No Match") + } + + tr := transform.NewReader(r, &expandTransformer{expand: expand_func}) + _, err := ioutil.ReadAll(tr) + require.Error(t, err) +} + +func TestExpanderMisc(t *testing.T) { + t.Parallel() + + tests := [...]struct { + orig string + expect string + }{ + {ends_in_dollar, ends_in_dollar}, + {ends_in_ddollar, ends_in_dollar}, + {ends_in_var, ends_in_var_expect}, + {many_dollars_orig, many_dollars_expect}, + } + + expand_func := func(s string) (string, error) { + switch s { + case "t3sT": + return "test", nil + // missing "i5_" case + } + + return "NOMATCH", errors.New("No Match") + } + + for i, tst := range tests { + tst := tst + t.Run(fmt.Sprintf("sub=%d", i), + func(t *testing.T) { + t.Parallel() + tr := transform.NewReader( + bytes.NewReader([]byte(tst.orig)), + &expandTransformer{expand: expand_func}, + ) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, tst.expect, string(actual)) + }, + ) + } +} + +func TestExpanderLongSrc(t *testing.T) { + t.Parallel() + + a := strings.Repeat("a", transformBufSize-1) + + tests := [...]struct { + orig string + expect string + }{ + {"foo${a}" + a, "foo" + a + a}, + {"${a}foo$a", a + "foo" + a}, + {"$a${", a + "${"}, + } + + expand_func := func(s string) (string, error) { + switch s { + case "a": + return a, nil + } + + return "NOMATCH", errors.New("No Match") + } + + for i, tst := range tests { + tst := tst + t.Run(fmt.Sprintf("sub=%d", i), + func(t *testing.T) { + t.Parallel() + tr := transform.NewReader( + &bufReader{buf: []byte(tst.orig)}, + &expandTransformer{expand: expand_func}, + ) + actual, err := ioutil.ReadAll(tr) + require.NoError(t, err) + assert.Exactly(t, tst.expect, string(actual)) + }, + ) + } +} + +func TestTransformLimit(t *testing.T) { + t.Parallel() + + a := strings.Repeat("a", transformBufSize-1) + + // The transform package uses an internal fixed-size buffer. + // These tests are expected to fail with specific errors when + // that buffer is exceeded. If they stop failing, then other + // tests (above) have likely stopped working correctly too. + tests := [...]struct { + orig string + err error + }{ + {"$a", transform.ErrShortDst}, + {"$" + a, transform.ErrShortSrc}, + } + + expand_func := func(s string) (string, error) { + switch s { + case "a": + return a + "aa", nil + case a: + return "a", nil + } + + return "NOMATCH", errors.New("No Match") + } + + for i, tst := range tests { + tst := tst + t.Run(fmt.Sprintf("sub=%d", i), + func(t *testing.T) { + t.Parallel() + tr := transform.NewReader( + bytes.NewReader([]byte(tst.orig)), + &expandTransformer{expand: expand_func}, + ) + _, err := ioutil.ReadAll(tr) + require.EqualError(t, err, tst.err.Error()) + }, + ) + } +} diff --git a/glide.lock b/glide.lock index b743924..8fc7259 100644 --- a/glide.lock +++ b/glide.lock @@ -1,21 +1,25 @@ -hash: bc86f2e1ca95a651be64c94f23db29be3b606272d03bc66bb8c8efca2b0a3545 -updated: 2017-07-31T12:01:52.034767994-07:00 +hash: fdcf479340c55a75cf42a047276f952cd2847537e12028db6c017479e45ea626 +updated: 2017-08-09T17:47:23.192205998-07:00 imports: - name: github.com/pkg/errors version: 645ef00459ed84a119197bfb8d8205042c6df63d +- name: golang.org/x/text + version: 836efe42bb4aa16aaa17b9c155d8813d336ed720 + subpackages: + - transform - name: gopkg.in/validator.v2 version: 07ffaad256c8e957050ad83d6472eb97d785013d - name: gopkg.in/yaml.v2 version: 25c4ec802a7d637f88d584ab26798e94ad14c13b testImports: - name: github.com/davecgh/go-spew - version: adab96458c51a58dc1783b3335dcce5461522e75 + version: 6d212800a42e8ab5c146b8ace3490ee17e5225f9 subpackages: - spew - name: github.com/google/gofuzz version: 24818f796faf91cd76ec7bddd72458fbced7a6c1 - name: github.com/pmezard/go-difflib - version: 792786c7400a136282c1664665ae0a8db921c6c2 + version: d8ed2627bdf02c080bf22230dbb337003b7aba2d subpackages: - difflib - name: github.com/spf13/cast diff --git a/glide.yaml b/glide.yaml index f0cbc12..3556b56 100644 --- a/glide.yaml +++ b/glide.yaml @@ -4,6 +4,9 @@ import: - package: gopkg.in/yaml.v2 - package: github.com/pkg/errors version: ~0.8.0 +- package: golang.org/x/text + subpackages: + - transform testImport: - package: github.com/google/gofuzz - package: github.com/stretchr/testify diff --git a/static_provider_test.go b/static_provider_test.go index 1ffdc79..c640baa 100644 --- a/static_provider_test.go +++ b/static_provider_test.go @@ -258,12 +258,16 @@ func TestStaticProviderWithExpand(t *testing.T) { p, err := NewStaticProviderWithExpand(map[string]interface{}{ "slice": []interface{}{"one", "${iTwo:2}"}, "value": `${iValue:""}`, + "empty": `${iEmpty:""}`, + "two": `${iTwo:2}`, "map": map[string]interface{}{ "drink?": "${iMap:tea?}", "tea?": "with cream", }, }, func(key string) (string, bool) { switch key { + case "iEmpty": + return "\"\"", true case "iValue": return "null", true case "iTwo": @@ -279,12 +283,37 @@ func TestStaticProviderWithExpand(t *testing.T) { assert.Equal(t, "one", p.Get("slice.0").String()) assert.Equal(t, "3", p.Get("slice.1").String()) - assert.Equal(t, "null", p.Get("value").Value()) + assert.Equal(t, nil, p.Get("value").Value()) + assert.Equal(t, "", p.Get("empty").Value()) + assert.Equal(t, int(3), p.Get("two").Value()) assert.Equal(t, "rum please!", p.Get("map.drink?").String()) assert.Equal(t, "with cream", p.Get("map.tea?").String()) } +func TestStaticProviderWithExpandEscapeHandling(t *testing.T) { + t.Parallel() + + p, err := NewStaticProviderWithExpand(map[string]interface{}{ + "nil": "a", + "one": "$a", + "two": "$$a", + "three": "$$$a", + "four": "$$$$a", + "five": "$$$$$a", + }, func(key string) (string, bool) { + return fmt.Sprint(len(key)), true + }) + + require.NoError(t, err, "can't create a static provider") + assert.Equal(t, "a", p.Get("nil").String()) + assert.Equal(t, "1", p.Get("one").String()) + assert.Equal(t, "$a", p.Get("two").String()) + assert.Equal(t, "$1", p.Get("three").String()) + assert.Equal(t, "$$a", p.Get("four").String()) + assert.Equal(t, "$$1", p.Get("five").String()) +} + func TestPopulateForMapOfDifferentKeyTypes(t *testing.T) { t.Parallel() diff --git a/yaml.go b/yaml.go index d9e3b98..da07cf9 100644 --- a/yaml.go +++ b/yaml.go @@ -27,11 +27,11 @@ import ( "io/ioutil" "os" "reflect" - "runtime" "strconv" "strings" "github.com/pkg/errors" + "golang.org/x/text/transform" "gopkg.in/yaml.v2" ) @@ -158,6 +158,18 @@ func NewYAMLProviderFromFiles(files ...string) (Provider, error) { // NewYAMLProviderWithExpand creates a configuration provider from a set of YAML // file names with ${var} or $var values replaced based on the mapping function. +// Variable names not wrapped in curly braces will be parsed according +// to the shell variable naming rules: +// +// ...a word consisting solely of underscores, digits, and +// alphabetics from the portable character set. The first +// character of a name is not a digit. +// +// For variables wrapped in braces, all characters between '{' and '}' +// will be passed to the expand function. e.g. "${foo:13}" will cause +// "foo:13" to be passed to the expand function. The sequence '$$' will +// be replaced by a literal '$'. All other sequences will be ignored +// for expansion purposes. func NewYAMLProviderWithExpand(mapping func(string) (string, bool), files ...string) (Provider, error) { readClosers, err := filesToReaders(files...) if err != nil { @@ -199,16 +211,17 @@ func newYAMLProviderFromReader(readers ...io.Reader) (Provider, error) { func newYAMLProviderFromReaderWithExpand( mapping func(string) (string, bool), readers ...io.Reader) (Provider, error) { - p, err := newYAMLProviderCore(readers...) - if err != nil { - return nil, err - } - if err := p.root.applyOnAllNodes(replace(mapping)); err != nil { - return nil, err + expandFunc := replace(mapping) + + ereaders := make([]io.Reader, len(readers)) + for i, reader := range readers { + ereaders[i] = transform.NewReader( + reader, + &expandTransformer{expand: expandFunc}) } - return newCachedProvider(p) + return newYAMLProviderFromReader(ereaders...) } // NewYAMLProviderFromBytes creates a config provider from a byte-backed YAML @@ -348,53 +361,6 @@ func (n *yamlNode) Children() []*yamlNode { return nodes } -// Apply expand to all nested elements of a node. -// There is no need to use reflection, because YAML unmarshaler is using -// map[interface{}]interface{} to store objects and []interface{} -// to store collections. -func recursiveApply(node interface{}, expand func(string) string) interface{} { - if node == nil { - return nil - } - switch t := node.(type) { - case map[interface{}]interface{}: - for k := range t { - t[k] = recursiveApply(t[k], expand) - } - return t - case []interface{}: - for i := range t { - t[i] = recursiveApply(t[i], expand) - } - return t - } - - return os.Expand(fmt.Sprint(node), expand) -} - -func (n *yamlNode) applyOnAllNodes(expand func(string) string) (err error) { - - defer func() { - if r := recover(); r != nil { - if _, ok := r.(runtime.Error); ok { - panic(r) - } - - err = r.(error) - } - }() - - n.value = recursiveApply(n.value, expand) - - for _, c := range n.Children() { - if err := c.applyOnAllNodes(expand); err != nil { - return err - } - } - - return -} - func unmarshalYAMLValue(reader io.Reader, value interface{}) error { raw, err := ioutil.ReadAll(reader) if err != nil { @@ -414,10 +380,8 @@ func unmarshalYAMLValue(reader io.Reader, value interface{}) error { // // In the case that HTTP_PORT is not provided, default value (in this case 8080) // will be used. -// -// TODO: what if someone wanted a literal ${FOO} in config? need a small escape hatch -func replace(lookUp func(string) (string, bool)) func(in string) string { - return func(in string) string { +func replace(lookUp func(string) (string, bool)) func(in string) (string, error) { + return func(in string) (string, error) { sep := strings.Index(in, _envSeparator) var key string var def string @@ -432,16 +396,16 @@ func replace(lookUp func(string) (string, bool)) func(in string) string { } if envVal, ok := lookUp(key); ok { - return envVal + return envVal, nil } if def == "" { - panic(fmt.Errorf(`default is empty for %q (use "" for empty string)`, key)) + return "", fmt.Errorf(`default is empty for %q (use "" for empty string)`, key) } else if def == _emptyDefault { - return "" + return "", nil } - return def + return def, nil } } diff --git a/yaml_test.go b/yaml_test.go index 675c258..2c5326a 100644 --- a/yaml_test.go +++ b/yaml_test.go @@ -70,7 +70,7 @@ func TestYAMLEnvInterpolation(t *testing.T) { } cfg := strings.NewReader(` -name: some name here +name: some $$name here owner: ${OWNER_EMAIL} module: fake: @@ -82,6 +82,9 @@ module: owner := p.Get("owner").String() require.Equal(t, "hello@there.yasss", owner) + + name := p.Get("name").String() + require.Equal(t, "some $name here", name) } func TestYAMLEnvInterpolationMissing(t *testing.T) {