diff --git a/ecdsa.go b/ecdsa.go index f9773812..1ad69ba2 100644 --- a/ecdsa.go +++ b/ecdsa.go @@ -53,8 +53,9 @@ func (m *SigningMethodECDSA) Alg() string { return m.Name } -// Implements the Verify method from SigningMethod -// For this verify method, key must be an ecdsa.PublicKey struct +// Implements the Verify method from SigningMethod. +// For this verify method, key must be in types of either *ecdsa.PublicKey or +// []*ecdsa.PublicKey (for rotation keys). func (m *SigningMethodECDSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -64,15 +65,6 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa return err } - // Get the key - var ecdsaKey *ecdsa.PublicKey - switch k := key.(type) { - case *ecdsa.PublicKey: - ecdsaKey = k - default: - return ErrInvalidKeyType - } - if len(sig) != 2*m.KeySize { return ErrECDSAVerification } @@ -80,19 +72,35 @@ func (m *SigningMethodECDSA) Verify(signingString, signature string, key interfa r := big.NewInt(0).SetBytes(sig[:m.KeySize]) s := big.NewInt(0).SetBytes(sig[m.KeySize:]) - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - // Verify the signature - if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { - return nil - } else { - return ErrECDSAVerification + // Get the keys + var keys []*ecdsa.PublicKey + switch v := key.(type) { + case *ecdsa.PublicKey: + keys = append(keys, v) + case []*ecdsa.PublicKey: + keys = v + } + if len(keys) == 0 { + return ErrInvalidKeyType + } + + var lastErr error + for _, ecdsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + if verifystatus := ecdsa.Verify(ecdsaKey, hasher.Sum(nil), r, s); verifystatus == true { + return nil + } + lastErr = ErrECDSAVerification } + return lastErr } // Implements the Sign method from SigningMethod diff --git a/ecdsa_test.go b/ecdsa_test.go index 753047b1..f1d16d77 100644 --- a/ecdsa_test.go +++ b/ecdsa_test.go @@ -75,6 +75,53 @@ func TestECDSAVerify(t *testing.T) { } } +func TestECDSAVerifyKeyRotation(t *testing.T) { + targetName := "Basic ES256" + for _, data := range ecdsaTestData { + if data.name != targetName { + continue + } + + var err error + + key, _ := ioutil.ReadFile("test/ec256-public.pem") + var ecdsaKey *ecdsa.PublicKey + if ecdsaKey, err = jwt.ParseECPublicKeyFromPEM(key); err != nil { + t.Errorf("Unable to parse ECDSA public key: %v", err) + } + + key, _ = ioutil.ReadFile("test/ec384-public.pem") + var invalidKey1 *ecdsa.PublicKey + if invalidKey1, err = jwt.ParseECPublicKeyFromPEM(key); err != nil { + t.Errorf("Unable to parse ECDSA public key: %v", err) + } + + key, _ = ioutil.ReadFile("test/ec512-public.pem") + var invalidKey2 *ecdsa.PublicKey + if invalidKey2, err = jwt.ParseECPublicKeyFromPEM(key); err != nil { + t.Errorf("Unable to parse ECDSA public key: %v", err) + } + + parts := strings.Split(data.tokenString, ".") + + method := jwt.GetSigningMethod(data.alg) + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, ecdsaKey, invalidKey2}) + if err != nil { + t.Errorf("[%v] Error while verifying key: %v", data.name, err) + } + + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } + + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*ecdsa.PublicKey{invalidKey1, invalidKey2}) + if err == nil { + t.Errorf("[%v] Key list with only invalid keys passed validation", data.name) + } + } +} + func TestECDSASign(t *testing.T) { for _, data := range ecdsaTestData { var err error diff --git a/hmac.go b/hmac.go index addbe5d4..cf0d5511 100644 --- a/hmac.go +++ b/hmac.go @@ -47,12 +47,6 @@ func (m *SigningMethodHMAC) Alg() string { // Verify the signature of HSXXX tokens. Returns nil if the signature is valid. func (m *SigningMethodHMAC) Verify(signingString, signature string, key interface{}) error { - // Verify the key is the right type - keyBytes, ok := key.([]byte) - if !ok { - return ErrInvalidKeyType - } - // Decode signature, for comparison sig, err := DecodeSegment(signature) if err != nil { @@ -64,17 +58,32 @@ func (m *SigningMethodHMAC) Verify(signingString, signature string, key interfac return ErrHashUnavailable } - // This signing method is symmetric, so we validate the signature - // by reproducing the signature from the signing string and key, then - // comparing that against the provided signature. - hasher := hmac.New(m.Hash.New, keyBytes) - hasher.Write([]byte(signingString)) - if !hmac.Equal(sig, hasher.Sum(nil)) { - return ErrSignatureInvalid + // Verify the keys are the right types + var keys [][]byte + switch v := key.(type) { + case []byte: + keys = append(keys, v) + case [][]byte: + keys = v + } + if len(keys) == 0 { + return ErrInvalidKeyType } - // No validation errors. Signature is good. - return nil + var lastErr error + for _, keyBytes := range keys { + // This signing method is symmetric, so we validate the signature + // by reproducing the signature from the signing string and key, then + // comparing that against the provided signature. + hasher := hmac.New(m.Hash.New, keyBytes) + hasher.Write([]byte(signingString)) + if hmac.Equal(sig, hasher.Sum(nil)) { + // No validation errors. Signature is good. + return nil + } + lastErr = ErrSignatureInvalid + } + return lastErr } // Implements the Sign method from SigningMethod for this signing method. diff --git a/hmac_test.go b/hmac_test.go index c7e114f4..b12a724c 100644 --- a/hmac_test.go +++ b/hmac_test.go @@ -62,6 +62,35 @@ func TestHMACVerify(t *testing.T) { } } +func TestHMACVerifyKeyRotation(t *testing.T) { + invalidKey1 := []byte("foo") + invalidKey2 := []byte("bar") + for _, data := range hmacTestData { + parts := strings.Split(data.tokenString, ".") + + method := jwt.GetSigningMethod(data.alg) + err := method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, hmacTestKey, invalidKey2}) + if data.valid && err != nil { + t.Errorf("[%v] Error while verifying key: %v", data.name, err) + } + if !data.valid && err == nil { + t.Errorf("[%v] Invalid key passed validation", data.name) + } + + if !data.valid { + continue + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], [][]byte{invalidKey1, invalidKey2}) + if err == nil { + t.Errorf("[%v] Key list with only invalid keys passed validation", data.name) + } + } +} + func TestHMACSign(t *testing.T) { for _, data := range hmacTestData { if data.valid { diff --git a/rsa.go b/rsa.go index e4caf1ca..7d2c0e0f 100644 --- a/rsa.go +++ b/rsa.go @@ -45,7 +45,8 @@ func (m *SigningMethodRSA) Alg() string { } // Implements the Verify method from SigningMethod -// For this signing method, must be an *rsa.PublicKey structure. +// For this signing method, key must be in types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSA) Verify(signingString, signature string, key interface{}) error { var err error @@ -55,22 +56,34 @@ func (m *SigningMethodRSA) Verify(signingString, signature string, key interface return err } - var rsaKey *rsa.PublicKey - var ok bool + if !m.Hash.Available() { + return ErrHashUnavailable + } - if rsaKey, ok = key.(*rsa.PublicKey); !ok { + var keys []*rsa.PublicKey + switch v := key.(type) { + case *rsa.PublicKey: + keys = append(keys, v) + case []*rsa.PublicKey: + keys = v + } + if len(keys) == 0 { return ErrInvalidKeyType } - // Create hasher - if !m.Hash.Available() { - return ErrHashUnavailable + var lastErr error + for _, rsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + // Verify the signature + lastErr = rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + if lastErr == nil { + return nil + } } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - - // Verify the signature - return rsa.VerifyPKCS1v15(rsaKey, m.Hash, hasher.Sum(nil), sig) + return lastErr } // Implements the Sign method from SigningMethod diff --git a/rsa_pss.go b/rsa_pss.go index c0147086..ca99b11a 100644 --- a/rsa_pss.go +++ b/rsa_pss.go @@ -80,7 +80,8 @@ func init() { } // Implements the Verify method from SigningMethod -// For this verify method, key must be an rsa.PublicKey struct +// For this verify method, key must be in the types of either *rsa.PublicKey or +// []*rsa.PublicKey (for rotation keys). func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interface{}) error { var err error @@ -90,27 +91,38 @@ func (m *SigningMethodRSAPSS) Verify(signingString, signature string, key interf return err } - var rsaKey *rsa.PublicKey - switch k := key.(type) { - case *rsa.PublicKey: - rsaKey = k - default: - return ErrInvalidKey - } - - // Create hasher if !m.Hash.Available() { return ErrHashUnavailable } - hasher := m.Hash.New() - hasher.Write([]byte(signingString)) - opts := m.Options - if m.VerifyOptions != nil { - opts = m.VerifyOptions + var keys []*rsa.PublicKey + switch v := key.(type) { + case *rsa.PublicKey: + keys = append(keys, v) + case []*rsa.PublicKey: + keys = v + } + if len(keys) == 0 { + return ErrInvalidKeyType } - return rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + var lastErr error + for _, rsaKey := range keys { + // Create hasher + hasher := m.Hash.New() + hasher.Write([]byte(signingString)) + + opts := m.Options + if m.VerifyOptions != nil { + opts = m.VerifyOptions + } + + lastErr = rsa.VerifyPSS(rsaKey, m.Hash, hasher.Sum(nil), sig, opts) + if lastErr == nil { + return nil + } + } + return lastErr } // Implements the Sign method from SigningMethod diff --git a/rsa_pss_test.go b/rsa_pss_test.go index e0134d9d..98ed2e16 100644 --- a/rsa_pss_test.go +++ b/rsa_pss_test.go @@ -70,6 +70,18 @@ func TestRSAPSSVerify(t *testing.T) { if !data.valid && err == nil { t.Errorf("[%v] Invalid key passed validation", data.name) } + + if !data.valid { + continue + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{rsaPSSKey}) + if err != nil { + t.Errorf("[%v] Error while verifying key list: %v", data.name, err) + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } } } diff --git a/rsa_test.go b/rsa_test.go index 7f67c5db..63787884 100644 --- a/rsa_test.go +++ b/rsa_test.go @@ -1,6 +1,7 @@ package jwt_test import ( + "crypto/rsa" "github.com/dgrijalva/jwt-go" "io/ioutil" "strings" @@ -59,6 +60,18 @@ func TestRSAVerify(t *testing.T) { if !data.valid && err == nil { t.Errorf("[%v] Invalid key passed validation", data.name) } + + if !data.valid { + continue + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{key}) + if err != nil { + t.Errorf("[%v] Error while verifying key list: %v", data.name, err) + } + err = method.Verify(strings.Join(parts[0:2], "."), parts[2], []*rsa.PublicKey{}) + if err == nil { + t.Errorf("[%v] Empty key list passed validation", data.name) + } } }