From 35823eb17afa586af6b60f595acd21fee77c2fd8 Mon Sep 17 00:00:00 2001 From: Pavol Ipoth Date: Mon, 27 May 2024 23:29:16 +0200 Subject: [PATCH] Refactor, move common parts to packages --- pkg/keycloak/proxy/handlers.go | 102 +--- pkg/keycloak/proxy/middleware.go | 849 ------------------------------- pkg/keycloak/proxy/misc.go | 98 +--- pkg/keycloak/proxy/oauth.go | 92 ---- pkg/keycloak/proxy/server.go | 45 +- pkg/keycloak/proxy/stores.go | 42 -- pkg/proxy/handlers/handlers.go | 59 +++ pkg/proxy/middleware/base.go | 288 +++++++++++ pkg/proxy/middleware/oauth.go | 321 ++++++++++++ pkg/proxy/middleware/security.go | 201 ++++++++ pkg/proxy/session/token.go | 63 +++ pkg/storage/storage.go | 1 + pkg/storage/store_redis.go | 19 + pkg/testsuite/misc_test.go | 10 +- pkg/utils/token.go | 252 +++++++++ 15 files changed, 1243 insertions(+), 1199 deletions(-) delete mode 100644 pkg/keycloak/proxy/stores.go create mode 100644 pkg/proxy/middleware/base.go create mode 100644 pkg/proxy/middleware/oauth.go create mode 100644 pkg/proxy/middleware/security.go create mode 100644 pkg/utils/token.go diff --git a/pkg/keycloak/proxy/handlers.go b/pkg/keycloak/proxy/handlers.go index 1776ba67..078f93f7 100644 --- a/pkg/keycloak/proxy/handlers.go +++ b/pkg/keycloak/proxy/handlers.go @@ -258,7 +258,7 @@ func oauthCallbackHandler( if enableRefreshTokens && refreshToken != "" { var encrypted string var stdRefreshClaims *jwt.Claims - stdRefreshClaims, err = parseRefreshToken(refreshToken) + stdRefreshClaims, err = utils.ParseRefreshToken(refreshToken) if err != nil { scope.Logger.Error(apperrors.ErrParseRefreshToken.Error(), zap.Error(err)) accessForbidden(writer, req) @@ -384,7 +384,7 @@ func oauthCallbackHandler( func loginHandler( logger *zap.Logger, openIDProviderTimeout time.Duration, - idpClient *gocloak.GoCloak, + httpClient *http.Client, enableLoginHandler bool, newOAuth2Config func(redirectionURL string) *oauth2.Config, getRedirectionURL func(wrt http.ResponseWriter, req *http.Request) string, @@ -416,7 +416,7 @@ func loginHandler( ctx = context.WithValue( ctx, oauth2.HTTPClient, - idpClient.RestyClient().GetClient(), + httpClient, ) if !enableLoginHandler { @@ -507,7 +507,7 @@ func loginHandler( req, writer, accessToken, - GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken), + session.GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken), ) if enableIDTokenCookie { @@ -515,7 +515,7 @@ func loginHandler( req, writer, idToken, - GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken), + session.GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken), ) } @@ -643,7 +643,7 @@ func logoutHandler( enableLogoutRedirect bool, store storage.Storage, cookManager *cookie.Manager, - idpClient *gocloak.GoCloak, + httpClient *http.Client, accessError func(wrt http.ResponseWriter, req *http.Request) context.Context, GetIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), ) func(wrt http.ResponseWriter, req *http.Request) { @@ -687,7 +687,7 @@ func logoutHandler( identityToken := user.RawToken //nolint:vetshadow - if refresh, _, err := retrieveRefreshToken( + if refresh, _, err := session.RetrieveRefreshToken( store, cookieRefreshName, encryptionKey, @@ -771,7 +771,6 @@ func logoutHandler( // step: do we have a revocation endpoint? if revocationURL != "" { - client := idpClient.RestyClient().GetClient() // step: add the authentication headers encodedID := url.QueryEscape(clientID) encodedSecret := url.QueryEscape(clientSecret) @@ -795,7 +794,7 @@ func logoutHandler( request.Header.Set("Content-Type", "application/x-www-form-urlencoded") start := time.Now() - response, err := client.Do(request) + response, err := httpClient.Do(request) if err != nil { scope.Logger.Error(apperrors.ErrRevocationReqFailure.Error(), zap.Error(err)) writer.WriteHeader(http.StatusInternalServerError) @@ -834,88 +833,3 @@ func logoutHandler( } } } - -// expirationHandler checks if the token has expired -func expirationHandler( - getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), - cookieAccessName string, -) func(wrt http.ResponseWriter, req *http.Request) { - return func(wrt http.ResponseWriter, req *http.Request) { - user, err := getIdentity(req, cookieAccessName, "") - if err != nil { - wrt.WriteHeader(http.StatusUnauthorized) - return - } - - if user.IsExpired() { - wrt.WriteHeader(http.StatusUnauthorized) - return - } - - wrt.WriteHeader(http.StatusOK) - } -} - -// tokenHandler display access token to screen -func tokenHandler( - getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), - cookieAccessName string, - accessError func(wrt http.ResponseWriter, req *http.Request) context.Context, -) func(wrt http.ResponseWriter, req *http.Request) { - return func(wrt http.ResponseWriter, req *http.Request) { - user, err := getIdentity(req, cookieAccessName, "") - if err != nil { - accessError(wrt, req) - return - } - - token, err := jwt.ParseSigned(user.RawToken) - if err != nil { - accessError(wrt, req) - return - } - - jsonMap := make(map[string]interface{}) - err = token.UnsafeClaimsWithoutVerification(&jsonMap) - if err != nil { - accessError(wrt, req) - return - } - - result, err := json.Marshal(jsonMap) - if err != nil { - accessError(wrt, req) - return - } - - wrt.Header().Set("Content-Type", "application/json") - _, _ = wrt.Write(result) - } -} - -// retrieveRefreshToken retrieves the refresh token from store or cookie -func retrieveRefreshToken( - store storage.Storage, - cookieRefreshName string, - encryptionKey string, - req *http.Request, - user *models.UserContext, -) (string, string, error) { - var token string - var err error - - switch store != nil { - case true: - token, err = GetRefreshTokenFromStore(req.Context(), store, user.RawToken) - default: - token, err = session.GetRefreshTokenFromCookie(req, cookieRefreshName) - } - - if err != nil { - return token, "", err - } - - encrypted := token // returns encrypted, avoids encoding twice - token, err = encryption.DecodeText(token, encryptionKey) - return token, encrypted, err -} diff --git a/pkg/keycloak/proxy/middleware.go b/pkg/keycloak/proxy/middleware.go index e14efe6e..1c241325 100644 --- a/pkg/keycloak/proxy/middleware.go +++ b/pkg/keycloak/proxy/middleware.go @@ -18,456 +18,24 @@ package proxy import ( "bytes" "context" - "errors" "fmt" "io" "net/http" "net/url" - "regexp" - "strconv" - "strings" "time" "github.com/Nerzal/gocloak/v12" oidc3 "github.com/coreos/go-oidc/v3/oidc" - "github.com/go-jose/go-jose/v3/jwt" - uuid "github.com/gofrs/uuid" "github.com/gogatekeeper/gatekeeper/pkg/authorization" "github.com/gogatekeeper/gatekeeper/pkg/constant" "github.com/gogatekeeper/gatekeeper/pkg/encryption" "github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie" - "github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics" "github.com/gogatekeeper/gatekeeper/pkg/proxy/models" - "github.com/gogatekeeper/gatekeeper/pkg/proxy/session" - "github.com/gogatekeeper/gatekeeper/pkg/storage" - "github.com/gogatekeeper/gatekeeper/pkg/utils" - "golang.org/x/oauth2" - "github.com/PuerkitoBio/purell" - "github.com/go-chi/chi/v5/middleware" "github.com/gogatekeeper/gatekeeper/pkg/apperrors" - "github.com/unrolled/secure" "go.uber.org/zap" - "go.uber.org/zap/zapcore" ) -const ( - // normalizeFlags is the options to purell - normalizeFlags purell.NormalizationFlags = purell.FlagRemoveDotSegments | purell.FlagRemoveDuplicateSlashes -) - -// entrypointMiddleware is custom filtering for incoming requests -func entrypointMiddleware(logger *zap.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - // @step: create a context for the request - scope := &models.RequestScope{} - // Save the exact formatting of the incoming request so we can use it later - scope.Path = req.URL.Path - scope.RawPath = req.URL.RawPath - scope.Logger = logger - - // We want to Normalize the URL so that we can more easily and accurately - // parse it to apply resource protection rules. - purell.NormalizeURL(req.URL, normalizeFlags) - - // ensure we have a slash in the url - if !strings.HasPrefix(req.URL.Path, "/") { - req.URL.Path = "/" + req.URL.Path - } - req.URL.RawPath = req.URL.EscapedPath() - - resp := middleware.NewWrapResponseWriter(wrt, 1) - start := time.Now() - // All the processing, including forwarding the request upstream and getting the response, - // happens here in this chain. - next.ServeHTTP(resp, req.WithContext(context.WithValue(req.Context(), constant.ContextScopeName, scope))) - - // @metric record the time taken then response code - metrics.LatencyMetric.Observe(time.Since(start).Seconds()) - metrics.StatusMetric.WithLabelValues(strconv.Itoa(resp.Status()), req.Method).Inc() - - // place back the original uri for any later consumers - req.URL.Path = scope.Path - req.URL.RawPath = scope.RawPath - }) - } -} - -// requestIDMiddleware is responsible for adding a request id if none found -func requestIDMiddleware(header string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - if v := req.Header.Get(header); v == "" { - uuid, err := uuid.NewV1() - if err != nil { - wrt.WriteHeader(http.StatusInternalServerError) - } - req.Header.Set(header, uuid.String()) - } - - next.ServeHTTP(wrt, req) - }) - } -} - -// loggingMiddleware is a custom http logger -func loggingMiddleware( - logger *zap.Logger, - verbose bool, -) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - start := time.Now() - resp, assertOk := w.(middleware.WrapResponseWriter) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - - if verbose { - requestLogger := logger.With( - zap.Any("headers", req.Header), - zap.String("path", req.URL.Path), - zap.String("method", req.Method), - ) - scope.Logger = requestLogger - } - - next.ServeHTTP(resp, req) - - addr := utils.RealIP(req) - - if req.URL.Path == req.URL.RawPath || req.URL.RawPath == "" { - scope.Logger.Info("client request", - zap.Duration("latency", time.Since(start)), - zap.Int("status", resp.Status()), - zap.Int("bytes", resp.BytesWritten()), - zap.String("client_ip", addr), - zap.String("remote_addr", req.RemoteAddr), - zap.String("method", req.Method), - zap.String("path", req.URL.Path)) - } else { - scope.Logger.Info("client request", - zap.Duration("latency", time.Since(start)), - zap.Int("status", resp.Status()), - zap.Int("bytes", resp.BytesWritten()), - zap.String("client_ip", addr), - zap.String("remote_addr", req.RemoteAddr), - zap.String("method", req.Method), - zap.String("path", req.URL.Path), - zap.String("raw path", req.URL.RawPath)) - } - }) - } -} - -/* - authenticationMiddleware is responsible for verifying the access token -*/ -//nolint:funlen,cyclop -func authenticationMiddleware( - logger *zap.Logger, - cookieAccessName string, - cookieRefreshName string, - getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), - idpClient *gocloak.GoCloak, - enableIDPSessionCheck bool, - provider *oidc3.Provider, - skipTokenVerification bool, - clientID string, - skipAccessTokenClientIDCheck bool, - skipAccessTokenIssuerCheck bool, - accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, - enableRefreshTokens bool, - redirectionURL string, - cookMgr *cookie.Manager, - enableEncryptedToken bool, - forceEncryptedCookie bool, - encryptionKey string, - redirectToAuthorization func(wrt http.ResponseWriter, req *http.Request) context.Context, - newOAuth2Config func(redirectionURL string) *oauth2.Config, - store storage.Storage, - accessTokenDuration time.Duration, -) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - - clientIP := utils.RealIP(req) - scope.Logger.Debug("authentication middleware") - - // grab the user identity from the request - user, err := getIdentity(req, cookieAccessName, "") - if err != nil { - scope.Logger.Error(err.Error()) - redirectToAuthorization(wrt, req) - return - } - - scope.Identity = user - ctx := context.WithValue(req.Context(), constant.ContextScopeName, scope) - lLog := scope.Logger.With( - zap.String("client_ip", clientIP), - zap.String("remote_addr", req.RemoteAddr), - zap.String("username", user.Name), - zap.String("sub", user.ID), - zap.String("expired_on", user.ExpiresAt.String()), - ) - - // IMPORTANT: For all calls with go-oidc library be aware - // that calls accept context parameter and you have to pass - // client from provider through this parameter, although - // provider is already configured with client!!! - // https://github.com/coreos/go-oidc/issues/402 - httpClient := idpClient.RestyClient().GetClient() - oidcLibCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - if enableIDPSessionCheck { - tokenSource := oauth2.StaticTokenSource( - &oauth2.Token{AccessToken: user.RawToken}, - ) - _, err := provider.UserInfo(oidcLibCtx, tokenSource) - if err != nil { - scope.Logger.Error(err.Error()) - redirectToAuthorization(wrt, req) - return - } - } - - // step: skip if we are running skip-token-verification - if skipTokenVerification { - scope.Logger.Warn( - "skip token verification enabled, " + - "skipping verification - TESTING ONLY", - ) - - if user.IsExpired() { - lLog.Error(apperrors.ErrSessionExpiredVerifyOff.Error()) - redirectToAuthorization(wrt, req) - return - } - } else { //nolint:gocritic - _, err := verifyToken( - ctx, - provider, - user.RawToken, - clientID, - skipAccessTokenClientIDCheck, - skipAccessTokenIssuerCheck, - ) - if err != nil { - if errors.Is(err, apperrors.ErrTokenSignature) { - lLog.Error( - apperrors.ErrAccTokenVerifyFailure.Error(), - zap.Error(err), - ) - accessForbidden(wrt, req) - return - } - - if !strings.Contains(err.Error(), "token is expired") { - lLog.Error( - apperrors.ErrAccTokenVerifyFailure.Error(), - zap.Error(err), - ) - accessForbidden(wrt, req) - return - } - - if !enableRefreshTokens { - lLog.Error(apperrors.ErrSessionExpiredRefreshOff.Error()) - redirectToAuthorization(wrt, req) - return - } - - lLog.Info("accces token for user has expired, attemping to refresh the token") - - // step: check if the user has refresh token - refresh, _, err := retrieveRefreshToken( - store, - cookieRefreshName, - encryptionKey, - req.WithContext(ctx), - user, - ) - if err != nil { - scope.Logger.Error( - apperrors.ErrRefreshTokenNotFound.Error(), - zap.Error(err), - ) - redirectToAuthorization(wrt, req) - return - } - - var stdRefreshClaims *jwt.Claims - stdRefreshClaims, err = parseRefreshToken(refresh) - if err != nil { - lLog.Error( - apperrors.ErrParseRefreshToken.Error(), - zap.Error(err), - ) - accessForbidden(wrt, req) - return - } - if user.ID != stdRefreshClaims.Subject { - lLog.Error( - apperrors.ErrAccRefreshTokenMismatch.Error(), - zap.Error(err), - ) - accessForbidden(wrt, req) - return - } - - // attempt to refresh the access token, possibly with a renewed refresh token - // - // NOTE: atm, this does not retrieve explicit refresh token expiry from oauth2, - // and take identity expiry instead: with keycloak, they are the same and equal to - // "SSO session idle" keycloak setting. - // - // exp: expiration of the access token - // expiresIn: expiration of the ID token - conf := newOAuth2Config(redirectionURL) - - lLog.Debug( - "issuing refresh token request", - zap.String("current access token", user.RawToken), - zap.String("refresh token", refresh), - ) - - newAccToken, newRawAccToken, newRefreshToken, accessExpiresAt, refreshExpiresIn, err := getRefreshedToken(ctx, conf, httpClient, refresh) - if err != nil { - switch err { - case apperrors.ErrRefreshTokenExpired: - lLog.Warn("refresh token has expired, cannot retrieve access token") - cookMgr.ClearAllCookies(req.WithContext(ctx), wrt) - default: - lLog.Debug( - apperrors.ErrAccTokenRefreshFailure.Error(), - zap.String("access token", user.RawToken), - ) - lLog.Error( - apperrors.ErrAccTokenRefreshFailure.Error(), - zap.Error(err), - ) - } - - redirectToAuthorization(wrt, req) - return - } - - lLog.Debug( - "info about tokens after refreshing", - zap.String("new access token", newRawAccToken), - zap.String("new refresh token", newRefreshToken), - ) - - accessExpiresIn := time.Until(accessExpiresAt) - - if newRefreshToken != "" { - refresh = newRefreshToken - } - - if refreshExpiresIn == 0 { - // refresh token expiry claims not available: try to parse refresh token - refreshExpiresIn = GetAccessCookieExpiration(lLog, accessTokenDuration, refresh) - } - - lLog.Info( - "injecting the refreshed access token cookie", - zap.Duration("refresh_expires_in", refreshExpiresIn), - zap.Duration("expires_in", accessExpiresIn), - ) - - accessToken := newRawAccToken - - if enableEncryptedToken || forceEncryptedCookie { - if accessToken, err = encryption.EncodeText(accessToken, encryptionKey); err != nil { - lLog.Error( - apperrors.ErrEncryptAccToken.Error(), - zap.Error(err), - ) - accessForbidden(wrt, req) - return - } - } - - // step: inject the refreshed access token - cookMgr.DropAccessTokenCookie(req.WithContext(ctx), wrt, accessToken, accessExpiresIn) - - // update the with the new access token and inject into the context - newUser, err := session.ExtractIdentity(&newAccToken) - if err != nil { - lLog.Error(err.Error()) - accessForbidden(wrt, req) - return - } - - // step: inject the renewed refresh token - if newRefreshToken != "" { - lLog.Debug( - "renew refresh cookie with new refresh token", - zap.Duration("refresh_expires_in", refreshExpiresIn), - ) - var encryptedRefreshToken string - encryptedRefreshToken, err = encryption.EncodeText(newRefreshToken, encryptionKey) - if err != nil { - lLog.Error( - apperrors.ErrEncryptRefreshToken.Error(), - zap.Error(err), - ) - wrt.WriteHeader(http.StatusInternalServerError) - return - } - - if store != nil { - go func(ctx context.Context, old string, newToken string, encrypted string) { - ctxx, cancel := context.WithCancel(ctx) - defer cancel() - if err = store.Delete(ctxx, utils.GetHashKey(old)); err != nil { - lLog.Error( - apperrors.ErrDelTokFromStore.Error(), - zap.Error(err), - ) - } - - if err = store.Set(ctxx, utils.GetHashKey(newToken), encrypted, refreshExpiresIn); err != nil { - lLog.Error( - apperrors.ErrSaveTokToStore.Error(), - zap.Error(err), - ) - return - } - }(ctx, user.RawToken, newRawAccToken, encryptedRefreshToken) - } else { - cookMgr.DropRefreshTokenCookie(req.WithContext(ctx), wrt, encryptedRefreshToken, refreshExpiresIn) - } - } - - // IMPORTANT: on this rely other middlewares, must be refreshed - // with new identity! - newUser.RawToken = newRawAccToken - scope.Identity = newUser - ctx = context.WithValue(req.Context(), constant.ContextScopeName, scope) - } - } - - *req = *(req.WithContext(ctx)) - next.ServeHTTP(wrt, req) - }) - } -} - /* authorizationMiddleware is responsible for verifying permissions in access_token/uma_token */ @@ -682,420 +250,3 @@ func authorizationMiddleware( }) } } - -// checkClaim checks whether claim in userContext matches claimName, match. It can be String or Strings claim. -// -//nolint:cyclop -func checkClaim( - logger *zap.Logger, - user *models.UserContext, - claimName string, - match *regexp.Regexp, - resourceURL string, -) bool { - errFields := []zapcore.Field{ - zap.String("claim", claimName), - zap.String("access", "denied"), - zap.String("email", user.Email), - zap.String("resource", resourceURL), - } - - lLog := logger.With(errFields...) - if _, found := user.Claims[claimName]; !found { - lLog.Warn("the token does not have the claim") - return false - } - - switch user.Claims[claimName].(type) { - case []interface{}: - claims, assertOk := user.Claims[claimName].([]interface{}) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return false - } - - for _, v := range claims { - value, ok := v.(string) - if !ok { - lLog.Warn( - "Problem while asserting claim", - zap.String( - "issued", - fmt.Sprintf("%v", user.Claims[claimName]), - ), - zap.String("required", match.String()), - ) - - return false - } - - if match.MatchString(value) { - return true - } - } - - lLog.Warn( - "claim requirement does not match any element claim group in token", - zap.String("issued", fmt.Sprintf("%v", user.Claims[claimName])), - zap.String("required", match.String()), - ) - - return false - case string: - claims, assertOk := user.Claims[claimName].(string) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return false - } - if match.MatchString(claims) { - return true - } - - lLog.Warn( - "claim requirement does not match claim in token", - zap.String("issued", claims), - zap.String("required", match.String()), - ) - - return false - default: - logger.Error( - "unable to extract the claim from token not string or array of strings", - ) - } - - lLog.Warn("unexpected error") - return false -} - -// admissionMiddleware is responsible for checking the access token against the protected resource -// -//nolint:cyclop -func admissionMiddleware( - logger *zap.Logger, - resource *authorization.Resource, - matchClaims map[string]string, - accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, -) func(http.Handler) http.Handler { - claimMatches := make(map[string]*regexp.Regexp) - for k, v := range matchClaims { - claimMatches[k] = regexp.MustCompile(v) - } - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - // we don't need to continue is a decision has been made - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - if scope.AccessDenied { - next.ServeHTTP(wrt, req) - return - } - - user := scope.Identity - lLog := scope.Logger.With( - zap.String("access", "denied"), - zap.String("email", user.Email), - zap.String("resource", resource.URL), - ) - - // @step: we need to check the roles - if !utils.HasAccess(resource.Roles, user.Roles, !resource.RequireAnyRole) { - lLog.Warn("access denied, invalid roles", - zap.String("roles", resource.GetRoles())) - accessForbidden(wrt, req) - return - } - - if len(resource.Headers) > 0 { - var reqHeaders []string - - for _, resVal := range resource.Headers { - resVals := strings.Split(resVal, ":") - name := resVals[0] - canonName := http.CanonicalHeaderKey(name) - values, ok := req.Header[canonName] - if !ok { - lLog.Warn("access denied, invalid headers", - zap.String("headers", resource.GetHeaders())) - accessForbidden(wrt, req) - return - } - - for _, value := range values { - headVal := fmt.Sprintf( - "%s:%s", - strings.ToLower(name), - strings.ToLower(value), - ) - reqHeaders = append(reqHeaders, headVal) - } - } - - // @step: we need to check the headers - if !utils.HasAccess(resource.Headers, reqHeaders, true) { - lLog.Warn("access denied, invalid headers", - zap.String("headers", resource.GetHeaders())) - accessForbidden(wrt, req) - return - } - } - - // @step: check if we have any groups, the groups are there - if !utils.HasAccess(resource.Groups, user.Groups, false) { - lLog.Warn("access denied, invalid groups", - zap.String("groups", strings.Join(resource.Groups, ","))) - accessForbidden(wrt, req) - return - } - - // step: if we have any claim matching, lets validate the tokens has the claims - for claimName, match := range claimMatches { - if !checkClaim(scope.Logger, user, claimName, match, resource.URL) { - accessForbidden(wrt, req) - return - } - } - - scope.Logger.Debug("access permitted to resource", - zap.String("access", "permitted"), - zap.String("email", user.Email), - zap.Duration("expires", time.Until(user.ExpiresAt)), - zap.String("resource", resource.URL)) - - next.ServeHTTP(wrt, req) - }) - } -} - -// responseHeaderMiddleware is responsible for adding response headers -func responseHeaderMiddleware(headers map[string]string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - // @step: inject any custom response headers - for k, v := range headers { - wrt.Header().Set(k, v) - } - - next.ServeHTTP(wrt, req) - }) - } -} - -// identityHeadersMiddleware is responsible for adding the authentication headers to upstream -// -//nolint:cyclop -func identityHeadersMiddleware( - logger *zap.Logger, - custom []string, - cookieAccessName string, - cookieRefreshName string, - noProxy bool, - enableTokenHeader bool, - enableAuthzHeader bool, - enableAuthzCookies bool, -) func(http.Handler) http.Handler { - customClaims := make(map[string]string) - const minSliceLength int = 1 - cookieFilter := []string{cookieAccessName, cookieRefreshName} - - for _, val := range custom { - xslices := strings.Split(val, "|") - val = xslices[0] - if len(xslices) > minSliceLength { - customClaims[val] = utils.ToHeader(xslices[1]) - } else { - customClaims[val] = fmt.Sprintf("X-Auth-%s", utils.ToHeader(val)) - } - } - - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - - var headers http.Header - if noProxy { - headers = wrt.Header() - } else { - headers = req.Header - } - - if scope.Identity != nil { - user := scope.Identity - headers.Set("X-Auth-Audience", strings.Join(user.Audiences, ",")) - headers.Set("X-Auth-Email", user.Email) - headers.Set("X-Auth-ExpiresIn", user.ExpiresAt.String()) - headers.Set("X-Auth-Groups", strings.Join(user.Groups, ",")) - headers.Set("X-Auth-Roles", strings.Join(user.Roles, ",")) - headers.Set("X-Auth-Subject", user.ID) - headers.Set("X-Auth-Userid", user.Name) - headers.Set("X-Auth-Username", user.Name) - - // should we add the token header? - if enableTokenHeader { - headers.Set("X-Auth-Token", user.RawToken) - } - // add the authorization header if requested - if enableAuthzHeader { - headers.Set("Authorization", fmt.Sprintf("Bearer %s", user.RawToken)) - } - // are we filtering out the cookies - if !enableAuthzCookies { - _ = cookie.FilterCookies(req, cookieFilter) - } - // inject any custom claims - for claim, header := range customClaims { - if claim, found := user.Claims[claim]; found { - headers.Set(header, fmt.Sprintf("%v", claim)) - } - } - } - - next.ServeHTTP(wrt, req) - }) - } -} - -// securityMiddleware performs numerous security checks on the request -func securityMiddleware( - logger *zap.Logger, - allowedHosts []string, - browserXSSFilter bool, - contentSecurityPolicy string, - contentTypeNosniff bool, - frameDeny bool, - sslRedirect bool, - accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, -) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - logger.Info("enabling the security filter middleware") - - secure := secure.New(secure.Options{ - AllowedHosts: allowedHosts, - BrowserXssFilter: browserXSSFilter, - ContentSecurityPolicy: contentSecurityPolicy, - ContentTypeNosniff: contentTypeNosniff, - FrameDeny: frameDeny, - SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, - SSLRedirect: sslRedirect, - }) - - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - - if err := secure.Process(wrt, req); err != nil { - scope.Logger.Warn("failed security middleware", zap.Error(err)) - accessForbidden(wrt, req) - return - } - - next.ServeHTTP(wrt, req) - }) - } -} - -// methodCheck middleware -func methodCheckMiddleware(logger *zap.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - logger.Info("enabling the method check middleware") - - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - if !utils.IsValidHTTPMethod(req.Method) { - logger.Warn("method not implemented ", zap.String("method", req.Method)) - wrt.WriteHeader(http.StatusNotImplemented) - return - } - - next.ServeHTTP(wrt, req) - }) - } -} - -// proxyDenyMiddleware just block everything -func proxyDenyMiddleware(logger *zap.Logger) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - ctxVal := req.Context().Value(constant.ContextScopeName) - - var scope *models.RequestScope - if ctxVal == nil { - scope = &models.RequestScope{} - } else { - var assertOk bool - scope, assertOk = ctxVal.(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - } - - scope.AccessDenied = true - // update the request context - ctx := context.WithValue(req.Context(), constant.ContextScopeName, scope) - - next.ServeHTTP(wrt, req.WithContext(ctx)) - }) - } -} - -// denyMiddleware -func denyMiddleware( - logger *zap.Logger, - accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, -) func(http.Handler) http.Handler { - return func(_ http.Handler) http.Handler { - logger.Info("enabling the deny middleware") - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - accessForbidden(wrt, req) - }) - } -} - -// hmacMiddleware verifies hmac -func hmacMiddleware(logger *zap.Logger, encKey string) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { - scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) - if !assertOk { - logger.Error(apperrors.ErrAssertionFailed.Error()) - return - } - - if scope.AccessDenied { - next.ServeHTTP(wrt, req) - return - } - - expectedMAC := req.Header.Get(constant.HeaderXHMAC) - if expectedMAC == "" { - logger.Debug(apperrors.ErrHmacHeaderEmpty.Error()) - wrt.WriteHeader(http.StatusBadRequest) - return - } - - reqHmac, err := utils.GenerateHmac(req, encKey) - if err != nil { - logger.Error(err.Error()) - } - - if reqHmac != expectedMAC { - logger.Debug(apperrors.ErrHmacMismatch.Error()) - wrt.WriteHeader(http.StatusBadRequest) - return - } - - next.ServeHTTP(wrt, req) - }) - } -} diff --git a/pkg/keycloak/proxy/misc.go b/pkg/keycloak/proxy/misc.go index fb1589f3..9b40e51a 100644 --- a/pkg/keycloak/proxy/misc.go +++ b/pkg/keycloak/proxy/misc.go @@ -235,40 +235,6 @@ func redirectToAuthorization( } } -// GetAccessCookieExpiration calculates the expiration of the access token cookie -func GetAccessCookieExpiration( - logger *zap.Logger, - accessTokenDuration time.Duration, - refresh string, -) time.Duration { - // notes: by default the duration of the access token will be the configuration option, if - // however we can decode the refresh token, we will set the duration to the duration of the - // refresh token - duration := accessTokenDuration - - webToken, err := jwt.ParseSigned(refresh) - if err != nil { - logger.Error("unable to parse token") - } - - if ident, err := session.ExtractIdentity(webToken); err == nil { - delta := time.Until(ident.ExpiresAt) - - if delta > 0 { - duration = delta - } - - logger.Debug( - "parsed refresh token with new duration", - zap.Duration("new duration", delta), - ) - } else { - logger.Debug("refresh token is opaque and cannot be used to extend calculated duration") - } - - return duration -} - //nolint:cyclop func getPAT( logger *zap.Logger, @@ -411,7 +377,7 @@ func WithUMAIdentity( return authorization.DeniedAuthz, apperrors.ErrAccessMismatchUmaToken } - _, err = verifyToken( + _, err = utils.VerifyToken( req.Context(), provider, umaUser.RawToken, @@ -606,7 +572,7 @@ func verifyOIDCTokens( var oAccToken *oidc3.IDToken var err error - oIDToken, err = verifyToken(ctx, provider, rawIDToken, clientID, false, false) + oIDToken, err = utils.VerifyToken(ctx, provider, rawIDToken, clientID, false, false) if err != nil { return nil, nil, errors.Join(apperrors.ErrVerifyIDToken, err) } @@ -621,7 +587,7 @@ func verifyOIDCTokens( } } - oAccToken, err = verifyToken( + oAccToken, err = utils.VerifyToken( ctx, provider, rawAccessToken, @@ -636,49 +602,6 @@ func verifyOIDCTokens( return oAccToken, oIDToken, nil } -func verifyToken( - ctx context.Context, - provider *oidc3.Provider, - rawToken string, - clientID string, - skipClientIDCheck bool, - skipIssuerCheck bool, -) (*oidc3.IDToken, error) { - // This verifier with this configuration checks only signatures - // we want to know if we are using valid token - // bad is that Verify method doesn't check first signatures, so - // we have to do it like this - verifier := provider.Verifier( - &oidc3.Config{ - ClientID: clientID, - SkipClientIDCheck: true, - SkipIssuerCheck: true, - SkipExpiryCheck: true, - }, - ) - _, err := verifier.Verify(ctx, rawToken) - if err != nil { - return nil, errors.Join(apperrors.ErrTokenSignature, err) - } - - // Now doing expiration check - verifier = provider.Verifier( - &oidc3.Config{ - ClientID: clientID, - SkipClientIDCheck: skipClientIDCheck, - SkipIssuerCheck: skipIssuerCheck, - SkipExpiryCheck: false, - }, - ) - - oToken, err := verifier.Verify(ctx, rawToken) - if err != nil { - return nil, err - } - - return oToken, nil -} - func encryptToken( scope *models.RequestScope, rawToken string, @@ -762,18 +685,3 @@ func refreshUmaToken( umaUser.RawToken = tok.AccessToken return umaUser, nil } - -func parseRefreshToken(rawRefreshToken string) (*jwt.Claims, error) { - refreshToken, err := jwt.ParseSigned(rawRefreshToken) - if err != nil { - return nil, err - } - - stdRefreshClaims := &jwt.Claims{} - err = refreshToken.UnsafeClaimsWithoutVerification(stdRefreshClaims) - if err != nil { - return nil, err - } - - return stdRefreshClaims, nil -} diff --git a/pkg/keycloak/proxy/oauth.go b/pkg/keycloak/proxy/oauth.go index 46fd009d..01053576 100644 --- a/pkg/keycloak/proxy/oauth.go +++ b/pkg/keycloak/proxy/oauth.go @@ -17,7 +17,6 @@ package proxy import ( "net/http" - "strings" "time" "github.com/gogatekeeper/gatekeeper/pkg/apperrors" @@ -25,8 +24,6 @@ import ( "github.com/grokify/go-pkce" "golang.org/x/net/context" "golang.org/x/oauth2" - - "github.com/go-jose/go-jose/v3/jwt" ) // newOAuth2Config returns a oauth2 config @@ -55,95 +52,6 @@ func newOAuth2Config( } } -// getRefreshedToken attempts to refresh the access token, returning the parsed token, optionally with a renewed -// refresh token and the time the access and refresh tokens expire -// -// NOTE: we may be able to extract the specific (non-standard) claim refresh_expires_in and refresh_expires -// from response.RawBody. -// When not available, keycloak provides us with the same (for now) expiry value for ID token. -func getRefreshedToken( - ctx context.Context, - conf *oauth2.Config, - httpClient *http.Client, - oldRefreshToken string, -) (jwt.JSONWebToken, string, string, time.Time, time.Duration, error) { - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - start := time.Now() - - tkn, err := conf.TokenSource(ctx, &oauth2.Token{RefreshToken: oldRefreshToken}).Token() - if err != nil { - if strings.Contains(err.Error(), "invalid_grant") { - return jwt.JSONWebToken{}, - "", - "", - time.Time{}, - time.Duration(0), - apperrors.ErrRefreshTokenExpired - } - return jwt.JSONWebToken{}, - "", - "", - time.Time{}, - time.Duration(0), - err - } - - taken := time.Since(start).Seconds() - metrics.OauthTokensMetric.WithLabelValues("renew").Inc() - metrics.OauthLatencyMetric.WithLabelValues("renew").Observe(taken) - - token, err := jwt.ParseSigned(tkn.AccessToken) - if err != nil { - return jwt.JSONWebToken{}, - "", - "", - time.Time{}, - time.Duration(0), - err - } - - refreshToken, err := jwt.ParseSigned(tkn.RefreshToken) - if err != nil { - return jwt.JSONWebToken{}, - "", - "", - time.Time{}, - time.Duration(0), - err - } - - stdClaims := &jwt.Claims{} - err = token.UnsafeClaimsWithoutVerification(stdClaims) - if err != nil { - return jwt.JSONWebToken{}, - "", - "", - time.Time{}, - time.Duration(0), - err - } - - refreshStdClaims := &jwt.Claims{} - err = refreshToken.UnsafeClaimsWithoutVerification(refreshStdClaims) - if err != nil { - return jwt.JSONWebToken{}, - "", - "", - time.Time{}, - time.Duration(0), - err - } - - refreshExpiresIn := time.Until(refreshStdClaims.Expiry.Time()) - - return *token, - tkn.AccessToken, - tkn.RefreshToken, - stdClaims.Expiry.Time(), - refreshExpiresIn, - nil -} - // exchangeAuthenticationCode exchanges the authentication code with the oauth server for a access token func exchangeAuthenticationCode( ctx context.Context, diff --git a/pkg/keycloak/proxy/server.go b/pkg/keycloak/proxy/server.go index 78c7862d..9df08f69 100644 --- a/pkg/keycloak/proxy/server.go +++ b/pkg/keycloak/proxy/server.go @@ -56,6 +56,7 @@ import ( proxycore "github.com/gogatekeeper/gatekeeper/pkg/proxy/core" "github.com/gogatekeeper/gatekeeper/pkg/proxy/handlers" "github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics" + gmiddleware "github.com/gogatekeeper/gatekeeper/pkg/proxy/middleware" "github.com/gogatekeeper/gatekeeper/pkg/proxy/session" "github.com/gogatekeeper/gatekeeper/pkg/storage" "github.com/gogatekeeper/gatekeeper/pkg/utils" @@ -224,7 +225,7 @@ func (r *OauthProxy) useDefaultStack(engine chi.Router) { engine.NotFound(handlers.EmptyHandler) if r.Config.EnableDefaultDeny || r.Config.EnableDefaultDenyStrict { - engine.Use(methodCheckMiddleware(r.Log)) + engine.Use(gmiddleware.MethodCheckMiddleware(r.Log)) } else { engine.MethodNotAllowed(handlers.EmptyHandler) } @@ -234,7 +235,7 @@ func (r *OauthProxy) useDefaultStack(engine chi.Router) { // @check if the request tracking id middleware is enabled if r.Config.EnableRequestID { r.Log.Info("enabled the correlation request id middleware") - engine.Use(requestIDMiddleware(r.Config.RequestIDHeader)) + engine.Use(gmiddleware.RequestIDMiddleware(r.Config.RequestIDHeader)) } if r.Config.EnableCompression { @@ -242,10 +243,10 @@ func (r *OauthProxy) useDefaultStack(engine chi.Router) { } // @step: enable the entrypoint middleware - engine.Use(entrypointMiddleware(r.Log)) + engine.Use(gmiddleware.EntrypointMiddleware(r.Log)) if r.Config.EnableLogging { - engine.Use(loggingMiddleware(r.Log, r.Config.Verbose)) + engine.Use(gmiddleware.LoggingMiddleware(r.Log, r.Config.Verbose)) } // step: load the templates if any @@ -281,7 +282,7 @@ func (r *OauthProxy) useDefaultStack(engine chi.Router) { if r.Config.EnableSecurityFilter { engine.Use( - securityMiddleware( + gmiddleware.SecurityMiddleware( r.Log, r.Config.Hostnames, r.Config.EnableBrowserXSSFilter, @@ -371,7 +372,7 @@ func (r *OauthProxy) CreateReverseProxy() error { ) if r.Config.EnableHmac { - engine.Use(hmacMiddleware(r.Log, r.Config.EncryptionKey)) + engine.Use(gmiddleware.HmacMiddleware(r.Log, r.Config.EncryptionKey)) } // @step: configure CORS middleware @@ -404,7 +405,7 @@ func (r *OauthProxy) CreateReverseProxy() error { r.Router = engine if len(r.Config.ResponseHeaders) > 0 { - engine.Use(responseHeaderMiddleware(r.Config.ResponseHeaders)) + engine.Use(gmiddleware.ResponseHeaderMiddleware(r.Config.ResponseHeaders)) } // step: define admin subrouter: health and metrics @@ -432,12 +433,12 @@ func (r *OauthProxy) CreateReverseProxy() error { ) } - authMid := authenticationMiddleware( + authMid := gmiddleware.AuthenticationMiddleware( r.Log, r.Config.CookieAccessName, r.Config.CookieRefreshName, r.GetIdentity, - r.IdpClient, + r.IdpClient.RestyClient().GetClient(), r.Config.EnableIDPSessionCheck, r.Provider, r.Config.SkipTokenVerification, @@ -460,7 +461,7 @@ func (r *OauthProxy) CreateReverseProxy() error { loginHand := loginHandler( r.Log, r.Config.OpenIDProviderTimeout, - r.IdpClient, + r.IdpClient.RestyClient().GetClient(), r.Config.EnableLoginHandler, r.newOAuth2Config, r.getRedirectionURL, @@ -491,7 +492,7 @@ func (r *OauthProxy) CreateReverseProxy() error { r.Config.EnableLogoutRedirect, r.Store, r.Cm, - r.IdpClient, + r.IdpClient.RestyClient().GetClient(), r.accessError, r.GetIdentity, ) @@ -540,15 +541,15 @@ func (r *OauthProxy) CreateReverseProxy() error { ) // step: add the routing for oauth - engine.With(proxyDenyMiddleware(r.Log)).Route(r.Config.BaseURI+r.Config.OAuthURI, func(eng chi.Router) { + engine.With(gmiddleware.ProxyDenyMiddleware(r.Log)).Route(r.Config.BaseURI+r.Config.OAuthURI, func(eng chi.Router) { eng.MethodNotAllowed(handlers.MethodNotAllowHandlder) eng.HandleFunc(constant.AuthorizationURL, oauthAuthorizationHand) eng.Get(constant.CallbackURL, oauthCallbackHand) - eng.Get(constant.ExpiredURL, expirationHandler(r.GetIdentity, r.Config.CookieAccessName)) + eng.Get(constant.ExpiredURL, handlers.ExpirationHandler(r.GetIdentity, r.Config.CookieAccessName)) eng.With(authMid).Get(constant.LogoutURL, logoutHand) eng.With(authMid).Get( constant.TokenURL, - tokenHandler(r.GetIdentity, r.Config.CookieAccessName, r.accessError), + handlers.TokenHandler(r.GetIdentity, r.Config.CookieAccessName, r.accessError), ) eng.Post(constant.LoginURL, loginHand) eng.Get(constant.DiscoveryURL, handlers.DiscoveryHandler(r.Log, r.WithOAuthURI)) @@ -579,7 +580,7 @@ func (r *OauthProxy) CreateReverseProxy() error { } if r.Config.ListenAdmin == "" { - engine.With(proxyDenyMiddleware(r.Log)).Mount(constant.DebugURL, debugEngine) + engine.With(gmiddleware.ProxyDenyMiddleware(r.Log)).Mount(constant.DebugURL, debugEngine) } } @@ -591,7 +592,7 @@ func (r *OauthProxy) CreateReverseProxy() error { admin.MethodNotAllowed(handlers.EmptyHandler) admin.NotFound(handlers.EmptyHandler) admin.Use(middleware.Recoverer) - admin.Use(proxyDenyMiddleware(r.Log)) + admin.Use(gmiddleware.ProxyDenyMiddleware(r.Log)) admin.Route("/", func(e chi.Router) { e.Mount(r.Config.OAuthURI, adminEngine) if debugEngine != nil { @@ -655,13 +656,13 @@ func (r *OauthProxy) CreateReverseProxy() error { middlewares := []func(http.Handler) http.Handler{ authMid, - admissionMiddleware( + gmiddleware.AdmissionMiddleware( r.Log, res, r.Config.MatchClaims, r.accessForbidden, ), - identityHeadersMiddleware( + gmiddleware.IdentityHeadersMiddleware( r.Log, r.Config.AddClaims, r.Config.CookieAccessName, @@ -675,8 +676,8 @@ func (r *OauthProxy) CreateReverseProxy() error { if res.URL == constant.AllPath && !res.WhiteListed && enableDefaultDenyStrict { middlewares = []func(http.Handler) http.Handler{ - denyMiddleware(r.Log, r.accessForbidden), - proxyDenyMiddleware(r.Log), + gmiddleware.DenyMiddleware(r.Log, r.accessForbidden), + gmiddleware.ProxyDenyMiddleware(r.Log), } } @@ -708,13 +709,13 @@ func (r *OauthProxy) CreateReverseProxy() error { r.GetIdentity, r.accessForbidden, ), - admissionMiddleware( + gmiddleware.AdmissionMiddleware( r.Log, res, r.Config.MatchClaims, r.accessForbidden, ), - identityHeadersMiddleware( + gmiddleware.IdentityHeadersMiddleware( r.Log, r.Config.AddClaims, r.Config.CookieAccessName, diff --git a/pkg/keycloak/proxy/stores.go b/pkg/keycloak/proxy/stores.go deleted file mode 100644 index 9b3211aa..00000000 --- a/pkg/keycloak/proxy/stores.go +++ /dev/null @@ -1,42 +0,0 @@ -/* -Copyright 2015 All rights reserved. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package proxy - -import ( - "context" - - "github.com/gogatekeeper/gatekeeper/pkg/apperrors" - "github.com/gogatekeeper/gatekeeper/pkg/storage" - "github.com/gogatekeeper/gatekeeper/pkg/utils" -) - -// Get retrieves a token from the store, the key we are using here is the access token -func GetRefreshTokenFromStore( - ctx context.Context, - store storage.Storage, - token string, -) (string, error) { - // step: the key is the access token - val, err := store.Get(ctx, utils.GetHashKey(token)) - if err != nil { - return val, err - } - if val == "" { - return val, apperrors.ErrNoSessionStateFound - } - - return val, nil -} diff --git a/pkg/proxy/handlers/handlers.go b/pkg/proxy/handlers/handlers.go index ceff0938..a4002ed0 100644 --- a/pkg/proxy/handlers/handlers.go +++ b/pkg/proxy/handlers/handlers.go @@ -24,6 +24,7 @@ import ( "net/http/pprof" "github.com/go-chi/chi/v5" + "github.com/go-jose/go-jose/v3/jwt" "github.com/gogatekeeper/gatekeeper/pkg/apperrors" "github.com/gogatekeeper/gatekeeper/pkg/constant" "github.com/gogatekeeper/gatekeeper/pkg/encryption" @@ -223,3 +224,61 @@ func GetRedirectionURL( return fmt.Sprintf("%s%s", redirect, withOAuthURI(constant.CallbackURL)) } } + +// ExpirationHandler checks if the token has expired +func ExpirationHandler( + getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), + cookieAccessName string, +) func(wrt http.ResponseWriter, req *http.Request) { + return func(wrt http.ResponseWriter, req *http.Request) { + user, err := getIdentity(req, cookieAccessName, "") + if err != nil { + wrt.WriteHeader(http.StatusUnauthorized) + return + } + + if user.IsExpired() { + wrt.WriteHeader(http.StatusUnauthorized) + return + } + + wrt.WriteHeader(http.StatusOK) + } +} + +// TokenHandler display access token to screen +func TokenHandler( + getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), + cookieAccessName string, + accessError func(wrt http.ResponseWriter, req *http.Request) context.Context, +) func(wrt http.ResponseWriter, req *http.Request) { + return func(wrt http.ResponseWriter, req *http.Request) { + user, err := getIdentity(req, cookieAccessName, "") + if err != nil { + accessError(wrt, req) + return + } + + token, err := jwt.ParseSigned(user.RawToken) + if err != nil { + accessError(wrt, req) + return + } + + jsonMap := make(map[string]interface{}) + err = token.UnsafeClaimsWithoutVerification(&jsonMap) + if err != nil { + accessError(wrt, req) + return + } + + result, err := json.Marshal(jsonMap) + if err != nil { + accessError(wrt, req) + return + } + + wrt.Header().Set("Content-Type", "application/json") + _, _ = wrt.Write(result) + } +} diff --git a/pkg/proxy/middleware/base.go b/pkg/proxy/middleware/base.go new file mode 100644 index 00000000..0334cb09 --- /dev/null +++ b/pkg/proxy/middleware/base.go @@ -0,0 +1,288 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + uuid "github.com/gofrs/uuid" + "github.com/gogatekeeper/gatekeeper/pkg/constant" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/models" + "github.com/gogatekeeper/gatekeeper/pkg/utils" + + "github.com/PuerkitoBio/purell" + "github.com/go-chi/chi/v5/middleware" + "github.com/gogatekeeper/gatekeeper/pkg/apperrors" + "go.uber.org/zap" +) + +const ( + // normalizeFlags is the options to purell + normalizeFlags purell.NormalizationFlags = purell.FlagRemoveDotSegments | purell.FlagRemoveDuplicateSlashes +) + +// entrypointMiddleware is custom filtering for incoming requests +func EntrypointMiddleware(logger *zap.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + // @step: create a context for the request + scope := &models.RequestScope{} + // Save the exact formatting of the incoming request so we can use it later + scope.Path = req.URL.Path + scope.RawPath = req.URL.RawPath + scope.Logger = logger + + // We want to Normalize the URL so that we can more easily and accurately + // parse it to apply resource protection rules. + purell.NormalizeURL(req.URL, normalizeFlags) + + // ensure we have a slash in the url + if !strings.HasPrefix(req.URL.Path, "/") { + req.URL.Path = "/" + req.URL.Path + } + req.URL.RawPath = req.URL.EscapedPath() + + resp := middleware.NewWrapResponseWriter(wrt, 1) + start := time.Now() + // All the processing, including forwarding the request upstream and getting the response, + // happens here in this chain. + next.ServeHTTP(resp, req.WithContext(context.WithValue(req.Context(), constant.ContextScopeName, scope))) + + // @metric record the time taken then response code + metrics.LatencyMetric.Observe(time.Since(start).Seconds()) + metrics.StatusMetric.WithLabelValues(strconv.Itoa(resp.Status()), req.Method).Inc() + + // place back the original uri for any later consumers + req.URL.Path = scope.Path + req.URL.RawPath = scope.RawPath + }) + } +} + +// requestIDMiddleware is responsible for adding a request id if none found +func RequestIDMiddleware(header string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + if v := req.Header.Get(header); v == "" { + uuid, err := uuid.NewV1() + if err != nil { + wrt.WriteHeader(http.StatusInternalServerError) + } + req.Header.Set(header, uuid.String()) + } + + next.ServeHTTP(wrt, req) + }) + } +} + +// loggingMiddleware is a custom http logger +func LoggingMiddleware( + logger *zap.Logger, + verbose bool, +) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { + start := time.Now() + resp, assertOk := w.(middleware.WrapResponseWriter) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + + if verbose { + requestLogger := logger.With( + zap.Any("headers", req.Header), + zap.String("path", req.URL.Path), + zap.String("method", req.Method), + ) + scope.Logger = requestLogger + } + + next.ServeHTTP(resp, req) + + addr := utils.RealIP(req) + + if req.URL.Path == req.URL.RawPath || req.URL.RawPath == "" { + scope.Logger.Info("client request", + zap.Duration("latency", time.Since(start)), + zap.Int("status", resp.Status()), + zap.Int("bytes", resp.BytesWritten()), + zap.String("client_ip", addr), + zap.String("remote_addr", req.RemoteAddr), + zap.String("method", req.Method), + zap.String("path", req.URL.Path)) + } else { + scope.Logger.Info("client request", + zap.Duration("latency", time.Since(start)), + zap.Int("status", resp.Status()), + zap.Int("bytes", resp.BytesWritten()), + zap.String("client_ip", addr), + zap.String("remote_addr", req.RemoteAddr), + zap.String("method", req.Method), + zap.String("path", req.URL.Path), + zap.String("raw path", req.URL.RawPath)) + } + }) + } +} + +// ResponseHeaderMiddleware is responsible for adding response headers +func ResponseHeaderMiddleware(headers map[string]string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + // @step: inject any custom response headers + for k, v := range headers { + wrt.Header().Set(k, v) + } + + next.ServeHTTP(wrt, req) + }) + } +} + +// DenyMiddleware +func DenyMiddleware( + logger *zap.Logger, + accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, +) func(http.Handler) http.Handler { + return func(_ http.Handler) http.Handler { + logger.Info("enabling the deny middleware") + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + accessForbidden(wrt, req) + }) + } +} + +// ProxyDenyMiddleware just block everything +func ProxyDenyMiddleware(logger *zap.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + ctxVal := req.Context().Value(constant.ContextScopeName) + + var scope *models.RequestScope + if ctxVal == nil { + scope = &models.RequestScope{} + } else { + var assertOk bool + scope, assertOk = ctxVal.(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + } + + scope.AccessDenied = true + // update the request context + ctx := context.WithValue(req.Context(), constant.ContextScopeName, scope) + + next.ServeHTTP(wrt, req.WithContext(ctx)) + }) + } +} + +// MethodCheck middleware +func MethodCheckMiddleware(logger *zap.Logger) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + logger.Info("enabling the method check middleware") + + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + if !utils.IsValidHTTPMethod(req.Method) { + logger.Warn("method not implemented ", zap.String("method", req.Method)) + wrt.WriteHeader(http.StatusNotImplemented) + return + } + + next.ServeHTTP(wrt, req) + }) + } +} + +// IdentityHeadersMiddleware is responsible for adding the authentication headers to upstream +// +//nolint:cyclop +func IdentityHeadersMiddleware( + logger *zap.Logger, + custom []string, + cookieAccessName string, + cookieRefreshName string, + noProxy bool, + enableTokenHeader bool, + enableAuthzHeader bool, + enableAuthzCookies bool, +) func(http.Handler) http.Handler { + customClaims := make(map[string]string) + const minSliceLength int = 1 + cookieFilter := []string{cookieAccessName, cookieRefreshName} + + for _, val := range custom { + xslices := strings.Split(val, "|") + val = xslices[0] + if len(xslices) > minSliceLength { + customClaims[val] = utils.ToHeader(xslices[1]) + } else { + customClaims[val] = fmt.Sprintf("X-Auth-%s", utils.ToHeader(val)) + } + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + + var headers http.Header + if noProxy { + headers = wrt.Header() + } else { + headers = req.Header + } + + if scope.Identity != nil { + user := scope.Identity + headers.Set("X-Auth-Audience", strings.Join(user.Audiences, ",")) + headers.Set("X-Auth-Email", user.Email) + headers.Set("X-Auth-ExpiresIn", user.ExpiresAt.String()) + headers.Set("X-Auth-Groups", strings.Join(user.Groups, ",")) + headers.Set("X-Auth-Roles", strings.Join(user.Roles, ",")) + headers.Set("X-Auth-Subject", user.ID) + headers.Set("X-Auth-Userid", user.Name) + headers.Set("X-Auth-Username", user.Name) + + // should we add the token header? + if enableTokenHeader { + headers.Set("X-Auth-Token", user.RawToken) + } + // add the authorization header if requested + if enableAuthzHeader { + headers.Set("Authorization", fmt.Sprintf("Bearer %s", user.RawToken)) + } + // are we filtering out the cookies + if !enableAuthzCookies { + _ = cookie.FilterCookies(req, cookieFilter) + } + // inject any custom claims + for claim, header := range customClaims { + if claim, found := user.Claims[claim]; found { + headers.Set(header, fmt.Sprintf("%v", claim)) + } + } + } + + next.ServeHTTP(wrt, req) + }) + } +} diff --git a/pkg/proxy/middleware/oauth.go b/pkg/proxy/middleware/oauth.go new file mode 100644 index 00000000..cc69d6e3 --- /dev/null +++ b/pkg/proxy/middleware/oauth.go @@ -0,0 +1,321 @@ +package middleware + +import ( + "context" + "errors" + "net/http" + "strings" + "time" + + oidc3 "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/gogatekeeper/gatekeeper/pkg/apperrors" + "github.com/gogatekeeper/gatekeeper/pkg/constant" + "github.com/gogatekeeper/gatekeeper/pkg/encryption" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/models" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/session" + "github.com/gogatekeeper/gatekeeper/pkg/storage" + "github.com/gogatekeeper/gatekeeper/pkg/utils" + "go.uber.org/zap" + "golang.org/x/oauth2" +) + +/* + AuthenticationMiddleware is responsible for verifying the access token +*/ +//nolint:funlen,cyclop +func AuthenticationMiddleware( + logger *zap.Logger, + cookieAccessName string, + cookieRefreshName string, + getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error), + httpClient *http.Client, + enableIDPSessionCheck bool, + provider *oidc3.Provider, + skipTokenVerification bool, + clientID string, + skipAccessTokenClientIDCheck bool, + skipAccessTokenIssuerCheck bool, + accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, + enableRefreshTokens bool, + redirectionURL string, + cookMgr *cookie.Manager, + enableEncryptedToken bool, + forceEncryptedCookie bool, + encryptionKey string, + redirectToAuthorization func(wrt http.ResponseWriter, req *http.Request) context.Context, + newOAuth2Config func(redirectionURL string) *oauth2.Config, + store storage.Storage, + accessTokenDuration time.Duration, +) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + + clientIP := utils.RealIP(req) + scope.Logger.Debug("authentication middleware") + + // grab the user identity from the request + user, err := getIdentity(req, cookieAccessName, "") + if err != nil { + scope.Logger.Error(err.Error()) + redirectToAuthorization(wrt, req) + return + } + + scope.Identity = user + ctx := context.WithValue(req.Context(), constant.ContextScopeName, scope) + lLog := scope.Logger.With( + zap.String("client_ip", clientIP), + zap.String("remote_addr", req.RemoteAddr), + zap.String("username", user.Name), + zap.String("sub", user.ID), + zap.String("expired_on", user.ExpiresAt.String()), + ) + + // IMPORTANT: For all calls with go-oidc library be aware + // that calls accept context parameter and you have to pass + // client from provider through this parameter, although + // provider is already configured with client!!! + // https://github.com/coreos/go-oidc/issues/402 + oidcLibCtx := context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + if enableIDPSessionCheck { + tokenSource := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: user.RawToken}, + ) + _, err := provider.UserInfo(oidcLibCtx, tokenSource) + if err != nil { + scope.Logger.Error(err.Error()) + redirectToAuthorization(wrt, req) + return + } + } + + // step: skip if we are running skip-token-verification + if skipTokenVerification { + scope.Logger.Warn( + "skip token verification enabled, " + + "skipping verification - TESTING ONLY", + ) + + if user.IsExpired() { + lLog.Error(apperrors.ErrSessionExpiredVerifyOff.Error()) + redirectToAuthorization(wrt, req) + return + } + } else { //nolint:gocritic + _, err := utils.VerifyToken( + ctx, + provider, + user.RawToken, + clientID, + skipAccessTokenClientIDCheck, + skipAccessTokenIssuerCheck, + ) + if err != nil { + if errors.Is(err, apperrors.ErrTokenSignature) { + lLog.Error( + apperrors.ErrAccTokenVerifyFailure.Error(), + zap.Error(err), + ) + accessForbidden(wrt, req) + return + } + + if !strings.Contains(err.Error(), "token is expired") { + lLog.Error( + apperrors.ErrAccTokenVerifyFailure.Error(), + zap.Error(err), + ) + accessForbidden(wrt, req) + return + } + + if !enableRefreshTokens { + lLog.Error(apperrors.ErrSessionExpiredRefreshOff.Error()) + redirectToAuthorization(wrt, req) + return + } + + lLog.Info("accces token for user has expired, attemping to refresh the token") + + // step: check if the user has refresh token + refresh, _, err := session.RetrieveRefreshToken( + store, + cookieRefreshName, + encryptionKey, + req.WithContext(ctx), + user, + ) + if err != nil { + scope.Logger.Error( + apperrors.ErrRefreshTokenNotFound.Error(), + zap.Error(err), + ) + redirectToAuthorization(wrt, req) + return + } + + var stdRefreshClaims *jwt.Claims + stdRefreshClaims, err = utils.ParseRefreshToken(refresh) + if err != nil { + lLog.Error( + apperrors.ErrParseRefreshToken.Error(), + zap.Error(err), + ) + accessForbidden(wrt, req) + return + } + if user.ID != stdRefreshClaims.Subject { + lLog.Error( + apperrors.ErrAccRefreshTokenMismatch.Error(), + zap.Error(err), + ) + accessForbidden(wrt, req) + return + } + + // attempt to refresh the access token, possibly with a renewed refresh token + // + // NOTE: atm, this does not retrieve explicit refresh token expiry from oauth2, + // and take identity expiry instead: with keycloak, they are the same and equal to + // "SSO session idle" keycloak setting. + // + // exp: expiration of the access token + // expiresIn: expiration of the ID token + conf := newOAuth2Config(redirectionURL) + + lLog.Debug( + "issuing refresh token request", + zap.String("current access token", user.RawToken), + zap.String("refresh token", refresh), + ) + + newAccToken, newRawAccToken, newRefreshToken, accessExpiresAt, refreshExpiresIn, err := utils.GetRefreshedToken(ctx, conf, httpClient, refresh) + if err != nil { + switch err { + case apperrors.ErrRefreshTokenExpired: + lLog.Warn("refresh token has expired, cannot retrieve access token") + cookMgr.ClearAllCookies(req.WithContext(ctx), wrt) + default: + lLog.Debug( + apperrors.ErrAccTokenRefreshFailure.Error(), + zap.String("access token", user.RawToken), + ) + lLog.Error( + apperrors.ErrAccTokenRefreshFailure.Error(), + zap.Error(err), + ) + } + + redirectToAuthorization(wrt, req) + return + } + + lLog.Debug( + "info about tokens after refreshing", + zap.String("new access token", newRawAccToken), + zap.String("new refresh token", newRefreshToken), + ) + + accessExpiresIn := time.Until(accessExpiresAt) + + if newRefreshToken != "" { + refresh = newRefreshToken + } + + if refreshExpiresIn == 0 { + // refresh token expiry claims not available: try to parse refresh token + refreshExpiresIn = session.GetAccessCookieExpiration(lLog, accessTokenDuration, refresh) + } + + lLog.Info( + "injecting the refreshed access token cookie", + zap.Duration("refresh_expires_in", refreshExpiresIn), + zap.Duration("expires_in", accessExpiresIn), + ) + + accessToken := newRawAccToken + + if enableEncryptedToken || forceEncryptedCookie { + if accessToken, err = encryption.EncodeText(accessToken, encryptionKey); err != nil { + lLog.Error( + apperrors.ErrEncryptAccToken.Error(), + zap.Error(err), + ) + accessForbidden(wrt, req) + return + } + } + + // step: inject the refreshed access token + cookMgr.DropAccessTokenCookie(req.WithContext(ctx), wrt, accessToken, accessExpiresIn) + + // update the with the new access token and inject into the context + newUser, err := session.ExtractIdentity(&newAccToken) + if err != nil { + lLog.Error(err.Error()) + accessForbidden(wrt, req) + return + } + + // step: inject the renewed refresh token + if newRefreshToken != "" { + lLog.Debug( + "renew refresh cookie with new refresh token", + zap.Duration("refresh_expires_in", refreshExpiresIn), + ) + var encryptedRefreshToken string + encryptedRefreshToken, err = encryption.EncodeText(newRefreshToken, encryptionKey) + if err != nil { + lLog.Error( + apperrors.ErrEncryptRefreshToken.Error(), + zap.Error(err), + ) + wrt.WriteHeader(http.StatusInternalServerError) + return + } + + if store != nil { + go func(ctx context.Context, old string, newToken string, encrypted string) { + ctxx, cancel := context.WithCancel(ctx) + defer cancel() + if err = store.Delete(ctxx, utils.GetHashKey(old)); err != nil { + lLog.Error( + apperrors.ErrDelTokFromStore.Error(), + zap.Error(err), + ) + } + + if err = store.Set(ctxx, utils.GetHashKey(newToken), encrypted, refreshExpiresIn); err != nil { + lLog.Error( + apperrors.ErrSaveTokToStore.Error(), + zap.Error(err), + ) + return + } + }(ctx, user.RawToken, newRawAccToken, encryptedRefreshToken) + } else { + cookMgr.DropRefreshTokenCookie(req.WithContext(ctx), wrt, encryptedRefreshToken, refreshExpiresIn) + } + } + + // IMPORTANT: on this rely other middlewares, must be refreshed + // with new identity! + newUser.RawToken = newRawAccToken + scope.Identity = newUser + ctx = context.WithValue(req.Context(), constant.ContextScopeName, scope) + } + } + + *req = *(req.WithContext(ctx)) + next.ServeHTTP(wrt, req) + }) + } +} diff --git a/pkg/proxy/middleware/security.go b/pkg/proxy/middleware/security.go new file mode 100644 index 00000000..447a224e --- /dev/null +++ b/pkg/proxy/middleware/security.go @@ -0,0 +1,201 @@ +package middleware + +import ( + "context" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + "github.com/gogatekeeper/gatekeeper/pkg/apperrors" + "github.com/gogatekeeper/gatekeeper/pkg/authorization" + "github.com/gogatekeeper/gatekeeper/pkg/constant" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/models" + "github.com/gogatekeeper/gatekeeper/pkg/utils" + "github.com/unrolled/secure" + "go.uber.org/zap" +) + +// SecurityMiddleware performs numerous security checks on the request +func SecurityMiddleware( + logger *zap.Logger, + allowedHosts []string, + browserXSSFilter bool, + contentSecurityPolicy string, + contentTypeNosniff bool, + frameDeny bool, + sslRedirect bool, + accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, +) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + logger.Info("enabling the security filter middleware") + + secure := secure.New(secure.Options{ + AllowedHosts: allowedHosts, + BrowserXssFilter: browserXSSFilter, + ContentSecurityPolicy: contentSecurityPolicy, + ContentTypeNosniff: contentTypeNosniff, + FrameDeny: frameDeny, + SSLProxyHeaders: map[string]string{"X-Forwarded-Proto": "https"}, + SSLRedirect: sslRedirect, + }) + + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + + if err := secure.Process(wrt, req); err != nil { + scope.Logger.Warn("failed security middleware", zap.Error(err)) + accessForbidden(wrt, req) + return + } + + next.ServeHTTP(wrt, req) + }) + } +} + +// HmacMiddleware verifies hmac +func HmacMiddleware(logger *zap.Logger, encKey string) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + + if scope.AccessDenied { + next.ServeHTTP(wrt, req) + return + } + + expectedMAC := req.Header.Get(constant.HeaderXHMAC) + if expectedMAC == "" { + logger.Debug(apperrors.ErrHmacHeaderEmpty.Error()) + wrt.WriteHeader(http.StatusBadRequest) + return + } + + reqHmac, err := utils.GenerateHmac(req, encKey) + if err != nil { + logger.Error(err.Error()) + } + + if reqHmac != expectedMAC { + logger.Debug(apperrors.ErrHmacMismatch.Error()) + wrt.WriteHeader(http.StatusBadRequest) + return + } + + next.ServeHTTP(wrt, req) + }) + } +} + +// AdmissionMiddleware is responsible for checking the access token against the protected resource +// +//nolint:cyclop +func AdmissionMiddleware( + logger *zap.Logger, + resource *authorization.Resource, + matchClaims map[string]string, + accessForbidden func(wrt http.ResponseWriter, req *http.Request) context.Context, +) func(http.Handler) http.Handler { + claimMatches := make(map[string]*regexp.Regexp) + for k, v := range matchClaims { + claimMatches[k] = regexp.MustCompile(v) + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(wrt http.ResponseWriter, req *http.Request) { + // we don't need to continue is a decision has been made + scope, assertOk := req.Context().Value(constant.ContextScopeName).(*models.RequestScope) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return + } + if scope.AccessDenied { + next.ServeHTTP(wrt, req) + return + } + + user := scope.Identity + lLog := scope.Logger.With( + zap.String("access", "denied"), + zap.String("email", user.Email), + zap.String("resource", resource.URL), + ) + + // @step: we need to check the roles + if !utils.HasAccess(resource.Roles, user.Roles, !resource.RequireAnyRole) { + lLog.Warn("access denied, invalid roles", + zap.String("roles", resource.GetRoles())) + accessForbidden(wrt, req) + return + } + + if len(resource.Headers) > 0 { + var reqHeaders []string + + for _, resVal := range resource.Headers { + resVals := strings.Split(resVal, ":") + name := resVals[0] + canonName := http.CanonicalHeaderKey(name) + values, ok := req.Header[canonName] + if !ok { + lLog.Warn("access denied, invalid headers", + zap.String("headers", resource.GetHeaders())) + accessForbidden(wrt, req) + return + } + + for _, value := range values { + headVal := fmt.Sprintf( + "%s:%s", + strings.ToLower(name), + strings.ToLower(value), + ) + reqHeaders = append(reqHeaders, headVal) + } + } + + // @step: we need to check the headers + if !utils.HasAccess(resource.Headers, reqHeaders, true) { + lLog.Warn("access denied, invalid headers", + zap.String("headers", resource.GetHeaders())) + accessForbidden(wrt, req) + return + } + } + + // @step: check if we have any groups, the groups are there + if !utils.HasAccess(resource.Groups, user.Groups, false) { + lLog.Warn("access denied, invalid groups", + zap.String("groups", strings.Join(resource.Groups, ","))) + accessForbidden(wrt, req) + return + } + + // step: if we have any claim matching, lets validate the tokens has the claims + for claimName, match := range claimMatches { + if !utils.CheckClaim(scope.Logger, user, claimName, match, resource.URL) { + accessForbidden(wrt, req) + return + } + } + + scope.Logger.Debug("access permitted to resource", + zap.String("access", "permitted"), + zap.String("email", user.Email), + zap.Duration("expires", time.Until(user.ExpiresAt)), + zap.String("resource", resource.URL)) + + next.ServeHTTP(wrt, req) + }) + } +} diff --git a/pkg/proxy/session/token.go b/pkg/proxy/session/token.go index 935fd74a..65de7713 100644 --- a/pkg/proxy/session/token.go +++ b/pkg/proxy/session/token.go @@ -6,6 +6,7 @@ import ( "net/http" "strconv" "strings" + "time" "github.com/go-jose/go-jose/v3/jwt" "github.com/gogatekeeper/gatekeeper/pkg/apperrors" @@ -13,6 +14,7 @@ import ( "github.com/gogatekeeper/gatekeeper/pkg/encryption" "github.com/gogatekeeper/gatekeeper/pkg/proxy/cookie" "github.com/gogatekeeper/gatekeeper/pkg/proxy/models" + "github.com/gogatekeeper/gatekeeper/pkg/storage" "go.uber.org/zap" ) @@ -231,3 +233,64 @@ func ExtractIdentity(token *jwt.JSONWebToken) (*models.UserContext, error) { Permissions: customClaims.Authorization, }, nil } + +// retrieveRefreshToken retrieves the refresh token from store or cookie +func RetrieveRefreshToken( + store storage.Storage, + cookieRefreshName string, + encryptionKey string, + req *http.Request, + user *models.UserContext, +) (string, string, error) { + var token string + var err error + + switch store != nil { + case true: + token, err = store.GetRefreshTokenFromStore(req.Context(), user.RawToken) + default: + token, err = GetRefreshTokenFromCookie(req, cookieRefreshName) + } + + if err != nil { + return token, "", err + } + + encrypted := token // returns encrypted, avoids encoding twice + token, err = encryption.DecodeText(token, encryptionKey) + return token, encrypted, err +} + +// GetAccessCookieExpiration calculates the expiration of the access token cookie +func GetAccessCookieExpiration( + logger *zap.Logger, + accessTokenDuration time.Duration, + refresh string, +) time.Duration { + // notes: by default the duration of the access token will be the configuration option, if + // however we can decode the refresh token, we will set the duration to the duration of the + // refresh token + duration := accessTokenDuration + + webToken, err := jwt.ParseSigned(refresh) + if err != nil { + logger.Error("unable to parse token") + } + + if ident, err := ExtractIdentity(webToken); err == nil { + delta := time.Until(ident.ExpiresAt) + + if delta > 0 { + duration = delta + } + + logger.Debug( + "parsed refresh token with new duration", + zap.Duration("new duration", delta), + ) + } else { + logger.Debug("refresh token is opaque and cannot be used to extend calculated duration") + } + + return duration +} diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 6624fa2b..99d75c2f 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -20,6 +20,7 @@ type Storage interface { Delete(context.Context, string) error // Close is used to close off any resources Close() error + GetRefreshTokenFromStore(ctx context.Context, token string) (string, error) } // createStorage creates the store client for use diff --git a/pkg/storage/store_redis.go b/pkg/storage/store_redis.go index 9344954a..fa384b10 100644 --- a/pkg/storage/store_redis.go +++ b/pkg/storage/store_redis.go @@ -19,6 +19,8 @@ import ( "context" "time" + "github.com/gogatekeeper/gatekeeper/pkg/apperrors" + "github.com/gogatekeeper/gatekeeper/pkg/utils" redis "github.com/redis/go-redis/v9" ) @@ -80,3 +82,20 @@ func (r RedisStore) Close() error { return nil } + +// Get retrieves a token from the store, the key we are using here is the access token +func (r RedisStore) GetRefreshTokenFromStore( + ctx context.Context, + token string, +) (string, error) { + // step: the key is the access token + val, err := r.Get(ctx, utils.GetHashKey(token)) + if err != nil { + return val, err + } + if val == "" { + return val, apperrors.ErrNoSessionStateFound + } + + return val, nil +} diff --git a/pkg/testsuite/misc_test.go b/pkg/testsuite/misc_test.go index 5dbbb69b..45b98454 100644 --- a/pkg/testsuite/misc_test.go +++ b/pkg/testsuite/misc_test.go @@ -23,7 +23,7 @@ import ( "testing" "time" - keycloakproxy "github.com/gogatekeeper/gatekeeper/pkg/keycloak/proxy" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/session" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -95,7 +95,7 @@ func TestGetAccessCookieExpiration_NoExp(t *testing.T) { c := newFakeKeycloakConfig() c.AccessTokenDuration = time.Duration(1) * time.Hour proxy := newFakeProxy(c, &fakeAuthConfig{}).proxy - duration := keycloakproxy.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) + duration := session.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) assertAlmostEquals(t, c.AccessTokenDuration, duration) } @@ -107,7 +107,7 @@ func TestGetAccessCookieExpiration_ZeroExp(t *testing.T) { c := newFakeKeycloakConfig() c.AccessTokenDuration = time.Duration(1) * time.Hour proxy := newFakeProxy(c, &fakeAuthConfig{}).proxy - duration := keycloakproxy.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) + duration := session.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) assert.Greater(t, duration, 0*time.Second, "duration should be positive") assertAlmostEquals(t, c.AccessTokenDuration, duration) } @@ -120,7 +120,7 @@ func TestGetAccessCookieExpiration_PastExp(t *testing.T) { c := newFakeKeycloakConfig() c.AccessTokenDuration = time.Duration(1) * time.Hour proxy := newFakeProxy(c, &fakeAuthConfig{}).proxy - duration := keycloakproxy.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) + duration := session.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) assertAlmostEquals(t, c.AccessTokenDuration, duration) } @@ -131,7 +131,7 @@ func TestGetAccessCookieExpiration_ValidExp(t *testing.T) { c := newFakeKeycloakConfig() c.AccessTokenDuration = time.Duration(1) * time.Hour proxy := newFakeProxy(c, &fakeAuthConfig{}).proxy - duration := keycloakproxy.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) + duration := session.GetAccessCookieExpiration(proxy.Log, c.AccessTokenDuration, token) expectedDuration := time.Until(time.Unix(fToken.Claims.Exp, 0)) assertAlmostEquals(t, expectedDuration, duration) } diff --git a/pkg/utils/token.go b/pkg/utils/token.go new file mode 100644 index 00000000..f180a045 --- /dev/null +++ b/pkg/utils/token.go @@ -0,0 +1,252 @@ +package utils + +import ( + "context" + "errors" + "fmt" + "net/http" + "regexp" + "strings" + "time" + + oidc3 "github.com/coreos/go-oidc/v3/oidc" + "github.com/go-jose/go-jose/v3/jwt" + "github.com/gogatekeeper/gatekeeper/pkg/apperrors" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/metrics" + "github.com/gogatekeeper/gatekeeper/pkg/proxy/models" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + "golang.org/x/oauth2" +) + +func VerifyToken( + ctx context.Context, + provider *oidc3.Provider, + rawToken string, + clientID string, + skipClientIDCheck bool, + skipIssuerCheck bool, +) (*oidc3.IDToken, error) { + // This verifier with this configuration checks only signatures + // we want to know if we are using valid token + // bad is that Verify method doesn't check first signatures, so + // we have to do it like this + verifier := provider.Verifier( + &oidc3.Config{ + ClientID: clientID, + SkipClientIDCheck: true, + SkipIssuerCheck: true, + SkipExpiryCheck: true, + }, + ) + _, err := verifier.Verify(ctx, rawToken) + if err != nil { + return nil, errors.Join(apperrors.ErrTokenSignature, err) + } + + // Now doing expiration check + verifier = provider.Verifier( + &oidc3.Config{ + ClientID: clientID, + SkipClientIDCheck: skipClientIDCheck, + SkipIssuerCheck: skipIssuerCheck, + SkipExpiryCheck: false, + }, + ) + + oToken, err := verifier.Verify(ctx, rawToken) + if err != nil { + return nil, err + } + + return oToken, nil +} + +func ParseRefreshToken(rawRefreshToken string) (*jwt.Claims, error) { + refreshToken, err := jwt.ParseSigned(rawRefreshToken) + if err != nil { + return nil, err + } + + stdRefreshClaims := &jwt.Claims{} + err = refreshToken.UnsafeClaimsWithoutVerification(stdRefreshClaims) + if err != nil { + return nil, err + } + + return stdRefreshClaims, nil +} + +// GetRefreshedToken attempts to refresh the access token, returning the parsed token, optionally with a renewed +// refresh token and the time the access and refresh tokens expire +// +// NOTE: we may be able to extract the specific (non-standard) claim refresh_expires_in and refresh_expires +// from response.RawBody. +// When not available, keycloak provides us with the same (for now) expiry value for ID token. +func GetRefreshedToken( + ctx context.Context, + conf *oauth2.Config, + httpClient *http.Client, + oldRefreshToken string, +) (jwt.JSONWebToken, string, string, time.Time, time.Duration, error) { + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + start := time.Now() + + tkn, err := conf.TokenSource(ctx, &oauth2.Token{RefreshToken: oldRefreshToken}).Token() + if err != nil { + if strings.Contains(err.Error(), "invalid_grant") { + return jwt.JSONWebToken{}, + "", + "", + time.Time{}, + time.Duration(0), + apperrors.ErrRefreshTokenExpired + } + return jwt.JSONWebToken{}, + "", + "", + time.Time{}, + time.Duration(0), + err + } + + taken := time.Since(start).Seconds() + metrics.OauthTokensMetric.WithLabelValues("renew").Inc() + metrics.OauthLatencyMetric.WithLabelValues("renew").Observe(taken) + + token, err := jwt.ParseSigned(tkn.AccessToken) + if err != nil { + return jwt.JSONWebToken{}, + "", + "", + time.Time{}, + time.Duration(0), + err + } + + refreshToken, err := jwt.ParseSigned(tkn.RefreshToken) + if err != nil { + return jwt.JSONWebToken{}, + "", + "", + time.Time{}, + time.Duration(0), + err + } + + stdClaims := &jwt.Claims{} + err = token.UnsafeClaimsWithoutVerification(stdClaims) + if err != nil { + return jwt.JSONWebToken{}, + "", + "", + time.Time{}, + time.Duration(0), + err + } + + refreshStdClaims := &jwt.Claims{} + err = refreshToken.UnsafeClaimsWithoutVerification(refreshStdClaims) + if err != nil { + return jwt.JSONWebToken{}, + "", + "", + time.Time{}, + time.Duration(0), + err + } + + refreshExpiresIn := time.Until(refreshStdClaims.Expiry.Time()) + + return *token, + tkn.AccessToken, + tkn.RefreshToken, + stdClaims.Expiry.Time(), + refreshExpiresIn, + nil +} + +// CheckClaim checks whether claim in userContext matches claimName, match. It can be String or Strings claim. +// +//nolint:cyclop +func CheckClaim( + logger *zap.Logger, + user *models.UserContext, + claimName string, + match *regexp.Regexp, + resourceURL string, +) bool { + errFields := []zapcore.Field{ + zap.String("claim", claimName), + zap.String("access", "denied"), + zap.String("email", user.Email), + zap.String("resource", resourceURL), + } + + lLog := logger.With(errFields...) + if _, found := user.Claims[claimName]; !found { + lLog.Warn("the token does not have the claim") + return false + } + + switch user.Claims[claimName].(type) { + case []interface{}: + claims, assertOk := user.Claims[claimName].([]interface{}) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return false + } + + for _, v := range claims { + value, ok := v.(string) + if !ok { + lLog.Warn( + "Problem while asserting claim", + zap.String( + "issued", + fmt.Sprintf("%v", user.Claims[claimName]), + ), + zap.String("required", match.String()), + ) + + return false + } + + if match.MatchString(value) { + return true + } + } + + lLog.Warn( + "claim requirement does not match any element claim group in token", + zap.String("issued", fmt.Sprintf("%v", user.Claims[claimName])), + zap.String("required", match.String()), + ) + + return false + case string: + claims, assertOk := user.Claims[claimName].(string) + if !assertOk { + logger.Error(apperrors.ErrAssertionFailed.Error()) + return false + } + if match.MatchString(claims) { + return true + } + + lLog.Warn( + "claim requirement does not match claim in token", + zap.String("issued", claims), + zap.String("required", match.String()), + ) + + return false + default: + logger.Error( + "unable to extract the claim from token not string or array of strings", + ) + } + + lLog.Warn("unexpected error") + return false +}