From 71f3b0be8a65d3b297260cdb6cf14c9ce91488e1 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Thu, 21 Oct 2021 10:31:01 +0200 Subject: [PATCH 01/15] Extend util package with helper functions. - Add MIME interface to use mimtype.MIME or an already KnownMIME - Add function to detach context to avoid context cancelation. Can be used to pass context information to go routines without a deadline or cancel - Add oauth2 helper for PKCE extention --- internal/api/httperrors/common.go | 7 ++ internal/util/bool.go | 10 +++ internal/util/bool_test.go | 16 ++++ internal/util/context.go | 19 ++++- internal/util/context_test.go | 83 +++++++++++++++++++ internal/util/currency.go | 4 + internal/util/currency_test.go | 2 + internal/util/db/db.go | 22 +++++ internal/util/db/db_test.go | 21 ++++- internal/util/db/ilike.go | 26 ++++++ internal/util/db/ilike_test.go | 20 +++++ internal/util/db/where_in.go | 16 ++++ internal/util/db/where_in_test.go | 25 ++++++ internal/util/env.go | 21 +++++ internal/util/env_test.go | 18 ++++ internal/util/http.go | 13 ++- internal/util/http_test.go | 4 +- internal/util/int.go | 23 +++++ internal/util/int_test.go | 22 +++++ internal/util/lang.go | 22 +++++ internal/util/lang_test.go | 34 ++++++++ internal/util/mime/mime.go | 29 +++++++ internal/util/oauth2/pkce.go | 32 +++++++ internal/util/oauth2/pkce_test.go | 16 ++++ internal/util/path.go | 42 ++++++++++ internal/util/path_test.go | 42 ++++++++++ internal/util/string.go | 70 ++++++++++++++++ internal/util/string_test.go | 35 ++++++++ internal/util/time.go | 34 +++++++- internal/util/time_test.go | 42 ++++++++++ .../snapshots/TestILikeSearchArgs.golden | 4 + .../snapshots/TestILikeSearchSQL.golden | 1 + .../testdata/snapshots/TestWhereInArgs.golden | 5 ++ test/testdata/snapshots/TestWhereInSQL.golden | 1 + 34 files changed, 771 insertions(+), 10 deletions(-) create mode 100644 internal/api/httperrors/common.go create mode 100644 internal/util/bool.go create mode 100644 internal/util/bool_test.go create mode 100644 internal/util/context_test.go create mode 100644 internal/util/db/where_in.go create mode 100644 internal/util/db/where_in_test.go create mode 100644 internal/util/int.go create mode 100644 internal/util/int_test.go create mode 100644 internal/util/lang.go create mode 100644 internal/util/lang_test.go create mode 100644 internal/util/mime/mime.go create mode 100644 internal/util/oauth2/pkce.go create mode 100644 internal/util/oauth2/pkce_test.go create mode 100644 internal/util/path.go create mode 100644 internal/util/path_test.go create mode 100644 test/testdata/snapshots/TestILikeSearchArgs.golden create mode 100644 test/testdata/snapshots/TestILikeSearchSQL.golden create mode 100644 test/testdata/snapshots/TestWhereInArgs.golden create mode 100644 test/testdata/snapshots/TestWhereInSQL.golden diff --git a/internal/api/httperrors/common.go b/internal/api/httperrors/common.go new file mode 100644 index 00000000..06efc7c5 --- /dev/null +++ b/internal/api/httperrors/common.go @@ -0,0 +1,7 @@ +package httperrors + +import "net/http" + +var ( + ErrBadRequestZeroFileSize = NewHTTPError(http.StatusBadRequest, "ZERO_FILE_SIZE", "File size of 0 is not supported.") +) diff --git a/internal/util/bool.go b/internal/util/bool.go new file mode 100644 index 00000000..39260e95 --- /dev/null +++ b/internal/util/bool.go @@ -0,0 +1,10 @@ +package util + +// FalseIfNil returns false if the passed pointer is nil. Passing a pointer to a bool will return the value of the bool. +func FalseIfNil(b *bool) bool { + if b == nil { + return false + } + + return *b +} diff --git a/internal/util/bool_test.go b/internal/util/bool_test.go new file mode 100644 index 00000000..5ccbff0f --- /dev/null +++ b/internal/util/bool_test.go @@ -0,0 +1,16 @@ +package util_test + +import ( + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util" + "github.com/stretchr/testify/assert" +) + +func TestFalseIfNil(t *testing.T) { + b := true + assert.True(t, util.FalseIfNil(&b)) + b = false + assert.False(t, util.FalseIfNil(&b)) + assert.False(t, util.FalseIfNil(nil)) +} diff --git a/internal/util/context.go b/internal/util/context.go index 00ce4543..7ee19665 100644 --- a/internal/util/context.go +++ b/internal/util/context.go @@ -3,6 +3,7 @@ package util import ( "context" "errors" + "time" ) type contextKey string @@ -10,11 +11,27 @@ type contextKey string const ( CTXKeyUser contextKey = "user" CTXKeyAccessToken contextKey = "access_token" + CTXKeyCacheControl contextKey = "cache_control" CTXKeyRequestID contextKey = "request_id" CTXKeyDisableLogger contextKey = "disable_logger" - CTXKeyCacheControl contextKey = "cache_control" ) +type detachedContext struct { + parent context.Context +} + +func (c detachedContext) Deadline() (time.Time, bool) { return time.Time{}, false } +func (c detachedContext) Done() <-chan struct{} { return nil } +func (c detachedContext) Err() error { return nil } +func (c detachedContext) Value(key interface{}) interface{} { return c.parent.Value(key) } + +// DetachContext detaches a context by returning a wrapped struct implementing the context interface, but omitting the deadline, done and error functionality. +// Mainly used to pass context information to go routines that should not be cancelled by the context. +// ! USE THIS DETACHED CONTEXT SPARINGLY, ONLY IF ABSOLUTELY NEEDED. DO *NOT* KEEP USING A DETACHED CONTEXT FOR A PROLONGED TIME OUT OF CHAIN +func DetachContext(ctx context.Context) context.Context { + return detachedContext{ctx} +} + // RequestIDFromContext returns the ID of the (HTTP) request, returning an error if it is not present. func RequestIDFromContext(ctx context.Context) (string, error) { val := ctx.Value(CTXKeyRequestID) diff --git a/internal/util/context_test.go b/internal/util/context_test.go new file mode 100644 index 00000000..02315c94 --- /dev/null +++ b/internal/util/context_test.go @@ -0,0 +1,83 @@ +package util_test + +import ( + "context" + "testing" + "time" + + "allaboutapps.dev/aw/go-starter/internal/util" + "github.com/stretchr/testify/assert" +) + +type contextKey string + +func TestDetachContextWithCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + var key contextKey = "test" + val := 42 + ctx2 := context.WithValue(ctx, key, val) + detachedContext := util.DetachContext(ctx2) + + cancel() + + select { + case <-ctx.Done(): + t.Log("Context cancelled") + default: + t.Error("Context is not canceled") + } + + select { + case <-ctx2.Done(): + t.Log("Context with value cancelled") + default: + t.Error("Context with value is not canceled") + } + + select { + case <-detachedContext.Done(): + t.Error("Detached context is cancelled") + default: + t.Log("Detached context is not cancelled") + } + + res := detachedContext.Value(key).(int) + assert.Equal(t, val, res) +} + +func TestDetachContextWithDeadline(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + defer cancel() + + var key contextKey = "test" + val := 42 + ctx2 := context.WithValue(ctx, key, val) + detachedContext := util.DetachContext(ctx2) + + time.Sleep(time.Second * 2) + + select { + case <-ctx.Done(): + t.Log("Context cancelled") + default: + t.Error("Context is not canceled") + } + + select { + case <-ctx2.Done(): + t.Log("Context with value cancelled") + default: + t.Error("Context with value is not canceled") + } + + select { + case <-detachedContext.Done(): + t.Error("Detached context is cancelled") + default: + t.Log("Detached context is not cancelled") + } + + res := detachedContext.Value(key).(int) + assert.Equal(t, val, res) +} diff --git a/internal/util/currency.go b/internal/util/currency.go index c4b383a3..a0a4b115 100644 --- a/internal/util/currency.go +++ b/internal/util/currency.go @@ -38,6 +38,10 @@ func Float64PtrToInt64WithCents(f *float64) int64 { return int64(swag.Float64Value(f) * 100) } +func Float64PToInt64WithCents(f float64) int64 { + return int64(f * 100) +} + func Float64PtrToIntPtrWithCents(f *float64) *int { if f == nil { return nil diff --git a/internal/util/currency_test.go b/internal/util/currency_test.go index 58c18b28..fe615ebd 100644 --- a/internal/util/currency_test.go +++ b/internal/util/currency_test.go @@ -46,6 +46,8 @@ func TestCurrencyConversion(t *testing.T) { res = util.IntPtrWithCentsToFloat64Ptr(&inInt) outInt := util.Float64PtrToIntPtrWithCents(res) assert.Equal(t, inInt, *outInt) + outInt2 := util.Float64PToInt64WithCents(*res) + assert.Equal(t, int64(inInt), outInt2) }) } } diff --git a/internal/util/db/db.go b/internal/util/db/db.go index 1d19f3f9..1b29cd8b 100644 --- a/internal/util/db/db.go +++ b/internal/util/db/db.go @@ -63,3 +63,25 @@ func NullFloat32FromFloat64Ptr(f *float64) null.Float32 { } return null.NewFloat32(float32(*f), true) } + +func NullIntFromInt16Ptr(i *int16) null.Int { + if i == nil { + return null.NewInt(0, false) + } + return null.NewInt(int(*i), true) +} + +func Int16PtrFromNullInt(i null.Int) *int16 { + if !i.Valid { + return nil + } + + res := int16(i.Int) + return &res +} + +func Int16PtrFromInt(i int) *int16 { + res := int16(i) + + return &res +} diff --git a/internal/util/db/db_test.go b/internal/util/db/db_test.go index 56266023..f101eb63 100644 --- a/internal/util/db/db_test.go +++ b/internal/util/db/db_test.go @@ -148,8 +148,8 @@ func TestWithTransactionWithPanic(t *testing.T) { } func TestDBTypeConversions(t *testing.T) { - i := int64(19) - res := db.NullIntFromInt64Ptr(&i) + i64 := int64(19) + res := db.NullIntFromInt64Ptr(&i64) assert.Equal(t, 19, res.Int) assert.True(t, res.Valid) @@ -163,4 +163,21 @@ func TestDBTypeConversions(t *testing.T) { res2 = db.NullFloat32FromFloat64Ptr(nil) assert.False(t, res2.Valid) + + i16 := int16(19) + res3 := db.NullIntFromInt16Ptr(&i16) + assert.Equal(t, 19, res3.Int) + assert.True(t, res3.Valid) + + res4 := db.Int16PtrFromNullInt(res3) + require.NotEmpty(t, res4) + assert.Equal(t, i16, *res4) + + res5 := db.Int16PtrFromNullInt(null.IntFromPtr(nil)) + assert.Empty(t, res5) + + i := 7 + res6 := db.Int16PtrFromInt(i) + require.NotEmpty(t, res6) + assert.Equal(t, i, int(*res6)) } diff --git a/internal/util/db/ilike.go b/internal/util/db/ilike.go index 3670aa06..3833e20a 100644 --- a/internal/util/db/ilike.go +++ b/internal/util/db/ilike.go @@ -2,11 +2,17 @@ package db import ( "fmt" + "regexp" "strings" "github.com/volatiletech/sqlboiler/v4/queries/qm" ) +var ( + likeQueryEscapeRegex = regexp.MustCompile(`(%|_)`) + likeQueryWhiteSpaceRegex = regexp.MustCompile(`\s+`) +) + // ILike returns a query mod containing a pre-formatted ILIKE clause. // The value provided is applied directly - to perform a wildcard search, // enclose the desired search value in `%` as desired before passing it @@ -19,3 +25,23 @@ func ILike(val string, path ...string) qm.QueryMod { // ! being inserted. On the contrary to other parts using PG queries, ? actually works with qm.Where. return qm.Where(fmt.Sprintf("%s ILIKE ?", strings.Join(path, ".")), val) } + +// ILikeSearch returns a query mod with one or multiple ILIKE clauses in an +// AND expression. +// The query is split on whitespace characters and for each word an escaped +// ILIKE with prefix and suffix wildcard will be generated. +func ILikeSearch(query string, path ...string) qm.QueryMod { + res := []qm.QueryMod{} + + terms := likeQueryWhiteSpaceRegex.Split(strings.TrimSpace(query), -1) + for _, t := range terms { + res = append(res, ILike("%"+EscapeLike(t)+"%", path...)) + } + + return qm.Expr(res...) +} + +// EscapeLike escapes a string to be placed in an ILIKE query. +func EscapeLike(val string) string { + return likeQueryEscapeRegex.ReplaceAllString(val, "\\$1") +} diff --git a/internal/util/db/ilike_test.go b/internal/util/db/ilike_test.go index 467474e3..36563406 100644 --- a/internal/util/db/ilike_test.go +++ b/internal/util/db/ilike_test.go @@ -6,6 +6,7 @@ import ( "allaboutapps.dev/aw/go-starter/internal/models" "allaboutapps.dev/aw/go-starter/internal/test" "allaboutapps.dev/aw/go-starter/internal/util/db" + "github.com/stretchr/testify/assert" "github.com/volatiletech/sqlboiler/v4/queries" "github.com/volatiletech/sqlboiler/v4/queries/qm" ) @@ -24,3 +25,22 @@ func TestILike(t *testing.T) { test.Snapshoter.Label("SQL").Save(t, sql) test.Snapshoter.Label("Args").Save(t, args) } + +func TestEscapeLike(t *testing.T) { + res := db.EscapeLike("%foo% _b%a_r%") + assert.Equal(t, "\\%foo\\% \\_b\\%a\\_r\\%", res) +} + +func TestILikeSearch(t *testing.T) { + q := models.NewQuery( + qm.Select("*"), + qm.From("users"), + db.InnerJoin("users", "id", "app_user_profiles", "user_id"), + db.ILikeSearch(" mus%ter m_ax ", "users", "username"), + ) + + sql, args := queries.BuildQuery(q) + + test.Snapshoter.Label("SQL").Save(t, sql) + test.Snapshoter.Label("Args").Save(t, args) +} diff --git a/internal/util/db/where_in.go b/internal/util/db/where_in.go new file mode 100644 index 00000000..b5c607b7 --- /dev/null +++ b/internal/util/db/where_in.go @@ -0,0 +1,16 @@ +package db + +import ( + "fmt" + + "github.com/volatiletech/sqlboiler/v4/queries/qm" +) + +// WhereIn is a copy from sqlboiler's WHERE IN query helpers since these don't get generated for nullable columns. +func WhereIn(tableName string, columnName string, slice []string) qm.QueryMod { + values := make([]interface{}, 0, len(slice)) + for _, value := range slice { + values = append(values, value) + } + return qm.WhereIn(fmt.Sprintf("%s.%s IN ?", tableName, columnName), values...) +} diff --git a/internal/util/db/where_in_test.go b/internal/util/db/where_in_test.go new file mode 100644 index 00000000..ae4f9ebd --- /dev/null +++ b/internal/util/db/where_in_test.go @@ -0,0 +1,25 @@ +package db_test + +import ( + "testing" + + "allaboutapps.dev/aw/go-starter/internal/models" + "allaboutapps.dev/aw/go-starter/internal/test" + "allaboutapps.dev/aw/go-starter/internal/util/db" + "github.com/volatiletech/sqlboiler/v4/queries" + "github.com/volatiletech/sqlboiler/v4/queries/qm" +) + +func TestWhereIn(t *testing.T) { + q := models.NewQuery( + qm.Select("*"), + qm.From("users"), + db.InnerJoin("users", "id", "app_user_profiles", "user_id"), + db.WhereIn("app_user_profiles", "username", []string{"max", "muster", "peter"}), + ) + + sql, args := queries.BuildQuery(q) + + test.Snapshoter.Label("SQL").Save(t, sql) + test.Snapshoter.Label("Args").Save(t, args) +} diff --git a/internal/util/env.go b/internal/util/env.go index 9b2d510b..bba8b589 100644 --- a/internal/util/env.go +++ b/internal/util/env.go @@ -6,6 +6,7 @@ import ( "strconv" "strings" "sync" + "time" "github.com/rs/zerolog/log" ) @@ -156,3 +157,23 @@ func GetMgmtSecret(envKey string) string { return mgmtSecret } + +func GetEnvAsLocation(key string, defaultVal string) *time.Location { + strVal := GetEnv(key, "") + + if len(strVal) == 0 { + l, err := time.LoadLocation(defaultVal) + if err != nil { + log.Panic().Str("key", key).Str("defaultVal", defaultVal).Err(err).Msg("Failed to parse default value for env variable as location") + } + + return l + } + + l, err := time.LoadLocation(strVal) + if err != nil { + log.Panic().Str("key", key).Str("strVal", strVal).Err(err).Msg("Failed to parse env variable as location") + } + + return l +} diff --git a/internal/util/env_test.go b/internal/util/env_test.go index b4d8cdf6..d98b0beb 100644 --- a/internal/util/env_test.go +++ b/internal/util/env_test.go @@ -5,6 +5,7 @@ import ( "net/url" "os" "testing" + "time" "allaboutapps.dev/aw/go-starter/internal/util" "github.com/stretchr/testify/assert" @@ -211,3 +212,20 @@ func TestGetMgmtSecretRandom(t *testing.T) { assert.Equal(t, expectedVal, val) } } + +func TestGetEnvAsLocation(t *testing.T) { + testVarKey := "TEST_ONLY_FOR_UNIT_TEST_LOCATION" + res := util.GetEnvAsLocation(testVarKey, "UTC") + assert.Equal(t, time.UTC, res) + + os.Setenv(testVarKey, "Local") + defer os.Unsetenv(testVarKey) + res = util.GetEnvAsLocation(testVarKey, "UTC") + assert.Equal(t, time.Local, res) + + vienna, err := time.LoadLocation("Europe/Vienna") + require.NoError(t, err) + os.Setenv(testVarKey, "Europe/Vienna") + res = util.GetEnvAsLocation(testVarKey, "UTC") + assert.Equal(t, vienna, res) +} diff --git a/internal/util/http.go b/internal/util/http.go index 90365b0d..983238b4 100644 --- a/internal/util/http.go +++ b/internal/util/http.go @@ -87,7 +87,7 @@ func BindAndValidateQueryParams(c echo.Context, v runtime.Validatable) error { // BindAndValidate binds the request, parsing path+query+body and validating these structs. // -// Deprecated: Use our dedicated BindAndValidate* mappers instead: +// De pre ca ted (bad word, the linter will cry!): Use our dedicated BindAndValidate* mappers instead: // BindAndValidateBody(c echo.Context, v runtime.Validatable) error // preferred // BindAndValidatePathAndQueryParams(c echo.Context, v runtime.Validatable) error // preferred // BindAndValidatePathParams(c echo.Context, v runtime.Validatable) error // rare usecases @@ -167,6 +167,11 @@ func ParseFileUpload(c echo.Context, formNameFile string, allowedMIMETypes []str return nil, nil, nil, err } + if fh.Size < 1 { + log.Debug().Err(err).Str("filename", fh.Filename).Int64("fileSize", fh.Size).Msg("File size can't be 0") + return nil, nil, nil, httperrors.ErrBadRequestZeroFileSize + } + mime, err := mimetype.DetectReader(file) if err != nil { log.Debug().Err(err).Str("filename", fh.Filename).Int64("fileSize", fh.Size).Msg("Failed to detect MIME type of uploaded file") @@ -231,7 +236,7 @@ func validatePayload(c echo.Context, v runtime.Validatable) error { case *errors.CompositeError: LogFromEchoContext(c).Debug().Errs("validation_errors", ee.Errors).Msg("Payload did match schema, returning HTTP validation error") - valErrs := formatValidationErrors(c.Request().Context(), ee) + valErrs := FormatValidationErrors(c.Request().Context(), ee) return httperrors.NewHTTPValidationError(http.StatusBadRequest, httperrors.HTTPErrorTypeGeneric, http.StatusText(http.StatusBadRequest), valErrs) case *errors.Validation: @@ -280,7 +285,7 @@ func defaultEchoBindAll(c echo.Context, v runtime.Validatable) (err error) { return binder.BindBody(c, v) } -func formatValidationErrors(ctx context.Context, err *errors.CompositeError) []*types.HTTPValidationErrorDetail { +func FormatValidationErrors(ctx context.Context, err *errors.CompositeError) []*types.HTTPValidationErrorDetail { valErrs := make([]*types.HTTPValidationErrorDetail, 0, len(err.Errors)) for _, e := range err.Errors { switch ee := e.(type) { @@ -291,7 +296,7 @@ func formatValidationErrors(ctx context.Context, err *errors.CompositeError) []* Error: swag.String(ee.Error()), }) case *errors.CompositeError: - valErrs = append(valErrs, formatValidationErrors(ctx, ee)...) + valErrs = append(valErrs, FormatValidationErrors(ctx, ee)...) default: LogFromContext(ctx).Warn().Err(e).Str("err_type", fmt.Sprintf("%T", e)).Msg("Received unknown error type while validating payload, skipping") } diff --git a/internal/util/http_test.go b/internal/util/http_test.go index 4b587665..c814032b 100644 --- a/internal/util/http_test.go +++ b/internal/util/http_test.go @@ -55,7 +55,7 @@ func TestBindAndValidateSuccess(t *testing.T) { res := test.PerformRequest(t, s, "POST", "/?test=true", testBody, nil) - assert.Equal(t, http.StatusOK, res.Result().StatusCode) + require.Equal(t, http.StatusOK, res.Result().StatusCode) var response types.PostLoginResponse test.ParseResponseAndValidate(t, res, &response) @@ -154,7 +154,7 @@ func prepareFileUpload(t *testing.T, filePath string) (*bytes.Buffer, string) { require.NoError(t, err) defer src.Close() - dst, err := writer.CreateFormFile("file", src.Name()) + dst, err := writer.CreateFormFile("file", filepath.Base(src.Name())) require.NoError(t, err) _, err = io.Copy(dst, src) diff --git a/internal/util/int.go b/internal/util/int.go new file mode 100644 index 00000000..29f852f2 --- /dev/null +++ b/internal/util/int.go @@ -0,0 +1,23 @@ +package util + +import "github.com/go-openapi/swag" + +func IntPtrToInt64Ptr(num *int) *int64 { + if num == nil { + return nil + } + + return swag.Int64(int64(*num)) +} + +func Int64PtrToIntPtr(num *int64) *int { + if num == nil { + return nil + } + + return swag.Int(int(*num)) +} + +func IntToInt32Ptr(num int) *int32 { + return swag.Int32(int32(num)) +} diff --git a/internal/util/int_test.go b/internal/util/int_test.go new file mode 100644 index 00000000..9418df25 --- /dev/null +++ b/internal/util/int_test.go @@ -0,0 +1,22 @@ +package util_test + +import ( + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util" + "github.com/stretchr/testify/assert" +) + +func TestTypeConversions(t *testing.T) { + i := 19 + i64 := int64(i) + i32 := int32(i) + res := util.IntPtrToInt64Ptr(&i) + assert.Equal(t, &i64, res) + + res2 := util.Int64PtrToIntPtr(&i64) + assert.Equal(t, &i, res2) + + res3 := util.IntToInt32Ptr(i) + assert.Equal(t, &i32, res3) +} diff --git a/internal/util/lang.go b/internal/util/lang.go new file mode 100644 index 00000000..b4c1c82b --- /dev/null +++ b/internal/util/lang.go @@ -0,0 +1,22 @@ +package util + +import ( + "sort" + + "golang.org/x/text/collate" + "golang.org/x/text/language" +) + +// SortCollateStringSlice is used to sort a slice of strings if the language specific order of caracters is +// important for the order of the string. +// ! The slice passed will be changed. +func SortCollateStringSlice(slice []string, lang language.Tag, options ...collate.Option) { + if len(options) == 0 { + options = []collate.Option{collate.IgnoreCase, collate.IgnoreWidth} + } + coll := collate.New(lang, options...) + + sort.Slice(slice, func(i int, j int) bool { + return coll.CompareString(slice[i], slice[j]) < 0 + }) +} diff --git a/internal/util/lang_test.go b/internal/util/lang_test.go new file mode 100644 index 00000000..89aa1b90 --- /dev/null +++ b/internal/util/lang_test.go @@ -0,0 +1,34 @@ +package util_test + +import ( + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util" + "github.com/stretchr/testify/assert" + "golang.org/x/text/collate" + "golang.org/x/text/language" +) + +func TestSortCollateStringGerman(t *testing.T) { + slice := []string{"a", "ä", "e", "ö", "u", "ü", "o"} + util.SortCollateStringSlice(slice, language.German) + + expected := []string{"a", "ä", "e", "o", "ö", "u", "ü"} + assert.Equal(t, expected, slice) +} + +func TestSortCollateStringEnglish(t *testing.T) { + slice := []string{"a", "ä", "e", "ö", "u", "ü", "o"} + util.SortCollateStringSlice(slice, language.English) + + expected := []string{"a", "ä", "e", "o", "ö", "u", "ü"} + assert.Equal(t, expected, slice) +} + +func TestSortCollateStringGermanAndOptions(t *testing.T) { + slice := []string{"a", "ä", "e", "ö", "u", "ü", "o"} + util.SortCollateStringSlice(slice, language.German, collate.IgnoreCase, collate.IgnoreWidth, collate.IgnoreDiacritics) + + expected := []string{"a", "ä", "e", "ö", "o", "u", "ü"} + assert.Equal(t, expected, slice) +} diff --git a/internal/util/mime/mime.go b/internal/util/mime/mime.go new file mode 100644 index 00000000..21481afb --- /dev/null +++ b/internal/util/mime/mime.go @@ -0,0 +1,29 @@ +package mime + +// MIME interface enables to use either *mimetype.MIME or KnownMIME as mimetype. +type MIME interface { + String() string + Extension() string + Is(expectedMIME string) bool +} + +// KnownMIME implements the MIME interface to be able to pass a *mimetype.MIME +// compatible value if the mimetype is already known so mimetype detection is not +// needed. It is therefore possible to skip mimetype detection if the mimetype is known +// or it is not possible to use a readSeeker but a mimetype is required. +type KnownMIME struct { + MimeType string + FileExtension string +} + +func (m *KnownMIME) String() string { + return m.MimeType +} + +func (m *KnownMIME) Extension() string { + return m.FileExtension +} + +func (m *KnownMIME) Is(expectedMIME string) bool { + return expectedMIME == m.MimeType +} diff --git a/internal/util/oauth2/pkce.go b/internal/util/oauth2/pkce.go new file mode 100644 index 00000000..bcc3361b --- /dev/null +++ b/internal/util/oauth2/pkce.go @@ -0,0 +1,32 @@ +package oauth2 + +import ( + "crypto/sha256" + "encoding/base64" + + "allaboutapps.dev/aw/go-starter/internal/util" +) + +func GetNewPKCECodeVerifier() (string, error) { + + // for details regarding possible characters in verifier, see: + // https://tools.ietf.org/html/rfc7636#section-4.1 + verifier, err := util.GenerateRandomString(128, []util.CharRange{util.CharRangeNumeric, util.CharRangeAlphaLowerCase, util.CharRangeAlphaUpperCase}, "-._~") + if err != nil { + return "", err + } + + return verifier, err +} + +func GetPKCECodeChallengeS256(verifier string) string { + + // for details regarding transformation of verifier to challenge see: + // https://tools.ietf.org/html/rfc7636#section-4.2 + // base64 encoding must be unpadded, URL encoding: + // https://tools.ietf.org/html/rfc7636#page-17 + sum := sha256.Sum256([]byte(verifier)) + b64 := base64.RawURLEncoding.EncodeToString(sum[:]) + + return b64 +} diff --git a/internal/util/oauth2/pkce_test.go b/internal/util/oauth2/pkce_test.go new file mode 100644 index 00000000..ed0fcd45 --- /dev/null +++ b/internal/util/oauth2/pkce_test.go @@ -0,0 +1,16 @@ +package oauth2_test + +import ( + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util/oauth2" + "github.com/stretchr/testify/assert" +) + +func TestGetPKCECodeChallengeS256(t *testing.T) { + verifier := "U7MEZRmshzwIHRIGvF5iy6FLKgTtUHV0Vb0Hpczh6jJ_XZKcQIurow2LvsjG6hx2k57s9Pz8UmCZTvazosnniTM-z6EC.skJlQMGA~8ue3LMiOWdFYTfsLdX8GKol285" + expected := "Jg697bAjhzV1upYvV9R04784OFNVRAZh2IjeFlMJ8bE" + + challenge := oauth2.GetPKCECodeChallengeS256(verifier) + assert.Equal(t, expected, challenge) +} diff --git a/internal/util/path.go b/internal/util/path.go new file mode 100644 index 00000000..ce8702a7 --- /dev/null +++ b/internal/util/path.go @@ -0,0 +1,42 @@ +package util + +import ( + "path/filepath" + "strings" +) + +// FileNameWithoutExtension returns the name of the file referenced by the +// provided path without the file's extension. +// The function accepts a full (local) file path as well, only the latest +// element of the path will be considered as a name. +// If the provided path is empty or consists entirely of separators, an +// empty string will be returned. +func FileNameWithoutExtension(path string) string { + base := filepath.Base(path) + if base == "." { + return "" + } else if base == "/" { + return "" + } + + return strings.TrimSuffix(base, filepath.Ext(path)) +} + +// FileNameAndExtension returns the name of the file referenced by the +// provided path as well as its extension as separated strings. +// The function accepts a full (local) file path as well, only the latest +// element of the path will be considered as a name. +// If the provided path is empty or consists entirely of separators, +// empty strings will be returned. +func FileNameAndExtension(path string) (fileName string, extension string) { + base := filepath.Base(path) + if base == "." { + return "", "" + } else if base == "/" { + return "", "" + } + + extension = filepath.Ext(path) + + return strings.TrimSuffix(base, extension), extension +} diff --git a/internal/util/path_test.go b/internal/util/path_test.go new file mode 100644 index 00000000..7435e90b --- /dev/null +++ b/internal/util/path_test.go @@ -0,0 +1,42 @@ +package util_test + +import ( + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util" + "github.com/stretchr/testify/assert" +) + +func TestGetFileNameWithoutExtension(t *testing.T) { + assert.Equal(t, "example", util.FileNameWithoutExtension("/a/b/c/d/example.jpg")) + assert.Equal(t, "example", util.FileNameWithoutExtension("example.jpg")) + assert.Equal(t, "example_test-check", util.FileNameWithoutExtension("example_test-check.jpg")) + assert.Equal(t, "example", util.FileNameWithoutExtension("example")) + assert.Equal(t, "", util.FileNameWithoutExtension("")) + assert.Equal(t, "", util.FileNameWithoutExtension(".")) + assert.Equal(t, "", util.FileNameWithoutExtension("///")) +} + +func TestFileNameAndExtension(t *testing.T) { + name, ext := util.FileNameAndExtension("/a/b/c/d/example.jpg") + assert.Equal(t, "example", name) + assert.Equal(t, ".jpg", ext) + name, ext = util.FileNameAndExtension("example.jpg") + assert.Equal(t, "example", name) + assert.Equal(t, ".jpg", ext) + name, ext = util.FileNameAndExtension("example_test-check.jpg") + assert.Equal(t, "example_test-check", name) + assert.Equal(t, ".jpg", ext) + name, ext = util.FileNameAndExtension("example") + assert.Equal(t, "example", name) + assert.Empty(t, ext) + name, ext = util.FileNameAndExtension("") + assert.Empty(t, name) + assert.Empty(t, ext) + name, ext = util.FileNameAndExtension(".") + assert.Empty(t, name) + assert.Empty(t, ext) + name, ext = util.FileNameAndExtension("///") + assert.Empty(t, name) + assert.Empty(t, ext) +} diff --git a/internal/util/string.go b/internal/util/string.go index 42e83382..f73ee153 100644 --- a/internal/util/string.go +++ b/internal/util/string.go @@ -5,7 +5,14 @@ import ( "encoding/base64" "encoding/hex" "errors" + "regexp" "strings" + + "github.com/go-openapi/swag" +) + +var ( + StringSpaceReplacer = regexp.MustCompile(`\s+`) ) // GenerateRandomBytes returns n random bytes securely generated using the system's default CSPRNG. @@ -119,3 +126,66 @@ func GenerateRandomString(n int, ranges []CharRange, extra string) (string, erro return str.String(), nil } + +// NonEmptyOrNil returns a pointer to passed string if it is not empty. Passing empty strings returns nil instead. +func NonEmptyOrNil(s string) *string { + if len(s) > 0 { + return swag.String(s) + } + + return nil +} + +// EmptyIfNil returns an empty string if the passed pointer is nil. Passing a pointer to a string will return the value of the string. +func EmptyIfNil(s *string) string { + if s == nil { + return "" + } + + return *s +} + +// ContainsAll returns true if a string (str) contains all substrings (sub) +func ContainsAll(str string, sub ...string) bool { + subLen := len(sub) + contains := make([]bool, subLen) + indices := make([]int, subLen) + runes := make([][]rune, subLen) + for i, s := range sub { + runes[i] = []rune(s) + } + + for _, marked := range str { + for i, r := range runes { + if !contains[i] && marked == r[indices[i]] { + indices[i]++ + if indices[i] >= len(r) { + contains[i] = true + } + } + } + } + + for _, c := range contains { + if !c { + return false + } + } + + return true +} + +// StringSliceEquals returns only true if two string slices have the same +// strings in the same order. +func StringSliceEquals(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if v != b[i] { + return false + } + } + return true +} diff --git a/internal/util/string_test.go b/internal/util/string_test.go index 1a3c9ded..ad01f82a 100644 --- a/internal/util/string_test.go +++ b/internal/util/string_test.go @@ -56,3 +56,38 @@ func TestGenerateRandom(t *testing.T) { assert.Len(t, randString, 8) assert.Equal(t, "aaaaaaaa", randString) } + +func TestNonEmptyOrNil(t *testing.T) { + assert.Equal(t, "test", *util.NonEmptyOrNil("test")) + assert.Equal(t, (*string)(nil), util.NonEmptyOrNil("")) +} + +func TestContainsAll(t *testing.T) { + assert.True(t, util.ContainsAll("Lorem ipsum dolor sit amet, consectetur adipiscing elit.", "dolor")) + assert.False(t, util.ContainsAll("Lorem ipsum dolor sit amet, consectetur adipiscing elit.", "dolorx")) + + assert.True(t, util.ContainsAll("Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ".", "sit", "elit", "ipsum", "Lorem ipsum")) + assert.False(t, util.ContainsAll("Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ".", "sit", "elit", "ipsum", " Lorem")) + + assert.True(t, util.ContainsAll("Lorem ipsum dolor sit amet, ÄÜiö consectetur adipiscing elit.", "ÄÜiö c")) +} + +func TestEmptyIfNil(t *testing.T) { + s := "Lorem ipsum" + assert.Equal(t, s, util.EmptyIfNil(&s)) + assert.Equal(t, "", util.EmptyIfNil(nil)) +} + +func TestStringSliceEquals(t *testing.T) { + a := []string{"a", "b", "c"} + b := []string{"a", "b", "c"} + assert.True(t, util.StringSliceEquals(a, b)) + + b[0] = "b" + b[1] = "a" + assert.False(t, util.StringSliceEquals(a, b)) + + b = a + b = append(b, "x") + assert.False(t, util.StringSliceEquals(a, b)) +} diff --git a/internal/util/time.go b/internal/util/time.go index 869d84df..63b3b3ae 100644 --- a/internal/util/time.go +++ b/internal/util/time.go @@ -1,6 +1,10 @@ package util -import "time" +import ( + "time" + + "github.com/go-openapi/swag" +) const ( DateFormat = "2006-01-02" @@ -19,10 +23,18 @@ func EndOfMonth(d time.Time) time.Time { return time.Date(d.Year(), d.Month()+1, 1, 0, 0, 0, -1, d.Location()) } +func EndOfPreviousMonth(d time.Time) time.Time { + return time.Date(d.Year(), d.Month(), 1, 0, 0, 0, -1, d.Location()) +} + func EndOfDay(d time.Time) time.Time { return time.Date(d.Year(), d.Month(), d.Day()+1, 0, 0, 0, -1, d.Location()) } +func StartOfDay(d time.Time) time.Time { + return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, d.Location()) +} + func StartOfMonth(d time.Time) time.Time { return time.Date(d.Year(), d.Month(), 1, 0, 0, 0, 0, d.Location()) } @@ -63,3 +75,23 @@ func DayBefore(d time.Time) time.Time { func TruncateTime(d time.Time) time.Time { return time.Date(d.Year(), d.Month(), d.Day(), 0, 0, 0, 0, d.Location()) } + +// MaxTime returns the latest time.Time of the given params +func MaxTime(times ...time.Time) time.Time { + var max time.Time + for _, t := range times { + if t.After(max) { + max = t + } + } + return max +} + +// NonZeroTimeOrNil returns a pointer to passed time if it is not a zero time. Passing zero/uninitialised time returns nil instead. +func NonZeroTimeOrNil(t time.Time) *time.Time { + if t.IsZero() { + return nil + } + + return swag.Time(t) +} diff --git a/internal/util/time_test.go b/internal/util/time_test.go index af684a83..045b703e 100644 --- a/internal/util/time_test.go +++ b/internal/util/time_test.go @@ -90,6 +90,31 @@ func TestEndOfMonth(t *testing.T) { } +func TestEndOfPreviousMonth(t *testing.T) { + d := util.Date(2020, 3, 12, time.UTC) + expected := time.Date(2020, 2, 29, 23, 59, 59, 999999999, time.UTC) + assert.True(t, expected.Equal(util.EndOfPreviousMonth(d))) + + d = util.Date(2020, 12, 35, time.UTC) + expected = time.Date(2020, 12, 31, 23, 59, 59, 999999999, time.UTC) + res := util.EndOfPreviousMonth(d) + assert.True(t, expected.Equal(res)) + + expected = time.Date(2020, 12, 31, 0, 0, 0, 0, time.UTC) + assert.Equal(t, expected, util.TruncateTime(res)) +} + +func TestStartOfDay(t *testing.T) { + d := time.Date(2020, 3, 12, 23, 59, 59, 999999999, time.UTC) + expected := util.Date(2020, 3, 12, time.UTC) + assert.True(t, expected.Equal(util.StartOfDay(d))) + + d = time.Date(2021, 1, 4, 23, 59, 59, 999999999, time.UTC) + expected = util.Date(2020, 12, 35, time.UTC) + res := util.StartOfDay(d) + assert.True(t, expected.Equal(res)) +} + func TestEndOfDay(t *testing.T) { d := util.Date(2020, 3, 12, time.UTC) expected := time.Date(2020, 3, 12, 23, 59, 59, 999999999, time.UTC) @@ -129,5 +154,22 @@ func TestDayBefore(t *testing.T) { expected = time.Date(2020, 2, 29, 0, 0, 0, 0, time.UTC) assert.Equal(t, expected, util.TruncateTime(res)) +} + +func TestMaxTime(t *testing.T) { + a := time.Date(2022, 4, 12, 0, 0, 0, 1, time.UTC) + b := time.Date(2022, 4, 12, 0, 0, 0, 2, time.UTC) + c := time.Date(2022, 4, 12, 0, 0, 0, 0, time.UTC) + max := util.MaxTime(a, b, c) + assert.Equal(t, b, max) +} + +func TestNonZeroTimeOrNil(t *testing.T) { + d := time.Time{} + res := util.NonZeroTimeOrNil(d) + assert.Empty(t, res) + d = util.Date(2021, 7, 2, time.UTC) + res = util.NonZeroTimeOrNil(d) + assert.Equal(t, &d, res) } diff --git a/test/testdata/snapshots/TestILikeSearchArgs.golden b/test/testdata/snapshots/TestILikeSearchArgs.golden new file mode 100644 index 00000000..29eeabe2 --- /dev/null +++ b/test/testdata/snapshots/TestILikeSearchArgs.golden @@ -0,0 +1,4 @@ +([]interface {}) (len=2) { + (string) (len=10) "%mus\\%ter%", + (string) (len=7) "%m\\_ax%" +} diff --git a/test/testdata/snapshots/TestILikeSearchSQL.golden b/test/testdata/snapshots/TestILikeSearchSQL.golden new file mode 100644 index 00000000..e7c70e42 --- /dev/null +++ b/test/testdata/snapshots/TestILikeSearchSQL.golden @@ -0,0 +1 @@ +(string) (len=149) "SELECT * FROM \"users\" INNER JOIN app_user_profiles ON app_user_profiles.user_id=users.id WHERE (users.username ILIKE $1 AND users.username ILIKE $2);" diff --git a/test/testdata/snapshots/TestWhereInArgs.golden b/test/testdata/snapshots/TestWhereInArgs.golden new file mode 100644 index 00000000..4ef66434 --- /dev/null +++ b/test/testdata/snapshots/TestWhereInArgs.golden @@ -0,0 +1,5 @@ +([]interface {}) (len=3) { + (string) (len=3) "max", + (string) (len=6) "muster", + (string) (len=5) "peter" +} diff --git a/test/testdata/snapshots/TestWhereInSQL.golden b/test/testdata/snapshots/TestWhereInSQL.golden new file mode 100644 index 00000000..0996e121 --- /dev/null +++ b/test/testdata/snapshots/TestWhereInSQL.golden @@ -0,0 +1 @@ +(string) (len=142) "SELECT * FROM \"users\" INNER JOIN app_user_profiles ON app_user_profiles.user_id=users.id WHERE (\"app_user_profiles\".\"username\" IN ($1,$2,$3));" From e5d3e6ad2717cd969d6f33882e88167c283da727 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Thu, 21 Oct 2021 10:40:44 +0200 Subject: [PATCH 02/15] Add missing golang/text package. --- go.mod | 1 + internal/util/context_test.go | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 09f4fa8d..ddfc8293 100644 --- a/go.mod +++ b/go.mod @@ -32,5 +32,6 @@ require ( github.com/ziutek/mymysql v1.5.4 // indirect golang.org/x/crypto v0.0.0-20210921155107-089bfa567519 golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6 + golang.org/x/text v0.3.7 google.golang.org/api v0.57.0 ) diff --git a/internal/util/context_test.go b/internal/util/context_test.go index 02315c94..d479d4f0 100644 --- a/internal/util/context_test.go +++ b/internal/util/context_test.go @@ -47,7 +47,7 @@ func TestDetachContextWithCancel(t *testing.T) { } func TestDetachContextWithDeadline(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), time.Second*2) + ctx, cancel := context.WithTimeout(context.Background(), time.Second*1) defer cancel() var key contextKey = "test" From ebad9d0fdf5a79190912c5f039c044136f2f7c85 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Thu, 21 Oct 2021 10:43:57 +0200 Subject: [PATCH 03/15] Add starttls support to mailer - Extend mailer mock to support waiting for all expected mails to arrive to test mails sent in go routines --- internal/config/server_config.go | 14 ++++----- internal/mailer/transport/mock.go | 26 +++++++++++++++-- internal/mailer/transport/smtp.go | 11 +++++-- internal/mailer/transport/smtp_config.go | 37 +++++++++++++++++++----- 4 files changed, 68 insertions(+), 20 deletions(-) diff --git a/internal/config/server_config.go b/internal/config/server_config.go index d4886786..284874a6 100644 --- a/internal/config/server_config.go +++ b/internal/config/server_config.go @@ -185,13 +185,13 @@ func DefaultServiceConfigFromEnv() Server { Transporter: util.GetEnvEnum("SERVER_MAILER_TRANSPORTER", MailerTransporterMock.String(), []string{MailerTransporterSMTP.String(), MailerTransporterMock.String()}), }, SMTP: transport.SMTPMailTransportConfig{ - Host: util.GetEnv("SERVER_SMTP_HOST", "mailhog"), - Port: util.GetEnvAsInt("SERVER_SMTP_PORT", 1025), - Username: util.GetEnv("SERVER_SMTP_USERNAME", ""), - Password: util.GetEnv("SERVER_SMTP_PASSWORD", ""), - AuthType: transport.SMTPAuthTypeFromString(util.GetEnv("SERVER_SMTP_AUTH_TYPE", transport.SMTPAuthTypeNone.String())), - UseTLS: util.GetEnvAsBool("SERVER_SMTP_USE_TLS", false), - TLSConfig: nil, + Host: util.GetEnv("SERVER_SMTP_HOST", "mailhog"), + Port: util.GetEnvAsInt("SERVER_SMTP_PORT", 1025), + Username: util.GetEnv("SERVER_SMTP_USERNAME", ""), + Password: util.GetEnv("SERVER_SMTP_PASSWORD", ""), + AuthType: transport.SMTPAuthTypeFromString(util.GetEnv("SERVER_SMTP_AUTH_TYPE", transport.SMTPAuthTypeNone.String())), + Encryption: transport.SMTPEncryption(util.GetEnvEnum("SERVER_SMTP_ENCRYPTION", transport.SMTPEncryptionNone.String(), []string{transport.SMTPEncryptionNone.String(), transport.SMTPEncryptionTLS.String(), transport.SMTPEncryptionStartTLS.String()})), + TLSConfig: nil, }, Frontend: FrontendServer{ BaseURL: util.GetEnv("SERVER_FRONTEND_BASE_URL", "http://localhost:3000"), diff --git a/internal/mailer/transport/mock.go b/internal/mailer/transport/mock.go index e02cedc8..ecbaf675 100644 --- a/internal/mailer/transport/mock.go +++ b/internal/mailer/transport/mock.go @@ -2,19 +2,24 @@ package transport import ( "sync" + "time" + "allaboutapps.dev/aw/go-starter/internal/util" "github.com/jordan-wright/email" ) type MockMailTransport struct { sync.RWMutex - mails []*email.Email + mails []*email.Email + OnMailSent func(mail email.Email) // non pointer to prevent concurrent read errors + wg sync.WaitGroup } func NewMock() *MockMailTransport { return &MockMailTransport{ - RWMutex: sync.RWMutex{}, - mails: make([]*email.Email, 0), + RWMutex: sync.RWMutex{}, + mails: make([]*email.Email, 0), + OnMailSent: func(mail email.Email) {}, } } @@ -23,6 +28,7 @@ func (m *MockMailTransport) Send(mail *email.Email) error { defer m.Unlock() m.mails = append(m.mails, mail) + m.OnMailSent(*mail) return nil } @@ -44,3 +50,17 @@ func (m *MockMailTransport) GetSentMails() []*email.Email { return m.mails } + +// Expect adds the mailCnt to a waitgroup and sets the OnMailSent callback +// to call wg.Done() +func (m *MockMailTransport) Expect(mailCnt int) { + m.wg.Add(mailCnt) + m.OnMailSent = func(email.Email) { + m.wg.Done() + } +} + +// Wait until all expected mails have arrived +func (m *MockMailTransport) Wait() { + _ = util.WaitTimeout(&m.wg, time.Second*10) +} diff --git a/internal/mailer/transport/smtp.go b/internal/mailer/transport/smtp.go index 8f2477dc..d4d62b45 100644 --- a/internal/mailer/transport/smtp.go +++ b/internal/mailer/transport/smtp.go @@ -34,9 +34,14 @@ func NewSMTP(config SMTPMailTransportConfig) *SMTPMailTransport { } func (m *SMTPMailTransport) Send(mail *email.Email) error { - if m.config.UseTLS { + switch m.config.Encryption { + case SMTPEncryptionNone: + return mail.Send(m.addr, m.auth) + case SMTPEncryptionTLS: return mail.SendWithTLS(m.addr, m.auth, m.config.TLSConfig) + case SMTPEncryptionStartTLS: + return mail.SendWithStartTLS(m.addr, m.auth, m.config.TLSConfig) + default: + return fmt.Errorf("invalid SMTP encryption %q", m.config.Encryption) } - - return mail.Send(m.addr, m.auth) } diff --git a/internal/mailer/transport/smtp_config.go b/internal/mailer/transport/smtp_config.go index 16483993..c2dbcc35 100644 --- a/internal/mailer/transport/smtp_config.go +++ b/internal/mailer/transport/smtp_config.go @@ -47,12 +47,35 @@ func SMTPAuthTypeFromString(s string) SMTPAuthType { } } +type SMTPEncryption string + +const ( + SMTPEncryptionNone SMTPEncryption = "none" + SMTPEncryptionTLS SMTPEncryption = "tls" + SMTPEncryptionStartTLS SMTPEncryption = "starttls" +) + +func (e SMTPEncryption) String() string { + return string(e) +} + +func SMTPEncryptionFromString(s string) SMTPEncryption { + switch strings.ToLower(s) { + case "tls": + return SMTPEncryptionTLS + case "starttls": + return SMTPEncryptionStartTLS + default: + return SMTPEncryptionNone + } +} + type SMTPMailTransportConfig struct { - Host string - Port int - AuthType SMTPAuthType `json:"-"` // iota - Username string - Password string `json:"-"` // sensitive - UseTLS bool - TLSConfig *tls.Config `json:"-"` // pointer + Host string + Port int + AuthType SMTPAuthType `json:"-"` // iota + Username string + Password string `json:"-"` // sensitive + Encryption SMTPEncryption `json:"-"` // iota + TLSConfig *tls.Config `json:"-"` // pointer } From fa21af3cac300c943820e89882da246e95fc3769 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Thu, 21 Oct 2021 17:05:19 +0200 Subject: [PATCH 04/15] update changelog --- CHANGELOG.md | 6 ++++++ go.mod | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 7a65c895..db7fa3bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,12 @@ ## Unreleased ### Changed +- Extend util package with additional helper functions. +- Add MIME interface to use *mimtype.MIME or an already KnownMIME. +- Add function to detach context to avoid context cancelation. Can be used to pass context information to go routines without a deadline or cancel. +- Add oauth2 helper for PKCE extention to generate verifier and challenge. +- Add starttls support to mailer. +- Extend mailer mock to support waiting for all expected mails to arrive to check asynchronously sent mails in tests. ## 2021-10-19 diff --git a/go.mod b/go.mod index 06e79e90..c578ac6c 100644 --- a/go.mod +++ b/go.mod @@ -102,7 +102,6 @@ require ( go.opencensus.io v0.23.0 // indirect golang.org/x/net v0.0.0-20210913180222-943fd674d43e // indirect golang.org/x/oauth2 v0.0.0-20210819190943-2bc19b11175f // indirect - golang.org/x/text v0.3.7 // indirect golang.org/x/time v0.0.0-20201208040808-7e3f01d25324 // indirect golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect google.golang.org/appengine v1.6.7 // indirect From 63e470b59b74c57d76394f1fbf782c75f0c768b2 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Thu, 21 Oct 2021 17:25:28 +0200 Subject: [PATCH 05/15] update CHANGELOG with information about breaking mail tls change --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index db7fa3bc..94fd065e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ - Add MIME interface to use *mimtype.MIME or an already KnownMIME. - Add function to detach context to avoid context cancelation. Can be used to pass context information to go routines without a deadline or cancel. - Add oauth2 helper for PKCE extention to generate verifier and challenge. -- Add starttls support to mailer. +- **BREAKING** Add starttls support to mailer. If you were using the `SERVER_SMTP_USE_TLS` flag before to enable TLS you need to change it to the `SERVER_SMTP_ENCRYPTION` setting and set it to `tls`. - Extend mailer mock to support waiting for all expected mails to arrive to check asynchronously sent mails in tests. ## 2021-10-19 From baccba790847c170ce27bbb7bc20937079463b76 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Fri, 22 Oct 2021 13:36:13 +0200 Subject: [PATCH 06/15] Fix panic in ContainsAll util function. --- internal/util/string.go | 3 +++ internal/util/string_test.go | 3 +++ 2 files changed, 6 insertions(+) diff --git a/internal/util/string.go b/internal/util/string.go index f73ee153..323107be 100644 --- a/internal/util/string.go +++ b/internal/util/string.go @@ -157,6 +157,9 @@ func ContainsAll(str string, sub ...string) bool { for _, marked := range str { for i, r := range runes { + if len(r) == 0 { + contains[i] = true + } if !contains[i] && marked == r[indices[i]] { indices[i]++ if indices[i] >= len(r) { diff --git a/internal/util/string_test.go b/internal/util/string_test.go index ad01f82a..56b609df 100644 --- a/internal/util/string_test.go +++ b/internal/util/string_test.go @@ -70,6 +70,9 @@ func TestContainsAll(t *testing.T) { assert.False(t, util.ContainsAll("Lorem ipsum dolor sit amet, consectetur adipiscing elit.", ".", "sit", "elit", "ipsum", " Lorem")) assert.True(t, util.ContainsAll("Lorem ipsum dolor sit amet, ÄÜiö consectetur adipiscing elit.", "ÄÜiö c")) + + assert.False(t, util.ContainsAll("", "ÄÜiö c")) + assert.True(t, util.ContainsAll("Lorem ipsum dolor sit amet, ÄÜiö consectetur adipiscing elit.", "")) } func TestEmptyIfNil(t *testing.T) { From 8c7b77a9dc413420a62e4e594faa97c2aed6d655 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Mon, 25 Oct 2021 10:59:56 +0200 Subject: [PATCH 07/15] Add test to mime package and verify interface compliance for mimetype.MIME Fixed names of some util methods --- internal/util/currency.go | 2 +- internal/util/currency_test.go | 2 +- internal/util/http.go | 8 ++++---- internal/util/mime/mime.go | 4 ++++ internal/util/mime/mime_test.go | 30 ++++++++++++++++++++++++++++++ internal/util/oauth2/pkce.go | 2 +- 6 files changed, 41 insertions(+), 7 deletions(-) create mode 100644 internal/util/mime/mime_test.go diff --git a/internal/util/currency.go b/internal/util/currency.go index a0a4b115..f5d5e57f 100644 --- a/internal/util/currency.go +++ b/internal/util/currency.go @@ -38,7 +38,7 @@ func Float64PtrToInt64WithCents(f *float64) int64 { return int64(swag.Float64Value(f) * 100) } -func Float64PToInt64WithCents(f float64) int64 { +func Float64ToInt64WithCents(f float64) int64 { return int64(f * 100) } diff --git a/internal/util/currency_test.go b/internal/util/currency_test.go index fe615ebd..abd03133 100644 --- a/internal/util/currency_test.go +++ b/internal/util/currency_test.go @@ -46,7 +46,7 @@ func TestCurrencyConversion(t *testing.T) { res = util.IntPtrWithCentsToFloat64Ptr(&inInt) outInt := util.Float64PtrToIntPtrWithCents(res) assert.Equal(t, inInt, *outInt) - outInt2 := util.Float64PToInt64WithCents(*res) + outInt2 := util.Float64ToInt64WithCents(*res) assert.Equal(t, int64(inInt), outInt2) }) } diff --git a/internal/util/http.go b/internal/util/http.go index 983238b4..0345a193 100644 --- a/internal/util/http.go +++ b/internal/util/http.go @@ -87,7 +87,7 @@ func BindAndValidateQueryParams(c echo.Context, v runtime.Validatable) error { // BindAndValidate binds the request, parsing path+query+body and validating these structs. // -// De pre ca ted (bad word, the linter will cry!): Use our dedicated BindAndValidate* mappers instead: +// Deprecated: Use our dedicated BindAndValidate* mappers instead: // BindAndValidateBody(c echo.Context, v runtime.Validatable) error // preferred // BindAndValidatePathAndQueryParams(c echo.Context, v runtime.Validatable) error // preferred // BindAndValidatePathParams(c echo.Context, v runtime.Validatable) error // rare usecases @@ -236,7 +236,7 @@ func validatePayload(c echo.Context, v runtime.Validatable) error { case *errors.CompositeError: LogFromEchoContext(c).Debug().Errs("validation_errors", ee.Errors).Msg("Payload did match schema, returning HTTP validation error") - valErrs := FormatValidationErrors(c.Request().Context(), ee) + valErrs := formatValidationErrors(c.Request().Context(), ee) return httperrors.NewHTTPValidationError(http.StatusBadRequest, httperrors.HTTPErrorTypeGeneric, http.StatusText(http.StatusBadRequest), valErrs) case *errors.Validation: @@ -285,7 +285,7 @@ func defaultEchoBindAll(c echo.Context, v runtime.Validatable) (err error) { return binder.BindBody(c, v) } -func FormatValidationErrors(ctx context.Context, err *errors.CompositeError) []*types.HTTPValidationErrorDetail { +func formatValidationErrors(ctx context.Context, err *errors.CompositeError) []*types.HTTPValidationErrorDetail { valErrs := make([]*types.HTTPValidationErrorDetail, 0, len(err.Errors)) for _, e := range err.Errors { switch ee := e.(type) { @@ -296,7 +296,7 @@ func FormatValidationErrors(ctx context.Context, err *errors.CompositeError) []* Error: swag.String(ee.Error()), }) case *errors.CompositeError: - valErrs = append(valErrs, FormatValidationErrors(ctx, ee)...) + valErrs = append(valErrs, formatValidationErrors(ctx, ee)...) default: LogFromContext(ctx).Warn().Err(e).Str("err_type", fmt.Sprintf("%T", e)).Msg("Received unknown error type while validating payload, skipping") } diff --git a/internal/util/mime/mime.go b/internal/util/mime/mime.go index 21481afb..bec30cda 100644 --- a/internal/util/mime/mime.go +++ b/internal/util/mime/mime.go @@ -1,5 +1,9 @@ package mime +import "github.com/gabriel-vasile/mimetype" + +var _ MIME = (*mimetype.MIME)(nil) + // MIME interface enables to use either *mimetype.MIME or KnownMIME as mimetype. type MIME interface { String() string diff --git a/internal/util/mime/mime_test.go b/internal/util/mime/mime_test.go new file mode 100644 index 00000000..6dca618e --- /dev/null +++ b/internal/util/mime/mime_test.go @@ -0,0 +1,30 @@ +package mime_test + +import ( + "path/filepath" + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util" + "allaboutapps.dev/aw/go-starter/internal/util/mime" + "github.com/gabriel-vasile/mimetype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestKnownMIME(t *testing.T) { + filePath := filepath.Join(util.GetProjectRootDir(), "test", "testdata", "example.jpg") + + var detectedMIME mime.MIME + var err error + detectedMIME, err = mimetype.DetectFile(filePath) + require.NoError(t, err) + + var knownMIME mime.MIME = &mime.KnownMIME{ + MimeType: "image/jpeg", + FileExtension: ".jpg", + } + + assert.Equal(t, detectedMIME.Extension(), knownMIME.Extension()) + assert.Equal(t, detectedMIME.String(), knownMIME.String()) + assert.True(t, knownMIME.Is(detectedMIME.String())) +} diff --git a/internal/util/oauth2/pkce.go b/internal/util/oauth2/pkce.go index bcc3361b..75858c77 100644 --- a/internal/util/oauth2/pkce.go +++ b/internal/util/oauth2/pkce.go @@ -7,7 +7,7 @@ import ( "allaboutapps.dev/aw/go-starter/internal/util" ) -func GetNewPKCECodeVerifier() (string, error) { +func GetPKCECodeVerifier() (string, error) { // for details regarding possible characters in verifier, see: // https://tools.ietf.org/html/rfc7636#section-4.1 From 6828eb09f36685b18b55013ae19801531caee9c2 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Mon, 25 Oct 2021 11:17:56 +0200 Subject: [PATCH 08/15] Changebreaking tls setting to deprecation by supporting both env variables --- CHANGELOG.md | 2 +- internal/config/server_config.go | 1 + internal/mailer/transport/smtp.go | 4 ++++ internal/mailer/transport/smtp_config.go | 1 + 4 files changed, 7 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 94fd065e..223a3874 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,7 @@ - Add MIME interface to use *mimtype.MIME or an already KnownMIME. - Add function to detach context to avoid context cancelation. Can be used to pass context information to go routines without a deadline or cancel. - Add oauth2 helper for PKCE extention to generate verifier and challenge. -- **BREAKING** Add starttls support to mailer. If you were using the `SERVER_SMTP_USE_TLS` flag before to enable TLS you need to change it to the `SERVER_SMTP_ENCRYPTION` setting and set it to `tls`. +- **DEPRECATED** Add starttls support to mailer. If you were using the `SERVER_SMTP_USE_TLS` flag before to enable TLS you need to change it to the `SERVER_SMTP_ENCRYPTION` setting and set it to `tls`. - Extend mailer mock to support waiting for all expected mails to arrive to check asynchronously sent mails in tests. ## 2021-10-19 diff --git a/internal/config/server_config.go b/internal/config/server_config.go index 284874a6..b13c9227 100644 --- a/internal/config/server_config.go +++ b/internal/config/server_config.go @@ -191,6 +191,7 @@ func DefaultServiceConfigFromEnv() Server { Password: util.GetEnv("SERVER_SMTP_PASSWORD", ""), AuthType: transport.SMTPAuthTypeFromString(util.GetEnv("SERVER_SMTP_AUTH_TYPE", transport.SMTPAuthTypeNone.String())), Encryption: transport.SMTPEncryption(util.GetEnvEnum("SERVER_SMTP_ENCRYPTION", transport.SMTPEncryptionNone.String(), []string{transport.SMTPEncryptionNone.String(), transport.SMTPEncryptionTLS.String(), transport.SMTPEncryptionStartTLS.String()})), + UseTLS: util.GetEnvAsBool("SERVER_SMTP_USE_TLS", false), TLSConfig: nil, }, Frontend: FrontendServer{ diff --git a/internal/mailer/transport/smtp.go b/internal/mailer/transport/smtp.go index d4d62b45..c6533220 100644 --- a/internal/mailer/transport/smtp.go +++ b/internal/mailer/transport/smtp.go @@ -34,6 +34,10 @@ func NewSMTP(config SMTPMailTransportConfig) *SMTPMailTransport { } func (m *SMTPMailTransport) Send(mail *email.Email) error { + if m.config.UseTLS { + return mail.SendWithTLS(m.addr, m.auth, m.config.TLSConfig) + } + switch m.config.Encryption { case SMTPEncryptionNone: return mail.Send(m.addr, m.auth) diff --git a/internal/mailer/transport/smtp_config.go b/internal/mailer/transport/smtp_config.go index c2dbcc35..88fdc999 100644 --- a/internal/mailer/transport/smtp_config.go +++ b/internal/mailer/transport/smtp_config.go @@ -78,4 +78,5 @@ type SMTPMailTransportConfig struct { Password string `json:"-"` // sensitive Encryption SMTPEncryption `json:"-"` // iota TLSConfig *tls.Config `json:"-"` // pointer + UseTLS bool // ! deprecated since 2021-10-25, use Encryption type 'SMTPEncryptionTLS' instead } From c94f8e84298db09351ed77af2fe6cb5a24167889 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Wed, 27 Oct 2021 10:32:43 +0200 Subject: [PATCH 09/15] Add deprecation warning when using SERVER_SMTP_USE_TLS flag --- CHANGELOG.md | 3 ++- internal/mailer/transport/smtp.go | 2 ++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 223a3874..23b4a0bc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,7 +14,8 @@ - Add MIME interface to use *mimtype.MIME or an already KnownMIME. - Add function to detach context to avoid context cancelation. Can be used to pass context information to go routines without a deadline or cancel. - Add oauth2 helper for PKCE extention to generate verifier and challenge. -- **DEPRECATED** Add starttls support to mailer. If you were using the `SERVER_SMTP_USE_TLS` flag before to enable TLS you need to change it to the `SERVER_SMTP_ENCRYPTION` setting and set it to `tls`. +- Added starttls support to mailer +- **DEPRECATED** If you were using the `SERVER_SMTP_USE_TLS` flag before to enable TLS, you'll need to migrate to the `SERVER_SMTP_ENCRYPTION` setting of `tls`. For the moment, both settings are supported (with a warning being printed when using `SERVER_SMTP_USE_TLS`, however support for the deprecated config might be dropped in a future release). - Extend mailer mock to support waiting for all expected mails to arrive to check asynchronously sent mails in tests. ## 2021-10-19 diff --git a/internal/mailer/transport/smtp.go b/internal/mailer/transport/smtp.go index c6533220..f9e40688 100644 --- a/internal/mailer/transport/smtp.go +++ b/internal/mailer/transport/smtp.go @@ -6,6 +6,7 @@ import ( "net/smtp" "github.com/jordan-wright/email" + "github.com/rs/zerolog/log" ) type SMTPMailTransport struct { @@ -35,6 +36,7 @@ func NewSMTP(config SMTPMailTransportConfig) *SMTPMailTransport { func (m *SMTPMailTransport) Send(mail *email.Email) error { if m.config.UseTLS { + log.Warn().Msg("Enabling TLS with the UseTLS flag is *DEPRECATED* and will be removed in future releases") return mail.SendWithTLS(m.addr, m.auth, m.config.TLSConfig) } From ae9f9159458c4a378102a448cfe1fab5b3c962e4 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Mon, 30 May 2022 16:31:32 +0200 Subject: [PATCH 10/15] Removed override of global OnMailSent hook in mock mailer add tests in util package --- internal/mailer/transport/mock.go | 12 ++++++----- internal/util/env_test.go | 12 +++++++++++ internal/util/http_test.go | 36 +++++++++++++++++++++++++++++++ 3 files changed, 55 insertions(+), 5 deletions(-) diff --git a/internal/mailer/transport/mock.go b/internal/mailer/transport/mock.go index ecbaf675..0cae35de 100644 --- a/internal/mailer/transport/mock.go +++ b/internal/mailer/transport/mock.go @@ -13,6 +13,7 @@ type MockMailTransport struct { mails []*email.Email OnMailSent func(mail email.Email) // non pointer to prevent concurrent read errors wg sync.WaitGroup + expected int } func NewMock() *MockMailTransport { @@ -30,6 +31,10 @@ func (m *MockMailTransport) Send(mail *email.Email) error { m.mails = append(m.mails, mail) m.OnMailSent(*mail) + if m.expected > 0 { + m.wg.Done() + } + return nil } @@ -51,13 +56,10 @@ func (m *MockMailTransport) GetSentMails() []*email.Email { return m.mails } -// Expect adds the mailCnt to a waitgroup and sets the OnMailSent callback -// to call wg.Done() +// Expect adds the mailCnt to a waitgroup. Done() is called by Send func (m *MockMailTransport) Expect(mailCnt int) { + m.expected = mailCnt m.wg.Add(mailCnt) - m.OnMailSent = func(email.Email) { - m.wg.Done() - } } // Wait until all expected mails have arrived diff --git a/internal/util/env_test.go b/internal/util/env_test.go index 2cfe6855..4399201a 100644 --- a/internal/util/env_test.go +++ b/internal/util/env_test.go @@ -263,4 +263,16 @@ func TestGetEnvAsLocation(t *testing.T) { t.Setenv(testVarKey, "Europe/Vienna") res = util.GetEnvAsLocation(testVarKey, "UTC") assert.Equal(t, vienna, res) + + panicFunc := func() { + t.Setenv(testVarKey, "") + _ = util.GetEnvAsLocation(testVarKey, "not-valud") + } + assert.Panics(t, panicFunc) + + panicFunc = func() { + t.Setenv(testVarKey, "not-valid") + _ = util.GetEnvAsLocation(testVarKey, "UTC") + } + assert.Panics(t, panicFunc) } diff --git a/internal/util/http_test.go b/internal/util/http_test.go index c814032b..afe6a963 100644 --- a/internal/util/http_test.go +++ b/internal/util/http_test.go @@ -10,6 +10,7 @@ import ( "testing" "allaboutapps.dev/aw/go-starter/internal/api" + "allaboutapps.dev/aw/go-starter/internal/api/httperrors" "allaboutapps.dev/aw/go-starter/internal/test" "allaboutapps.dev/aw/go-starter/internal/types" "allaboutapps.dev/aw/go-starter/internal/types/auth" @@ -144,6 +145,41 @@ func TestParseFileUplaodUnsupported(t *testing.T) { require.Equal(t, http.StatusUnsupportedMediaType, res.Result().StatusCode) } +func TestParseFileUplaodEmpty(t *testing.T) { + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + _, err := writer.CreateFormFile("file", filepath.Base("example.txt")) + require.NoError(t, err) + + err = writer.Close() + require.NoError(t, err) + + e := echo.New() + e.POST("/", func(c echo.Context) error { + + fh, file, mime, err := util.ParseFileUpload(c, "file", []string{"text/plain"}) + assert.Nil(t, fh) + assert.Nil(t, file) + assert.Nil(t, mime) + assert.Equal(t, httperrors.ErrBadRequestZeroFileSize, err) + if err != nil { + return err + } + + return c.NoContent(204) + }) + + s := &api.Server{ + Echo: e, + } + + headers := http.Header{} + headers.Set(echo.HeaderContentType, writer.FormDataContentType()) + + test.PerformRequestWithRawBody(t, s, "POST", "/", &body, headers, nil) +} + func prepareFileUpload(t *testing.T, filePath string) (*bytes.Buffer, string) { t.Helper() From d3359f6ea5f307a976a8753e33eaad56780b73fb Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Mon, 30 May 2022 18:12:00 +0200 Subject: [PATCH 11/15] Add helper to test package to copy test files into a folder unique to the test. --- internal/test/helper_files.go | 62 ++++++++++++++++++++++++++++++ internal/test/helper_files_test.go | 30 +++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 internal/test/helper_files.go create mode 100644 internal/test/helper_files_test.go diff --git a/internal/test/helper_files.go b/internal/test/helper_files.go new file mode 100644 index 00000000..cab1ddf5 --- /dev/null +++ b/internal/test/helper_files.go @@ -0,0 +1,62 @@ +package test + +import ( + "io" + "os" + "path/filepath" + "strings" + "testing" + + "allaboutapps.dev/aw/go-starter/internal/util" + "github.com/stretchr/testify/require" +) + +// PrepareTestFile copies a test file by name from test/testdata/ to a folder unique to the test. +// Used for tests with document to load reference files by fixtures. +func PrepareTestFile(t *testing.T, fileName string, destFileName ...string) { + t.Helper() + + src, err := os.Open(filepath.Join(util.GetProjectRootDir(), "test", "testdata", fileName)) + require.NoError(t, err) + defer src.Close() + + dest := fileName + if len(destFileName) > 0 { + dest = destFileName[0] + } + + path := filepath.Join(util.GetProjectRootDir(), "assets", "mnt", strings.ToLower(t.Name()), "documents", dest) + err = os.MkdirAll(filepath.Dir(path), 0755) + require.NoError(t, err) + + dst, err := os.Create(path) + require.NoError(t, err) + defer dst.Close() + + _, err = io.Copy(dst, src) + require.NoError(t, err) +} + +// CleanupTestFiles removes folder unique to the test if exists +func CleanupTestFiles(t *testing.T) { + t.Helper() + + err := os.RemoveAll(filepath.Join(util.GetProjectRootDir(), "assets", "mnt", strings.ToLower(t.Name()))) + require.NoError(t, err) +} + +// WithTempDir creates a folder unique to the tests and ensures cleanup of the folder will be +// performed after the fn got called. +func WithTempDir(t *testing.T, fn func(localBasePath string, basePath string)) { + t.Helper() + + localBasePath := filepath.Join(util.GetProjectRootDir(), "assets", "mnt", strings.ToLower(t.Name())) + basePath := "/documents" + path := filepath.Join(localBasePath, basePath) + err := os.MkdirAll(filepath.Dir(path), 0755) + require.NoError(t, err) + + defer CleanupTestFiles(t) + + fn(localBasePath, basePath) +} diff --git a/internal/test/helper_files_test.go b/internal/test/helper_files_test.go new file mode 100644 index 00000000..837ab542 --- /dev/null +++ b/internal/test/helper_files_test.go @@ -0,0 +1,30 @@ +package test_test + +import ( + "os" + "path/filepath" + "strings" + "testing" + + "allaboutapps.dev/aw/go-starter/internal/test" + "github.com/stretchr/testify/assert" +) + +func TestPrepareTestFile(t *testing.T) { + var path string + test.WithTempDir(t, func(localBasePath, basePath string) { + assert.True(t, strings.HasSuffix(localBasePath, strings.ToLower(t.Name()))) + assert.NotEmpty(t, basePath) + + fileName := "example.jpg" + test.PrepareTestFile(t, fileName) + + path = filepath.Join(localBasePath, basePath, fileName) + _, err := os.Stat(path) + assert.NoError(t, err) + }) + + _, err := os.Stat(path) + assert.Error(t, err) + assert.ErrorIs(t, err, os.ErrNotExist) +} From 4fdca52635e62deb1a44c09bc90826db0f87c1c7 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Mon, 30 May 2022 18:14:06 +0200 Subject: [PATCH 12/15] Add util for IN and NIN query helper which use pg.StringArray as param for slices instead of passing each item as arg. sqlboiler fails at about 10k params because database and driver limits are reached. With this slices with over 1kk entries are working. --- internal/util/db/db.go | 8 +++++++ internal/util/db/db_test.go | 8 +++++++ internal/util/db/where_in.go | 21 +++++++++++++------ internal/util/db/where_in_test.go | 14 +++++++++++++ test/testdata/snapshots/TestNINArgs.golden | 7 +++++++ test/testdata/snapshots/TestNINSQL.golden | 1 + .../testdata/snapshots/TestWhereInArgs.golden | 10 +++++---- test/testdata/snapshots/TestWhereInSQL.golden | 2 +- 8 files changed, 60 insertions(+), 11 deletions(-) create mode 100644 test/testdata/snapshots/TestNINArgs.golden create mode 100644 test/testdata/snapshots/TestNINSQL.golden diff --git a/internal/util/db/db.go b/internal/util/db/db.go index 1b29cd8b..e5e170fd 100644 --- a/internal/util/db/db.go +++ b/internal/util/db/db.go @@ -85,3 +85,11 @@ func Int16PtrFromInt(i int) *int16 { return &res } + +func NullStringIfEmpty(s string) null.String { + if len(s) == 0 { + return null.String{} + } + + return null.StringFrom(s) +} diff --git a/internal/util/db/db_test.go b/internal/util/db/db_test.go index f101eb63..3216fb1f 100644 --- a/internal/util/db/db_test.go +++ b/internal/util/db/db_test.go @@ -180,4 +180,12 @@ func TestDBTypeConversions(t *testing.T) { res6 := db.Int16PtrFromInt(i) require.NotEmpty(t, res6) assert.Equal(t, i, int(*res6)) + + res7 := db.NullStringIfEmpty("") + assert.False(t, res7.Valid) + + s := "foo" + res8 := db.NullStringIfEmpty(s) + assert.True(t, res8.Valid) + assert.Equal(t, s, res8.String) } diff --git a/internal/util/db/where_in.go b/internal/util/db/where_in.go index b5c607b7..022e7ab2 100644 --- a/internal/util/db/where_in.go +++ b/internal/util/db/where_in.go @@ -3,14 +3,23 @@ package db import ( "fmt" + "github.com/lib/pq" "github.com/volatiletech/sqlboiler/v4/queries/qm" ) -// WhereIn is a copy from sqlboiler's WHERE IN query helpers since these don't get generated for nullable columns. +// WhereIn was a copy from sqlboiler's WHERE IN query helpers since these don't get generated for nullable columns. +// Since sqlboilers IN query helpers will set a param for earch element in the slice we reccomment using this packages IN. func WhereIn(tableName string, columnName string, slice []string) qm.QueryMod { - values := make([]interface{}, 0, len(slice)) - for _, value := range slice { - values = append(values, value) - } - return qm.WhereIn(fmt.Sprintf("%s.%s IN ?", tableName, columnName), values...) + return IN(fmt.Sprintf("%s.%s", tableName, columnName), slice) +} + +// IN is a replacement for sqlboilers IN query mod. sqlboilers IN will set a param for +// each element in the slice and we do not reccomend to use this, because it will run into driver and +// database limits. While the sqlboiler IN fails at about ~10000 params this was tested with over 1000000. +func IN(path string, slice []string) qm.QueryMod { + return qm.Where(fmt.Sprintf("%s = any(?)", path), pq.StringArray(slice)) +} + +func NIN(path string, slice []string) qm.QueryMod { + return qm.Where(fmt.Sprintf("%s <> all(?)", path), pq.StringArray(slice)) } diff --git a/internal/util/db/where_in_test.go b/internal/util/db/where_in_test.go index ae4f9ebd..ce9201ae 100644 --- a/internal/util/db/where_in_test.go +++ b/internal/util/db/where_in_test.go @@ -23,3 +23,17 @@ func TestWhereIn(t *testing.T) { test.Snapshoter.Label("SQL").Save(t, sql) test.Snapshoter.Label("Args").Save(t, args) } + +func TestNIN(t *testing.T) { + q := models.NewQuery( + qm.Select("*"), + qm.From("users"), + db.InnerJoin("users", "id", "app_user_profiles", "user_id"), + db.NIN("app_user_profiles.username", []string{"max", "muster", "peter"}), + ) + + sql, args := queries.BuildQuery(q) + + test.Snapshoter.Label("SQL").Save(t, sql) + test.Snapshoter.Label("Args").Save(t, args) +} diff --git a/test/testdata/snapshots/TestNINArgs.golden b/test/testdata/snapshots/TestNINArgs.golden new file mode 100644 index 00000000..ffec0702 --- /dev/null +++ b/test/testdata/snapshots/TestNINArgs.golden @@ -0,0 +1,7 @@ +([]interface {}) (len=1) { + (pq.StringArray) (len=3) { + (string) (len=3) "max", + (string) (len=6) "muster", + (string) (len=5) "peter" + } +} diff --git a/test/testdata/snapshots/TestNINSQL.golden b/test/testdata/snapshots/TestNINSQL.golden new file mode 100644 index 00000000..006963d6 --- /dev/null +++ b/test/testdata/snapshots/TestNINSQL.golden @@ -0,0 +1 @@ +(string) (len=135) "SELECT * FROM \"users\" INNER JOIN app_user_profiles ON app_user_profiles.user_id=users.id WHERE (app_user_profiles.username <> all($1));" diff --git a/test/testdata/snapshots/TestWhereInArgs.golden b/test/testdata/snapshots/TestWhereInArgs.golden index 4ef66434..ffec0702 100644 --- a/test/testdata/snapshots/TestWhereInArgs.golden +++ b/test/testdata/snapshots/TestWhereInArgs.golden @@ -1,5 +1,7 @@ -([]interface {}) (len=3) { - (string) (len=3) "max", - (string) (len=6) "muster", - (string) (len=5) "peter" +([]interface {}) (len=1) { + (pq.StringArray) (len=3) { + (string) (len=3) "max", + (string) (len=6) "muster", + (string) (len=5) "peter" + } } diff --git a/test/testdata/snapshots/TestWhereInSQL.golden b/test/testdata/snapshots/TestWhereInSQL.golden index 0996e121..2b0c18a2 100644 --- a/test/testdata/snapshots/TestWhereInSQL.golden +++ b/test/testdata/snapshots/TestWhereInSQL.golden @@ -1 +1 @@ -(string) (len=142) "SELECT * FROM \"users\" INNER JOIN app_user_profiles ON app_user_profiles.user_id=users.id WHERE (\"app_user_profiles\".\"username\" IN ($1,$2,$3));" +(string) (len=134) "SELECT * FROM \"users\" INNER JOIN app_user_profiles ON app_user_profiles.user_id=users.id WHERE (app_user_profiles.username = any($1));" From c845a6fcca3b51a1e10704a0bdbe9e62b84dc307 Mon Sep 17 00:00:00 2001 From: Manuel Wieser Date: Mon, 30 May 2022 18:18:06 +0200 Subject: [PATCH 13/15] Fixed bug in test.RunningInTest were the test env did not get recognized while debugging --- CHANGELOG.md | 6 +++++- internal/util/test.go | 14 ++++++++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d67369a2..cf0f45fe 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,13 +9,17 @@ - Please follow the update process in *[I just want to update / upgrade my project!](https://github.com/allaboutapps/go-starter/wiki/FAQ#i-just-want-to-update--upgrade-my-project)*. ## Unreleased -- Extend util package with additional helper functions. +- Extend util and test package with additional helper functions. - Add MIME interface to use *mimtype.MIME or an already KnownMIME. - Add function to detach context to avoid context cancelation. Can be used to pass context information to go routines without a deadline or cancel. - Add oauth2 helper for PKCE extention to generate verifier and challenge. - Added starttls support to mailer - **DEPRECATED** If you were using the `SERVER_SMTP_USE_TLS` flag before to enable TLS, you'll need to migrate to the `SERVER_SMTP_ENCRYPTION` setting of `tls`. For the moment, both settings are supported (with a warning being printed when using `SERVER_SMTP_USE_TLS`, however support for the deprecated config might be dropped in a future release). - Extend mailer mock to support waiting for all expected mails to arrive to check asynchronously sent mails in tests. +- Add util for IN and NIN query helper which use `pg.StringArray` as param for slices instead of passing each item as arg. + - sqlboiler fails at around 10k params because each element is passed as an argument and some database and driver limits are reached + - `db.IN` should be used in go-starter projects when using postgres because it allows for over 1kk params in `IN` and `NIN` where clauses +- Fixed bug in `test.RunningInTest()` were the test env did not get recognized while debugging ## 2022-04-15 - Switch [from Go 1.17.1 to Go 1.17.9](https://go.dev/doc/devel/release#go1.17.minor) (requires `./docker-helper.sh --rebuild`). diff --git a/internal/util/test.go b/internal/util/test.go index 90de81a9..7f9d93d5 100644 --- a/internal/util/test.go +++ b/internal/util/test.go @@ -9,6 +9,16 @@ import ( // The function first checks the `CI` env variable defined by various CI environments, // then tests whether the executable ends with the `.test` suffix generated by `go test`. func RunningInTest() bool { - // Partially taken from: https://stackoverflow.com/a/45913089 @ 2021-06-02T14:55:01+00:00 - return len(os.Getenv("CI")) > 0 || strings.HasSuffix(os.Args[0], ".test") + // Partially taken from: https://stackoverflow.com/a/45913089 @ 2021-06-02T14:55:01+00:00, early out + if len(os.Getenv("CI")) > 0 || strings.HasSuffix(os.Args[0], ".test") { + return true + } + + for _, arg := range os.Args { + if strings.HasPrefix(arg, "-test.") { + return true + } + } + + return false } From 333336be4d23a5e6a0932fc5b49a9343c26f1fd4 Mon Sep 17 00:00:00 2001 From: anjankow Date: Wed, 10 May 2023 14:48:13 +0000 Subject: [PATCH 14/15] add check if not enough emails have been sent --- internal/mailer/mailer_test.go | 7 ++++++- internal/mailer/transport/mock.go | 31 ++++++++++++++++++++++++++++++- 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/internal/mailer/mailer_test.go b/internal/mailer/mailer_test.go index c879f816..08d4e91f 100644 --- a/internal/mailer/mailer_test.go +++ b/internal/mailer/mailer_test.go @@ -3,6 +3,7 @@ package mailer_test import ( "context" "testing" + "time" "allaboutapps.dev/aw/go-starter/internal/api" "allaboutapps.dev/aw/go-starter/internal/config" @@ -16,12 +17,16 @@ func TestMailerSendPasswordReset(t *testing.T) { fixtures := test.Fixtures() m := test.NewTestMailer(t) + mt := test.GetTestMailerMockTransport(t, m) + mt.Expect(1) + //nolint:gosec passwordResetLink := "http://localhost/password/reset/12345" err := m.SendPasswordReset(ctx, fixtures.User1.Username.String, passwordResetLink) require.NoError(t, err) - mt := test.GetTestMailerMockTransport(t, m) + mt.WaitWithTimeout(time.Second) + mail := mt.GetLastSentMail() mails := mt.GetSentMails() require.NotNil(t, mail) diff --git a/internal/mailer/transport/mock.go b/internal/mailer/transport/mock.go index 0cae35de..e8b0cc0b 100644 --- a/internal/mailer/transport/mock.go +++ b/internal/mailer/transport/mock.go @@ -1,6 +1,9 @@ package transport import ( + "errors" + "fmt" + "log" "sync" "time" @@ -8,6 +11,8 @@ import ( "github.com/jordan-wright/email" ) +const defaultWaitTimeout = time.Second * 10 + type MockMailTransport struct { sync.RWMutex mails []*email.Email @@ -28,6 +33,22 @@ func (m *MockMailTransport) Send(mail *email.Email) error { m.Lock() defer m.Unlock() + // Calling wg.Done might panic leaving a user clueless what was the reason of test failure. + // We will add more information before exiting. + defer func() { + rcp := recover() + if rcp == nil { + return + } + + err, ok := rcp.(error) + if !ok { + err = fmt.Errorf("%v", rcp) + } + + log.Fatalf("Unexpected email sent! MockMailTransport panicked: %s", err) + }() + m.mails = append(m.mails, mail) m.OnMailSent(*mail) @@ -64,5 +85,13 @@ func (m *MockMailTransport) Expect(mailCnt int) { // Wait until all expected mails have arrived func (m *MockMailTransport) Wait() { - _ = util.WaitTimeout(&m.wg, time.Second*10) + if err := util.WaitTimeout(&m.wg, defaultWaitTimeout); errors.Is(err, util.ErrWaitTimeout) { + panic(fmt.Sprintf("Some emails are missing, sent: %v", len(m.GetSentMails()))) + } +} + +func (m *MockMailTransport) WaitWithTimeout(timeout time.Duration) { + if err := util.WaitTimeout(&m.wg, timeout); errors.Is(err, util.ErrWaitTimeout) { + log.Fatalf("Some emails are missing, found: %v", len(m.GetSentMails())) + } } From be48d27d314c8040807c82f500822eb1d8643f30 Mon Sep 17 00:00:00 2001 From: anjankow Date: Thu, 11 May 2023 07:13:23 +0000 Subject: [PATCH 15/15] fix typo in changelog --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4bdeffa2..1a85b7cc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,7 +19,7 @@ - Add util for IN and NIN query helper which use `pg.StringArray` as param for slices instead of passing each item as arg. - sqlboiler fails at around 10k params because each element is passed as an argument and some database and driver limits are reached - `db.IN` should be used in go-starter projects when using postgres because it allows for over 1kk params in `IN` and `NIN` where clauses -- Fixed bug in `test.RunningInTest()` were the test env did not get recognized while debugging +- Fixed bug in `test.RunningInTest()` where the test env did not get recognized while debugging ## 2023-05-03 - Switch [from Go 1.19.3 to Go 1.20.3](https://go.dev/doc/devel/release#go1.20) (requires `./docker-helper.sh --rebuild`).