From 40849b2ae71a60696dc4c0ca586ab4049d85015c Mon Sep 17 00:00:00 2001 From: Noam <69756316+noamsan@users.noreply.github.com> Date: Mon, 5 Dec 2022 17:34:24 +0200 Subject: [PATCH 1/2] add CheckJWTMulti --- middleware.go | 110 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 110 insertions(+) diff --git a/middleware.go b/middleware.go index 89dcd3f7..f8013fed 100644 --- a/middleware.go +++ b/middleware.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "net/http" + "os" ) // ContextKey is the key used in the request @@ -19,6 +20,8 @@ type JWTMiddleware struct { validateOnOptions bool } +type JWTMiddlewares []*JWTMiddleware + // ValidateToken takes in a string JWT and makes sure it is valid and // returns the valid token. If it is not valid it will return nil and // an error message describing why validation failed. @@ -26,6 +29,11 @@ type JWTMiddleware struct { // In the default implementation we can add safe defaults for those. type ValidateToken func(context.Context, string) (interface{}, error) +func IsDebug() bool { + _, exists := os.LookupEnv("DEBUG") + return exists +} + // New constructs a new JWTMiddleware instance with the supplied options. // It requires a ValidateToken function to be passed in, so it can // properly validate tokens. @@ -90,3 +98,105 @@ func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler { next.ServeHTTP(w, r) }) } + +// CheckJWTMulti is the main JWTMiddleware function which performs the main logic. It +// is passed a http.Handler which will be called if the JWT passes validation for one +// of the JWTMiddleware configs in a slice. +func (mm JWTMiddlewares) CheckJWTMulti(next http.Handler) http.Handler { + if IsDebug() { + fmt.Println("CheckJWTMulti") + fmt.Println(mm) + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + broken := false + for i := 0; i < len(mm); i++ { + m := mm[i] + if IsDebug() { + fmt.Println("\ncurrent conf:") + fmt.Println(m.validateToken) + } + isLast := true + if (i + 1) == len(mm) { + isLast = true + } else { + isLast = false + } + // If we don't validate on OPTIONS and this is OPTIONS + // then continue onto next without validating. + if !m.validateOnOptions && r.Method == http.MethodOptions { + next.ServeHTTP(w, r) + broken = true + break + } + + token, err := m.tokenExtractor(r) + if err != nil { + // This is not ErrJWTMissing because an error here means that the + // tokenExtractor had an error and _not_ that the token was missing. + m.errorHandler(w, r, fmt.Errorf("error extracting token: %w", err)) + broken = true + break + } + + if token == "" { + // If credentials are optional continue + // onto next without validating. + if m.credentialsOptional { + next.ServeHTTP(w, r) + broken = true + break + } + + if !isLast { + if IsDebug() { + fmt.Println("token empty, but not last m") + } + continue + } else { + if IsDebug() { + fmt.Println("token empty, is last m") + } + } + // Credentials were not optional so we error. + m.errorHandler(w, r, ErrJWTMissing) + broken = true + break + } + + // Validate the token using the token validator. + validToken, err := m.validateToken(r.Context(), token) + if err != nil { + if !isLast { + if IsDebug() { + fmt.Println("\ntoken not valid, but not last m") + } + continue + } else { + if IsDebug() { + fmt.Println("\ntoken not valid, is last m") + } + } + m.errorHandler(w, r, &invalidError{details: err}) + broken = true + break + } + + // No err means we have a valid token, so set + // it into the context and continue onto next. + r = r.Clone(context.WithValue(r.Context(), ContextKey{}, validToken)) + next.ServeHTTP(w, r) + broken = true + break + } + if broken { + if IsDebug() { + fmt.Println("break") + } + return + } else { + if IsDebug() { + fmt.Println("not break") + } + } + }) +} From 8496b0b14efcd467fdba823efefc476ac5ba5e9d Mon Sep 17 00:00:00 2001 From: Noam <69756316+noamsan@users.noreply.github.com> Date: Mon, 5 Dec 2022 19:53:59 +0200 Subject: [PATCH 2/2] Update middleware.go --- middleware.go | 55 +++++---------------------------------------------- 1 file changed, 5 insertions(+), 50 deletions(-) diff --git a/middleware.go b/middleware.go index f8013fed..3d79b639 100644 --- a/middleware.go +++ b/middleware.go @@ -4,7 +4,6 @@ import ( "context" "fmt" "net/http" - "os" ) // ContextKey is the key used in the request @@ -29,11 +28,6 @@ type JWTMiddlewares []*JWTMiddleware // In the default implementation we can add safe defaults for those. type ValidateToken func(context.Context, string) (interface{}, error) -func IsDebug() bool { - _, exists := os.LookupEnv("DEBUG") - return exists -} - // New constructs a new JWTMiddleware instance with the supplied options. // It requires a ValidateToken function to be passed in, so it can // properly validate tokens. @@ -103,18 +97,9 @@ func (m *JWTMiddleware) CheckJWT(next http.Handler) http.Handler { // is passed a http.Handler which will be called if the JWT passes validation for one // of the JWTMiddleware configs in a slice. func (mm JWTMiddlewares) CheckJWTMulti(next http.Handler) http.Handler { - if IsDebug() { - fmt.Println("CheckJWTMulti") - fmt.Println(mm) - } return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - broken := false for i := 0; i < len(mm); i++ { m := mm[i] - if IsDebug() { - fmt.Println("\ncurrent conf:") - fmt.Println(m.validateToken) - } isLast := true if (i + 1) == len(mm) { isLast = true @@ -125,8 +110,7 @@ func (mm JWTMiddlewares) CheckJWTMulti(next http.Handler) http.Handler { // then continue onto next without validating. if !m.validateOnOptions && r.Method == http.MethodOptions { next.ServeHTTP(w, r) - broken = true - break + return } token, err := m.tokenExtractor(r) @@ -134,8 +118,7 @@ func (mm JWTMiddlewares) CheckJWTMulti(next http.Handler) http.Handler { // This is not ErrJWTMissing because an error here means that the // tokenExtractor had an error and _not_ that the token was missing. m.errorHandler(w, r, fmt.Errorf("error extracting token: %w", err)) - broken = true - break + return } if token == "" { @@ -143,60 +126,32 @@ func (mm JWTMiddlewares) CheckJWTMulti(next http.Handler) http.Handler { // onto next without validating. if m.credentialsOptional { next.ServeHTTP(w, r) - broken = true - break + return } if !isLast { - if IsDebug() { - fmt.Println("token empty, but not last m") - } continue - } else { - if IsDebug() { - fmt.Println("token empty, is last m") - } } // Credentials were not optional so we error. m.errorHandler(w, r, ErrJWTMissing) - broken = true - break + return } // Validate the token using the token validator. validToken, err := m.validateToken(r.Context(), token) if err != nil { if !isLast { - if IsDebug() { - fmt.Println("\ntoken not valid, but not last m") - } continue - } else { - if IsDebug() { - fmt.Println("\ntoken not valid, is last m") - } } m.errorHandler(w, r, &invalidError{details: err}) - broken = true - break + return } // No err means we have a valid token, so set // it into the context and continue onto next. r = r.Clone(context.WithValue(r.Context(), ContextKey{}, validToken)) next.ServeHTTP(w, r) - broken = true - break - } - if broken { - if IsDebug() { - fmt.Println("break") - } return - } else { - if IsDebug() { - fmt.Println("not break") - } } }) }