diff --git a/jose/validate.go b/jose/validate.go index f4047eea..6a904167 100644 --- a/jose/validate.go +++ b/jose/validate.go @@ -162,6 +162,12 @@ func validateSigJWK(jwk *JSONWebKey) error { return nil } errctx = "kty 'OKP' and crv 'Ed25519'" + case OpaqueSigner: + for _, alg := range k.Algs() { + if jwk.Algorithm == string(alg) { + return nil + } + } } return errors.Errorf("alg '%s' is not compatible with %s", jwk.Algorithm, errctx) diff --git a/jose/validate_test.go b/jose/validate_test.go index 790d68ef..d09710be 100644 --- a/jose/validate_test.go +++ b/jose/validate_test.go @@ -13,6 +13,7 @@ import ( "github.com/pkg/errors" "github.com/smallstep/assert" + "go.step.sm/crypto/keyutil" "go.step.sm/crypto/pemutil" ) @@ -259,3 +260,62 @@ func TestValidateX5C(t *testing.T) { }) } } + +func TestValidateJWK_sig(t *testing.T) { + mustSigner := func(kty, crv string, size int) crypto.Signer { + signer, err := keyutil.GenerateSigner(kty, crv, size) + if err != nil { + t.Fatal(err) + } + return signer + } + + rsaKey := mustSigner("RSA", "", 2048) + p256Key := mustSigner("EC", "P-256", 0) + p384key := mustSigner("EC", "P-384", 0) + p521Key := mustSigner("EC", "P-521", 0) + edKey := mustSigner("OKP", "Ed25519", 0) + + type args struct { + jwk *JSONWebKey + } + tests := []struct { + name string + args args + wantErr bool + }{ + {"ok ES256", args{&JSONWebKey{Use: "sig", Algorithm: ES256, Key: p256Key}}, false}, + {"ok ES384", args{&JSONWebKey{Use: "sig", Algorithm: ES384, Key: p384key}}, false}, + {"ok ES512", args{&JSONWebKey{Use: "sig", Algorithm: ES512, Key: p521Key}}, false}, + {"ok ES256 pub", args{&JSONWebKey{Use: "sig", Algorithm: ES256, Key: p256Key.Public()}}, false}, + {"ok ES384 pub", args{&JSONWebKey{Use: "sig", Algorithm: ES384, Key: p384key.Public()}}, false}, + {"ok ES512 pub", args{&JSONWebKey{Use: "sig", Algorithm: ES512, Key: p521Key.Public()}}, false}, + {"ok RS256", args{&JSONWebKey{Use: "sig", Algorithm: RS256, Key: rsaKey}}, false}, + {"ok RS384", args{&JSONWebKey{Use: "sig", Algorithm: RS384, Key: rsaKey.Public()}}, false}, + {"ok RS512", args{&JSONWebKey{Use: "sig", Algorithm: RS512, Key: rsaKey}}, false}, + {"ok PS256", args{&JSONWebKey{Use: "sig", Algorithm: PS256, Key: rsaKey.Public()}}, false}, + {"ok PS384", args{&JSONWebKey{Use: "sig", Algorithm: PS384, Key: rsaKey}}, false}, + {"ok PS512", args{&JSONWebKey{Use: "sig", Algorithm: PS512, Key: rsaKey.Public()}}, false}, + {"ok EdDSA", args{&JSONWebKey{Use: "sig", Algorithm: EdDSA, Key: edKey}}, false}, + {"ok EdDSA pub", args{&JSONWebKey{Use: "sig", Algorithm: EdDSA, Key: edKey.Public()}}, false}, + {"ok HS256", args{&JSONWebKey{Use: "sig", Algorithm: HS256, Key: []byte("raw-key")}}, false}, + {"ok HS384", args{&JSONWebKey{Use: "sig", Algorithm: HS384, Key: []byte("raw-key")}}, false}, + {"ok HS512", args{&JSONWebKey{Use: "sig", Algorithm: HS512, Key: []byte("raw-key")}}, false}, + {"ok OpaqueSigner", args{&JSONWebKey{Use: "sig", Algorithm: ES256, Key: NewOpaqueSigner(p256Key)}}, false}, + {"fail alg empty", args{&JSONWebKey{Use: "sig", Key: p256Key}}, true}, + {"fail ECDSA", args{&JSONWebKey{Use: "sig", Algorithm: ES384, Key: p256Key}}, true}, + {"fail ECDSA pub", args{&JSONWebKey{Use: "sig", Algorithm: ES384, Key: p256Key.Public()}}, true}, + {"fail RSA", args{&JSONWebKey{Use: "sig", Algorithm: ES256, Key: rsaKey}}, true}, + {"fail Ed25519", args{&JSONWebKey{Use: "sig", Algorithm: ES256, Key: edKey}}, true}, + {"fail bytes", args{&JSONWebKey{Use: "sig", Algorithm: ES256, Key: []byte("raw-key")}}, true}, + {"fail OpaqueSigner", args{&JSONWebKey{Use: "sig", Algorithm: RS256, Key: p256Key}}, true}, + {"fail unknown", args{&JSONWebKey{Use: "sig", Algorithm: HS256, Key: "raw-key"}}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ValidateJWK(tt.args.jwk); (err != nil) != tt.wantErr { + t.Errorf("ValidateJWK() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index f558b072..7013bb33 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -111,6 +111,46 @@ const ( TPMKMS Type = "tpmkms" ) +// TypeOf returns the type of of the given uri. +func TypeOf(rawuri string) (Type, error) { + u, err := uri.Parse(rawuri) + if err != nil { + return DefaultKMS, err + } + t := Type(u.Scheme).normalize() + if err := t.Validate(); err != nil { + return DefaultKMS, err + } + return t, nil +} + +func (t Type) normalize() Type { + return Type(strings.ToLower(string(t))) +} + +// Validate return an error if the type is not a supported one. +func (t Type) Validate() error { + typ := t.normalize() + + switch typ { + case DefaultKMS, SoftKMS: // Go crypto based kms. + return nil + case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms. + return nil + case YubiKey, PKCS11, TPMKMS: // Hardware based kms. + return nil + case SSHAgentKMS, CAPIKMS: // Others + return nil + } + + // Check other registered types + if _, ok := registry.Load(typ); ok { + return nil + } + + return fmt.Errorf("unsupported kms type %s", t) +} + // Options are the KMS options. They represent the kms object in the ca.json. type Options struct { // The type of the KMS to use. @@ -155,18 +195,7 @@ func (o *Options) Validate() error { if o == nil { return nil } - - typ := strings.ToLower(string(o.Type)) - switch Type(typ) { - case DefaultKMS, SoftKMS: // Go crypto based kms. - case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms. - case YubiKey, PKCS11, TPMKMS: // Hardware based kms. - case SSHAgentKMS, CAPIKMS: // Others - default: - return fmt.Errorf("unsupported kms type %s", o.Type) - } - - return nil + return o.Type.Validate() } // GetType returns the type in the type property or the one present in the URI. @@ -175,11 +204,7 @@ func (o *Options) GetType() (Type, error) { return o.Type, nil } if o.URI != "" { - u, err := uri.Parse(o.URI) - if err != nil { - return DefaultKMS, err - } - return Type(strings.ToLower(u.Scheme)), nil + return TypeOf(o.URI) } return SoftKMS, nil } diff --git a/kms/apiv1/options_test.go b/kms/apiv1/options_test.go index ebcb6db8..d1f1aede 100644 --- a/kms/apiv1/options_test.go +++ b/kms/apiv1/options_test.go @@ -1,9 +1,32 @@ package apiv1 import ( + "context" + "crypto" + "os" "testing" ) +type fakeKM struct{} + +func (f *fakeKM) GetPublicKey(req *GetPublicKeyRequest) (crypto.PublicKey, error) { + return nil, NotImplementedError{} +} +func (f *fakeKM) CreateKey(req *CreateKeyRequest) (*CreateKeyResponse, error) { + return nil, NotImplementedError{} +} +func (f *fakeKM) CreateSigner(req *CreateSignerRequest) (crypto.Signer, error) { + return nil, NotImplementedError{} +} +func (f *fakeKM) Close() error { return NotImplementedError{} } + +func TestMain(m *testing.M) { + Register(Type("fake"), func(ctx context.Context, opts Options) (KeyManager, error) { + return &fakeKM{}, nil + }) + os.Exit(m.Run()) +} + func TestOptions_Validate(t *testing.T) { tests := []struct { name string @@ -115,3 +138,41 @@ func TestErrAlreadyExists_Error(t *testing.T) { }) } } + +func TestTypeOf(t *testing.T) { + type args struct { + rawuri string + } + tests := []struct { + name string + args args + want Type + wantErr bool + }{ + {"ok softkms", args{"softkms:foo=bar"}, SoftKMS, false}, + {"ok cloudkms", args{"CLOUDKMS:"}, CloudKMS, false}, + {"ok amazonkms", args{"awskms:foo=bar"}, AmazonKMS, false}, + {"ok pkcs11", args{"PKCS11:foo=bar"}, PKCS11, false}, + {"ok yubikey", args{"yubikey:foo=bar"}, YubiKey, false}, + {"ok sshagentkms", args{"sshagentkms:"}, SSHAgentKMS, false}, + {"ok azurekms", args{"azurekms:foo=bar"}, AzureKMS, false}, + {"ok capi", args{"CAPI:foo-bar"}, CAPIKMS, false}, + {"ok tpmkms", args{"tpmkms:"}, TPMKMS, false}, + {"ok registered", args{"FAKE:"}, Type("fake"), false}, + {"fail empty", args{""}, DefaultKMS, true}, + {"fail parse", args{"softkms"}, DefaultKMS, true}, + {"fail kms", args{"foobar:foo=bar"}, DefaultKMS, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := TypeOf(tt.args.rawuri) + if (err != nil) != tt.wantErr { + t.Errorf("TypeOf() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("TypeOf() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/kms/apiv1/registry.go b/kms/apiv1/registry.go index 5a8cf4db..51811499 100644 --- a/kms/apiv1/registry.go +++ b/kms/apiv1/registry.go @@ -13,12 +13,12 @@ type KeyManagerNewFunc func(ctx context.Context, opts Options) (KeyManager, erro // Register adds to the registry a method to create a KeyManager of type t. func Register(t Type, fn KeyManagerNewFunc) { - registry.Store(t, fn) + registry.Store(t.normalize(), fn) } // LoadKeyManagerNewFunc returns the function initialize a KayManager. func LoadKeyManagerNewFunc(t Type) (KeyManagerNewFunc, bool) { - v, ok := registry.Load(t) + v, ok := registry.Load(t.normalize()) if !ok { return nil, false } diff --git a/kms/kms.go b/kms/kms.go index e06fbdec..c5ad7042 100644 --- a/kms/kms.go +++ b/kms/kms.go @@ -32,6 +32,9 @@ type Options = apiv1.Options // Type represents the KMS type used. type Type = apiv1.Type +// TypeOf returns the KMS type of the given uri. +var TypeOf = apiv1.TypeOf + // Default is the implementation of the default KMS. var Default = &softkms.SoftKMS{}