diff --git a/go.mod b/go.mod index 590f75a1..3edf4f53 100644 --- a/go.mod +++ b/go.mod @@ -10,7 +10,8 @@ require ( github.com/Azure/azure-sdk-for-go/sdk/keyvault/azkeys v0.10.0 github.com/Masterminds/sprig/v3 v3.2.3 github.com/ThalesIgnite/crypto11 v1.2.5 - github.com/aws/aws-sdk-go v1.49.21 + github.com/aws/aws-sdk-go-v2/config v1.26.5 + github.com/aws/aws-sdk-go-v2/service/kms v1.27.9 github.com/go-jose/go-jose/v3 v3.0.1 github.com/go-piv/piv-go v1.11.0 github.com/golang/mock v1.6.0 @@ -40,6 +41,18 @@ require ( github.com/AzureAD/microsoft-authentication-library-for-go v1.1.1 // indirect github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect + github.com/aws/aws-sdk-go-v2 v1.24.1 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.16.16 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 // indirect + github.com/aws/smithy-go v1.19.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/go-logr/logr v1.3.0 // indirect @@ -55,7 +68,6 @@ require ( github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect github.com/huandu/xstrings v1.3.3 // indirect github.com/imdario/mergo v0.3.12 // indirect - github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/kr/pretty v0.3.1 // indirect github.com/kylelemons/godebug v1.1.0 // indirect github.com/miekg/pkcs11 v1.0.3 // indirect @@ -71,6 +83,7 @@ require ( go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect go.opentelemetry.io/otel v1.21.0 // indirect go.opentelemetry.io/otel/metric v1.21.0 // indirect + go.opentelemetry.io/otel/sdk v1.21.0 // indirect go.opentelemetry.io/otel/trace v1.21.0 // indirect golang.org/x/oauth2 v0.16.0 // indirect golang.org/x/sync v0.6.0 // indirect diff --git a/go.sum b/go.sum index 15e19a4a..6c211c90 100644 --- a/go.sum +++ b/go.sum @@ -149,9 +149,35 @@ github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpi github.com/aws/aws-sdk-go v1.25.11/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.27.0/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= github.com/aws/aws-sdk-go v1.37.0/go.mod h1:hcU610XS61/+aQV88ixoOzUoG7v3b31pl2zKMmprdro= -github.com/aws/aws-sdk-go v1.49.21 h1:Rl8KW6HqkwzhATwvXhyr7vD4JFUMi7oXGAw9SrxxIFY= -github.com/aws/aws-sdk-go v1.49.21/go.mod h1:LF8svs817+Nz+DmiMQKTO3ubZ/6IaTpq3TjupRn3Eqk= github.com/aws/aws-sdk-go-v2 v0.18.0/go.mod h1:JWVYvqSMppoMJC0x5wdwiImzgXTI9FuZwxzkQq9wy+g= +github.com/aws/aws-sdk-go-v2 v1.24.1 h1:xAojnj+ktS95YZlDf0zxWBkbFtymPeDP+rvUQIH3uAU= +github.com/aws/aws-sdk-go-v2 v1.24.1/go.mod h1:LNh45Br1YAkEKaAqvmE1m8FUx6a5b/V0oAKV7of29b4= +github.com/aws/aws-sdk-go-v2/config v1.26.5 h1:lodGSevz7d+kkFJodfauThRxK9mdJbyutUxGq1NNhvw= +github.com/aws/aws-sdk-go-v2/config v1.26.5/go.mod h1:DxHrz6diQJOc9EwDslVRh84VjjrE17g+pVZXUeSxaDU= +github.com/aws/aws-sdk-go-v2/credentials v1.16.16 h1:8q6Rliyv0aUFAVtzaldUEcS+T5gbadPbWdV1WcAddK8= +github.com/aws/aws-sdk-go-v2/credentials v1.16.16/go.mod h1:UHVZrdUsv63hPXFo1H7c5fEneoVo9UXiz36QG1GEPi0= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10 h1:vF+Zgd9s+H4vOXd5BMaPWykta2a6Ih0AKLq/X6NYKn4= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.2.10/go.mod h1:6BkRjejp/GR4411UGqkX8+wFMbFbqsUIimfK4XjOKR4= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10 h1:nYPe006ktcqUji8S2mqXf9c/7NdiKriOwMvWQHgYztw= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.5.10/go.mod h1:6UV4SZkVvmODfXKql4LCbaZUpF7HO2BX38FgBf9ZOLw= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2 h1:GrSw8s0Gs/5zZ0SX+gX4zQjRnRsMJDJ2sLur1gRBhEM= +github.com/aws/aws-sdk-go-v2/internal/ini v1.7.2/go.mod h1:6fQQgfuGmw8Al/3M2IgIllycxV7ZW7WCdVSqfBeUiCY= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino= +github.com/aws/aws-sdk-go-v2/service/kms v1.27.9 h1:W9PbZAZAEcelhhjb7KuwUtf+Lbc+i7ByYJRuWLlnxyQ= +github.com/aws/aws-sdk-go-v2/service/kms v1.27.9/go.mod h1:2tFmR7fQnOdQlM2ZCEPpFnBIQD1U8wmXmduBgZbOag0= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 h1:eajuO3nykDPdYicLlP3AGgOyVN3MOlFmZv7WGTuJPow= +github.com/aws/aws-sdk-go-v2/service/sso v1.18.7/go.mod h1:+mJNDdF+qiUlNKNC3fxn74WWNN+sOiGOEImje+3ScPM= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7 h1:QPMJf+Jw8E1l7zqhZmMlFw6w1NmfkfiSK8mS4zOx3BA= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.21.7/go.mod h1:ykf3COxYI0UJmxcfcxcVuz7b6uADi1FkiUz6Eb7AgM8= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.7 h1:NzO4Vrau795RkUdSHKEwiR01FaGzGOH1EETJ+5QHnm0= +github.com/aws/aws-sdk-go-v2/service/sts v1.26.7/go.mod h1:6h2YuIoxaMSCFf5fi1EgZAwdfkGMgDY+DVfa61uLe4U= +github.com/aws/smithy-go v1.19.0 h1:KWFKQV80DpP3vJrrA9sVAHQ5gc2z8i4EzrLhLlWXcBM= +github.com/aws/smithy-go v1.19.0/go.mod h1:NukqUGpCZIILqqiV0NIjeFh24kd/FAa4beRb6nbIUPE= github.com/aybabtme/rgbterm v0.0.0-20170906152045-cc83f3b3ce59/go.mod h1:q/89r3U2H7sSsE2t6Kca0lfwTK8JdoNGS/yzM/4iH5I= github.com/benbjohnson/clock v1.0.3/go.mod h1:bGMdMPoPVvcYyt1gHDf4J2KE153Yf9BuiUKYMaxlTDM= github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q= @@ -480,9 +506,7 @@ github.com/jhump/protoreflect v1.8.2/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8 github.com/jhump/protoreflect v1.9.0/go.mod h1:7GcYQDdMU/O/BBrl/cX6PNHpXh6cenjd8pneu5yW7Tg= github.com/jmespath/go-jmespath v0.0.0-20160202185014-0b12d6b521d8/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= github.com/jmespath/go-jmespath v0.0.0-20180206201540-c2b33e8439af/go.mod h1:Nht3zPeWKUH0NzdCt2Blrr5ys8VGpn0CEB0cQHVjt7k= -github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= -github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8= github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U= github.com/joho/godotenv v1.3.0/go.mod h1:7hK45KPybAkOC6peb+G5yklZfMxEjkZhHbwpqxOKXbg= github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22gdxWY5EU2bo= @@ -831,8 +855,9 @@ go.opentelemetry.io/otel/metric v0.20.0/go.mod h1:598I5tYlH1vzBjn+BTuhzTCSb/9deb go.opentelemetry.io/otel/metric v1.21.0 h1:tlYWfeo+Bocx5kLEloTjbcDwBuELRrIFxwdQ36PlJu4= go.opentelemetry.io/otel/metric v1.21.0/go.mod h1:o1p3CA8nNHW8j5yuQLdc1eeqEaPfzug24uvsyIEJRWM= go.opentelemetry.io/otel/oteltest v0.20.0/go.mod h1:L7bgKf9ZB7qCwT9Up7i9/pn0PWIa9FqQ2IQ8LoxiGnw= -go.opentelemetry.io/otel/sdk v0.20.0 h1:JsxtGXd06J8jrnya7fdI/U/MR6yXA5DtbZy+qoHQlr8= go.opentelemetry.io/otel/sdk v0.20.0/go.mod h1:g/IcepuwNsoiX5Byy2nNV0ySUF1em498m7hBWC279Yc= +go.opentelemetry.io/otel/sdk v1.21.0 h1:FTt8qirL1EysG6sTQRZ5TokkU8d0ugCj8htOgThZXQ8= +go.opentelemetry.io/otel/sdk v1.21.0/go.mod h1:Nna6Yv7PWTdgJHVRD9hIYywQBRx7pbox6nwBnZIxl/E= go.opentelemetry.io/otel/sdk/export/metric v0.20.0/go.mod h1:h7RBNMsDJ5pmI1zExLi+bJK+Dr8NQCh0qGhm1KDnNlE= go.opentelemetry.io/otel/sdk/metric v0.20.0/go.mod h1:knxiS8Xd4E/N+ZqKmUPf3gTTZ4/0TjTXukfxjzSTpHE= go.opentelemetry.io/otel/trace v0.20.0/go.mod h1:6GjCW8zgDjwGHGa6GkyeB8+/5vjT16gUEi0Nf1iBdgw= diff --git a/kms/awskms/awskms.go b/kms/awskms/awskms.go index a85dc71a..82665429 100644 --- a/kms/awskms/awskms.go +++ b/kms/awskms/awskms.go @@ -10,10 +10,9 @@ import ( "strings" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/pkg/errors" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" @@ -25,92 +24,97 @@ const Scheme = string(apiv1.AmazonKMS) // KMS implements a KMS using AWS Key Management Service. type KMS struct { - session *session.Session - service KeyManagementClient + client KeyManagementClient } // KeyManagementClient defines the methods on KeyManagementClient that this // package will use. This interface will be used for unit testing. type KeyManagementClient interface { - GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) - CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) - CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) - SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) + GetPublicKey(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) + CreateKey(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error) + CreateAlias(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error) + Sign(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) } // customerMasterKeySpecMapping is a mapping between the step signature algorithm, // and bits for RSA keys, with awskms CustomerMasterKeySpec. var customerMasterKeySpecMapping = map[apiv1.SignatureAlgorithm]interface{}{ - apiv1.UnspecifiedSignAlgorithm: kms.CustomerMasterKeySpecEccNistP256, - apiv1.SHA256WithRSA: map[int]string{ - 0: kms.CustomerMasterKeySpecRsa3072, - 2048: kms.CustomerMasterKeySpecRsa2048, - 3072: kms.CustomerMasterKeySpecRsa3072, - 4096: kms.CustomerMasterKeySpecRsa4096, + apiv1.UnspecifiedSignAlgorithm: types.KeySpecEccNistP256, + apiv1.SHA256WithRSA: map[int]types.KeySpec{ + 0: types.KeySpecRsa3072, + 2048: types.KeySpecRsa2048, + 3072: types.KeySpecRsa3072, + 4096: types.KeySpecRsa4096, }, - apiv1.SHA512WithRSA: map[int]string{ - 0: kms.CustomerMasterKeySpecRsa4096, - 4096: kms.CustomerMasterKeySpecRsa4096, + apiv1.SHA512WithRSA: map[int]types.KeySpec{ + 0: types.KeySpecRsa3072, + 2048: types.KeySpecRsa2048, + 3072: types.KeySpecRsa3072, + 4096: types.KeySpecRsa4096, }, - apiv1.SHA256WithRSAPSS: map[int]string{ - 0: kms.CustomerMasterKeySpecRsa3072, - 2048: kms.CustomerMasterKeySpecRsa2048, - 3072: kms.CustomerMasterKeySpecRsa3072, - 4096: kms.CustomerMasterKeySpecRsa4096, + apiv1.SHA256WithRSAPSS: map[int]types.KeySpec{ + 0: types.KeySpecRsa3072, + 2048: types.KeySpecRsa2048, + 3072: types.KeySpecRsa3072, + 4096: types.KeySpecRsa4096, }, - apiv1.SHA512WithRSAPSS: map[int]string{ - 0: kms.CustomerMasterKeySpecRsa4096, - 4096: kms.CustomerMasterKeySpecRsa4096, + apiv1.SHA512WithRSAPSS: map[int]types.KeySpec{ + 0: types.KeySpecRsa3072, + 2048: types.KeySpecRsa2048, + 3072: types.KeySpecRsa3072, + 4096: types.KeySpecRsa4096, }, - apiv1.ECDSAWithSHA256: kms.CustomerMasterKeySpecEccNistP256, - apiv1.ECDSAWithSHA384: kms.CustomerMasterKeySpecEccNistP384, - apiv1.ECDSAWithSHA512: kms.CustomerMasterKeySpecEccNistP521, + apiv1.ECDSAWithSHA256: types.KeySpecEccNistP256, + apiv1.ECDSAWithSHA384: types.KeySpecEccNistP384, + apiv1.ECDSAWithSHA512: types.KeySpecEccNistP521, } -// New creates a new AWSKMS. By default, sessions will be created using the +// New creates a new AWSKMS. By default, clients will be created using the // credentials in `~/.aws/credentials`, but this can be overridden using the // CredentialsFile option, the Region and Profile can also be configured as // options. // -// AWS sessions can also be configured with environment variables, see docs at -// https://docs.aws.amazon.com/sdk-for-go/api/aws/session/ for all the options. -func New(_ context.Context, opts apiv1.Options) (*KMS, error) { - var o session.Options +// AWS clients can also be configured with environment variables, see docs at +// https://aws.github.io/aws-sdk-go-v2/docs/configuring-sdk/ for all the +// options. +func New(ctx context.Context, opts apiv1.Options) (*KMS, error) { + var configOptions []func(*config.LoadOptions) error if opts.URI != "" { u, err := uri.ParseWithScheme(Scheme, opts.URI) if err != nil { return nil, err } - o.Profile = u.Get("profile") + + if v := u.Get("profile"); v != "" { + configOptions = append(configOptions, config.WithSharedConfigProfile(v)) + } if v := u.Get("region"); v != "" { - o.Config.Region = new(string) - *o.Config.Region = v + configOptions = append(configOptions, config.WithRegion(v)) } - if f := u.Get("credentials-file"); f != "" { - o.SharedConfigFiles = []string{f} + if v := u.Get("credentials-file"); v != "" { + configOptions = append(configOptions, config.WithSharedConfigFiles([]string{v})) } } // Deprecated way to set configuration parameters. if opts.Region != "" { - o.Config.Region = &opts.Region + configOptions = append(configOptions, config.WithRegion(opts.Region)) } if opts.Profile != "" { - o.Profile = opts.Profile + configOptions = append(configOptions, config.WithSharedConfigProfile(opts.Profile)) } if opts.CredentialsFile != "" { - o.SharedConfigFiles = []string{opts.CredentialsFile} + configOptions = append(configOptions, config.WithSharedConfigFiles([]string{opts.CredentialsFile})) } - sess, err := session.NewSessionWithOptions(o) + cfg, err := config.LoadDefaultConfig(ctx, configOptions...) if err != nil { - return nil, errors.Wrap(err, "error creating AWS session") + return nil, errors.Wrap(err, "error loading AWS config") } return &KMS{ - session: sess, - service: kms.New(sess), + client: kms.NewFromConfig(cfg), }, nil } @@ -125,6 +129,7 @@ func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, er if req.Name == "" { return nil, errors.New("getPublicKey 'name' cannot be empty") } + keyID, err := parseKeyID(req.Name) if err != nil { return nil, err @@ -133,11 +138,11 @@ func (k *KMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, er ctx, cancel := defaultContext() defer cancel() - resp, err := k.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{ + resp, err := k.client.GetPublicKey(ctx, &kms.GetPublicKeyInput{ KeyId: &keyID, }) if err != nil { - return nil, errors.Wrap(err, "awskms GetPublicKeyWithContext failed") + return nil, errors.Wrap(err, "awskms GetPublicKey failed") } return pemutil.ParseDER(resp.PublicKey) @@ -150,30 +155,36 @@ func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, return nil, errors.New("createKeyRequest 'name' cannot be empty") } + keyName, err := parseName(req.Name) + if err != nil { + return nil, err + } + keySpec, err := getCustomerMasterKeySpecMapping(req.SignatureAlgorithm, req.Bits) if err != nil { return nil, err } - tag := new(kms.Tag) - tag.SetTagKey("name") - tag.SetTagValue(req.Name) + tag := types.Tag{ + TagKey: pointer("name"), + TagValue: pointer(keyName), + } input := &kms.CreateKeyInput{ - Description: &req.Name, - CustomerMasterKeySpec: &keySpec, - Tags: []*kms.Tag{tag}, + Description: pointer(keyName), + KeySpec: keySpec, + Tags: []types.Tag{tag}, + KeyUsage: types.KeyUsageTypeSignVerify, } - input.SetKeyUsage(kms.KeyUsageTypeSignVerify) ctx, cancel := defaultContext() defer cancel() - resp, err := k.service.CreateKeyWithContext(ctx, input) + resp, err := k.client.CreateKey(ctx, input) if err != nil { - return nil, errors.Wrap(err, "awskms CreateKeyWithContext failed") + return nil, errors.Wrap(err, "awskms CreateKey failed") } - if err := k.createKeyAlias(*resp.KeyMetadata.KeyId, req.Name); err != nil { + if err := k.createKeyAlias(*resp.KeyMetadata.KeyId, keyName); err != nil { return nil, err } @@ -201,17 +212,15 @@ func (k *KMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, } func (k *KMS) createKeyAlias(keyID, alias string) error { - alias = "alias/" + alias + "-" + keyID[:8] - ctx, cancel := defaultContext() defer cancel() - _, err := k.service.CreateAliasWithContext(ctx, &kms.CreateAliasInput{ - AliasName: &alias, - TargetKeyId: &keyID, + _, err := k.client.CreateAlias(ctx, &kms.CreateAliasInput{ + AliasName: pointer("alias/" + alias + "-" + keyID[:8]), + TargetKeyId: pointer(keyID), }) if err != nil { - return errors.Wrap(err, "awskms CreateAliasWithContext failed") + return errors.Wrap(err, "awskms CreateAlias failed") } return nil } @@ -221,7 +230,7 @@ func (k *KMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error if req.SigningKey == "" { return nil, errors.New("createSigner 'signingKey' cannot be empty") } - return NewSigner(k.service, req.SigningKey) + return NewSigner(k.client, req.SigningKey) } // Close closes the connection of the KMS client. @@ -229,6 +238,10 @@ func (k *KMS) Close() error { return nil } +func pointer[T any](v T) *T { + return &v +} + func defaultContext() (context.Context, context.CancelFunc) { return context.WithTimeout(context.Background(), 15*time.Second) } @@ -249,21 +262,36 @@ func parseKeyID(name string) (string, error) { return name, nil } -func getCustomerMasterKeySpecMapping(alg apiv1.SignatureAlgorithm, bits int) (string, error) { +// parseName extracts the name from an uri. +func parseName(rawuri string) (string, error) { + if strings.HasPrefix(rawuri, "awskms:") || strings.HasPrefix(rawuri, "aws:") { + u, err := uri.Parse(rawuri) + if err != nil { + return "", err + } + if k := u.Get("name"); k != "" { + return k, nil + } + return "", errors.Errorf("failed to get name from %s", rawuri) + } + return rawuri, nil +} + +func getCustomerMasterKeySpecMapping(alg apiv1.SignatureAlgorithm, bits int) (types.KeySpec, error) { v, ok := customerMasterKeySpecMapping[alg] if !ok { return "", errors.Errorf("awskms does not support signature algorithm '%s'", alg) } switch v := v.(type) { - case string: + case types.KeySpec: return v, nil - case map[int]string: - s, ok := v[bits] + case map[int]types.KeySpec: + ks, ok := v[bits] if !ok { return "", errors.Errorf("awskms does not support signature algorithm '%s' with '%d' bits", alg, bits) } - return s, nil + return ks, nil default: return "", errors.Errorf("unexpected error: this should not happen") } diff --git a/kms/awskms/awskms_test.go b/kms/awskms/awskms_test.go index 0b2196f6..cd821346 100644 --- a/kms/awskms/awskms_test.go +++ b/kms/awskms/awskms_test.go @@ -4,44 +4,34 @@ import ( "context" "crypto" "fmt" - "os" - "path/filepath" "reflect" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/aws/session" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/pemutil" ) +func TestRegister(t *testing.T) { + fn, ok := apiv1.LoadKeyManagerNewFunc(apiv1.AmazonKMS) + require.True(t, ok) + _, err := fn(context.Background(), apiv1.Options{}) + require.NoError(t, err) +} + func TestNew(t *testing.T) { ctx := context.Background() - sess, err := session.NewSessionWithOptions(session.Options{}) + cfg, err := config.LoadDefaultConfig(context.Background()) if err != nil { t.Fatal(err) } expected := &KMS{ - session: sess, - service: kms.New(sess), - } - - // This will force an error in the session creation. - // It does not fail with missing credentials. - forceError := func(t *testing.T) { - key := "AWS_CA_BUNDLE" - value := os.Getenv(key) - t.Setenv(key, filepath.Join(os.TempDir(), "missing-ca.crt")) - t.Cleanup(func() { - if value == "" { - os.Unsetenv(key) - } else { - t.Setenv(key, value) - } - }) + client: kms.NewFromConfig(cfg), } type args struct { @@ -55,26 +45,20 @@ func TestNew(t *testing.T) { wantErr bool }{ {"ok", args{ctx, apiv1.Options{}}, expected, false}, - {"ok with options", args{ctx, apiv1.Options{ + {"fail with options", args{ctx, apiv1.Options{ Region: "us-east-1", Profile: "smallstep", - CredentialsFile: "~/aws/credentials", - }}, expected, false}, - {"ok with uri", args{ctx, apiv1.Options{ - URI: "awskms:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials", - }}, expected, false}, - {"fail", args{ctx, apiv1.Options{}}, nil, true}, - {"fail uri", args{ctx, apiv1.Options{ + CredentialsFile: "~/aws/missing", + }}, nil, true}, + {"fail with uri", args{ctx, apiv1.Options{ + URI: "awskms:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/missing", + }}, nil, true}, + {"fail bad uri", args{ctx, apiv1.Options{ URI: "pkcs11:region=us-east-1;profile=smallstep;credentials-file=/var/run/aws/credentials", }}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - // Force an error in the session loading - if tt.wantErr { - forceError(t) - } - got, err := New(tt.args.ctx, tt.args.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) @@ -85,7 +69,7 @@ func TestNew(t *testing.T) { t.Errorf("New() = %#v, want %#v", got, tt.want) } } else { - if got.session == nil || got.service == nil { + if got.client == nil { t.Errorf("New() = %#v, want %#v", got, tt.want) } } @@ -101,8 +85,7 @@ func TestKMS_GetPublicKey(t *testing.T) { } type fields struct { - session *session.Session - service KeyManagementClient + client KeyManagementClient } type args struct { req *apiv1.GetPublicKeyRequest @@ -114,22 +97,22 @@ func TestKMS_GetPublicKey(t *testing.T) { want crypto.PublicKey wantErr bool }{ - {"ok", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{ + {"ok", fields{okClient}, args{&apiv1.GetPublicKeyRequest{ Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", }}, key, false}, - {"fail empty", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true}, - {"fail name", fields{nil, okClient}, args{&apiv1.GetPublicKeyRequest{ + {"fail empty", fields{okClient}, args{&apiv1.GetPublicKeyRequest{}}, nil, true}, + {"fail name", fields{okClient}, args{&apiv1.GetPublicKeyRequest{ Name: "awskms:key-id=", }}, nil, true}, - {"fail getPublicKey", fields{nil, &MockClient{ - getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + {"fail getPublicKey", fields{&MockClient{ + getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return nil, fmt.Errorf("an error") }, }}, args{&apiv1.GetPublicKeyRequest{ Name: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", }}, nil, true}, - {"fail not der", fields{nil, &MockClient{ - getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + {"fail not der", fields{&MockClient{ + getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return &kms.GetPublicKeyOutput{ KeyId: input.KeyId, PublicKey: []byte(publicKey), @@ -142,8 +125,7 @@ func TestKMS_GetPublicKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &KMS{ - session: tt.fields.session, - service: tt.fields.service, + client: tt.fields.client, } got, err := k.GetPublicKey(tt.args.req) if (err != nil) != tt.wantErr { @@ -165,8 +147,7 @@ func TestKMS_CreateKey(t *testing.T) { } type fields struct { - session *session.Session - service KeyManagementClient + client KeyManagementClient } type args struct { req *apiv1.CreateKeyRequest @@ -178,7 +159,7 @@ func TestKMS_CreateKey(t *testing.T) { want *apiv1.CreateKeyResponse wantErr bool }{ - {"ok", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + {"ok", fields{okClient}, args{&apiv1.CreateKeyRequest{ Name: "root", SignatureAlgorithm: apiv1.ECDSAWithSHA256, }}, &apiv1.CreateKeyResponse{ @@ -188,8 +169,8 @@ func TestKMS_CreateKey(t *testing.T) { SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", }, }, false}, - {"ok rsa", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ - Name: "root", + {"ok rsa with uri", fields{okClient}, args{&apiv1.CreateKeyRequest{ + Name: "awskms:name=root", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 2048, }}, &apiv1.CreateKeyResponse{ @@ -199,40 +180,48 @@ func TestKMS_CreateKey(t *testing.T) { SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", }, }, false}, - {"fail empty", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true}, - {"fail unsupported alg", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + {"fail empty", fields{okClient}, args{&apiv1.CreateKeyRequest{}}, nil, true}, + {"fail unsupported alg", fields{okClient}, args{&apiv1.CreateKeyRequest{ Name: "root", SignatureAlgorithm: apiv1.PureEd25519, }}, nil, true}, - {"fail unsupported bits", fields{nil, okClient}, args{&apiv1.CreateKeyRequest{ + {"fail unsupported bits", fields{okClient}, args{&apiv1.CreateKeyRequest{ Name: "root", SignatureAlgorithm: apiv1.SHA256WithRSA, Bits: 1234, }}, nil, true}, - {"fail createKey", fields{nil, &MockClient{ - createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) { + {"fail uri parse", fields{okClient}, args{&apiv1.CreateKeyRequest{ + Name: "awskms:%name=root", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail uri no name", fields{okClient}, args{&apiv1.CreateKeyRequest{ + Name: "awskms:name", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }}, nil, true}, + {"fail createKey", fields{&MockClient{ + createKey: func(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { return nil, fmt.Errorf("an error") }, - createAliasWithContext: okClient.createAliasWithContext, - getPublicKeyWithContext: okClient.getPublicKeyWithContext, + createAlias: okClient.createAlias, + getPublicKey: okClient.getPublicKey, }}, args{&apiv1.CreateKeyRequest{ Name: "root", SignatureAlgorithm: apiv1.ECDSAWithSHA256, }}, nil, true}, - {"fail createAlias", fields{nil, &MockClient{ - createKeyWithContext: okClient.createKeyWithContext, - createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) { + {"fail createAlias", fields{&MockClient{ + createKey: okClient.createKey, + createAlias: func(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { return nil, fmt.Errorf("an error") }, - getPublicKeyWithContext: okClient.getPublicKeyWithContext, + getPublicKey: okClient.getPublicKey, }}, args{&apiv1.CreateKeyRequest{ Name: "root", SignatureAlgorithm: apiv1.ECDSAWithSHA256, }}, nil, true}, - {"fail getPublicKey", fields{nil, &MockClient{ - createKeyWithContext: okClient.createKeyWithContext, - createAliasWithContext: okClient.createAliasWithContext, - getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + {"fail getPublicKey", fields{&MockClient{ + createKey: okClient.createKey, + createAlias: okClient.createAlias, + getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return nil, fmt.Errorf("an error") }, }}, args{&apiv1.CreateKeyRequest{ @@ -243,8 +232,7 @@ func TestKMS_CreateKey(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &KMS{ - session: tt.fields.session, - service: tt.fields.service, + client: tt.fields.client, } got, err := k.CreateKey(tt.args.req) if (err != nil) != tt.wantErr { @@ -266,8 +254,7 @@ func TestKMS_CreateSigner(t *testing.T) { } type fields struct { - session *session.Session - service KeyManagementClient + client KeyManagementClient } type args struct { req *apiv1.CreateSignerRequest @@ -279,21 +266,20 @@ func TestKMS_CreateSigner(t *testing.T) { want crypto.Signer wantErr bool }{ - {"ok", fields{nil, client}, args{&apiv1.CreateSignerRequest{ + {"ok", fields{client}, args{&apiv1.CreateSignerRequest{ SigningKey: "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936", }}, &Signer{ - service: client, + client: client, keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936", publicKey: key, }, false}, - {"fail empty", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true}, - {"fail preload", fields{nil, client}, args{&apiv1.CreateSignerRequest{}}, nil, true}, + {"fail empty", fields{client}, args{&apiv1.CreateSignerRequest{}}, nil, true}, + {"fail preload", fields{client}, args{&apiv1.CreateSignerRequest{}}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &KMS{ - session: tt.fields.session, - service: tt.fields.service, + client: tt.fields.client, } got, err := k.CreateSigner(tt.args.req) if (err != nil) != tt.wantErr { @@ -309,21 +295,19 @@ func TestKMS_CreateSigner(t *testing.T) { func TestKMS_Close(t *testing.T) { type fields struct { - session *session.Session - service KeyManagementClient + client KeyManagementClient } tests := []struct { name string fields fields wantErr bool }{ - {"ok", fields{nil, getOKClient()}, false}, + {"ok", fields{getOKClient()}, false}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { k := &KMS{ - session: tt.fields.session, - service: tt.fields.service, + client: tt.fields.client, } if err := k.Close(); (err != nil) != tt.wantErr { t.Errorf("KMS.Close() error = %v, wantErr %v", err, tt.wantErr) @@ -362,3 +346,54 @@ func Test_parseKeyID(t *testing.T) { }) } } + +func Test_getCustomerMasterKeySpecMapping(t *testing.T) { + tmp := customerMasterKeySpecMapping + t.Cleanup(func() { + customerMasterKeySpecMapping = tmp + }) + + // Fail type switch + customerMasterKeySpecMapping[apiv1.SignatureAlgorithm(100)] = "string" + + type args struct { + alg apiv1.SignatureAlgorithm + bits int + } + tests := []struct { + name string + args args + want types.KeySpec + assertion assert.ErrorAssertionFunc + }{ + {"UnspecifiedSignAlgorithm", args{apiv1.UnspecifiedSignAlgorithm, 0}, types.KeySpecEccNistP256, assert.NoError}, + {"SHA256WithRSA", args{apiv1.SHA256WithRSA, 0}, types.KeySpecRsa3072, assert.NoError}, + {"SHA256WithRSA+2048", args{apiv1.SHA256WithRSA, 2048}, types.KeySpecRsa2048, assert.NoError}, + {"SHA256WithRSA+3072", args{apiv1.SHA256WithRSA, 3072}, types.KeySpecRsa3072, assert.NoError}, + {"SHA256WithRSA+4096", args{apiv1.SHA256WithRSA, 4096}, types.KeySpecRsa4096, assert.NoError}, + {"SHA512WithRSA", args{apiv1.SHA512WithRSA, 0}, types.KeySpecRsa3072, assert.NoError}, + {"SHA512WithRSA+2048", args{apiv1.SHA512WithRSA, 2048}, types.KeySpecRsa2048, assert.NoError}, + {"SHA512WithRSA+3072", args{apiv1.SHA512WithRSA, 3072}, types.KeySpecRsa3072, assert.NoError}, + {"SHA512WithRSA+4096", args{apiv1.SHA512WithRSA, 4096}, types.KeySpecRsa4096, assert.NoError}, + {"SHA256WithRSAPSS", args{apiv1.SHA256WithRSAPSS, 0}, types.KeySpecRsa3072, assert.NoError}, + {"SHA256WithRSAPSS+2048", args{apiv1.SHA256WithRSAPSS, 2048}, types.KeySpecRsa2048, assert.NoError}, + {"SHA256WithRSAPSS+3072", args{apiv1.SHA256WithRSAPSS, 3072}, types.KeySpecRsa3072, assert.NoError}, + {"SHA256WithRSAPSS+4096", args{apiv1.SHA256WithRSAPSS, 4096}, types.KeySpecRsa4096, assert.NoError}, + {"SHA512WithRSAPSS", args{apiv1.SHA512WithRSAPSS, 0}, types.KeySpecRsa3072, assert.NoError}, + {"SHA512WithRSAPSS+2048", args{apiv1.SHA512WithRSAPSS, 2048}, types.KeySpecRsa2048, assert.NoError}, + {"SHA512WithRSAPSS+3072", args{apiv1.SHA512WithRSAPSS, 3072}, types.KeySpecRsa3072, assert.NoError}, + {"SHA512WithRSAPSS+4096", args{apiv1.SHA512WithRSAPSS, 4096}, types.KeySpecRsa4096, assert.NoError}, + {"ECDSAWithSHA256", args{apiv1.ECDSAWithSHA256, 0}, types.KeySpecEccNistP256, assert.NoError}, + {"ECDSAWithSHA384", args{apiv1.ECDSAWithSHA384, 0}, types.KeySpecEccNistP384, assert.NoError}, + {"ECDSAWithSHA512", args{apiv1.ECDSAWithSHA512, 0}, types.KeySpecEccNistP521, assert.NoError}, + {"fail Ed25519", args{apiv1.PureEd25519, 0}, "", assert.Error}, + {"fail type switch", args{apiv1.SignatureAlgorithm(100), 0}, "", assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getCustomerMasterKeySpecMapping(tt.args.alg, tt.args.bits) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/kms/awskms/mock_test.go b/kms/awskms/mock_test.go index 5a7d5bd4..4e4e405e 100644 --- a/kms/awskms/mock_test.go +++ b/kms/awskms/mock_test.go @@ -1,34 +1,34 @@ package awskms import ( + "context" "encoding/pem" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" ) type MockClient struct { - getPublicKeyWithContext func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) - createKeyWithContext func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) - createAliasWithContext func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) - signWithContext func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) + getPublicKey func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) + createKey func(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error) + createAlias func(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error) + sign func(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) } -func (m *MockClient) GetPublicKeyWithContext(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { - return m.getPublicKeyWithContext(ctx, input, opts...) +func (m *MockClient) GetPublicKey(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + return m.getPublicKey(ctx, input, opts...) } -func (m *MockClient) CreateKeyWithContext(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) { - return m.createKeyWithContext(ctx, input, opts...) +func (m *MockClient) CreateKey(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { + return m.createKey(ctx, input, opts...) } -func (m *MockClient) CreateAliasWithContext(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) { - return m.createAliasWithContext(ctx, input, opts...) +func (m *MockClient) CreateAlias(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { + return m.createAlias(ctx, input, opts...) } -func (m *MockClient) SignWithContext(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) { - return m.signWithContext(ctx, input, opts...) +func (m *MockClient) Sign(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) { + return m.sign(ctx, input, opts...) } const ( @@ -46,24 +46,24 @@ var signature = []byte{ func getOKClient() *MockClient { return &MockClient{ - getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { block, _ := pem.Decode([]byte(publicKey)) return &kms.GetPublicKeyOutput{ KeyId: input.KeyId, PublicKey: block.Bytes, }, nil }, - createKeyWithContext: func(ctx aws.Context, input *kms.CreateKeyInput, opts ...request.Option) (*kms.CreateKeyOutput, error) { - md := new(kms.KeyMetadata) - md.SetKeyId(keyID) + createKey: func(ctx context.Context, input *kms.CreateKeyInput, opts ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { return &kms.CreateKeyOutput{ - KeyMetadata: md, + KeyMetadata: &types.KeyMetadata{ + KeyId: pointer(keyID), + }, }, nil }, - createAliasWithContext: func(ctx aws.Context, input *kms.CreateAliasInput, opts ...request.Option) (*kms.CreateAliasOutput, error) { + createAlias: func(ctx context.Context, input *kms.CreateAliasInput, opts ...func(*kms.Options)) (*kms.CreateAliasOutput, error) { return &kms.CreateAliasOutput{}, nil }, - signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) { + sign: func(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) { return &kms.SignOutput{ Signature: signature, }, nil diff --git a/kms/awskms/signer.go b/kms/awskms/signer.go index 47418f3a..e31535a3 100644 --- a/kms/awskms/signer.go +++ b/kms/awskms/signer.go @@ -9,20 +9,21 @@ import ( "crypto/rsa" "io" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" "github.com/pkg/errors" "go.step.sm/crypto/pemutil" ) // Signer implements a crypto.Signer using the AWS KMS. type Signer struct { - service KeyManagementClient + client KeyManagementClient keyID string publicKey crypto.PublicKey } // NewSigner creates a new signer using a key in the AWS KMS. -func NewSigner(svc KeyManagementClient, signingKey string) (*Signer, error) { +func NewSigner(client KeyManagementClient, signingKey string) (*Signer, error) { keyID, err := parseKeyID(signingKey) if err != nil { return nil, err @@ -30,8 +31,8 @@ func NewSigner(svc KeyManagementClient, signingKey string) (*Signer, error) { // Make sure that the key exists. signer := &Signer{ - service: svc, - keyID: keyID, + client: client, + keyID: keyID, } if err := signer.preloadKey(keyID); err != nil { return nil, err @@ -44,11 +45,11 @@ func (s *Signer) preloadKey(keyID string) error { ctx, cancel := defaultContext() defer cancel() - resp, err := s.service.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{ - KeyId: &keyID, + resp, err := s.client.GetPublicKey(ctx, &kms.GetPublicKeyInput{ + KeyId: pointer(keyID), }) if err != nil { - return errors.Wrap(err, "awskms GetPublicKeyWithContext failed") + return errors.Wrap(err, "awskms GetPublicKey failed") } s.publicKey, err = pemutil.ParseDER(resp.PublicKey) @@ -68,54 +69,54 @@ func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([]byt } req := &kms.SignInput{ - KeyId: &s.keyID, - SigningAlgorithm: &alg, + KeyId: pointer(s.keyID), + SigningAlgorithm: alg, Message: digest, + MessageType: types.MessageTypeDigest, } - req.SetMessageType("DIGEST") ctx, cancel := defaultContext() defer cancel() - resp, err := s.service.SignWithContext(ctx, req) + resp, err := s.client.Sign(ctx, req) if err != nil { - return nil, errors.Wrap(err, "awsKMS SignWithContext failed") + return nil, errors.Wrap(err, "awskms Sign failed") } return resp.Signature, nil } -func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (string, error) { +func getSigningAlgorithm(key crypto.PublicKey, opts crypto.SignerOpts) (types.SigningAlgorithmSpec, error) { switch key.(type) { case *rsa.PublicKey: _, isPSS := opts.(*rsa.PSSOptions) switch h := opts.HashFunc(); h { case crypto.SHA256: if isPSS { - return kms.SigningAlgorithmSpecRsassaPssSha256, nil + return types.SigningAlgorithmSpecRsassaPssSha256, nil } - return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil + return types.SigningAlgorithmSpecRsassaPkcs1V15Sha256, nil case crypto.SHA384: if isPSS { - return kms.SigningAlgorithmSpecRsassaPssSha384, nil + return types.SigningAlgorithmSpecRsassaPssSha384, nil } - return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil + return types.SigningAlgorithmSpecRsassaPkcs1V15Sha384, nil case crypto.SHA512: if isPSS { - return kms.SigningAlgorithmSpecRsassaPssSha512, nil + return types.SigningAlgorithmSpecRsassaPssSha512, nil } - return kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil + return types.SigningAlgorithmSpecRsassaPkcs1V15Sha512, nil default: return "", errors.Errorf("unsupported hash function %v", h) } case *ecdsa.PublicKey: switch h := opts.HashFunc(); h { case crypto.SHA256: - return kms.SigningAlgorithmSpecEcdsaSha256, nil + return types.SigningAlgorithmSpecEcdsaSha256, nil case crypto.SHA384: - return kms.SigningAlgorithmSpecEcdsaSha384, nil + return types.SigningAlgorithmSpecEcdsaSha384, nil case crypto.SHA512: - return kms.SigningAlgorithmSpecEcdsaSha512, nil + return types.SigningAlgorithmSpecEcdsaSha512, nil default: return "", errors.Errorf("unsupported hash function %v", h) } diff --git a/kms/awskms/signer_test.go b/kms/awskms/signer_test.go index 9694c62a..6d7edea0 100644 --- a/kms/awskms/signer_test.go +++ b/kms/awskms/signer_test.go @@ -1,6 +1,7 @@ package awskms import ( + "context" "crypto" "crypto/ecdsa" "crypto/rand" @@ -10,9 +11,8 @@ import ( "reflect" "testing" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms" + "github.com/aws/aws-sdk-go-v2/service/kms/types" "go.step.sm/crypto/pemutil" ) @@ -34,18 +34,18 @@ func TestNewSigner(t *testing.T) { wantErr bool }{ {"ok", args{okClient, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, &Signer{ - service: okClient, + client: okClient, keyID: "be468355-ca7a-40d9-a28b-8ae1c4c7f936", publicKey: key, }, false}, {"fail parse", args{okClient, "awskms:key-id="}, nil, true}, {"fail preload", args{&MockClient{ - getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return nil, fmt.Errorf("an error") }, }, "awskms:key-id=be468355-ca7a-40d9-a28b-8ae1c4c7f936"}, nil, true}, {"fail preload not der", args{&MockClient{ - getPublicKeyWithContext: func(ctx aws.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { + getPublicKey: func(ctx context.Context, input *kms.GetPublicKeyInput, opts ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { return &kms.GetPublicKeyOutput{ KeyId: input.KeyId, PublicKey: []byte(publicKey), @@ -75,7 +75,7 @@ func TestSigner_Public(t *testing.T) { } type fields struct { - service KeyManagementClient + client KeyManagementClient keyID string publicKey crypto.PublicKey } @@ -89,7 +89,7 @@ func TestSigner_Public(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Signer{ - service: tt.fields.service, + client: tt.fields.client, keyID: tt.fields.keyID, publicKey: tt.fields.publicKey, } @@ -108,7 +108,7 @@ func TestSigner_Sign(t *testing.T) { } type fields struct { - service KeyManagementClient + client KeyManagementClient keyID string publicKey crypto.PublicKey } @@ -128,7 +128,7 @@ func TestSigner_Sign(t *testing.T) { {"fail alg", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.MD5}, nil, true}, {"fail key", fields{okClient, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", []byte("key")}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true}, {"fail sign", fields{&MockClient{ - signWithContext: func(ctx aws.Context, input *kms.SignInput, opts ...request.Option) (*kms.SignOutput, error) { + sign: func(ctx context.Context, input *kms.SignInput, opts ...func(*kms.Options)) (*kms.SignOutput, error) { return nil, fmt.Errorf("an error") }, }, "be468355-ca7a-40d9-a28b-8ae1c4c7f936", key}, args{rand.Reader, []byte("digest"), crypto.SHA256}, nil, true}, @@ -136,7 +136,7 @@ func TestSigner_Sign(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Signer{ - service: tt.fields.service, + client: tt.fields.client, keyID: tt.fields.keyID, publicKey: tt.fields.publicKey, } @@ -160,7 +160,7 @@ func Test_getSigningAlgorithm(t *testing.T) { tests := []struct { name string args args - want string + want types.SigningAlgorithmSpec wantErr bool }{ {"rsa+sha256", args{&rsa.PublicKey{}, crypto.SHA256}, "RSASSA_PKCS1_V1_5_SHA_256", false},