From d6fe013f63638807a27d5fb1004d3e03e68d2b0c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 5 Apr 2024 11:00:49 +0200 Subject: [PATCH 01/14] Support skipping certificate private key check on request --- kms/tpmkms/tpmkms.go | 6 +++++- kms/tpmkms/uri.go | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 05f7871e..60a4d9c3 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -778,6 +778,10 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store if o.store != "" { store = o.store } + skipFindCertificatKey := "false" + if o.skipFindCertificateKey { + skipFindCertificatKey = "true" + } leaf := req.CertificateChain[0] fp, err := fingerprint.New(leaf.Raw, crypto.SHA1, fingerprint.HexFingerprint) @@ -786,7 +790,7 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store } if err := k.windowsCertificateManager.StoreCertificate(&apiv1.StoreCertificateRequest{ - Name: fmt.Sprintf("capi:sha1=%s;store-location=%s;store=%s;", fp, location, store), + Name: fmt.Sprintf("capi:sha1=%s;store-location=%s;store=%s;skip-find-certificate-key=%s", fp, location, store, skipFindCertificatKey), Certificate: leaf, }); err != nil { return fmt.Errorf("failed storing certificate using Windows platform cryptography provider: %w", err) diff --git a/kms/tpmkms/uri.go b/kms/tpmkms/uri.go index 993c0021..86429848 100644 --- a/kms/tpmkms/uri.go +++ b/kms/tpmkms/uri.go @@ -21,6 +21,7 @@ type objectProperties struct { store string intermediateStoreLocation string intermediateStore string + skipFindCertificateKey bool } func parseNameURI(nameURI string) (o objectProperties, err error) { @@ -59,6 +60,7 @@ func parseNameURI(nameURI string) (o objectProperties, err error) { o.store = u.Get("store") o.intermediateStoreLocation = u.Get("intermediate-store-location") o.intermediateStore = u.Get("intermediate-store") + o.skipFindCertificateKey = u.GetBool("skip-find-certificate-key") // validation if o.ak && o.attestBy != "" { From 958326e75499d3ef8b7dfe7a16a57fc6e0d38266 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 5 Apr 2024 11:05:15 +0200 Subject: [PATCH 02/14] Upgrade `go get golang.org/x/net` to `v0.23.0` --- go.mod | 2 +- go.sum | 2 ++ 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 4243eaf2..d09ecdf9 100644 --- a/go.mod +++ b/go.mod @@ -25,7 +25,7 @@ require ( github.com/smallstep/go-attestation v0.4.4-0.20240109183208-413678f90935 github.com/stretchr/testify v1.9.0 golang.org/x/crypto v0.21.0 - golang.org/x/net v0.22.0 + golang.org/x/net v0.23.0 golang.org/x/sys v0.18.0 google.golang.org/api v0.172.0 google.golang.org/grpc v1.62.1 diff --git a/go.sum b/go.sum index ff500adc..423e8f5a 100644 --- a/go.sum +++ b/go.sum @@ -997,6 +997,8 @@ golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc= golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= +golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs= +golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= From 4d88201cc7a12b915cb355995878ae77be457e37 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 8 Apr 2024 20:20:31 +0200 Subject: [PATCH 03/14] Use `uri.Values` when storing certificate chains on Windows --- kms/tpmkms/tpmkms.go | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 60a4d9c3..e0007e6f 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -778,9 +778,9 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store if o.store != "" { store = o.store } - skipFindCertificatKey := "false" + skipFindCertificateKey := "false" if o.skipFindCertificateKey { - skipFindCertificatKey = "true" + skipFindCertificateKey = "true" } leaf := req.CertificateChain[0] @@ -789,8 +789,14 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store return fmt.Errorf("failed calculating certificate SHA1 fingerprint: %w", err) } + uv := url.Values{} + uv.Set("sha1", fp) + uv.Set("store-location", location) + uv.Set("store", store) + uv.Set("skip-find-certificate-key", skipFindCertificateKey) + if err := k.windowsCertificateManager.StoreCertificate(&apiv1.StoreCertificateRequest{ - Name: fmt.Sprintf("capi:sha1=%s;store-location=%s;store=%s;skip-find-certificate-key=%s", fp, location, store, skipFindCertificatKey), + Name: uri.New("capi", uv).String(), Certificate: leaf, }); err != nil { return fmt.Errorf("failed storing certificate using Windows platform cryptography provider: %w", err) From 7935f172dd202e81c4338d079c4357f92bf743d4 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Apr 2024 15:45:25 +0200 Subject: [PATCH 04/14] Add support for removing certificates to TPM and CAPI KMS --- kms/capi/capi.go | 170 +++++++++++++++++++++++++++++++++++++++++++ kms/tpmkms/tpmkms.go | 103 ++++++++++++++++++++++++++ kms/tpmkms/uri.go | 8 ++ tpm/ak.go | 28 ++++--- tpm/key.go | 24 +++--- 5 files changed, 307 insertions(+), 26 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 4dd92503..a9871fd7 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -714,6 +714,176 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return nil } +func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + u, err := uri.ParseWithScheme(Scheme, req.Name) + if err != nil { + return fmt.Errorf("failed to parse URI: %w", err) + } + + sha1Hash := u.Get(HashArg) + keyID := u.Get(KeyIDArg) + issuerName := u.Get(IssuerNameArg) + serialNumber := u.Get(SerialNumberArg) + + var storeLocation string + if storeLocation = u.Get(StoreLocationArg); storeLocation == "" { + storeLocation = "user" + } + + var certStoreLocation uint32 + switch storeLocation { + case "user": + certStoreLocation = certStoreCurrentUser + case "machine": + certStoreLocation = certStoreLocalMachine + default: + return fmt.Errorf("invalid cert store location %v", storeLocation) + } + + var storeName string + if storeName = u.Get(StoreNameArg); storeName == "" { + storeName = "My" + } + + st, err := windows.CertOpenStore( + certStoreProvSystem, + 0, + 0, + certStoreLocation, + uintptr(unsafe.Pointer(wide(storeName)))) + if err != nil { + return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + } + + var certHandle *windows.CertContext + + switch { + case sha1Hash != "": + sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial + + sha1Bytes, err := hex.DecodeString(sha1Hash) + if err != nil { + return fmt.Errorf("%s must be in hex format: %w", HashArg, err) + } + searchData := CERT_ID_KEYIDORHASH{ + idChoice: CERT_ID_SHA1_HASH, + KeyIDOrHash: CRYPTOAPI_BLOB{ + len: uint32(len(sha1Bytes)), + data: uintptr(unsafe.Pointer(&sha1Bytes[0])), + }, + } + certHandle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findCertID, + uintptr(unsafe.Pointer(&searchData)), nil) + if err != nil { + return fmt.Errorf("findCertificateInStore failed: %w", err) + } + if certHandle == nil { + return apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)} + } + defer windows.CertFreeCertificateContext(certHandle) + + if err := removeCertificateUsingWindowsContext(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil + case keyID != "": + keyID = strings.TrimPrefix(keyID, "0x") // Support specifying the hash as 0x like with serial + + keyIDBytes, err := hex.DecodeString(keyID) + if err != nil { + return fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err) + } + searchData := CERT_ID_KEYIDORHASH{ + idChoice: CERT_ID_KEY_IDENTIFIER, + KeyIDOrHash: CRYPTOAPI_BLOB{ + len: uint32(len(keyIDBytes)), + data: uintptr(unsafe.Pointer(&keyIDBytes[0])), + }, + } + certHandle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findCertID, + uintptr(unsafe.Pointer(&searchData)), nil) + if err != nil { + return fmt.Errorf("findCertificateInStore failed: %w", err) + } + if certHandle == nil { + return apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)} + } + defer windows.CertFreeCertificateContext(certHandle) + + if err := removeCertificateUsingWindowsContext(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil + case issuerName != "" && serialNumber != "": + //TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead + // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id + // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number + var serialBytes []byte + if strings.HasPrefix(serialNumber, "0x") { + serialNumber = strings.TrimPrefix(serialNumber, "0x") + serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed + serialBytes, err = hex.DecodeString(serialNumber) + if err != nil { + return fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err) + } + } else { + bi := new(big.Int) + bi, ok := bi.SetString(serialNumber, 10) + if !ok { + return fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg) + } + serialBytes = bi.Bytes() + } + var prevCert *windows.CertContext + for { + certHandle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findIssuerStr, + uintptr(unsafe.Pointer(wide(issuerName))), prevCert) + + if err != nil { + return fmt.Errorf("findCertificateInStore failed: %w", err) + } + + if certHandle == nil { + return apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} + } + defer windows.CertFreeCertificateContext(certHandle) + + x509Cert, err := certContextToX509(certHandle) + if err != nil { + return fmt.Errorf("could not unmarshal certificate to DER: %w", err) + } + + if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) { + if err := removeCertificateUsingWindowsContext(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil + } + + prevCert = certHandle + } + default: + return fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) + } +} + +func removeCertificateUsingWindowsContext(certContext *windows.CertContext) error { + r, _, err := procCertDeleteCertificateFromStore.Call(uintptr(unsafe.Pointer(certContext))) + if r != 1 { + return fmt.Errorf("procCertDeleteCertificateFromStore failed with %X: %v", r, err) + } + return nil +} + type CAPISigner struct { algorithmGroup string keyHandle uintptr diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index e0007e6f..902323f3 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -860,6 +860,98 @@ func (k *TPMKMS) storeIntermediateToWindowsCertificateStore(c *x509.Certificate, return nil } +func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + switch { + case req.Name == "": + return errors.New("deleteCertificateRequest 'name' cannot be empty") + } + + if k.usesWindowsCertificateStore() { + if err := k.deleteCertificateFromWindowsCertificateStore(&apiv1.DeleteCertificateRequest{ + Name: req.Name, + }); err != nil { + return fmt.Errorf("failed deleting certificate from Windows platform cryptography provider: %w", err) + } + + return nil + } + + // TODO(hs): support delete by serial? If not, the behavior for TPM storage and Windows + // certificate store storage will be different, and may need different behavior when + // implementing certificate management. + + properties, err := parseNameURI(req.Name) + if err != nil { + return fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + ctx := context.Background() + if properties.ak { + ak, err := k.tpm.GetAK(ctx, properties.name) + if err != nil { + return err + } + if err := ak.SetCertificateChain(ctx, nil); err != nil { + return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err) + } + } else { + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return err + } + if err := key.SetCertificateChain(ctx, nil); err != nil { + return fmt.Errorf("failed storing certificate for key %q: %w", properties.name, err) + } + } + + return nil +} + +func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteCertificateRequest) error { + o, err := parseNameURI(req.Name) + if err != nil { + return fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + location := k.windowsCertificateStoreLocation + if o.storeLocation != "" { + location = o.storeLocation + } + store := k.windowsCertificateStore + if o.store != "" { + store = o.store + } + + uv := url.Values{} + uv.Set("store-location", location) + uv.Set("store", store) + + switch { + case o.serial != "": + uv.Set("serial", o.serial) + uv.Set("issuer", o.issuer) + case o.keyID != "": + uv.Set("key-id", o.keyID) + case o.sha1 != "": + uv.Set("sha1", o.sha1) + default: + return errors.New(`at least one of "serial", "key-id" or "sha1" is expected to be set`) + } + + dk, ok := k.windowsCertificateManager.(deletingCertificateManager) + if !ok { + return fmt.Errorf("expected Windows certificate manager to implement DeleteCertificate") + } + + if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{ // TODO(hs): handle specific error cases? + Name: uri.New("capi", uv).String(), + }); err != nil { + return fmt.Errorf("failed deleting certificate using Windows platform cryptography provider: %w", err) + } + + return nil +} + // attestationClient is a wrapper for [attestation.Client], containing // all of the required references to perform attestation against the // Smallstep Attestation CA. @@ -1183,8 +1275,19 @@ func generateWindowsSubjectKeyID(pub crypto.PublicKey) (string, error) { return hex.EncodeToString(hash[:]), nil } +type deletingCertificateManager interface { + apiv1.CertificateManager + DeleteCertificate(req *apiv1.DeleteCertificateRequest) error +} + +type deletingCertificateChainManager interface { + apiv1.CertificateChainManager + DeleteCertificate(req *apiv1.DeleteCertificateRequest) error +} + var _ apiv1.KeyManager = (*TPMKMS)(nil) var _ apiv1.Attester = (*TPMKMS)(nil) var _ apiv1.CertificateManager = (*TPMKMS)(nil) var _ apiv1.CertificateChainManager = (*TPMKMS)(nil) +var _ deletingCertificateChainManager = (*TPMKMS)(nil) var _ apiv1.AttestationClient = (*attestationClient)(nil) diff --git a/kms/tpmkms/uri.go b/kms/tpmkms/uri.go index 86429848..2614d6b7 100644 --- a/kms/tpmkms/uri.go +++ b/kms/tpmkms/uri.go @@ -22,6 +22,10 @@ type objectProperties struct { intermediateStoreLocation string intermediateStore string skipFindCertificateKey bool + keyID string + sha1 string + serial string + issuer string } func parseNameURI(nameURI string) (o objectProperties, err error) { @@ -61,6 +65,10 @@ func parseNameURI(nameURI string) (o objectProperties, err error) { o.intermediateStoreLocation = u.Get("intermediate-store-location") o.intermediateStore = u.Get("intermediate-store") o.skipFindCertificateKey = u.GetBool("skip-find-certificate-key") + o.keyID = u.Get("key-id") + o.sha1 = u.Get("sha1") + o.serial = u.Get("serial") + o.issuer = u.Get("issuer") // validation if o.ak && o.attestBy != "" { diff --git a/tpm/ak.go b/tpm/ak.go index a3d2c176..4a0f9baf 100644 --- a/tpm/ak.go +++ b/tpm/ak.go @@ -384,28 +384,26 @@ func (ak *AK) SetCertificateChain(ctx context.Context, chain []*x509.Certificate } defer closeTPM(ctx, ak.tpm, &err) - if len(chain) == 0 { - return errors.New("certificate chain must contain at least one certificate") - } - akPublic, err := ak.public(internalCall(ctx)) if err != nil { return fmt.Errorf("failed getting AK public key: %w", err) } - leaf := chain[0] - leafPK, ok := leaf.PublicKey.(crypto.PublicKey) - if !ok { - return fmt.Errorf("unexpected type for AK certificate public key: %T", leaf.PublicKey) - } + if len(chain) > 0 { + leaf := chain[0] + leafPK, ok := leaf.PublicKey.(crypto.PublicKey) + if !ok { + return fmt.Errorf("unexpected type for AK certificate public key: %T", leaf.PublicKey) + } - publicKey, ok := leafPK.(comparablePublicKey) - if !ok { - return errors.New("certificate public key can't be compared to a crypto.PublicKey") - } + publicKey, ok := leafPK.(comparablePublicKey) + if !ok { + return errors.New("certificate public key can't be compared to a crypto.PublicKey") + } - if !publicKey.Equal(akPublic) { - return errors.New("AK public key does not match the leaf certificate public key") + if !publicKey.Equal(akPublic) { + return errors.New("AK public key does not match the leaf certificate public key") + } } ak.chain = chain // TODO(hs): deep copy, so that certs can't be changed by pointer? diff --git a/tpm/key.go b/tpm/key.go index cb1b08c3..87a91b76 100644 --- a/tpm/key.go +++ b/tpm/key.go @@ -435,19 +435,21 @@ func (k *Key) SetCertificateChain(ctx context.Context, chain []*x509.Certificate return fmt.Errorf("failed getting signer for key: %w", err) } - leaf := chain[0] - leafPK, ok := leaf.PublicKey.(crypto.PublicKey) - if !ok { - return fmt.Errorf("unexpected type for certificate public key: %T", leaf.PublicKey) - } + if len(chain) > 0 { + leaf := chain[0] + leafPK, ok := leaf.PublicKey.(crypto.PublicKey) + if !ok { + return fmt.Errorf("unexpected type for certificate public key: %T", leaf.PublicKey) + } - publicKey, ok := leafPK.(comparablePublicKey) - if !ok { - return errors.New("certificate public key can't be compared to a crypto.PublicKey") - } + publicKey, ok := leafPK.(comparablePublicKey) + if !ok { + return errors.New("certificate public key can't be compared to a crypto.PublicKey") + } - if !publicKey.Equal(signer.Public()) { - return errors.New("public key does not match the leaf certificate public key") + if !publicKey.Equal(signer.Public()) { + return errors.New("public key does not match the leaf certificate public key") + } } k.chain = chain // TODO(hs): deep copy, so that certs can't be changed by pointer? From cb5c13b6b50baec28c904b9c1f4f61eb1e418afb Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Apr 2024 16:05:25 +0200 Subject: [PATCH 05/14] Ignore non existing certificates and fix linting issue --- kms/capi/capi.go | 8 ++++---- kms/tpmkms/tpmkms.go | 3 +-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index a9871fd7..7db61aa4 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -781,7 +781,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)} + return nil } defer windows.CertFreeCertificateContext(certHandle) @@ -812,7 +812,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)} + return nil } defer windows.CertFreeCertificateContext(certHandle) @@ -853,9 +853,8 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } if certHandle == nil { - return apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} + return nil } - defer windows.CertFreeCertificateContext(certHandle) x509Cert, err := certContextToX509(certHandle) if err != nil { @@ -866,6 +865,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { if err := removeCertificateUsingWindowsContext(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } + defer windows.CertFreeCertificateContext(certHandle) return nil } diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 902323f3..7c3d9c0d 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -861,8 +861,7 @@ func (k *TPMKMS) storeIntermediateToWindowsCertificateStore(c *x509.Certificate, } func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { - switch { - case req.Name == "": + if req.Name == "" { return errors.New("deleteCertificateRequest 'name' cannot be empty") } From 6c1a983ae5a9a5b4adcfd6a7dcbde1d9fa758178 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Apr 2024 16:11:01 +0200 Subject: [PATCH 06/14] Fix missing `procCertDeleteCertificateFromStore` definition --- kms/capi/ncrypt_windows.go | 1 + 1 file changed, 1 insertion(+) diff --git a/kms/capi/ncrypt_windows.go b/kms/capi/ncrypt_windows.go index 49d4dc5c..44ce2561 100644 --- a/kms/capi/ncrypt_windows.go +++ b/kms/capi/ncrypt_windows.go @@ -145,6 +145,7 @@ var ( crypt32 = windows.MustLoadDLL("crypt32.dll") procCertFindCertificateInStore = crypt32.MustFindProc("CertFindCertificateInStore") + procCertDeleteCertificateFromStore = crypt32.MustFindProc("CertDeleteCertificateFromStore") procCryptFindCertificateKeyProvInfo = crypt32.MustFindProc("CryptFindCertificateKeyProvInfo") procCertStrToName = crypt32.MustFindProc("CertStrToNameW") ) From 6f2eab144b2e1878fed75a6ed604dccd8fdbf82c Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Apr 2024 17:18:09 +0200 Subject: [PATCH 07/14] Add docs to the new TPM and CAPI KMS methods --- kms/capi/capi.go | 25 ++++++++++++------------- kms/tpmkms/tpmkms.go | 12 ++++++++++++ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 7db61aa4..4fbeb6b7 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -714,6 +714,14 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return nil } +// DeleteCertificate deletes a certificate from the Windows certificate store. It uses +// largely the same logic for searching for the certificate as [LoadCertificate], but +// deletes it as soon as it it's found. +// +// # Experimental +// +// Notice: This method is EXPERIMENTAL and may be changed or removed in a later +// release. func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { u, err := uri.ParseWithScheme(Scheme, req.Name) if err != nil { @@ -785,7 +793,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } defer windows.CertFreeCertificateContext(certHandle) - if err := removeCertificateUsingWindowsContext(certHandle); err != nil { + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } return nil @@ -816,7 +824,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } defer windows.CertFreeCertificateContext(certHandle) - if err := removeCertificateUsingWindowsContext(certHandle); err != nil { + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } return nil @@ -862,28 +870,19 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) { - if err := removeCertificateUsingWindowsContext(certHandle); err != nil { + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } defer windows.CertFreeCertificateContext(certHandle) return nil } - prevCert = certHandle } default: - return fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) + return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } } -func removeCertificateUsingWindowsContext(certContext *windows.CertContext) error { - r, _, err := procCertDeleteCertificateFromStore.Call(uintptr(unsafe.Pointer(certContext))) - if r != 1 { - return fmt.Errorf("procCertDeleteCertificateFromStore failed with %X: %v", r, err) - } - return nil -} - type CAPISigner struct { algorithmGroup string keyHandle uintptr diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 7c3d9c0d..54136e58 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -860,6 +860,18 @@ func (k *TPMKMS) storeIntermediateToWindowsCertificateStore(c *x509.Certificate, return nil } +// DeleteCertificate deletes a certificate for the key identified by name from the +// TPMKMS. If the instance is configured to use the Windows certificate store, it'll +// delete the certificate from the certificate store, backed by a CAPIKMS instance. +// +// It's possible to delete a specific certificate for a key by specifying it's SHA1 +// or serial. This is only supported if the instance is configured to use the Windows +// certificate store. +// +// # Experimental +// +// Notice: This method is EXPERIMENTAL and may be changed or removed in a later +// release. func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { if req.Name == "" { return errors.New("deleteCertificateRequest 'name' cannot be empty") From b5b8a3438f1045393510224a5b5514ed80c66541 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Apr 2024 17:39:02 +0200 Subject: [PATCH 08/14] Remove `CertDeleteCertificateFromStore` and use `windows` instead --- kms/capi/ncrypt_windows.go | 1 - 1 file changed, 1 deletion(-) diff --git a/kms/capi/ncrypt_windows.go b/kms/capi/ncrypt_windows.go index 44ce2561..49d4dc5c 100644 --- a/kms/capi/ncrypt_windows.go +++ b/kms/capi/ncrypt_windows.go @@ -145,7 +145,6 @@ var ( crypt32 = windows.MustLoadDLL("crypt32.dll") procCertFindCertificateInStore = crypt32.MustFindProc("CertFindCertificateInStore") - procCertDeleteCertificateFromStore = crypt32.MustFindProc("CertDeleteCertificateFromStore") procCryptFindCertificateKeyProvInfo = crypt32.MustFindProc("CryptFindCertificateKeyProvInfo") procCertStrToName = crypt32.MustFindProc("CertStrToNameW") ) From 8e154290809188a6273aafa038c994641e199d0b Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 9 Apr 2024 17:53:17 +0200 Subject: [PATCH 09/14] Fix comment spelling error --- kms/capi/capi.go | 2 +- kms/tpmkms/tpmkms.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 4fbeb6b7..600f34ae 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -716,7 +716,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { // DeleteCertificate deletes a certificate from the Windows certificate store. It uses // largely the same logic for searching for the certificate as [LoadCertificate], but -// deletes it as soon as it it's found. +// deletes it as soon as it's found. // // # Experimental // diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 54136e58..2f0b425d 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -954,7 +954,7 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC return fmt.Errorf("expected Windows certificate manager to implement DeleteCertificate") } - if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{ // TODO(hs): handle specific error cases? + if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{ Name: uri.New("capi", uv).String(), }); err != nil { return fmt.Errorf("failed deleting certificate using Windows platform cryptography provider: %w", err) From 7f950f1ecb59117bb736c8d88e646a00ed4dadc6 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Fri, 12 Apr 2024 13:37:17 +0200 Subject: [PATCH 10/14] Remove unused `defer` while freeing up Windows certificate context --- kms/capi/capi.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 600f34ae..a493c3a2 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -873,7 +873,8 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { return fmt.Errorf("failed removing certificate: %w", err) } - defer windows.CertFreeCertificateContext(certHandle) + + windows.CertFreeCertificateContext(certHandle) return nil } prevCert = certHandle From 75696c3872fd754d97a7eddcdde0f1a32bf12802 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 15 Apr 2024 14:48:58 +0200 Subject: [PATCH 11/14] Free certificate handles while looking for one to delete --- kms/capi/capi.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index a493c3a2..9ea4cda2 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -859,10 +859,10 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { if err != nil { return fmt.Errorf("findCertificateInStore failed: %w", err) } - if certHandle == nil { return nil } + defer windows.CertFreeCertificateContext(certHandle) x509Cert, err := certContextToX509(certHandle) if err != nil { @@ -874,7 +874,6 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return fmt.Errorf("failed removing certificate: %w", err) } - windows.CertFreeCertificateContext(certHandle) return nil } prevCert = certHandle From fac70361ef13b70e4c9b54f21af1ce36c2945350 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Mon, 15 Apr 2024 22:41:34 +0200 Subject: [PATCH 12/14] Add `GetHexEncoded` method to `uri` package --- kms/capi/capi.go | 30 ++++++++++++------------------ kms/uri/uri.go | 17 ++++++++++++++++- kms/uri/uri_test.go | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+), 19 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 9ea4cda2..8d5705c5 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -503,7 +503,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.Get(HashArg) + sha1Hash := u.GetHexEncoded(HashArg) keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -544,18 +544,15 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert var certHandle *windows.CertContext switch { - case sha1Hash != "": - sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial - - sha1Bytes, err := hex.DecodeString(sha1Hash) - if err != nil { - return nil, fmt.Errorf("%s must be in hex format: %w", HashArg, err) + case len(sha1Hash) > 0: + if len(sha1Hash) != 20 { + return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash)) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_SHA1_HASH, KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(sha1Bytes)), - data: uintptr(unsafe.Pointer(&sha1Bytes[0])), + len: uint32(len(sha1Hash)), + data: uintptr(unsafe.Pointer(&sha1Hash[0])), }, } certHandle, err = findCertificateInStore(st, @@ -728,7 +725,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.Get(HashArg) + sha1Hash := u.GetHexEncoded(HashArg) keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -766,18 +763,15 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { var certHandle *windows.CertContext switch { - case sha1Hash != "": - sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial - - sha1Bytes, err := hex.DecodeString(sha1Hash) - if err != nil { - return fmt.Errorf("%s must be in hex format: %w", HashArg, err) + case len(sha1Hash) > 0: + if len(sha1Hash) != 20 { + return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash)) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_SHA1_HASH, KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(sha1Bytes)), - data: uintptr(unsafe.Pointer(&sha1Bytes[0])), + len: uint32(len(sha1Hash)), + data: uintptr(unsafe.Pointer(&sha1Hash[0])), }, } certHandle, err = findCertificateInStore(st, diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 9953407c..df4a1f94 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -150,13 +150,28 @@ func (u *URI) GetEncoded(key string) []byte { return nil } if len(v)%2 == 0 { - if b, err := hex.DecodeString(v); err == nil { + if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { return b } } return []byte(v) } +// GetHexEncoded returns the first value in the uri with the given key, it will +// return nil if the field is not present, is empty, or is not hex encoded. +func (u *URI) GetHexEncoded(key string) []byte { + v := u.Get(key) + if v == "" { + return nil + } + if len(v)%2 == 0 { + if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { + return b + } + } + return nil +} + // Pin returns the pin encoded in the url. It will read the pin from the // pin-value or the pin-source attributes. func (u *URI) Pin() string { diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index ae724481..f691516c 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNew(t *testing.T) { @@ -237,6 +238,7 @@ func TestURI_GetEncoded(t *testing.T) { want []byte }{ {"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, []byte{0x9a}}, + {"ok prefix", mustParse("yubikey:slot-id=0x9a"), args{"slot-id"}, []byte{0x9a}}, {"ok first", mustParse("yubikey:slot-id=9a9b;slot-id=9b"), args{"slot-id"}, []byte{0x9a, 0x9b}}, {"ok percent", mustParse("yubikey:slot-id=9a;foo=%9a%9b%9c"), args{"foo"}, []byte{0x9a, 0x9b, 0x9c}}, {"ok in query", mustParse("yubikey:slot-id=9a?foo=9a"), args{"foo"}, []byte{0x9a}}, @@ -342,3 +344,33 @@ func TestURI_GetInt(t *testing.T) { }) } } + +func TestURI_GetHexEncoded(t *testing.T) { + mustParse := func(t *testing.T, s string) *URI { + t.Helper() + u, err := Parse(s) + require.NoError(t, err) + return u + } + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want []byte + }{ + {"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}}, + {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}}, + {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}}, + {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil}, + {"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.uri.GetHexEncoded(tt.args.key) + assert.Equal(t, tt.want, got) + }) + } +} From 2fd7f84ea0917b1492ebb5f2f8a7387b3b1e8856 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 16 Apr 2024 17:11:13 +0200 Subject: [PATCH 13/14] Make `GetHexEncoded` return an error and improve error message formatting --- kms/capi/capi.go | 40 +++++++++++++++++++++++----------------- kms/uri/uri.go | 21 ++++++++++++--------- kms/uri/uri_test.go | 28 ++++++++++++++++++---------- 3 files changed, 53 insertions(+), 36 deletions(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 8d5705c5..f266d8ef 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -503,7 +503,10 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.GetHexEncoded(HashArg) + sha1Hash, err := u.GetHexEncoded(HashArg) + if err != nil { + return nil, fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err) + } keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -521,7 +524,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert case "machine": certStoreLocation = certStoreLocalMachine default: - return nil, fmt.Errorf("invalid cert store location %v", storeLocation) + return nil, fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -538,7 +541,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return nil, fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } var certHandle *windows.CertContext @@ -564,7 +567,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", HashArg, keyID)} } defer windows.CertFreeCertificateContext(certHandle) return certContextToX509(certHandle) @@ -573,7 +576,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert keyIDBytes, err := hex.DecodeString(keyID) if err != nil { - return nil, fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err) + return nil, fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_KEY_IDENTIFIER, @@ -591,7 +594,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)} } defer windows.CertFreeCertificateContext(certHandle) return certContextToX509(certHandle) @@ -605,13 +608,13 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed serialBytes, err = hex.DecodeString(serialNumber) if err != nil { - return nil, fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err) + return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) } } else { bi := new(big.Int) bi, ok := bi.SetString(serialNumber, 10) if !ok { - return nil, fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg) + return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) } serialBytes = bi.Bytes() } @@ -628,7 +631,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q and %s=%q not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} } x509Cert, err := certContextToX509(certHandle) @@ -645,7 +648,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert prevCert = certHandle } default: - return nil, fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) + return nil, fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } } @@ -667,7 +670,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { case "machine": certStoreLocation = certStoreLocalMachine default: - return fmt.Errorf("invalid cert store location %v", storeLocation) + return fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -700,7 +703,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } // Add the cert context to the system certificate store @@ -725,7 +728,10 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.GetHexEncoded(HashArg) + sha1Hash, err := u.GetHexEncoded(HashArg) + if err != nil { + return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err) + } keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -742,7 +748,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { case "machine": certStoreLocation = certStoreLocalMachine default: - return fmt.Errorf("invalid cert store location %v", storeLocation) + return fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -757,7 +763,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } var certHandle *windows.CertContext @@ -832,13 +838,13 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed serialBytes, err = hex.DecodeString(serialNumber) if err != nil { - return fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err) + return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) } } else { bi := new(big.Int) bi, ok := bi.SetString(serialNumber, 10) if !ok { - return fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg) + return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) } serialBytes = bi.Bytes() } diff --git a/kms/uri/uri.go b/kms/uri/uri.go index df4a1f94..a3e325b8 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -3,6 +3,7 @@ package uri import ( "bytes" "encoding/hex" + "fmt" "net/url" "os" "strconv" @@ -157,19 +158,21 @@ func (u *URI) GetEncoded(key string) []byte { return []byte(v) } -// GetHexEncoded returns the first value in the uri with the given key, it will -// return nil if the field is not present, is empty, or is not hex encoded. -func (u *URI) GetHexEncoded(key string) []byte { +// GetHexEncoded returns the first value in the uri with the given key. It +// returns nil if the field is not present or is empty. It will return an +// error if the the value is not properly hex encoded. +func (u *URI) GetHexEncoded(key string) ([]byte, error) { v := u.Get(key) if v == "" { - return nil + return nil, nil } - if len(v)%2 == 0 { - if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { - return b - } + + b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")) + if err != nil { + return nil, fmt.Errorf("failed decoding %q: %w", v, err) } - return nil + + return b, nil } // Pin returns the pin encoded in the url. It will read the pin from the diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index f691516c..6688ee73 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -356,20 +356,28 @@ func TestURI_GetHexEncoded(t *testing.T) { key string } tests := []struct { - name string - uri *URI - args args - want []byte + name string + uri *URI + args args + want []byte + wantErr bool }{ - {"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}}, - {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}}, - {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}}, - {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil}, - {"ok odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil}, + {"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}, false}, + {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, + {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, + {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil, false}, + {"fail odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil, true}, + {"fail invalid hex", mustParse(t, "capi:sha1=9z?bar=zar"), args{"sha1"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.uri.GetHexEncoded(tt.args.key) + got, err := tt.uri.GetHexEncoded(tt.args.key) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + assert.Equal(t, tt.want, got) }) } From 262590beae0ddab9000d63a3fab6caef90b042a7 Mon Sep 17 00:00:00 2001 From: Herman Slatman Date: Tue, 16 Apr 2024 17:12:35 +0200 Subject: [PATCH 14/14] Fix leftover `%v` format --- kms/capi/capi.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/kms/capi/capi.go b/kms/capi/capi.go index f266d8ef..3acb47dc 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -802,7 +802,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { keyIDBytes, err := hex.DecodeString(keyID) if err != nil { - return fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err) + return fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_KEY_IDENTIFIER,