diff --git a/edgecontext/edgecontext.go b/edgecontext/edgecontext.go index 5aeb28a11..3ba1f45f4 100644 --- a/edgecontext/edgecontext.go +++ b/edgecontext/edgecontext.go @@ -4,9 +4,11 @@ import ( "context" "errors" "strings" + "sync/atomic" "time" "github.com/reddit/baseplate.go/internal/gen-go/reddit/baseplate" + "github.com/reddit/baseplate.go/log" "github.com/reddit/baseplate.go/secrets" "github.com/reddit/baseplate.go/thriftbp" "github.com/reddit/baseplate.go/timebp" @@ -24,7 +26,9 @@ var ErrNoHeader = errors.New("edgecontext: no Edge-Request header found") // global vars that will be initialized in Init function. var ( - store *secrets.Store + store *secrets.Store + logger log.Wrapper + keysValue atomic.Value ) var serializerPool = thrift.NewTSerializerPool( @@ -53,7 +57,10 @@ var deserializerPool = thrift.NewTDeserializerPool( // Config for Init function. type Config struct { + // The secret store to get the keys for jwt validation Store *secrets.Store + // The logger to log key decoding errors + Logger log.Wrapper } // Init the global state. @@ -62,6 +69,11 @@ type Config struct { // otherwise they might panic. func Init(cfg Config) error { store = cfg.Store + logger = cfg.Logger + if logger == nil { + logger = log.NopWrapper + } + store.AddMiddlewares(validatorMiddleware) return nil } diff --git a/edgecontext/validator.go b/edgecontext/validator.go index 04af7d2a3..f2cac9772 100644 --- a/edgecontext/validator.go +++ b/edgecontext/validator.go @@ -1,11 +1,17 @@ package edgecontext import ( + "crypto/rsa" "errors" + "fmt" jwt "gopkg.in/dgrijalva/jwt-go.v3" + + "github.com/reddit/baseplate.go/secrets" ) +type keysType []*rsa.PublicKey + const ( authenticationPubKeySecretPath = "secret/authentication/public-key" jwtAlg = "RS256" @@ -43,22 +49,21 @@ func shouldShortCircutError(err error) bool { // ValidateToken parses and validates a jwt token, and return the decoded // AuthenticationToken. func ValidateToken(token string) (*AuthenticationToken, error) { - sec, err := store.GetVersionedSecret(authenticationPubKeySecretPath) - if err != nil { - return nil, err + keys, ok := keysValue.Load().(keysType) + if !ok { + // This would only happen when all previous middleware parsing failed. + return nil, errors.New("no public keys loaded") } - // TODO 1: Patch upstream to support key rotation natively: + // TODO: Patch upstream to support key rotation natively: // https://github.com/dgrijalva/jwt-go/pull/372 - // - // TODO 2: Use secrets middleware to cache parsed pubkeys. var lastErr error - for _, key := range sec.GetAll() { + for _, key := range keys { token, err := jwt.ParseWithClaims( token, &AuthenticationToken{}, func(_ *jwt.Token) (interface{}, error) { - return jwt.ParseRSAPublicKeyFromPEM([]byte(key)) + return key, nil }, ) if err != nil { @@ -78,3 +83,41 @@ func ValidateToken(token string) (*AuthenticationToken, error) { } return nil, lastErr } + +func validatorMiddleware(next secrets.SecretHandlerFunc) secrets.SecretHandlerFunc { + return func(sec *secrets.Secrets) { + defer next(sec) + + versioned, err := sec.GetVersionedSecret(authenticationPubKeySecretPath) + if err != nil { + logger(fmt.Sprintf( + "Failed to get secrets %q: %v", + authenticationPubKeySecretPath, + err, + )) + return + } + + all := versioned.GetAll() + keys := make(keysType, 0, len(all)) + for i, v := range all { + key, err := jwt.ParseRSAPublicKeyFromPEM([]byte(v)) + if err != nil { + logger(fmt.Sprintf( + "Failed to parse key #%d: %v", + i, + err, + )) + } else { + keys = append(keys, key) + } + } + + if len(keys) == 0 { + logger("No valid keys in secrets store.") + return + } + + keysValue.Store(keys) + } +}