diff --git a/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go b/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go index 6bcf7a0cc9..06039ebb18 100644 --- a/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go +++ b/pkg/server/plugin/keymanager/hashicorpvault/vault_client.go @@ -383,10 +383,9 @@ func (c *Client) CreateKey(ctx context.Context, spireKeyID string, keyType Trans // GetKey gets the transit engine key with the specified spire key id. // See: https://developer.hashicorp.com/vault/api-docs/secret/transit#read-key func (c *Client) GetKey(ctx context.Context, spireKeyID string) (string, error) { - // TODO: Handle errors here res, err := c.vaultClient.Logical().ReadWithContext(ctx, fmt.Sprintf("/%s/keys/%s", c.clientParams.TransitEnginePath, spireKeyID)) if err != nil { - return "", err + return "", status.Errorf(codes.Internal, "failed to get transit engine key: %v", err) } keys, ok := res.Data["keys"] diff --git a/pkg/server/plugin/keymanager/hashicorpvault/vault_client_test.go b/pkg/server/plugin/keymanager/hashicorpvault/vault_client_test.go index b4d64e6d6d..c1d83ad169 100644 --- a/pkg/server/plugin/keymanager/hashicorpvault/vault_client_test.go +++ b/pkg/server/plugin/keymanager/hashicorpvault/vault_client_test.go @@ -671,7 +671,73 @@ func TestCreateKeyErrorFromEndpoint(t *testing.T) { spiretest.RequireGRPCStatusHasPrefix(t, err, codes.Internal, "failed to create transit engine key: Error making API request.") } -// TODO: Test GetKey +func TestGetKey(t *testing.T) { + fakeVaultServer := newFakeVaultServer() + fakeVaultServer.CertAuthResponseCode = 200 + fakeVaultServer.CertAuthResponse = []byte(testCertAuthResponse) + fakeVaultServer.GetKeyResponseCode = 200 + fakeVaultServer.GetKeyResponse = []byte(testGetKeyResponse) + + s, addr, err := fakeVaultServer.NewTLSServer() + require.NoError(t, err) + + s.Start() + defer s.Close() + + cp := &ClientParams{ + VaultAddr: fmt.Sprintf("https://%v/", addr), + CACertPath: testRootCert, + ClientCertPath: testClientCert, + ClientKeyPath: testClientKey, + } + + cc, err := NewClientConfig(cp, hclog.Default()) + require.NoError(t, err) + + renewCh := make(chan struct{}) + client, err := cc.NewAuthenticatedClient(CERT, renewCh) + require.NoError(t, err) + + resp, err := client.GetKey(context.Background(), "x509-CA-A") + require.NoError(t, err) + + require.Equal(t, "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEV57LFbIQZzyZ2YcKZfB9mGWkUhJv\niRzIZOqV4wRHoUOZjMuhBMR2WviEsy65TYpcBjreAc6pbneiyhlTwPvgmw==\n-----END PUBLIC KEY-----\n", resp) +} + +func TestGetKeyErrorFromEndpoint(t *testing.T) { + fakeVaultServer := newFakeVaultServer() + fakeVaultServer.CertAuthResponseCode = 200 + fakeVaultServer.CertAuthResponse = []byte(testCertAuthResponse) + fakeVaultServer.GetKeyResponseCode = 500 + fakeVaultServer.GetKeyResponse = []byte("test error") + + s, addr, err := fakeVaultServer.NewTLSServer() + require.NoError(t, err) + + s.Start() + defer s.Close() + + retry := 0 // Disable retry + cp := &ClientParams{ + MaxRetries: &retry, + VaultAddr: fmt.Sprintf("https://%v/", addr), + CACertPath: testRootCert, + ClientCertPath: testClientCert, + ClientKeyPath: testClientKey, + } + + cc, err := NewClientConfig(cp, hclog.Default()) + require.NoError(t, err) + + renewCh := make(chan struct{}) + client, err := cc.NewAuthenticatedClient(CERT, renewCh) + require.NoError(t, err) + + resp, err := client.GetKey(context.Background(), "x509-CA-A") + spiretest.RequireGRPCStatusHasPrefix(t, err, codes.Internal, "failed to get transit engine key: Error making API request.") + require.Empty(t, resp) +} + // TODO: Test SignData func newFakeVaultServer() *FakeVaultServerConfig { diff --git a/pkg/server/plugin/keymanager/hashicorpvault/vault_fake_test.go b/pkg/server/plugin/keymanager/hashicorpvault/vault_fake_test.go index 804a7a2002..990cc71e56 100644 --- a/pkg/server/plugin/keymanager/hashicorpvault/vault_fake_test.go +++ b/pkg/server/plugin/keymanager/hashicorpvault/vault_fake_test.go @@ -8,12 +8,13 @@ import ( ) const ( - defaultTLSAuthEndpoint = "/v1/auth/cert/login" - defaultAppRoleAuthEndpoint = "/v1/auth/approle/login" - defaultK8sAuthEndpoint = "/v1/auth/kubernetes/login" - defaultRenewEndpoint = "/v1/auth/token/renew-self" - defaultLookupSelfEndpoint = "/v1/auth/token/lookup-self" - defaultCreateKeyEndpoint = "/v1/transit/keys/x509-CA-A" + defaultTLSAuthEndpoint = "PUT /v1/auth/cert/login" + defaultAppRoleAuthEndpoint = "PUT /v1/auth/approle/login" + defaultK8sAuthEndpoint = "PUT /v1/auth/kubernetes/login" + defaultRenewEndpoint = "POST /v1/auth/token/renew-self" + defaultLookupSelfEndpoint = "GET /v1/auth/token/lookup-self" + defaultCreateKeyEndpoint = "PUT /v1/transit/keys/{id}" + defaultGetKeyEndpoint = "GET /v1/transit/keys/{id}" listenAddr = "127.0.0.1:0" ) @@ -267,6 +268,41 @@ var ( "orphan": true } }` + + testGetKeyResponse = `{ + "request_id": "646eddbd-83fd-0cc1-387b-f1a17fa88c3d", + "lease_id": "", + "renewable": false, + "lease_duration": 0, + "data": { + "allow_plaintext_backup": false, + "auto_rotate_period": 0, + "deletion_allowed": false, + "derived": false, + "exportable": false, + "imported_key": false, + "keys": { + "1": { + "creation_time": "2024-09-16T18:18:54.284635756Z", + "name": "P-256", + "public_key": "-----BEGIN PUBLIC KEY-----\nMFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEV57LFbIQZzyZ2YcKZfB9mGWkUhJv\niRzIZOqV4wRHoUOZjMuhBMR2WviEsy65TYpcBjreAc6pbneiyhlTwPvgmw==\n-----END PUBLIC KEY-----\n" + } + }, + "latest_version": 1, + "min_available_version": 0, + "min_decryption_version": 1, + "min_encryption_version": 0, + "name": "x509-CA-A", + "supports_decryption": false, + "supports_derivation": false, + "supports_encryption": false, + "supports_signing": true, + "type": "ecdsa-p256" + }, + "wrap_info": null, + "warnings": null, + "auth": null +}` ) type FakeVaultServerConfig struct { @@ -297,6 +333,10 @@ type FakeVaultServerConfig struct { CreateKeyReqHandler func(code int, resp []byte) func(http.ResponseWriter, *http.Request) CreateKeyResponseCode int CreateKeyResponse []byte + GetKeyReqEndpoint string + GetKeyReqHandler func(code int, resp []byte) func(http.ResponseWriter, *http.Request) + GetKeyResponseCode int + GetKeyResponse []byte } // NewFakeVaultServerConfig returns VaultServerConfig with default values @@ -315,6 +355,8 @@ func NewFakeVaultServerConfig() *FakeVaultServerConfig { LookupSelfReqHandler: defaultReqHandler, CreateKeyReqEndpoint: defaultCreateKeyEndpoint, CreateKeyReqHandler: defaultReqHandler, + GetKeyReqEndpoint: defaultGetKeyEndpoint, + GetKeyReqHandler: defaultReqHandler, } } @@ -347,6 +389,7 @@ func (v *FakeVaultServerConfig) NewTLSServer() (srv *httptest.Server, addr strin mux.HandleFunc(v.RenewReqEndpoint, v.RenewReqHandler(v.RenewResponseCode, v.RenewResponse)) mux.HandleFunc(v.LookupSelfReqEndpoint, v.LookupSelfReqHandler(v.LookupSelfResponseCode, v.LookupSelfResponse)) mux.HandleFunc(v.CreateKeyReqEndpoint, v.CreateKeyReqHandler(v.CreateKeyResponseCode, v.CreateKeyResponse)) + mux.HandleFunc(v.GetKeyReqEndpoint, v.GetKeyReqHandler(v.GetKeyResponseCode, v.GetKeyResponse)) srv = httptest.NewUnstartedServer(mux) srv.Listener = l