Skip to content
This repository has been archived by the owner on May 21, 2022. It is now read-only.

Commit

Permalink
Native support for key rotation in verifications
Browse files Browse the repository at this point in the history
Add native support for key rotation for ES*, HS*, RS*, and PS*
verifications.

In those SigningMethod's Verify implementations, also allow the key to
be the type of the slice of the supported key type, so that the caller
can implement the KeyFunc to return all the accepted keys together to
support key rotation.

While key rotation verification can be done on the callers' side without
this change, this change provides better performance because:

- When trying the next key, the steps before actually using the key do
  not need to be performed again.

- If a verification process failed for non-key reasons (for example,
  because it's already expired), it saves the effort to try the next
  key.
  • Loading branch information
fishy committed Feb 22, 2020
1 parent dc14462 commit f7f2ef4
Show file tree
Hide file tree
Showing 8 changed files with 205 additions and 62 deletions.
46 changes: 27 additions & 19 deletions ecdsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ func (m *SigningMethodECDSA) Alg() string {
return m.Name
}

// Implements the Verify method from SigningMethod
// For this verify method, key must be an ecdsa.PublicKey struct
// Implements the Verify method from SigningMethod.
// For this verify method, key must be in types of either *ecdsa.PublicKey or
// []*ecdsa.PublicKey (for rotation keys).
func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error {
var err error

Expand All @@ -64,35 +65,42 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa
return err
}

// Get the key
var ecdsaKey *ecdsa.PublicKey
switch k := key.(type) {
case *ecdsa.PublicKey:
ecdsaKey = k
default:
return ErrInvalidKeyType
}

if len(sig) != 2*m.KeySize {
return ErrECDSAVerification
}

r := big.NewInt(0).SetBytes(sig[:m.KeySize])
s := big.NewInt(0).SetBytes(sig[m.KeySize:])

// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
return nil
} else {
return ErrECDSAVerification
// Get the keys
var keys []*ecdsa.PublicKey
switch v := key.(type) {
case *ecdsa.PublicKey:
keys = append(keys, v)
case []*ecdsa.PublicKey:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

var lastErr error
for _, ecdsaKey := range keys {
// Create hasher
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true {
return nil
}
lastErr = ErrECDSAVerification
}
return lastErr
}

// Implements the Sign method from SigningMethod
Expand Down
47 changes: 47 additions & 0 deletions ecdsa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,53 @@ func TestECDSAVerify(t *testing.T) {
}
}

func TestECDSAVerifyKeyRotation(t *testing.T) {
targetName := "Basic ES256"
for _, data := range ecdsaTestData {
if data.name != targetName {
continue
}

var err error

key, _ := ioutil.ReadFile("test/ec256-public.pem")
var ecdsaKey *ecdsa.PublicKey
if ecdsaKey, err = jwt.ParseECPublicKeyFromPEM(key); err != nil {
t.Errorf("Unable to parse ECDSA public key: %v", err)
}

key, _ = ioutil.ReadFile("test/ec384-public.pem")
var invalidKey1 *ecdsa.PublicKey
if invalidKey1, err = jwt.ParseECPublicKeyFromPEM(key); err != nil {
t.Errorf("Unable to parse ECDSA public key: %v", err)
}

key, _ = ioutil.ReadFile("test/ec512-public.pem")
var invalidKey2 *ecdsa.PublicKey
if invalidKey2, err = jwt.ParseECPublicKeyFromPEM(key); err != nil {
t.Errorf("Unable to parse ECDSA public key: %v", err)
}

parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, ecdsaKey, invalidKey2})
if err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}

err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}

err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, invalidKey2})
if err == nil {
t.Errorf("[%v] Key list with only invalid keys passed validation", data.name)
}
}
}

