From f7f2ef4f5a97bff32559685c6ce61b290e09dbf8 Mon Sep 17 00:00:00 2001 From: Yuxuan 'fishy' Wang Date: Thu, 9 Jan 2020 12:28:09 -0800 Subject: [PATCH] Native support for key rotation in verifications Add native support for key rotation for ES*, HS*, RS*, and PS* verifications. In those SigningMethod's Verify implementations, also allow the key to be the type of the slice of the supported key type, so that the caller can implement the KeyFunc to return all the accepted keys together to support key rotation. While key rotation verification can be done on the callers' side without this change, this change provides better performance because: - When trying the next key, the steps before actually using the key do not need to be performed again. - If a verification process failed for non-key reasons (for example, because it's already expired), it saves the effort to try the next key. --- ecdsa.go | 46 +++++++++++++++++++++++++++------------------- ecdsa_test.go | 47 +++++++++++++++++++++++++++++++++++++++++++++++ hmac.go | 39 ++++++++++++++++++++++++--------------- hmac_test.go | 29 +++++++++++++++++++++++++++++ rsa.go | 37 +++++++++++++++++++++++++------------ rsa_pss.go | 44 ++++++++++++++++++++++++++++---------------- rsa_pss_test.go | 12 ++++++++++++ rsa_test.go | 13 +++++++++++++ 8 files changed, 205 insertions(+), 62 deletions(-) diff --git a/ecdsa.go b/ecdsa.go index f9773812..1ad69ba2 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -53,8 +53,9 @@ func (m *SigningMethodECDSA) Alg() string { return m.Name } -// Implements the Verify method from SigningMethod -// For this verify method, key must be an ecdsa.PublicKey struct +// Implements the Verify method from SigningMethod. +// For this verify method, key must be in types of either *ecdsa.PublicKey or +// []*ecdsa.PublicKey (for rotation keys). func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -64,15 +65,6 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa return err } - // Get the key - var ecdsaKey *ecdsa.PublicKey - switch k := key.(type) { - case *ecdsa.PublicKey: - ecdsaKey = k - default: - return ErrInvalidKeyType - } - if len(sig) != 2*m.KeySize { return ErrECDSAVerification } @@ -80,19 +72,35 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa r := big.NewInt(0).SetBytes(sig[:m.KeySize]) s := big.NewInt(0).SetBytes(sig[m.KeySize:]) - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - // Verify the signature - if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { - return nil - } else { - return ErrECDSAVerification + // Get the keys + var keys []*ecdsa.PublicKey + switch v := key.(type) { + case *ecdsa.PublicKey: + keys = append(keys, v) + case []*ecdsa.PublicKey: + keys = v + } + if len(keys) == 0 { + return ErrInvalidKeyType + } + + var lastErr error + for _, ecdsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { + return nil + } + lastErr = ErrECDSAVerification } + return lastErr } // Implements the Sign method from SigningMethod diff --git a/ecdsa_test.go b/ecdsa_test.go index 753047b1..f1d16d77 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -75,6 +75,53 @@ func TestECDSAVerify(t *testing.T) { } } +func TestECDSAVerifyKeyRotation(t *testing.T) { + targetName := "Basic ES256" + for _, data := range ecdsaTestData { + if data.name != targetName { + continue + } + + var err error + + key, _ := ioutil.ReadFile("test/ec256-public.pem") + var ecdsaKey *ecdsa.PublicKey + if ecdsaKey, err = jwt.ParseECPublicKeyFromPEM(key); err != nil { + t.Errorf("Unable to parse ECDSA public key: %v", err) + } + + key, _ = ioutil.ReadFile("test/ec384-public.pem") + var invalidKey1 *ecdsa.PublicKey + if invalidKey1, err = jwt.ParseECPublicKeyFromPEM(key); err != nil { + t.Errorf("Unable to parse ECDSA public key: %v", err) + } + + key, _ = ioutil.ReadFile("test/ec512-public.pem") + var invalidKey2 *ecdsa.PublicKey + if invalidKey2, err = jwt.ParseECPublicKeyFromPEM(key); err != nil { + t.Errorf("Unable to parse ECDSA public key: %v", err) + } + + parts := strings.Split(data.tokenString, ".") + + method := jwt.GetSigningMethod(data.alg) + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, ecdsaKey, invalidKey2}) + if err != nil { + t.Errorf("[%v] Error while verifying key: %v", data.name, err) + } + + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } + + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, invalidKey2}) + if err == nil { + t.Errorf("[%v] Key list with only invalid keys passed validation", data.name) + } + } +} + func TestECDSASign(t *testing.T) { for _, data := range ecdsaTestData { var err error diff --git a/hmac.go b/hmac.go index addbe5d4..cf0d5511 100644 --- a/hmac.go +++ b/hmac.go @@ -47,12 +47,6 @@ func (m *SigningMethodHMAC) Alg() string { // Verify the signature of HSXXX tokens. Returns nil if the signature is valid. func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error { - // Verify the key is the right type - keyBytes, ok := key.([]byte) - if !ok { - return ErrInvalidKeyType - } - // Decode signature, for comparison sig, err := DecodeSegment(signature) if err != nil { @@ -64,17 +58,32 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac return ErrHashUnavailable } - // This signing method is symmetric, so we validate the signature - // by reproducing the signature from the signing string and key, then - // comparing that against the provided signature. - hasher := hmac.New(m.Hash.New, keyBytes) - hasher.Write([]byte(signingString)) - if !hmac.Equal(sig, hasher.Sum(nil)) { - return ErrSignatureInvalid + // Verify the keys are the right types + var keys [][]byte + switch v := key.(type) { + case []byte: + keys = append(keys, v) + case [][]byte: + keys = v + } + if len(keys) == 0 { + return ErrInvalidKeyType } - // No validation errors. Signature is good. - return nil + var lastErr error + for _, keyBytes := range keys { + // This signing method is symmetric, so we validate the signature + // by reproducing the signature from the signing string and key, then + // comparing that against the provided signature. + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + if hmac.Equal(sig, hasher.Sum(nil)) { + // No validation errors. Signature is good. + return nil + } + lastErr = ErrSignatureInvalid + } + return lastErr } // Implements the Sign method from SigningMethod for this signing method. diff --git a/hmac_test.go b/hmac_test.go index c7e114f4..b12a724c 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -62,6 +62,35 @@ func TestHMACVerify(t *testing.T) { } } +func TestHMACVerifyKeyRotation(t *testing.T) { + invalidKey1 := []byte("foo") + invalidKey2 := []byte("bar") + for _, data := range hmacTestData { + parts := strings.Split(data.tokenString, ".") + + method := jwt.GetSigningMethod(data.alg) + err := method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, hmacTestKey, invalidKey2}) + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying key: %v", data.name, err) + } + if !data.valid && err == nil { + t.Errorf("[%v] Invalid key passed validation", data.name) + } + + if !data.valid { + continue + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, invalidKey2}) + if err == nil { + t.Errorf("[%v] Key list with only invalid keys passed validation", data.name) + } + } +} + func TestHMACSign(t *testing.T) { for _, data := range hmacTestData { if data.valid { diff --git a/rsa.go b/rsa.go index e4caf1ca..7d2c0e0f 100644 --- a/rsa.go +++ b/rsa.go @@ -45,7 +45,8 @@ func (m *SigningMethodRSA) Alg() string { } // Implements the Verify method from SigningMethod -// For this signing method, must be an *rsa.PublicKey structure. +// For this signing method, key must be in types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -55,22 +56,34 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface return err } - var rsaKey *rsa.PublicKey - var ok bool + if !m.Hash.Available() { + return ErrHashUnavailable + } - if rsaKey, ok = key.(*rsa.PublicKey); !ok { + var keys []*rsa.PublicKey + switch v := key.(type) { + case *rsa.PublicKey: + keys = append(keys, v) + case []*rsa.PublicKey: + keys = v + } + if len(keys) == 0 { return ErrInvalidKeyType } - // Create hasher - if !m.Hash.Available() { - return ErrHashUnavailable + var lastErr error + for _, rsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + lastErr = rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + if lastErr == nil { + return nil + } } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - - // Verify the signature - return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + return lastErr } // Implements the Sign method from SigningMethod diff --git a/rsa_pss.go b/rsa_pss.go index c0147086..ca99b11a 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -80,7 +80,8 @@ func init() { } // Implements the Verify method from SigningMethod -// For this verify method, key must be an rsa.PublicKey struct +// For this verify method, key must be in the types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error { var err error @@ -90,27 +91,38 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf return err } - var rsaKey *rsa.PublicKey - switch k := key.(type) { - case *rsa.PublicKey: - rsaKey = k - default: - return ErrInvalidKey - } - - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - opts := m.Options - if m.VerifyOptions != nil { - opts = m.VerifyOptions + var keys []*rsa.PublicKey + switch v := key.(type) { + case *rsa.PublicKey: + keys = append(keys, v) + case []*rsa.PublicKey: + keys = v + } + if len(keys) == 0 { + return ErrInvalidKeyType } - return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + var lastErr error + for _, rsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + opts := m.Options + if m.VerifyOptions != nil { + opts = m.VerifyOptions + } + + lastErr = rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + if lastErr == nil { + return nil + } + } + return lastErr } // Implements the Sign method from SigningMethod diff --git a/rsa_pss_test.go b/rsa_pss_test.go index e0134d9d..98ed2e16 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -70,6 +70,18 @@ func TestRSAPSSVerify(t *testing.T) { if !data.valid && err == nil { t.Errorf("[%v] Invalid key passed validation", data.name) } + + if !data.valid { + continue + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{rsaPSSKey}) + if err != nil { + t.Errorf("[%v] Error while verifying key list: %v", data.name, err) + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } } } diff --git a/rsa_test.go b/rsa_test.go index 7f67c5db..63787884 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "crypto/rsa" "github.com/dgrijalva/jwt-go" "io/ioutil" "strings" @@ -59,6 +60,18 @@ func TestRSAVerify(t *testing.T) { if !data.valid && err == nil { t.Errorf("[%v] Invalid key passed validation", data.name) } + + if !data.valid { + continue + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{key}) + if err != nil { + t.Errorf("[%v] Error while verifying key list: %v", data.name, err) + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } } }