From bfbf5e8fa4301934238c3a8334e2947bb40ec548 Mon Sep 17 00:00:00 2001 From: Neko Date: Tue, 17 Dec 2024 22:00:03 +0800 Subject: [PATCH] feat: error handling and proper status codes (#5) --- cspell.config.yaml | 3 + go.mod | 5 + go.sum | 12 +- pkg/apierrors/apierrors.go | 14 ++ pkg/apierrors/errors.go | 56 +++-- pkg/backend/backend.go | 15 +- pkg/backend/error.go | 71 +++++++ pkg/backend/openai.go | 32 ++- pkg/jsonapi/error.go | 2 - pkg/utils/json.go | 93 ++++++++ pkg/utils/json_test.go | 106 ++++++++++ pkg/utils/string.go | 217 +++++++++++++++++++ pkg/utils/string_test.go | 420 +++++++++++++++++++++++++++++++++++++ 13 files changed, 1004 insertions(+), 42 deletions(-) create mode 100644 pkg/backend/error.go create mode 100644 pkg/utils/json.go create mode 100644 pkg/utils/json_test.go create mode 100644 pkg/utils/string.go create mode 100644 pkg/utils/string_test.go diff --git a/cspell.config.yaml b/cspell.config.yaml index 7aa66cd..22ed3e5 100644 --- a/cspell.config.yaml +++ b/cspell.config.yaml @@ -9,10 +9,12 @@ words: - cyclop - depguard - Describedby + - Detailf - dupl - durationcheck - elevenlabs - errcheck + - errchkjson - errname - execinquery - exhaustive @@ -59,6 +61,7 @@ words: - predeclared - reassign - revive + - samber - staticcheck - strconv - tagalign diff --git a/go.mod b/go.mod index 27fb4f0..0657757 100644 --- a/go.mod +++ b/go.mod @@ -9,9 +9,12 @@ require ( github.com/samber/lo v1.47.0 github.com/samber/mo v1.13.0 github.com/spf13/cobra v1.8.1 + github.com/stretchr/testify v1.10.0 + k8s.io/client-go v0.32.0 ) require ( + github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/gobuffalo/envy v1.7.0 // indirect github.com/gobuffalo/packd v0.3.0 // indirect github.com/gobuffalo/packr v1.30.1 // indirect @@ -20,6 +23,7 @@ require ( github.com/labstack/gommon v0.4.2 // indirect github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect + github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.3.0 // indirect github.com/spf13/pflag v1.0.5 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect @@ -31,4 +35,5 @@ require ( golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect golang.org/x/time v0.8.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index bfe09a0..96515db 100644 --- a/go.sum +++ b/go.sum @@ -6,8 +6,9 @@ github.com/coreos/go-semver v0.2.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3Ee github.com/cpuguy83/go-md2man v1.0.10/go.mod h1:SmD6nW6nTyfqj6ABTjUi3V3JVMnlJmwcJI5acqYI6dE= github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/gobuffalo/envy v1.7.0 h1:GlXgaiBkmrYMHco6t4j7SacKO4XUjvh5pwXh0f4uxXU= github.com/gobuffalo/envy v1.7.0/go.mod h1:n7DRkBerg/aorDM8kbduw5dN3oXGswK5liaSCx4T5NI= @@ -28,9 +29,12 @@ github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqx github.com/karrick/godirwalk v1.10.12/go.mod h1:RoGL9dQei4vP9ilrpETWE8CLOZ1kiN0LhBygSwrAsHA= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/labstack/echo/v4 v4.13.2 h1:9aAt4hstpH54qIcqkuUXRLTf+v7yOTfMPWzDtuqLmtA= github.com/labstack/echo/v4 v4.13.2/go.mod h1:uc9gDtHB8UWt3FfbYx0HyxcCuvR4YuPYOxF/1QjoV/c= github.com/labstack/gommon v0.4.2 h1:F8qTUNXgG1+6WQmqoUWnz8WiEU60mXVVw0P4ht1WRA0= @@ -46,8 +50,9 @@ github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh github.com/nekomeowww/fo v1.4.0 h1:ULX5KsnDzWHoDwHgtjd2wibpdpyh+5/5DITmvhJZyWY= github.com/nekomeowww/fo v1.4.0/go.mod h1:ctwQ+BZ0UYUb2s+yM7h9SFHjqGCXeUIXFLK2ujAneWw= github.com/pelletier/go-toml v1.2.0/go.mod h1:5z9KED0ma1S8pY6P1sdut58dfprrGBbd/94hg7ilaic= -github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.1.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/rogpeppe/go-internal v1.3.0 h1:RR9dF3JtopPvtkroDZuVD7qquD0bnHlKSqaQhgwt8yk= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= @@ -117,9 +122,12 @@ golang.org/x/time v0.8.0 h1:9i3RxcPv3PZnitoVGMPDKZSq1xW1gK1Xy3ArNOGZfEg= golang.org/x/time v0.8.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= golang.org/x/tools v0.0.0-20190624180213-70d37148ca0c/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +k8s.io/client-go v0.32.0 h1:DimtMcnN/JIKZcrSrstiwvvZvLjG0aSxy8PxN8IChp8= +k8s.io/client-go v0.32.0/go.mod h1:boDWvdM1Drk4NJj/VddSLnx59X3OPgwrOo0vGbtq9+8= diff --git a/pkg/apierrors/apierrors.go b/pkg/apierrors/apierrors.go index 1a9bb57..575bc7c 100644 --- a/pkg/apierrors/apierrors.go +++ b/pkg/apierrors/apierrors.go @@ -104,6 +104,20 @@ func (e *Error) WithSourceHeader(header string) *Error { return e } +func (e *Error) WithReason(reason string) *Error { + return e.WithMeta("reason", reason) +} + +func (e *Error) WithMeta(key string, val any) *Error { + if e.Meta.IsAbsent() { + e.Meta = mo.Some(map[string]any{}) + } + + e.Meta.MustGet()[key] = val + + return e +} + type ErrResponse struct { jsonapi.Response } diff --git a/pkg/apierrors/errors.go b/pkg/apierrors/errors.go index e05113e..4eeb1b1 100644 --- a/pkg/apierrors/errors.go +++ b/pkg/apierrors/errors.go @@ -10,24 +10,6 @@ func NewErrBadRequest() *Error { WithDetail("The request was invalid or cannot be served") } -func NewErrInternal() *Error { - return NewError(http.StatusInternalServerError, "INTERNAL_SERVER_ERROR"). - WithTitle("Internal Server Error"). - WithDetail("An internal server error occurred") -} - -func NewErrPermissionDenied() *Error { - return NewError(http.StatusForbidden, "PERMISSION_DENIED"). - WithTitle("Permission Denied"). - WithDetail("You do not have permission to access the requested resources") -} - -func NewErrUnavailable() *Error { - return NewError(http.StatusServiceUnavailable, "UNAVAILABLE"). - WithTitle("Service Unavailable"). - WithDetail("The requested service is unavailable") -} - func NewErrInvalidArgument() *Error { return NewError(http.StatusBadRequest, "INVALID_ARGUMENT"). WithTitle("Invalid Argument"). @@ -46,6 +28,18 @@ func NewErrUnauthorized() *Error { WithDetail("The requested resources require authentication") } +func NewErrPermissionDenied() *Error { + return NewError(http.StatusForbidden, "PERMISSION_DENIED"). + WithTitle("Permission Denied"). + WithDetail("You do not have permission to access the requested resources") +} + +func NewErrForbidden() *Error { + return NewError(http.StatusForbidden, "FORBIDDEN"). + WithTitle("Forbidden"). + WithDetail("You do not have permission to access the requested resources") +} + func NewErrNotFound() *Error { return NewError(http.StatusNotFound, "NOT_FOUND"). WithTitle("Not Found"). @@ -64,8 +58,26 @@ func NewErrQuotaExceeded() *Error { WithDetail("The request quota has been exceeded") } -func NewErrForbidden() *Error { - return NewError(http.StatusForbidden, "FORBIDDEN"). - WithTitle("Forbidden"). - WithDetail("You do not have permission to access the requested resources") +func NewErrInternal() *Error { + return NewError(http.StatusInternalServerError, "INTERNAL_SERVER_ERROR"). + WithTitle("Internal Server Error"). + WithDetail("An internal server error occurred") +} + +func NewErrBadGateway() *Error { + return NewError(http.StatusBadGateway, "BAD_GATEWAY"). + WithTitle("Bad gateway"). + WithDetail("The server received an invalid response from an upstream server") +} + +func NewErrUnavailable() *Error { + return NewError(http.StatusServiceUnavailable, "UNAVAILABLE"). + WithTitle("Service Unavailable"). + WithDetail("The requested service is unavailable") +} + +func NewUpstreamError(statusCode int) *Error { + return NewError(statusCode, "UPSTREAM_ERROR"). + WithTitle("Upstream Error"). + WithReason("An error occurred while processing the request from the upstream service") } diff --git a/pkg/backend/backend.go b/pkg/backend/backend.go index e6dc7ab..a8c76ea 100644 --- a/pkg/backend/backend.go +++ b/pkg/backend/backend.go @@ -36,29 +36,26 @@ type FullOptions struct { } func Speech(c echo.Context) mo.Result[any] { - options := new(Options) + var options Options - if err := c.Bind(options); err != nil { - return mo.Err[any](apierrors.NewErrBadRequest().WithCaller()) + if err := c.Bind(&options); err != nil { + return mo.Err[any](apierrors.NewErrBadRequest()) } - if options.Model == "" || options.Input == "" || options.Voice == "" { - return mo.Err[any](apierrors.NewErrBadRequest().WithCaller()) + return mo.Err[any](apierrors.NewErrInvalidArgument().WithDetail("either one of model, input, and voice parameter is required")) } backendAndModel := lo.Ternary( strings.Contains(options.Model, ":"), - //nolint:mnd - strings.SplitN(options.Model, ":", 2), + strings.SplitN(options.Model, ":", 2), //nolint:mnd []string{options.Model, ""}, ) fullOptions := FullOptions{ - Options: *options, + Options: options, Backend: backendAndModel[0], Model: backendAndModel[1], } return openai(c, fullOptions) - // return mo.Ok[any](c.JSON(http.StatusOK, fullOptions)) } diff --git a/pkg/backend/error.go b/pkg/backend/error.go new file mode 100644 index 0000000..28dba6b --- /dev/null +++ b/pkg/backend/error.go @@ -0,0 +1,71 @@ +package backend + +import ( + "encoding/json" + "io" + + "github.com/moeru-ai/unspeech/pkg/utils" + "github.com/samber/lo" + "github.com/samber/mo" +) + +var _ error = (*JSONResponseError)(nil) + +type JSONResponseError struct { + StatusCode int `json:"status_code"` + Message string `json:"message"` + + bodyParsed map[string]any +} + +func NewJSONResponseError(statusCode int, responseBody io.Reader) mo.Result[*JSONResponseError] { + jsonData, err := io.ReadAll(responseBody) + if err != nil { + return mo.Err[*JSONResponseError](err) + } + + resp := &JSONResponseError{ + StatusCode: statusCode, + } + + err = json.Unmarshal(jsonData, &resp.bodyParsed) + if err != nil { + return mo.Err[*JSONResponseError](err) + } + + errorMessage := utils.GetByJSONPath[string](resp.bodyParsed, "{ .message }") + errorStr := utils.GetByJSONPath[string](resp.bodyParsed, "{ .error }") + errorMap := utils.GetByJSONPath[map[string]any](resp.bodyParsed, "{ .error }") + errorStrFromErrorMap := utils.GetByJSONPath[string](errorMap, "{ .message }") + + resp.Message = lo.Must(lo.Coalesce(errorMessage, errorStr, errorStrFromErrorMap, "Unknown error")) + + return mo.Ok(resp) +} + +func (r *JSONResponseError) Error() string { + return r.Message +} + +var _ error = (*TextResponseError)(nil) + +type TextResponseError struct { + StatusCode int `json:"status_code"` + Body string `json:"body"` +} + +func (r *TextResponseError) Error() string { + return r.Body +} + +func NewTextResponseError(statusCode int, responseBody io.Reader) mo.Result[*TextResponseError] { + data, err := io.ReadAll(responseBody) + if err != nil { + return mo.Err[*TextResponseError](err) + } + + return mo.Ok(&TextResponseError{ + StatusCode: statusCode, + Body: string(data), + }) +} diff --git a/pkg/backend/openai.go b/pkg/backend/openai.go index f74d0da..e685ed8 100644 --- a/pkg/backend/openai.go +++ b/pkg/backend/openai.go @@ -3,11 +3,12 @@ package backend import ( "bytes" "encoding/json" + "log/slog" "net/http" + "strings" "github.com/labstack/echo/v4" "github.com/moeru-ai/unspeech/pkg/apierrors" - "github.com/nekomeowww/fo" "github.com/samber/mo" ) @@ -20,7 +21,7 @@ func openai(c echo.Context, options FullOptions) mo.Result[any] { Speed: options.Speed, } - payload := fo.May(json.Marshal(values)) + payload, _ := json.Marshal(values) //nolint:errchkjson req, err := http.NewRequestWithContext( c.Request().Context(), @@ -29,21 +30,38 @@ func openai(c echo.Context, options FullOptions) mo.Result[any] { bytes.NewBuffer(payload), ) if err != nil { - return mo.Err[any](apierrors.NewErrBadRequest().WithCaller()) + return mo.Err[any](apierrors.NewErrInternal().WithCaller()) } - // TODO: Bearer Auth + // Proxy the Authorization header + req.Header.Set("Authorization", c.Request().Header.Get("Authorization")) req.Header.Set("Content-Type", "application/json") res, err := http.DefaultClient.Do(req) - if err != nil { - return mo.Err[any](apierrors.NewErrBadRequest().WithCaller()) + return mo.Err[any](apierrors.NewErrBadGateway().WithDetail(err.Error()).WithError(err).WithCaller()) } defer res.Body.Close() - // body, _ := io.ReadAll(res.Body) + if res.StatusCode >= 400 && res.StatusCode < 600 { + switch { + case strings.HasPrefix(res.Header.Get("Content-Type"), "application/json"): + return mo.Err[any](apierrors. + NewUpstreamError(res.StatusCode). + WithDetail(NewJSONResponseError(res.StatusCode, res.Body).OrEmpty().Error())) + case strings.HasPrefix(res.Header.Get("Content-Type"), "text/"): + return mo.Err[any](apierrors. + NewUpstreamError(res.StatusCode). + WithDetail(NewTextResponseError(res.StatusCode, res.Body).OrEmpty().Error())) + default: + slog.Warn("unknown upstream error with unknown Content-Type", + slog.Int("status", res.StatusCode), + slog.String("content-type", res.Header.Get("Content-Type")), + slog.String("content-length", res.Header.Get("Content-Length")), + ) + } + } return mo.Ok[any](c.Stream(http.StatusOK, "audio/mp3", res.Body)) } diff --git a/pkg/jsonapi/error.go b/pkg/jsonapi/error.go index 6436820..b0095ba 100644 --- a/pkg/jsonapi/error.go +++ b/pkg/jsonapi/error.go @@ -43,8 +43,6 @@ type ErrorObject struct { Links mo.Option[*Links] `json:"links,omitempty"` // the HTTP status code applicable to this problem, expressed as a string value. Status int `json:"status,omitempty"` - // the HTTP status code applicable to this problem, expressed as a string value. - GrpcStatus uint64 `json:"grpc_status,omitempty"` // an application-specific error code, expressed as a string value. Code string `json:"code,omitempty"` // a short, human-readable summary of the problem diff --git a/pkg/utils/json.go b/pkg/utils/json.go new file mode 100644 index 0000000..30625ab --- /dev/null +++ b/pkg/utils/json.go @@ -0,0 +1,93 @@ +package utils + +import ( + "bytes" + "encoding/json" + "io" + + "k8s.io/client-go/util/jsonpath" +) + +func GetByJSONPathWithoutConvert(input any, template string) (string, error) { + j := jsonpath.New("document") + j.AllowMissingKeys(true) + + err := j.Parse(template) + if err != nil { + return "", err + } + + buffer := new(bytes.Buffer) + + err = j.Execute(buffer, input) + if err != nil { + return "", err + } + + return buffer.String(), nil +} + +func GetByJSONPath[T any](input any, template string) T { + var empty T + + result, err := GetByJSONPathWithoutConvert(input, template) + if err != nil { + return empty + } + + return FromStringOrEmpty[T](result) +} + +func ReadAsJSONWithClose(readCloser io.ReadCloser) (*bytes.Buffer, map[string]any, error) { + defer func() { + _ = readCloser.Close() + }() + + buffer, jsonMap, err := ReadAsJSON(readCloser) + if err != nil { + return buffer, jsonMap, err + } + + return buffer, jsonMap, nil +} + +func ReadAsJSON(reader io.Reader) (*bytes.Buffer, map[string]any, error) { + buffer := new(bytes.Buffer) + jsonMap := make(map[string]any) + + _, err := io.Copy(buffer, reader) + if err != nil { + return buffer, jsonMap, err + } + + err = json.Unmarshal(buffer.Bytes(), &jsonMap) + if err != nil { + return buffer, jsonMap, err + } + + return buffer, jsonMap, nil +} + +func FromMap[T any, MK comparable, MV any](m map[MK]MV) (*T, error) { + if m == nil { + return nil, nil + } + + if len(m) == 0 { + return nil, nil + } + + var initial T + + bs, err := json.Marshal(m) + if err != nil { + return nil, err + } + + err = json.Unmarshal(bs, &initial) + if err != nil { + return nil, err + } + + return &initial, nil +} diff --git a/pkg/utils/json_test.go b/pkg/utils/json_test.go new file mode 100644 index 0000000..f3de94b --- /dev/null +++ b/pkg/utils/json_test.go @@ -0,0 +1,106 @@ +package utils + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestJSONPathExecute(t *testing.T) { + t.Parallel() + + t.Run("string", func(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + payload map[string]any + template string + expected any + } + + testCases := []testCase{ + { + name: "model", + payload: map[string]any{ + "model": "gpt-4o", + }, + template: "{ .model }", + expected: "gpt-4o", + }, + { + name: "message role", + payload: map[string]any{ + "model": "gpt-4o", + "messages": []any{ + map[string]any{ + "role": "user", + "content": "Hello", + }, + }, + }, + template: "{ .messages[0].role }", + expected: "user", + }, + { + name: "message content", + payload: map[string]any{ + "model": "gpt-4o", + "messages": []any{ + map[string]any{ + "role": "user", + "content": "Hello", + }, + }, + }, + template: "{ .messages[0].content }", + expected: "Hello", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + assert.Equal(t, tc.expected, GetByJSONPath[string](tc.payload, tc.template)) + }) + } + }) + + t.Run("number", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "code": 401, + } + + assert.Equal(t, 401, GetByJSONPath[int](payload, "{ .code }")) + }) + + t.Run("null", func(t *testing.T) { + t.Parallel() + + t.Run("unknown nil", func(t *testing.T) { + t.Parallel() + + payload := map[string]any{ + "code": nil, + } + + assert.Equal(t, "", GetByJSONPath[string](payload, "{ .code }")) + }) + + t.Run("nil string", func(t *testing.T) { + t.Parallel() + + type payload struct { + Code *string `json:"code"` + } + + p := payload{ + Code: nil, + } + + assert.Equal(t, "", GetByJSONPath[string](p, "{ .code }")) + }) + }) +} diff --git a/pkg/utils/string.go b/pkg/utils/string.go new file mode 100644 index 0000000..1bed9fc --- /dev/null +++ b/pkg/utils/string.go @@ -0,0 +1,217 @@ +package utils + +import ( + "encoding/json" + "fmt" + "strconv" + "strings" +) + +var ( + errFailedToConvertStringToType = func(t any, err error) error { return fmt.Errorf("failed to convert string to type %T: %w", t, err) } +) + +func FromString[T any](str string) (T, error) { //nolint:gocyclo + var empty T + if str == "" { + switch any(empty).(type) { + case []byte: + val, _ := any(make([]byte, 0)).(T) + return val, nil + case []rune: + val, _ := any(make([]rune, 0)).(T) + return val, nil + case *strings.Builder: + val, _ := any(&strings.Builder{}).(T) + return val, nil + } + + return empty, nil + } + if str == "null" { + return empty, nil + } + if str == "" { + return empty, nil + } + + switch any(empty).(type) { + case string: + val, _ := any(str).(T) + return val, nil + case int: + val, err := strconv.ParseInt(str, 10, 0) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(int(val)).(T) + + return typeVal, nil + case int8: + val, err := strconv.ParseInt(str, 10, 8) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(int8(val)).(T) + + return typeVal, nil + case int16: + val, err := strconv.ParseInt(str, 10, 16) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(int16(val)).(T) + + return typeVal, nil + case int32: + val, err := strconv.ParseInt(str, 10, 32) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(int32(val)).(T) + + return typeVal, nil + case int64: + val, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(val).(T) + + return typeVal, nil + case uint: + val, err := strconv.ParseUint(str, 10, 0) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(uint(val)).(T) + + return typeVal, nil + case uint8: + val, err := strconv.ParseUint(str, 10, 8) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(uint8(val)).(T) + + return typeVal, nil + case uint16: + val, err := strconv.ParseUint(str, 10, 16) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(uint16(val)).(T) + + return typeVal, nil + case uint32: + val, err := strconv.ParseUint(str, 10, 32) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(uint32(val)).(T) + + return typeVal, nil + case uint64: + val, err := strconv.ParseUint(str, 10, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(val).(T) + + return typeVal, nil + case float32: + val, err := strconv.ParseFloat(str, 32) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(float32(val)).(T) + + return typeVal, nil + case float64: + val, err := strconv.ParseFloat(str, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(val).(T) + + return typeVal, nil + case complex64: + val, err := strconv.ParseComplex(str, 64) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(complex64(val)).(T) + + return typeVal, nil + case complex128: + val, err := strconv.ParseComplex(str, 128) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(val).(T) + + return typeVal, nil + case bool: + val, err := strconv.ParseBool(str) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + typeVal, _ := any(val).(T) + + return typeVal, nil + case []byte: + val, _ := any([]byte(str)).(T) + return val, nil + case []rune: + val, _ := any([]rune(str)).(T) + return val, nil + case *strings.Builder: + var sb strings.Builder + + sb.WriteString(str) + val, _ := any(&sb).(T) + + return val, nil + default: + var initial T + + err := json.Unmarshal([]byte(str), &initial) + if err != nil { + return empty, errFailedToConvertStringToType(empty, err) + } + + return initial, nil + } +} + +func FromStringOrEmpty[T any](str string) T { + var empty T + + val, err := FromString[T](str) + if err != nil { + return empty + } + + return val +} + +func IsNumber(str string) bool { + _, err := strconv.ParseFloat(str, 64) + + return err == nil +} diff --git a/pkg/utils/string_test.go b/pkg/utils/string_test.go new file mode 100644 index 0000000..ace9e3e --- /dev/null +++ b/pkg/utils/string_test.go @@ -0,0 +1,420 @@ +package utils + +import ( + "fmt" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFromString(t *testing.T) { + t.Run("Unsupported", func(t *testing.T) { + funcVal, err := FromString[func()]("") + require.NoError(t, err) + assert.Nil(t, funcVal) + + mapVal, err := FromString[map[string]any]("") + require.NoError(t, err) + assert.Nil(t, mapVal) + + mapVal, err = FromString[map[string]any]("") + require.NoError(t, err) + assert.Empty(t, mapVal) + + sliceVal, err := FromString[[]string]("") + require.NoError(t, err) + assert.Nil(t, sliceVal) + + sliceVal, err = FromString[[]string]("") + require.NoError(t, err) + assert.Empty(t, sliceVal) + + structVal, err := FromString[struct{}]("") + require.NoError(t, err) + assert.Empty(t, structVal) + + funcVal, err = FromString[func()]("{}") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type func(): json: cannot unmarshal object into Go value of type func()") + assert.Nil(t, funcVal) + }) + + t.Run("Empty", func(t *testing.T) { + stringVal, err := FromString[string]("") + require.NoError(t, err) + assert.Equal(t, "", stringVal) + + intVal, err := FromString[int]("") + require.NoError(t, err) + assert.Zero(t, intVal) + + int8Val, err := FromString[int8]("") + require.NoError(t, err) + assert.Zero(t, int8Val) + + int16Val, err := FromString[int16]("") + require.NoError(t, err) + assert.Zero(t, int16Val) + + int32Val, err := FromString[int32]("") + require.NoError(t, err) + assert.Zero(t, int32Val) + + int64Val, err := FromString[int64]("") + require.NoError(t, err) + assert.Zero(t, int64Val) + + uintVal, err := FromString[uint]("") + require.NoError(t, err) + assert.Zero(t, uintVal) + + uint8Val, err := FromString[uint8]("") + require.NoError(t, err) + assert.Zero(t, uint8Val) + + uint16Val, err := FromString[uint16]("") + require.NoError(t, err) + assert.Zero(t, uint16Val) + + uint32Val, err := FromString[uint32]("") + require.NoError(t, err) + assert.Zero(t, uint32Val) + + uint64Val, err := FromString[uint64]("") + require.NoError(t, err) + assert.Zero(t, uint64Val) + + float32Val, err := FromString[float32]("") + require.NoError(t, err) + assert.Zero(t, float32Val) + + float64Val, err := FromString[float64]("") + require.NoError(t, err) + assert.Zero(t, float64Val) + + complex64Val, err := FromString[complex64]("") + require.NoError(t, err) + assert.Zero(t, complex64Val) + + complex128Val, err := FromString[complex128]("") + require.NoError(t, err) + assert.Zero(t, complex128Val) + + boolVal, err := FromString[bool]("") + require.NoError(t, err) + assert.False(t, boolVal) + + bytesVal, err := FromString[[]byte]("") + require.NoError(t, err) + assert.Empty(t, bytesVal) + + runesVal, err := FromString[[]rune]("") + require.NoError(t, err) + assert.Empty(t, runesVal) + + mapVal, err := FromString[map[string]any]("{}") + require.NoError(t, err) + assert.Empty(t, mapVal) + + sliceVal, err := FromString[[]string]("[]") + require.NoError(t, err) + assert.Empty(t, sliceVal) + + structVal, err := FromString[struct{}]("{}") + require.NoError(t, err) + assert.Empty(t, structVal) + + builderVal, err := FromString[*strings.Builder]("") + require.NoError(t, err) + assert.NotNil(t, builderVal) + + anyVal, err := FromString[any]("") + require.NoError(t, err) + assert.Equal(t, "", fmt.Sprintf("%T", anyVal)) + assert.Nil(t, anyVal) + }) + + t.Run("Invalid", func(t *testing.T) { + intVal, err := FromString[int]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type int: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Zero(t, intVal) + + int8Val, err := FromString[int8]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type int8: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Zero(t, int8Val) + + int16Val, err := FromString[int16]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type int16: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Zero(t, int16Val) + + int32Val, err := FromString[int32]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type int32: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Zero(t, int32Val) + + int64Val, err := FromString[int64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type int64: strconv.ParseInt: parsing \"invalid\": invalid syntax") + assert.Zero(t, int64Val) + + uintVal, err := FromString[uint]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type uint: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Zero(t, uintVal) + + uint8Val, err := FromString[uint8]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type uint8: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Zero(t, uint8Val) + + uint16Val, err := FromString[uint16]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type uint16: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Zero(t, uint16Val) + + uint32Val, err := FromString[uint32]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type uint32: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Zero(t, uint32Val) + + uint64Val, err := FromString[uint64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type uint64: strconv.ParseUint: parsing \"invalid\": invalid syntax") + assert.Zero(t, uint64Val) + + float32Val, err := FromString[float32]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type float32: strconv.ParseFloat: parsing \"invalid\": invalid syntax") + assert.Zero(t, float32Val) + + float64Val, err := FromString[float64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type float64: strconv.ParseFloat: parsing \"invalid\": invalid syntax") + assert.Zero(t, float64Val) + + complex64Val, err := FromString[complex64]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type complex64: strconv.ParseComplex: parsing \"invalid\": invalid syntax") + assert.Zero(t, complex64Val) + + complex128Val, err := FromString[complex128]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type complex128: strconv.ParseComplex: parsing \"invalid\": invalid syntax") + assert.Zero(t, complex128Val) + + boolVal, err := FromString[bool]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type bool: strconv.ParseBool: parsing \"invalid\": invalid syntax") + assert.False(t, boolVal) + + mapVal, err := FromString[map[string]any]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type map[string]interface {}: invalid character 'i' looking for beginning of value") + assert.Nil(t, mapVal) + + mapVal, err = FromString[map[string]any]("[]") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type map[string]interface {}: json: cannot unmarshal array into Go value of type map[string]interface {}") + assert.Empty(t, mapVal) + + sliceVal, err := FromString[[]string]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type []string: invalid character 'i' looking for beginning of value") + assert.Nil(t, sliceVal) + + sliceVal, err = FromString[[]string]("{}") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type []string: json: cannot unmarshal object into Go value of type []string") + assert.Nil(t, sliceVal) + + structVal, err := FromString[struct{}]("invalid") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type struct {}: invalid character 'i' looking for beginning of value") + assert.Empty(t, structVal) + + structVal, err = FromString[struct{}]("[]") + require.Error(t, err) + require.EqualError(t, err, "failed to convert string to type struct {}: json: cannot unmarshal array into Go value of type struct {}") + assert.Empty(t, structVal) + }) + + t.Run("Valid", func(t *testing.T) { + stringVal, err := FromString[string]("abcd") + require.NoError(t, err) + assert.Equal(t, "abcd", stringVal) + + intVal, err := FromString[int]("1234") + require.NoError(t, err) + assert.Equal(t, 1234, intVal) + + int8Val, err := FromString[int8]("123") + require.NoError(t, err) + assert.Equal(t, int8(123), int8Val) + + int16Val, err := FromString[int16]("1234") + require.NoError(t, err) + assert.Equal(t, int16(1234), int16Val) + + int32Val, err := FromString[int32]("1234") + require.NoError(t, err) + assert.Equal(t, int32(1234), int32Val) + + int64Val, err := FromString[int64]("1234") + require.NoError(t, err) + assert.Equal(t, int64(1234), int64Val) + + uintVal, err := FromString[uint]("1234") + require.NoError(t, err) + assert.Equal(t, uint(1234), uintVal) + + uint8Val, err := FromString[uint8]("123") + require.NoError(t, err) + assert.Equal(t, uint8(123), uint8Val) + + uint16Val, err := FromString[uint16]("1234") + require.NoError(t, err) + assert.Equal(t, uint16(1234), uint16Val) + + uint32Val, err := FromString[uint32]("1234") + require.NoError(t, err) + assert.Equal(t, uint32(1234), uint32Val) + + uint64Val, err := FromString[uint64]("1234") + require.NoError(t, err) + assert.Equal(t, uint64(1234), uint64Val) + + float32Val, err := FromString[float32]("1234.56") + require.NoError(t, err) + assert.InDelta(t, float32(1234.56), float32Val, 0.0001) + + float64Val, err := FromString[float64]("1234.56") + require.NoError(t, err) + assert.InDelta(t, float64(1234.56), float64Val, 0.0001) + + complex64Val, err := FromString[complex64]("1234.56") + require.NoError(t, err) + assert.Equal(t, complex64(1234.56), complex64Val) + + complex128Val, err := FromString[complex128]("1234.56") + require.NoError(t, err) + assert.Equal(t, complex128(1234.56), complex128Val) + + boolVal, err := FromString[bool]("true") + require.NoError(t, err) + assert.True(t, boolVal) + + bytesVal, err := FromString[[]byte]("abcd") + require.NoError(t, err) + assert.Equal(t, []byte("abcd"), bytesVal) + + runesVal, err := FromString[[]rune]("abcd") + require.NoError(t, err) + assert.Equal(t, []rune("abcd"), runesVal) + + builderVal, err := FromString[*strings.Builder]("abcd") + require.NoError(t, err) + assert.Equal(t, "abcd", builderVal.String()) + + arrayVal, err := FromString[[]int]("[1,2,3,4]") + require.NoError(t, err) + assert.Equal(t, []int{1, 2, 3, 4}, arrayVal) + + mapVal, err := FromString[map[string]int](`{"a":1,"b":2,"c":3,"d":4}`) + require.NoError(t, err) + assert.Equal(t, map[string]int{"a": 1, "b": 2, "c": 3, "d": 4}, mapVal) + + structVal, err := FromString[struct{ A int }](`{"A":1}`) + require.NoError(t, err) + assert.Equal(t, struct{ A int }{A: 1}, structVal) + }) +} + +func TestFromStringOrEmpty(t *testing.T) { + t.Run("Unsupported", func(t *testing.T) { + assert.Nil(t, FromStringOrEmpty[func()]("")) + assert.Nil(t, FromStringOrEmpty[map[string]any]("")) + assert.Empty(t, FromStringOrEmpty[map[string]any]("")) + assert.Nil(t, FromStringOrEmpty[[]string]("")) + assert.Empty(t, FromStringOrEmpty[[]string]("")) + assert.Empty(t, FromStringOrEmpty[struct{}]("")) + }) + + t.Run("Empty", func(t *testing.T) { + assert.Nil(t, FromStringOrEmpty[func()]("abcd")) + assert.Nil(t, FromStringOrEmpty[map[string]any]("abcd")) + assert.Empty(t, FromStringOrEmpty[map[string]any]("abcd")) + assert.Nil(t, FromStringOrEmpty[[]string]("abcd")) + assert.Empty(t, FromStringOrEmpty[[]string]("abcd")) + assert.Empty(t, FromStringOrEmpty[struct{}]("abcd")) + assert.Equal(t, "", FromStringOrEmpty[string]("")) + assert.Zero(t, FromStringOrEmpty[int]("")) + assert.Zero(t, FromStringOrEmpty[int8]("")) + assert.Zero(t, FromStringOrEmpty[int16]("")) + assert.Zero(t, FromStringOrEmpty[int32]("")) + assert.Zero(t, FromStringOrEmpty[int64]("")) + assert.Zero(t, FromStringOrEmpty[uint]("")) + assert.Zero(t, FromStringOrEmpty[uint8]("")) + assert.Zero(t, FromStringOrEmpty[uint16]("")) + assert.Zero(t, FromStringOrEmpty[uint32]("")) + assert.Zero(t, FromStringOrEmpty[uint64]("")) + assert.Zero(t, FromStringOrEmpty[float32]("")) + assert.Zero(t, FromStringOrEmpty[float64]("")) + assert.Zero(t, FromStringOrEmpty[complex64]("")) + assert.Zero(t, FromStringOrEmpty[complex128]("")) + assert.False(t, FromStringOrEmpty[bool]("")) + assert.Empty(t, FromStringOrEmpty[[]byte]("")) + assert.Empty(t, FromStringOrEmpty[[]rune]("")) + assert.Equal(t, "", FromStringOrEmpty[*strings.Builder]("").String()) + }) + + t.Run("Invalid", func(t *testing.T) { + assert.Zero(t, FromStringOrEmpty[int]("invalid")) + assert.Zero(t, FromStringOrEmpty[int8]("invalid")) + assert.Zero(t, FromStringOrEmpty[int16]("invalid")) + assert.Zero(t, FromStringOrEmpty[int32]("invalid")) + assert.Zero(t, FromStringOrEmpty[int64]("invalid")) + assert.Zero(t, FromStringOrEmpty[uint]("invalid")) + assert.Zero(t, FromStringOrEmpty[uint8]("invalid")) + assert.Zero(t, FromStringOrEmpty[uint16]("invalid")) + assert.Zero(t, FromStringOrEmpty[uint32]("invalid")) + assert.Zero(t, FromStringOrEmpty[uint64]("invalid")) + assert.Zero(t, FromStringOrEmpty[float32]("invalid")) + assert.Zero(t, FromStringOrEmpty[float64]("invalid")) + assert.Zero(t, FromStringOrEmpty[complex64]("invalid")) + assert.Zero(t, FromStringOrEmpty[complex128]("invalid")) + assert.False(t, FromStringOrEmpty[bool]("invalid")) + assert.Empty(t, FromStringOrEmpty[map[string]any]("invalid")) + assert.Empty(t, FromStringOrEmpty[map[string]any]("[]")) + assert.Empty(t, FromStringOrEmpty[[]string]("invalid")) + assert.Empty(t, FromStringOrEmpty[[]string]("{}")) + assert.Empty(t, FromStringOrEmpty[struct{}]("invalid")) + assert.Empty(t, FromStringOrEmpty[struct{}]("[]")) + }) + + t.Run("Valid", func(t *testing.T) { + assert.Equal(t, "abcd", FromStringOrEmpty[string]("abcd")) + assert.Equal(t, 1234, FromStringOrEmpty[int]("1234")) + assert.Equal(t, int8(123), FromStringOrEmpty[int8]("123")) + assert.Equal(t, int16(1234), FromStringOrEmpty[int16]("1234")) + assert.Equal(t, int32(1234), FromStringOrEmpty[int32]("1234")) + assert.Equal(t, int64(1234), FromStringOrEmpty[int64]("1234")) + assert.Equal(t, uint(1234), FromStringOrEmpty[uint]("1234")) + assert.Equal(t, uint8(123), FromStringOrEmpty[uint8]("123")) + assert.Equal(t, uint16(1234), FromStringOrEmpty[uint16]("1234")) + assert.Equal(t, uint32(1234), FromStringOrEmpty[uint32]("1234")) + assert.Equal(t, uint64(1234), FromStringOrEmpty[uint64]("1234")) + assert.InDelta(t, float32(1234.56), FromStringOrEmpty[float32]("1234.56"), 0.0001) + assert.InDelta(t, float64(1234.56), FromStringOrEmpty[float64]("1234.56"), 0.0001) + assert.Equal(t, complex64(1234.56), FromStringOrEmpty[complex64]("1234.56")) + assert.Equal(t, complex128(1234.56), FromStringOrEmpty[complex128]("1234.56")) + assert.True(t, FromStringOrEmpty[bool]("true")) + assert.Equal(t, []byte("abcd"), FromStringOrEmpty[[]byte]("abcd")) + assert.Equal(t, []rune("abcd"), FromStringOrEmpty[[]rune]("abcd")) + assert.Equal(t, "abcd", FromStringOrEmpty[*strings.Builder]("abcd").String()) + }) +}