diff --git a/options.go b/options.go index efd634fd..df05ebcd 100644 --- a/options.go +++ b/options.go @@ -256,6 +256,53 @@ func ParseURL(sURL string) (*Options, error) { return options, nil } +func (opts *Options) ToURL() string { + dsn := "postgres://" + + if len(opts.User) > 0 { + dsn += opts.User + + if len(opts.Password) > 0 { + dsn += ":" + opts.Password + } + + dsn += "@" + } + + if len(opts.Addr) > 0 { + dsn += opts.Addr + } else { + dsn += "localhost:5432" + } + + dsn += "/" + opts.Database + + values := url.Values{} + + if opts.DialTimeout > 0 { + values.Add("connect_timeout", strconv.Itoa(int(opts.DialTimeout)/int(time.Second))) + } + + if len(opts.ApplicationName) > 0 { + values.Add("application_name", opts.ApplicationName) + } + + if opts.TLSConfig == nil { + values.Add("sslmode", "disable") + } else if opts.TLSConfig.InsecureSkipVerify { + values.Add("sslmode", "allow") + } else if !opts.TLSConfig.InsecureSkipVerify { + values.Add("sslmode", "verify-ca") + } + + encoded := values.Encode() + if len(encoded) > 0 { + dsn += "?" + encoded + } + + return dsn +} + func (opt *Options) getDialer() func(context.Context) (net.Conn, error) { return func(ctx context.Context) (net.Conn, error) { return opt.Dialer(ctx, opt.Network, opt.Addr) diff --git a/options_test.go b/options_test.go index 12178729..cd7772ab 100644 --- a/options_test.go +++ b/options_test.go @@ -7,6 +7,8 @@ import ( "errors" "testing" "time" + + "github.com/stretchr/testify/assert" ) func TestParseURL(t *testing.T) { @@ -261,3 +263,54 @@ func TestParseURL(t *testing.T) { }) } } + +func TestOptions_ToURL(t *testing.T) { + tests := []struct { + name string + opts *Options + expected string + }{ + {"Empty", &Options{Database: "postgres"}, "postgres://localhost:5432/postgres?sslmode=disable"}, + {"User", &Options{Database: "postgres", User: "postgres"}, "postgres://postgres@localhost:5432/postgres?sslmode=disable"}, + {"UserPass", &Options{Database: "postgres", User: "postgres", Password: "password"}, "postgres://postgres:password@localhost:5432/postgres?sslmode=disable"}, + {"UserPassAddr", &Options{Database: "postgres", User: "postgres", Password: "password", Addr: "somewhere:1234"}, "postgres://postgres:password@somewhere:1234/postgres?sslmode=disable"}, + {"UserPassAddrAppl", &Options{Database: "postgres", User: "postgres", Password: "password", Addr: "somewhere:1234", ApplicationName: "test"}, "postgres://postgres:password@somewhere:1234/postgres?application_name=test&sslmode=disable"}, + {"UserPassAddrApplTimeout", &Options{Database: "postgres", User: "postgres", Password: "password", Addr: "somewhere:1234", ApplicationName: "test", DialTimeout: time.Second}, "postgres://postgres:password@somewhere:1234/postgres?application_name=test&connect_timeout=1&sslmode=disable"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert := assert.New(t) + + actual := tt.opts.ToURL() + assert.Equal(tt.expected, actual) + + _, err := ParseURL(actual) + assert.NoError(err) + }) + } +} + +func TestOptions_ToURL_reparsable(t *testing.T) { + assert := assert.New(t) + + opts := &Options{ + Database: "postgres", + User: "postgres", + Password: "password", + Addr: "somewhere:1234", + ApplicationName: "test", + DialTimeout: time.Second, + } + + url := opts.ToURL() + + actual, err := ParseURL(url) + assert.NoError(err) + + assert.Equal(opts.Database, actual.Database) + assert.Equal(opts.User, actual.User) + assert.Equal(opts.Password, actual.Password) + assert.Equal(opts.Addr, actual.Addr) + assert.Equal(opts.ApplicationName, actual.ApplicationName) + assert.Equal(opts.DialTimeout, actual.DialTimeout) +}