Skip to content

Commit

Permalink
Merge pull request #484 from smallstep/herman/fix-cloudkms-resource-uris
Browse files Browse the repository at this point in the history
Fix GCP CloudKMS resource URIs sometimes starting with `cloudkms:`
  • Loading branch information
hslatman authored Apr 15, 2024
2 parents 2980706 + 927f094 commit cb64f07
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 76 deletions.
14 changes: 7 additions & 7 deletions kms/cloudkms/cloudkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo
return nil, err
}

var crytoKeyName string
var cryptoKeyName string

ctx, cancel := defaultContext()
defer cancel()
Expand Down Expand Up @@ -240,13 +240,13 @@ func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo
if err != nil {
return nil, errors.Wrap(err, "cloudKMS CreateCryptoKeyVersion failed")
}
crytoKeyName = response.Name
cryptoKeyName = response.Name
} else {
crytoKeyName = response.Name + "/cryptoKeyVersions/1"
cryptoKeyName = response.Name + "/cryptoKeyVersions/1"
}

// Use uri format for the keys
crytoKeyName = uri.NewOpaque(Scheme, crytoKeyName).String()
cryptoKeyName = uri.NewOpaque(Scheme, cryptoKeyName).String()

// Sleep deterministically to avoid retries because of PENDING_GENERATING.
// One second is often enough.
Expand All @@ -256,17 +256,17 @@ func (k *CloudKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespo

// Retrieve public key to add it to the response.
pk, err := k.GetPublicKey(&apiv1.GetPublicKeyRequest{
Name: crytoKeyName,
Name: cryptoKeyName,
})
if err != nil {
return nil, errors.Wrap(err, "cloudKMS GetPublicKey failed")
}

return &apiv1.CreateKeyResponse{
Name: crytoKeyName,
Name: cryptoKeyName,
PublicKey: pk,
CreateSignerRequest: apiv1.CreateSignerRequest{
SigningKey: crytoKeyName,
SigningKey: cryptoKeyName,
},
}, nil
}
Expand Down
70 changes: 37 additions & 33 deletions kms/cloudkms/cloudkms_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"cloud.google.com/go/kms/apiv1/kmspb"
gax "github.com/googleapis/gax-go/v2"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.step.sm/crypto/kms/apiv1"
"go.step.sm/crypto/kms/uri"
"go.step.sm/crypto/pemutil"
Expand Down Expand Up @@ -174,13 +175,9 @@ func TestCloudKMS_CreateSigner(t *testing.T) {
keyURI := uri.NewOpaque(Scheme, keyName).String()

pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

type fields struct {
client KeyManagementClient
Expand All @@ -196,17 +193,20 @@ func TestCloudKMS_CreateSigner(t *testing.T) {
wantErr bool
}{
{"ok", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: keyName}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false},
{"ok with uri", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: keyURI}}, &Signer{client: &MockClient{}, signingKey: keyName, publicKey: pk}, false},
{"fail", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return nil, fmt.Errorf("test error")
},
}}, args{&apiv1.CreateSignerRequest{SigningKey: ""}}, nil, true},
Expand Down Expand Up @@ -238,13 +238,9 @@ func TestCloudKMS_CreateKey(t *testing.T) {
alreadyExists := status.Error(codes.AlreadyExists, "already exists")

pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

var retries int
type fields struct {
Expand All @@ -269,7 +265,8 @@ func TestCloudKMS_CreateKey(t *testing.T) {
assert.Nil(t, req.CryptoKey.DestroyScheduledDuration)
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
Expand All @@ -284,7 +281,8 @@ func TestCloudKMS_CreateKey(t *testing.T) {
assert.Equal(t, req.CryptoKey.DestroyScheduledDuration, durationpb.New(24*time.Hour))
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
Expand All @@ -301,7 +299,8 @@ func TestCloudKMS_CreateKey(t *testing.T) {
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
Expand All @@ -318,7 +317,8 @@ func TestCloudKMS_CreateKey(t *testing.T) {
createCryptoKeyVersion: func(_ context.Context, _ *kmspb.CreateCryptoKeyVersionRequest, _ ...gax.CallOption) (*kmspb.CryptoKeyVersion, error) {
return &kmspb.CryptoKeyVersion{Name: keyName + "/cryptoKeyVersions/2"}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
Expand All @@ -332,7 +332,8 @@ func TestCloudKMS_CreateKey(t *testing.T) {
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
if retries != 2 {
retries++
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
Expand Down Expand Up @@ -391,7 +392,8 @@ func TestCloudKMS_CreateKey(t *testing.T) {
createCryptoKey: func(_ context.Context, _ *kmspb.CreateCryptoKeyRequest, _ ...gax.CallOption) (*kmspb.CryptoKey, error) {
return &kmspb.CryptoKey{Name: keyName}, nil
},
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return nil, testError
},
}},
Expand Down Expand Up @@ -424,13 +426,9 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
testError := fmt.Errorf("an error")

pemBytes, err := os.ReadFile("testdata/pub.pem")
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
pk, err := pemutil.ParseKey(pemBytes)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)

var retries int
type fields struct {
Expand All @@ -448,28 +446,32 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
}{
{"ok", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, pk, false},
{"ok with uri", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyURI}}, pk, false},
{"ok with resource uri", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyResource}}, pk, false},
{"ok with retries", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
if retries != 2 {
retries++
return nil, status.Error(codes.FailedPrecondition, "key is not enabled, current state is: PENDING_GENERATION")
Expand All @@ -481,14 +483,16 @@ func TestCloudKMS_GetPublicKey(t *testing.T) {
{"fail name", fields{&MockClient{}}, args{&apiv1.GetPublicKeyRequest{}}, nil, true},
{"fail get public key", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return nil, testError
},
}},
args{&apiv1.GetPublicKeyRequest{Name: keyName}}, nil, true},
{"fail parse pem", fields{
&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
},
}},
Expand Down
6 changes: 3 additions & 3 deletions kms/cloudkms/decrypter.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,19 +39,19 @@ func NewDecrypter(client KeyManagementClient, decryptionKey string) (*Decrypter,
client: client,
decryptionKey: resourceName(decryptionKey),
}
if err := decrypter.preloadKey(decryptionKey); err != nil { // TODO(hs): (option for) lazy load instead?
if err := decrypter.preloadKey(); err != nil { // TODO(hs): (option for) lazy load instead?
return nil, err
}

return decrypter, nil
}

func (d *Decrypter) preloadKey(signingKey string) error {
func (d *Decrypter) preloadKey() error {
ctx, cancel := defaultContext()
defer cancel()

response, err := d.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: signingKey,
Name: d.decryptionKey,
})
if err != nil {
return fmt.Errorf("cloudKMS GetPublicKey failed: %w", err)
Expand Down
25 changes: 18 additions & 7 deletions kms/cloudkms/decrypter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,26 @@ func TestCloudKMS_CreateDecrypter(t *testing.T) {
wantErr bool
}{
{"ok", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateDecrypterRequest{DecryptionKey: keyName}}, &Decrypter{client: &MockClient{}, decryptionKey: keyName, publicKey: pk}, false},
{"ok with uri", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateDecrypterRequest{DecryptionKey: "cloudkms:resource=" + keyName}}, &Decrypter{client: &MockClient{}, decryptionKey: keyName, publicKey: pk}, false},
{"ok with opaque uri", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}}, args{&apiv1.CreateDecrypterRequest{DecryptionKey: "cloudkms:" + keyName}}, &Decrypter{client: &MockClient{}, decryptionKey: keyName, publicKey: pk}, false},
{"fail", fields{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return nil, fmt.Errorf("test error")
},
}}, args{&apiv1.CreateDecrypterRequest{DecryptionKey: ""}}, nil, true},
Expand Down Expand Up @@ -92,17 +96,20 @@ func TestNewDecrypter(t *testing.T) {
wantErr bool
}{
{"ok", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string(pemBytes)}, nil
},
}, "decryptionKey"}, &Decrypter{client: &MockClient{}, decryptionKey: "decryptionKey", publicKey: pk}, false},
{"fail get public key", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return nil, fmt.Errorf("an error")
},
}, "decryptionKey"}, nil, true},
{"fail parse pem", args{&MockClient{
getPublicKey: func(_ context.Context, _ *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
getPublicKey: func(_ context.Context, r *kmspb.GetPublicKeyRequest, _ ...gax.CallOption) (*kmspb.PublicKey, error) {
assert.NotContains(t, r.Name, "cloudkms:")
return &kmspb.PublicKey{Pem: string("bad pem")}, nil
},
}, "decryptionKey"}, nil, true},
Expand Down Expand Up @@ -160,21 +167,25 @@ func TestDecrypter_Decrypt(t *testing.T) {
keyName := "projects/p/locations/l/keyRings/k/cryptoKeys/c/cryptoKeyVersions/1"
okClient := &MockClient{
asymmetricDecrypt: func(ctx context.Context, adr *kmspb.AsymmetricDecryptRequest, co ...gax.CallOption) (*kmspb.AsymmetricDecryptResponse, error) {
assert.NotContains(t, adr.Name, "cloudkms:")
return &kmspb.AsymmetricDecryptResponse{Plaintext: []byte("decrypted"), PlaintextCrc32C: wrapperspb.Int64(crc32c([]byte("decrypted"))), VerifiedCiphertextCrc32C: true}, nil
},
}
failClient := &MockClient{
asymmetricDecrypt: func(ctx context.Context, adr *kmspb.AsymmetricDecryptRequest, co ...gax.CallOption) (*kmspb.AsymmetricDecryptResponse, error) {
assert.NotContains(t, adr.Name, "cloudkms:")
return nil, fmt.Errorf("an error")
},
}
requestCRC32Client := &MockClient{
asymmetricDecrypt: func(ctx context.Context, adr *kmspb.AsymmetricDecryptRequest, co ...gax.CallOption) (*kmspb.AsymmetricDecryptResponse, error) {
assert.NotContains(t, adr.Name, "cloudkms:")
return &kmspb.AsymmetricDecryptResponse{Plaintext: []byte("decrypted"), PlaintextCrc32C: wrapperspb.Int64(crc32c([]byte("decrypted"))), VerifiedCiphertextCrc32C: false}, nil
},
}
responseCRC32Client := &MockClient{
asymmetricDecrypt: func(ctx context.Context, adr *kmspb.AsymmetricDecryptRequest, co ...gax.CallOption) (*kmspb.AsymmetricDecryptResponse, error) {
assert.NotContains(t, adr.Name, "cloudkms:")
return &kmspb.AsymmetricDecryptResponse{Plaintext: []byte("decrypted"), PlaintextCrc32C: wrapperspb.Int64(crc32c([]byte("wrong"))), VerifiedCiphertextCrc32C: true}, nil
},
}
Expand Down
6 changes: 3 additions & 3 deletions kms/cloudkms/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ func NewSigner(c KeyManagementClient, signingKey string) (*Signer, error) {
client: c,
signingKey: resourceName(signingKey),
}
if err := signer.preloadKey(signingKey); err != nil {
if err := signer.preloadKey(); err != nil {
return nil, err
}

return signer, nil
}

func (s *Signer) preloadKey(signingKey string) error {
func (s *Signer) preloadKey() error {
ctx, cancel := defaultContext()
defer cancel()

response, err := s.client.GetPublicKey(ctx, &kmspb.GetPublicKeyRequest{
Name: signingKey,
Name: s.signingKey,
})
if err != nil {
return errors.Wrap(err, "cloudKMS GetPublicKey failed")
Expand Down
Loading

0 comments on commit cb64f07

Please sign in to comment.