Skip to content

Commit

Permalink
Tt 13184 Upstream OAuth2 updates to fix TTL issue (#6643)
Browse files Browse the repository at this point in the history
<details open>
<summary><a href="https://tyktech.atlassian.net/browse/TT-13184"
title="TT-13184" target="_blank">TT-13184</a></summary>
  <br />
  <table>
    <tr>
      <th>Summary</th>
<td>Implement OAuth 2.0 Client Credentials for API Gateway
Authentication with Upstream Server</td>
    </tr>
    <tr>
      <th>Type</th>
      <td>
<img alt="Story"
src="https://tyktech.atlassian.net/rest/api/2/universal_avatar/view/type/issuetype/avatar/10315?size=medium"
/>
        Story
      </td>
    </tr>
    <tr>
      <th>Status</th>
      <td>Ready for Testing</td>
    </tr>
    <tr>
      <th>Points</th>
      <td>N/A</td>
    </tr>
    <tr>
      <th>Labels</th>
      <td>-</td>
    </tr>
  </table>
</details>
<!--
  do not remove this marker as it will break jira-lint's functionality.
  added_by_jira_lint
-->

---

<!-- Provide a general summary of your changes in the Title above -->

## Description

<!-- Describe your changes in detail -->

## Related Issue

<!-- This project only accepts pull requests related to open issues. -->
<!-- If suggesting a new feature or change, please discuss it in an
issue first. -->
<!-- If fixing a bug, there should be an issue describing it with steps
to reproduce. -->
<!-- OSS: Please link to the issue here. Tyk: please create/link the
JIRA ticket. -->

## Motivation and Context

<!-- Why is this change required? What problem does it solve? -->

## How This Has Been Tested

<!-- Please describe in detail how you tested your changes -->
<!-- Include details of your testing environment, and the tests -->
<!-- you ran to see how your change affects other areas of the code,
etc. -->
<!-- This information is helpful for reviewers and QA. -->

## Screenshots (if appropriate)

## Types of changes

<!-- What types of changes does your code introduce? Put an `x` in all
the boxes that apply: -->

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing
functionality to change)
- [ ] Refactoring or add test (improvements in base code or adds test
coverage to functionality)

## Checklist

<!-- Go over all the following points, and put an `x` in all the boxes
that apply -->
<!-- If there are no documentation updates required, mark the item as
checked. -->
<!-- Raise up any additional concerns not covered by the checklist. -->

- [ ] I ensured that the documentation is up to date
- [ ] I explained why this PR updates go.mod in detail with reasoning
why it's required
- [ ] I would like a code coverage CI quality gate exception and have
explained why
  • Loading branch information
andrei-tyk authored Oct 17, 2024
1 parent 8f37f4b commit 33db0e2
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 54 deletions.
10 changes: 8 additions & 2 deletions apidef/api_definitions.go
Original file line number Diff line number Diff line change
Expand Up @@ -813,18 +813,24 @@ type UpstreamOAuth struct {
HeaderName string `bson:"header_name" json:"header_name,omitempty"`
}

