Skip to content

Commit

Permalink
Refactor, move common parts to packages
Browse files Browse the repository at this point in the history
  • Loading branch information
p53 committed May 27, 2024
1 parent 3b095de commit 35823eb
Show file tree
Hide file tree
Showing 15 changed files with 1,243 additions and 1,199 deletions.
102 changes: 8 additions & 94 deletions pkg/keycloak/proxy/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,7 @@ func oauthCallbackHandler(
if enableRefreshTokens && refreshToken != "" {
var encrypted string
var stdRefreshClaims *jwt.Claims
stdRefreshClaims, err = parseRefreshToken(refreshToken)
stdRefreshClaims, err = utils.ParseRefreshToken(refreshToken)
if err != nil {
scope.Logger.Error(apperrors.ErrParseRefreshToken.Error(), zap.Error(err))
accessForbidden(writer, req)
Expand Down Expand Up @@ -384,7 +384,7 @@ func oauthCallbackHandler(
func loginHandler(
logger *zap.Logger,
openIDProviderTimeout time.Duration,
idpClient *gocloak.GoCloak,
httpClient *http.Client,
enableLoginHandler bool,
newOAuth2Config func(redirectionURL string) *oauth2.Config,
getRedirectionURL func(wrt http.ResponseWriter, req *http.Request) string,
Expand Down Expand Up @@ -416,7 +416,7 @@ func loginHandler(
ctx = context.WithValue(
ctx,
oauth2.HTTPClient,
idpClient.RestyClient().GetClient(),
httpClient,
)

if !enableLoginHandler {
Expand Down Expand Up @@ -507,15 +507,15 @@ func loginHandler(
req,
writer,
accessToken,
GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken),
session.GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken),
)

if enableIDTokenCookie {
cookManager.DropIDTokenCookie(
req,
writer,
idToken,
GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken),
session.GetAccessCookieExpiration(scope.Logger, accessTokenDuration, token.RefreshToken),
)
}

Expand Down Expand Up @@ -643,7 +643,7 @@ func logoutHandler(
enableLogoutRedirect bool,
store storage.Storage,
cookManager *cookie.Manager,
idpClient *gocloak.GoCloak,
httpClient *http.Client,
accessError func(wrt http.ResponseWriter, req *http.Request) context.Context,
GetIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
) func(wrt http.ResponseWriter, req *http.Request) {
Expand Down Expand Up @@ -687,7 +687,7 @@ func logoutHandler(
identityToken := user.RawToken

//nolint:vetshadow
if refresh, _, err := retrieveRefreshToken(
if refresh, _, err := session.RetrieveRefreshToken(
store,
cookieRefreshName,
encryptionKey,
Expand Down Expand Up @@ -771,7 +771,6 @@ func logoutHandler(

// step: do we have a revocation endpoint?
if revocationURL != "" {
client := idpClient.RestyClient().GetClient()
// step: add the authentication headers
encodedID := url.QueryEscape(clientID)
encodedSecret := url.QueryEscape(clientSecret)
Expand All @@ -795,7 +794,7 @@ func logoutHandler(
request.Header.Set("Content-Type", "application/x-www-form-urlencoded")

start := time.Now()
response, err := client.Do(request)
response, err := httpClient.Do(request)
if err != nil {
scope.Logger.Error(apperrors.ErrRevocationReqFailure.Error(), zap.Error(err))
writer.WriteHeader(http.StatusInternalServerError)
Expand Down Expand Up @@ -834,88 +833,3 @@ func logoutHandler(
}
}
}

// expirationHandler checks if the token has expired
func expirationHandler(
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
cookieAccessName string,
) func(wrt http.ResponseWriter, req *http.Request) {
return func(wrt http.ResponseWriter, req *http.Request) {
user, err := getIdentity(req, cookieAccessName, "")
if err != nil {
wrt.WriteHeader(http.StatusUnauthorized)
return
}

if user.IsExpired() {
wrt.WriteHeader(http.StatusUnauthorized)
return
}

wrt.WriteHeader(http.StatusOK)
}
}

// tokenHandler display access token to screen
func tokenHandler(
getIdentity func(req *http.Request, tokenCookie string, tokenHeader string) (*models.UserContext, error),
cookieAccessName string,
accessError func(wrt http.ResponseWriter, req *http.Request) context.Context,
) func(wrt http.ResponseWriter, req *http.Request) {
return func(wrt http.ResponseWriter, req *http.Request) {
user, err := getIdentity(req, cookieAccessName, "")
if err != nil {
accessError(wrt, req)
return
}

token, err := jwt.ParseSigned(user.RawToken)
if err != nil {
accessError(wrt, req)
return
}

jsonMap := make(map[string]interface{})
err = token.UnsafeClaimsWithoutVerification(&jsonMap)
if err != nil {
accessError(wrt, req)
return
}

result, err := json.Marshal(jsonMap)
if err != nil {
accessError(wrt, req)
return
}

wrt.Header().Set("Content-Type", "application/json")
_, _ = wrt.Write(result)
}
}

// retrieveRefreshToken retrieves the refresh token from store or cookie
func retrieveRefreshToken(
store storage.Storage,
cookieRefreshName string,
encryptionKey string,
req *http.Request,
user *models.UserContext,
) (string, string, error) {
var token string
var err error

switch store != nil {
case true:
token, err = GetRefreshTokenFromStore(req.Context(), store, user.RawToken)
default:
token, err = session.GetRefreshTokenFromCookie(req, cookieRefreshName)
}

if err != nil {
return token, "", err
}

encrypted := token // returns encrypted, avoids encoding twice
token, err = encryption.DecodeText(token, encryptionKey)
return token, encrypted, err
}
Loading

0 comments on commit 35823eb

Please sign in to comment.