
Commit

Merge branch 'master' into opa-v1.0.0
Pushpalanka authored Jan 12, 2025
2 parents 90d0e68 + 8d4721f commit f4bcd41
Showing 15 changed files with 194 additions and 132 deletions.
2 changes: 1 addition & 1 deletion .clusterfuzzlite/Dockerfile
@@ -1,4 +1,4 @@
FROM gcr.io/oss-fuzz-base/base-builder-go@sha256:bab77046ede6fae6f9ab931ada4b17a5adaf71842ac93e32300d06d5b9829891
FROM gcr.io/oss-fuzz-base/base-builder-go@sha256:9bf7fad8ca02443224c7518392d80c97a62b8cb0822f03aadf9193a7e27346f0

COPY . $SRC/skipper
COPY ./.clusterfuzzlite/build.sh $SRC/
1 change: 1 addition & 0 deletions filters/auth/authclient.go
@@ -32,6 +32,7 @@ type authClient struct {

type tokeninfoClient interface {
getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error)
Close()
}

var _ tokeninfoClient = &authClient{}
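With Close() added to the unexported tokeninfoClient interface, any implementation can be shut down uniformly (the test cleanup below relies on this). A minimal sketch of a hypothetical stub satisfying the updated interface; the type name and canned payload are made up, and it would have to live inside the auth package because the interface is unexported:

// fakeTokeninfoClient is a hypothetical test stub for the updated interface.
type fakeTokeninfoClient struct{ closed bool }

func (f *fakeTokeninfoClient) getTokeninfo(token string, _ filters.FilterContext) (map[string]any, error) {
	// Return a canned payload regardless of the token value.
	return map[string]any{"uid": "test-user", "expires_in": float64(600)}, nil
}

func (f *fakeTokeninfoClient) Close() {
	// Real implementations stop background goroutines or close idle connections here.
	f.closed = true
}

var _ tokeninfoClient = &fakeTokeninfoClient{}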
6 changes: 1 addition & 5 deletions filters/auth/grant.go
@@ -175,11 +175,7 @@ func (f *grantFilter) setupToken(token *oauth2.Token, tokeninfo map[string]interface

// By piggy-backing on the OIDC token container,
// we gain downstream compatibility with the oidcClaimsQuery filter.
ctx.StateBag()[oidcClaimsCacheKey] = tokenContainer{
OAuth2Token: token,
Subject: subject,
Claims: tokeninfo,
}
SetOIDCClaims(ctx, tokeninfo)

// Set the tokeninfo also in the tokeninfoCacheKey state bag, so we
// can reuse e.g. the forwardToken() filter.
2 changes: 2 additions & 0 deletions filters/auth/grant_test.go
@@ -210,6 +210,7 @@ func newAuthProxy(t *testing.T, config *auth.OAuthConfig, routes []*eskip.Route,
fr.Register(config.NewGrantCallback())
fr.Register(config.NewGrantClaimsQuery())
fr.Register(config.NewGrantLogout())
fr.Register(auth.NewOIDCQueryClaimsFilter())

pc := proxytest.Config{
RoutingOptions: routing.Options{
@@ -331,6 +332,7 @@ func TestGrantFlow(t *testing.T) {
config := newGrantTestConfig(tokeninfo.URL, provider.URL)

routes := eskip.MustParse(`* -> oauthGrant()
-> oidcClaimsQuery("/:sub")
-> status(204)
-> setResponseHeader("Backend-Request-Cookie", "${request.header.Cookie}")
-> <shunt>
6 changes: 1 addition & 5 deletions filters/auth/main_test.go
@@ -18,11 +18,7 @@ func TestMain(m *testing.M) {

func cleanupAuthClients() {
for _, c := range tokeninfoAuthClient {
if ac, ok := c.(*authClient); ok {
ac.Close()
} else if cc, ok := c.(*tokeninfoCache); ok {
cc.client.(*authClient).Close()
}
c.Close()
}

for _, c := range issuerAuthClient {
8 changes: 8 additions & 0 deletions filters/auth/oidc_introspection.go
@@ -42,6 +42,14 @@ func NewOIDCQueryClaimsFilter() filters.Spec {
}
}

// Sets OIDC claims in the state bag.
// Intended for use with the oidcClaimsQuery filter.
func SetOIDCClaims(ctx filters.FilterContext, claims map[string]interface{}) {
ctx.StateBag()[oidcClaimsCacheKey] = tokenContainer{
Claims: claims,
}
}

func (spec *oidcIntrospectionSpec) Name() string {
switch spec.typ {
case checkOIDCQueryClaims:
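SetOIDCClaims is exported, so a custom filter outside this package can populate the claims that a downstream oidcClaimsQuery filter later matches on, roughly as grant.go now does. A minimal sketch; the filter type, route wiring, and claim values are made up for illustration:

import (
	"github.com/zalando/skipper/filters"
	"github.com/zalando/skipper/filters/auth"
)

// claimsInjector is a hypothetical filter that stores claims for a
// downstream oidcClaimsQuery filter to evaluate.
type claimsInjector struct{}

func (claimsInjector) Request(ctx filters.FilterContext) {
	// In a real filter the claims would come from a validated token.
	auth.SetOIDCClaims(ctx, map[string]interface{}{
		"sub":    "someone",
		"groups": []interface{}{"readers"},
	})
}

func (claimsInjector) Response(filters.FilterContext) {}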
6 changes: 4 additions & 2 deletions filters/auth/tokeninfo.go
@@ -12,6 +12,7 @@ import (
"github.com/opentracing/opentracing-go"
"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/filters/annotate"
"github.com/zalando/skipper/metrics"
)

const (
@@ -32,9 +33,10 @@ type TokeninfoOptions struct {
Timeout time.Duration
MaxIdleConns int
Tracer opentracing.Tracer
Metrics metrics.Metrics

// CacheSize configures the maximum number of cached tokens.
// The cache evicts least recently used items first.
// The cache periodically evicts random items when number of cached tokens exceeds CacheSize.
// Zero value disables tokeninfo cache.
CacheSize int

@@ -100,7 +102,7 @@ func (o *TokeninfoOptions) newTokeninfoClient() (tokeninfoClient, error) {
}

if o.CacheSize > 0 {
c = newTokeninfoCache(c, o.CacheSize, o.CacheTTL)
c = newTokeninfoCache(c, o.Metrics, o.CacheSize, o.CacheTTL)
}
return c, nil
}
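With the new Metrics field, the cache layer is still only wired in when CacheSize is positive, and it now reports its entry count periodically. A rough configuration sketch, assuming code inside the auth package (newTokeninfoClient is unexported); the URL value, metrics.Default, and the numbers are illustrative assumptions, not part of this commit:

// Values are illustrative; metrics.Default is assumed to be a usable default Metrics instance.
o := TokeninfoOptions{
	URL:          "https://tokeninfo.example.org/oauth2/tokeninfo",
	Timeout:      2 * time.Second,
	MaxIdleConns: 20,
	Metrics:      metrics.Default,
	CacheSize:    1000,             // above this, random entries are evicted once per minute
	CacheTTL:     30 * time.Second, // optional cap on how long a cached entry is served
}

c, err := o.newTokeninfoClient() // wraps the client in newTokeninfoCache because CacheSize > 0
if err != nil {
	// handle error
}
defer c.Close() // also stops the cache's eviction goroutine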
149 changes: 79 additions & 70 deletions filters/auth/tokeninfocache.go
@@ -1,48 +1,55 @@
package auth

import (
"container/list"
"maps"
"sync"
"sync/atomic"
"time"

"github.com/zalando/skipper/filters"
"github.com/zalando/skipper/metrics"
)

type (
tokeninfoCache struct {
client tokeninfoClient
size int
ttl time.Duration
now func() time.Time

mu sync.Mutex
cache map[string]*entry
// least recently used token at the end
history *list.List
client tokeninfoClient
metrics metrics.Metrics
size int
ttl time.Duration
now func() time.Time

cache sync.Map // map[string]*entry
count atomic.Int64 // estimated number of cached entries, see https://github.com/golang/go/issues/20680
quit chan struct{}
}

entry struct {
cachedAt time.Time
expiresAt time.Time
info map[string]any
// reference in the history
href *list.Element
expiresAt time.Time
info map[string]any
infoExpiresAt time.Time
}
)

var _ tokeninfoClient = &tokeninfoCache{}

const expiresInField = "expires_in"

func newTokeninfoCache(client tokeninfoClient, size int, ttl time.Duration) *tokeninfoCache {
return &tokeninfoCache{
func newTokeninfoCache(client tokeninfoClient, metrics metrics.Metrics, size int, ttl time.Duration) *tokeninfoCache {
c := &tokeninfoCache{
client: client,
metrics: metrics,
size: size,
ttl: ttl,
now: time.Now,
cache: make(map[string]*entry, size),
history: list.New(),
quit: make(chan struct{}),
}
go c.evictLoop()
return c
}

func (c *tokeninfoCache) Close() {
c.client.Close()
close(c.quit)
}

func (c *tokeninfoCache) getTokeninfo(token string, ctx filters.FilterContext) (map[string]any, error) {
@@ -58,35 +65,21 @@ func (c *tokeninfoCache) getTokeninfo(token string, ctx filters.FilterContext) (
}

func (c *tokeninfoCache) cached(token string) map[string]any {
now := c.now()

c.mu.Lock()

if e, ok := c.cache[token]; ok {
if v, ok := c.cache.Load(token); ok {
now := c.now()
e := v.(*entry)
if now.Before(e.expiresAt) {
c.history.MoveToFront(e.href)
cachedInfo := e.info
c.mu.Unlock()

// It might be ok to return cached value
// without adjusting "expires_in" to avoid copy
// if caller never modifies the result and
// when "expires_in" did not change (same second)
// or for small TTL values
info := shallowCopyOf(cachedInfo)
info := maps.Clone(e.info)

elapsed := now.Sub(e.cachedAt).Truncate(time.Second).Seconds()
info[expiresInField] = info[expiresInField].(float64) - elapsed
info[expiresInField] = e.infoExpiresAt.Sub(now).Truncate(time.Second).Seconds()
return info
} else {
// remove expired
delete(c.cache, token)
c.history.Remove(e.href)
}
}

c.mu.Unlock()

return nil
}

@@ -95,38 +88,62 @@ func (c *tokeninfoCache) tryCache(token string, info map[string]any) {
if expiresIn <= 0 {
return
}
if c.ttl > 0 && expiresIn > c.ttl {
expiresIn = c.ttl
}

now := c.now()
expiresAt := now.Add(expiresIn)
e := &entry{
info: info,
infoExpiresAt: now.Add(expiresIn),
}

if c.ttl > 0 && expiresIn > c.ttl {
e.expiresAt = now.Add(c.ttl)
} else {
e.expiresAt = e.infoExpiresAt
}

c.mu.Lock()
defer c.mu.Unlock()
if _, loaded := c.cache.Swap(token, e); !loaded {
c.count.Add(1)
}
}

if e, ok := c.cache[token]; ok {
// update
e.cachedAt = now
e.expiresAt = expiresAt
e.info = info
c.history.MoveToFront(e.href)
return
func (c *tokeninfoCache) evictLoop() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-c.quit:
return
case <-ticker.C:
c.evict()
}
}
}

// create
c.cache[token] = &entry{
cachedAt: now,
expiresAt: expiresAt,
info: info,
href: c.history.PushFront(token),
func (c *tokeninfoCache) evict() {
now := c.now()
// Evict expired entries
c.cache.Range(func(key, value any) bool {
e := value.(*entry)
if now.After(e.expiresAt) {
if c.cache.CompareAndDelete(key, value) {
c.count.Add(-1)
}
}
return true
})

// Evict random entries until the cache size is within limits
if c.count.Load() > int64(c.size) {
c.cache.Range(func(key, value any) bool {
if c.cache.CompareAndDelete(key, value) {
c.count.Add(-1)
}
return c.count.Load() > int64(c.size)
})
}

// remove least used
if len(c.cache) > c.size {
leastUsed := c.history.Back()
delete(c.cache, leastUsed.Value.(string))
c.history.Remove(leastUsed)
if c.metrics != nil {
c.metrics.UpdateGauge("tokeninfocache.count", float64(c.count.Load()))
}
}

@@ -141,11 +158,3 @@ func expiresIn(info map[string]any) time.Duration {
}
return 0
}

func shallowCopyOf(info map[string]any) map[string]any {
m := make(map[string]any, len(info))
for k, v := range info {
m[k] = v
}
return m
}
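The cached read path now derives the returned expires_in from the stored infoExpiresAt instead of subtracting the elapsed time from the original value, and the entry's info map is cloned with maps.Clone before adjustment. A rough worked example of that arithmetic, assuming code inside the auth package; the timestamps and field values are made up:

// Suppose tokeninfo reported expires_in = 600 when the entry was cached at t0.
t0 := time.Date(2025, 1, 12, 10, 0, 0, 0, time.UTC)
e := &entry{
	info:          map[string]any{"uid": "u", expiresInField: float64(600)},
	expiresAt:     t0.Add(600 * time.Second), // no ttl cap applied in this example
	infoExpiresAt: t0.Add(600 * time.Second),
}

// A cached read 120 seconds later recomputes the remaining lifetime:
now := t0.Add(120 * time.Second)
remaining := e.infoExpiresAt.Sub(now).Truncate(time.Second).Seconds()
_ = remaining // 480; the maps.Clone copy of e.info is returned with expires_in set to this value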