// ClientCredentials holds the client credentials for upstream OAuth2 authentication.
type ClientCredentials struct {
// ClientAuthData holds the client ID and secret for upstream OAuth2 authentication.
type ClientAuthData struct {
// ClientID is the application's ID.
ClientID string `bson:"client_id" json:"client_id"`
// ClientSecret is the application's secret.
ClientSecret string `bson:"client_secret" json:"client_secret"`
}

// ClientCredentials holds the client credentials for upstream OAuth2 authentication.
type ClientCredentials struct {
ClientAuthData
// TokenURL is the resource server's token endpoint
// URL. This is a constant specific to each server.
TokenURL string `bson:"token_url" json:"token_url"`
// Scopes specifies optional requested permissions.
Scopes []string `bson:"scopes" json:"scopes,omitempty"`

// TokenProvider is the OAuth2 token provider for internal use.
TokenProvider oauth2.TokenSource `bson:"-" json:"-"`
}

Expand Down
6 changes: 0 additions & 6 deletions apidef/oas/schema/x-tyk-api-gateway.json
Original file line number Diff line number Diff line change
Expand Up @@ -2054,17 +2054,11 @@
},
"scopes":{
"type": ["array", "null"]
},
"endpointParams": {
"type": ["object", "null"]
}
}
},
"headerName": {
"type": "string"
},
"distributedToken": {
"type": "boolean"
}
}
}
Expand Down
6 changes: 0 additions & 6 deletions apidef/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -805,17 +805,11 @@ const Schema = `{
},
"scopes":{
"type": ["array", "null"]
},
"endpoint_params": {
"type": ["object", "null"]
}
}
},
"header_name": {
"type": "string"
},
"distributed_token": {
"type": "boolean"
}
}
}
Expand Down
84 changes: 49 additions & 35 deletions gateway/mw_oauth2_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"strings"
"time"

"golang.org/x/oauth2"

"github.com/sirupsen/logrus"
oauth2clientcredentials "golang.org/x/oauth2/clientcredentials"

Expand All @@ -24,12 +26,28 @@ const (
)

type OAuthHeaderProvider interface {
getOAuthHeaderValue(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error)
// getOAuthToken returns the OAuth token for the request.
getOAuthToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error)
}

type DistributedCacheOAuthProvider struct{}
type ClientCredentialsOAuthProvider struct{}

type PerAPIClientCredentialsOAuthProvider struct{}

type PerAPIOAuthProvider struct{}
func newUpstreamOAuthClientCredentialsCache(connectionHandler *storage.ConnectionHandler) *upstreamOAuthClientCredentialsCache {
return &upstreamOAuthClientCredentialsCache{RedisCluster: storage.RedisCluster{KeyPrefix: "upstreamOAuthCC-", ConnectionHandler: connectionHandler}}
}

type upstreamOAuthClientCredentialsCache struct {
storage.RedisCluster
}

type UpstreamOAuthCache interface {
// getToken returns the token from cache or issues a request to obtain it from the OAuth provider.
getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error)
// obtainToken issues a request to obtain the token from the OAuth provider.
obtainToken(ctx context.Context, OAuthSpec *UpstreamOAuth) (*oauth2.Token, error)
}

// UpstreamOAuth is a middleware that will do basic authentication for upstream connections.
// UpstreamOAuth middleware is only supported in Tyk OAS API definitions.
Expand Down Expand Up @@ -67,11 +85,14 @@ func (OAuthSpec *UpstreamOAuth) ProcessRequest(_ http.ResponseWriter, r *http.Re
upstreamOAuthProvider.HeaderName = oauthConfig.HeaderName
}

provider := getOAuthHeaderProvider(oauthConfig)
provider, err := getOAuthHeaderProvider(oauthConfig)
if err != nil {
return fmt.Errorf("failed to get OAuth header provider: %w", err), http.StatusInternalServerError
}

payload, err := provider.getOAuthHeaderValue(r, OAuthSpec)
payload, err := provider.getOAuthToken(r, OAuthSpec)
if err != nil {
return fmt.Errorf("failed to get OAuth token: %v", err), http.StatusInternalServerError
return fmt.Errorf("failed to get OAuth token: %w", err), http.StatusInternalServerError
}

upstreamOAuthProvider.AuthValue = payload
Expand All @@ -80,11 +101,12 @@ func (OAuthSpec *UpstreamOAuth) ProcessRequest(_ http.ResponseWriter, r *http.Re
return nil, http.StatusOK
}

func getOAuthHeaderProvider(oauthConfig apidef.UpstreamOAuth) OAuthHeaderProvider {
return &DistributedCacheOAuthProvider{}
func getOAuthHeaderProvider(oauthConfig apidef.UpstreamOAuth) (OAuthHeaderProvider, error) {
// to be extended when PasswordAuth is implemented
return &ClientCredentialsOAuthProvider{}, nil
}

func (p *PerAPIOAuthProvider) getOAuthHeaderValue(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
func (p *PerAPIClientCredentialsOAuthProvider) getOAuthHeaderValue(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
oauthConfig := OAuthSpec.Spec.UpstreamAuth.OAuth

if oauthConfig.ClientCredentials.TokenProvider == nil {
Expand All @@ -108,26 +130,17 @@ func handleOAuthError(r *http.Request, OAuthSpec *UpstreamOAuth, err error) (str
return "", err
}

func (p *DistributedCacheOAuthProvider) getOAuthHeaderValue(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
func (p *ClientCredentialsOAuthProvider) getOAuthToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
if OAuthSpec.Gw.UpstreamOAuthCache == nil {
OAuthSpec.Gw.UpstreamOAuthCache = newUpstreamOAuthCache(OAuthSpec.Gw.StorageConnectionHandler)
OAuthSpec.Gw.UpstreamOAuthCache = newUpstreamOAuthClientCredentialsCache(OAuthSpec.Gw.StorageConnectionHandler)
}

token, err := OAuthSpec.Gw.UpstreamOAuthCache.getToken(r, OAuthSpec)
if err != nil {
return handleOAuthError(r, OAuthSpec, err)
}

payload := fmt.Sprintf("Bearer %s", token)
return payload, nil
}

func newUpstreamOAuthCache(connectionHandler *storage.ConnectionHandler) *upstreamOAuthCache {
return &upstreamOAuthCache{RedisCluster: storage.RedisCluster{KeyPrefix: "upstreamOAuth-", ConnectionHandler: connectionHandler}}
}

type upstreamOAuthCache struct {
storage.RedisCluster
return fmt.Sprintf("Bearer %s", token), nil
}

func generateCacheKey(config apidef.UpstreamOAuth, apiId string) string {
Expand All @@ -143,34 +156,35 @@ func generateCacheKey(config apidef.UpstreamOAuth, apiId string) string {
return hex.EncodeToString(hash.Sum(nil))
}

func (cache *upstreamOAuthCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
func (cache *upstreamOAuthClientCredentialsCache) getToken(r *http.Request, OAuthSpec *UpstreamOAuth) (string, error) {
cacheKey := generateCacheKey(OAuthSpec.Spec.UpstreamAuth.OAuth, OAuthSpec.Spec.APIID)

token, err := cache.retryGetKeyAndLock(cacheKey)
tokenString, err := retryGetKeyAndLock(cacheKey, &cache.RedisCluster)
if err != nil {
return "", err
}

if token != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw), token)
if tokenString != "" {
decryptedToken := decrypt(getPaddedSecret(OAuthSpec.Gw), tokenString)
return decryptedToken, nil
}

token, err = cache.obtainToken(r.Context(), OAuthSpec)
token, err := cache.obtainToken(r.Context(), OAuthSpec)
if err != nil {
return "", err
}

encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw), token)
encryptedToken := encrypt(getPaddedSecret(OAuthSpec.Gw), token.AccessToken)

if err := cache.setTokenInCache(cacheKey, encryptedToken); err != nil {
ttl := time.Until(token.Expiry)
if err := setTokenInCache(cacheKey, encryptedToken, ttl, &cache.RedisCluster); err != nil {
return "", err
}

return token, nil
return token.AccessToken, nil
}

func (cache *upstreamOAuthCache) retryGetKeyAndLock(cacheKey string) (string, error) {
func retryGetKeyAndLock(cacheKey string, cache *storage.RedisCluster) (string, error) {
const maxRetries = 10
const retryDelay = 100 * time.Millisecond

Expand Down Expand Up @@ -204,20 +218,20 @@ func newOAuth2ClientCredentialsConfig(OAuthSpec *UpstreamOAuth) oauth2clientcred
}
}

func (cache *upstreamOAuthCache) obtainToken(ctx context.Context, OAuthSpec *UpstreamOAuth) (string, error) {
func (cache *upstreamOAuthClientCredentialsCache) obtainToken(ctx context.Context, OAuthSpec *UpstreamOAuth) (*oauth2.Token, error) {
cfg := newOAuth2ClientCredentialsConfig(OAuthSpec)

tokenSource := cfg.TokenSource(ctx)
oauthToken, err := tokenSource.Token()
if err != nil {
return "", err
return &oauth2.Token{}, err
}

return oauthToken.AccessToken, nil
return oauthToken, nil
}

func (cache *upstreamOAuthCache) setTokenInCache(cacheKey, token string) error {
oauthTokenExpiry := time.Now().Add(time.Hour)
func setTokenInCache(cacheKey string, token string, ttl time.Duration, cache *storage.RedisCluster) error {
oauthTokenExpiry := time.Now().Add(ttl)
return cache.SetKey(cacheKey, token, int64(oauthTokenExpiry.Sub(time.Now()).Seconds()))
}

Expand Down
10 changes: 6 additions & 4 deletions gateway/mw_oauth2_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ func TestUpstreamOauth2(t *testing.T) {
defer ts.Close()

cfg := apidef.ClientCredentials{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
TokenURL: ts.URL + "/token",
Scopes: []string{"scope1", "scope2"},
ClientAuthData: apidef.ClientAuthData{
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
},
TokenURL: ts.URL + "/token",
Scopes: []string{"scope1", "scope2"},
}

tst.Gw.BuildAndLoadAPI(
Expand Down
2 changes: 1 addition & 1 deletion gateway/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ type Gateway struct {
HostCheckerClient *http.Client
TracerProvider otel.TracerProvider
// UpstreamOAuthCache is used to cache upstream OAuth tokens
UpstreamOAuthCache *upstreamOAuthCache
UpstreamOAuthCache UpstreamOAuthCache

keyGen DefaultKeyGenerator

Expand Down

0 comments on commit 33db0e2

Please sign in to comment.