func TestECDSASign(t *testing.T) {
for _, data := range ecdsaTestData {
var err error
Expand Down
39 changes: 24 additions & 15 deletions hmac.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,12 +47,6 @@ func (m *SigningMethodHMAC) Alg() string {

// Verify the signature of HSXXX tokens. Returns nil if the signature is valid.
func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error {
// Verify the key is the right type
keyBytes, ok := key.([]byte)
if !ok {
return ErrInvalidKeyType
}

// Decode signature, for comparison
sig, err := DecodeSegment(signature)
if err != nil {
Expand All @@ -64,17 +58,32 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac
return ErrHashUnavailable
}

// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if !hmac.Equal(sig, hasher.Sum(nil)) {
return ErrSignatureInvalid
// Verify the keys are the right types
var keys [][]byte
switch v := key.(type) {
case []byte:
keys = append(keys, v)
case [][]byte:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

// No validation errors. Signature is good.
return nil
var lastErr error
for _, keyBytes := range keys {
// This signing method is symmetric, so we validate the signature
// by reproducing the signature from the signing string and key, then
// comparing that against the provided signature.
hasher := hmac.New(m.Hash.New, keyBytes)
hasher.Write([]byte(signingString))
if hmac.Equal(sig, hasher.Sum(nil)) {
// No validation errors. Signature is good.
return nil
}
lastErr = ErrSignatureInvalid
}
return lastErr
}

// Implements the Sign method from SigningMethod for this signing method.
Expand Down
29 changes: 29 additions & 0 deletions hmac_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,35 @@ func TestHMACVerify(t *testing.T) {
}
}

func TestHMACVerifyKeyRotation(t *testing.T) {
invalidKey1 := []byte("foo")
invalidKey2 := []byte("bar")
for _, data := range hmacTestData {
parts := strings.Split(data.tokenString, ".")

method := jwt.GetSigningMethod(data.alg)
err := method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, hmacTestKey, invalidKey2})
if data.valid && err != nil {
t.Errorf("[%v] Error while verifying key: %v", data.name, err)
}
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}

if !data.valid {
continue
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, invalidKey2})
if err == nil {
t.Errorf("[%v] Key list with only invalid keys passed validation", data.name)
}
}
}

func TestHMACSign(t *testing.T) {
for _, data := range hmacTestData {
if data.valid {
Expand Down
37 changes: 25 additions & 12 deletions rsa.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func (m *SigningMethodRSA) Alg() string {
}

// Implements the Verify method from SigningMethod
// For this signing method, must be an *rsa.PublicKey structure.
// For this signing method, key must be in types of either *rsa.PublicKey or
// []*rsa.PublicKey (for rotation keys).
func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error {
var err error

Expand All @@ -55,22 +56,34 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface
return err
}

var rsaKey *rsa.PublicKey
var ok bool
if !m.Hash.Available() {
return ErrHashUnavailable
}

if rsaKey, ok = key.(*rsa.PublicKey); !ok {
var keys []*rsa.PublicKey
switch v := key.(type) {
case *rsa.PublicKey:
keys = append(keys, v)
case []*rsa.PublicKey:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
var lastErr error
for _, rsaKey := range keys {
// Create hasher
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
lastErr = rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
if lastErr == nil {
return nil
}
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

// Verify the signature
return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig)
return lastErr
}

// Implements the Sign method from SigningMethod
Expand Down
44 changes: 28 additions & 16 deletions rsa_pss.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ func init() {
}

// Implements the Verify method from SigningMethod
// For this verify method, key must be an rsa.PublicKey struct
// For this verify method, key must be in the types of either *rsa.PublicKey or
// []*rsa.PublicKey (for rotation keys).
func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error {
var err error

Expand All @@ -90,27 +91,38 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf
return err
}

var rsaKey *rsa.PublicKey
switch k := key.(type) {
case *rsa.PublicKey:
rsaKey = k
default:
return ErrInvalidKey
}

// Create hasher
if !m.Hash.Available() {
return ErrHashUnavailable
}
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

opts := m.Options
if m.VerifyOptions != nil {
opts = m.VerifyOptions
var keys []*rsa.PublicKey
switch v := key.(type) {
case *rsa.PublicKey:
keys = append(keys, v)
case []*rsa.PublicKey:
keys = v
}
if len(keys) == 0 {
return ErrInvalidKeyType
}

return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
var lastErr error
for _, rsaKey := range keys {
// Create hasher
hasher := m.Hash.New()
hasher.Write([]byte(signingString))

opts := m.Options
if m.VerifyOptions != nil {
opts = m.VerifyOptions
}

lastErr = rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts)
if lastErr == nil {
return nil
}
}
return lastErr
}

// Implements the Sign method from SigningMethod
Expand Down
12 changes: 12 additions & 0 deletions rsa_pss_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,18 @@ func TestRSAPSSVerify(t *testing.T) {
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}

if !data.valid {
continue
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{rsaPSSKey})
if err != nil {
t.Errorf("[%v] Error while verifying key list: %v", data.name, err)
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}
}
}

Expand Down
13 changes: 13 additions & 0 deletions rsa_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package jwt_test

import (
"crypto/rsa"
"github.com/dgrijalva/jwt-go"
"io/ioutil"
"strings"
Expand Down Expand Up @@ -59,6 +60,18 @@ func TestRSAVerify(t *testing.T) {
if !data.valid && err == nil {
t.Errorf("[%v] Invalid key passed validation", data.name)
}

if !data.valid {
continue
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{key})
if err != nil {
t.Errorf("[%v] Error while verifying key list: %v", data.name, err)
}
err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{})
if err == nil {
t.Errorf("[%v] Empty key list passed validation", data.name)
}
}
}

Expand Down

0 comments on commit f7f2ef4

Please sign in to comment.