Skip to content

Commit

Permalink
Merge pull request #475 from smallstep/herman/windows-tpm-certificate…
Browse files Browse the repository at this point in the history
…-stores

Support skipping certificate private key check on request
  • Loading branch information
hslatman authored Apr 16, 2024
2 parents ba8d2ce + 262590b commit 1d8dca8
Show file tree
Hide file tree
Showing 7 changed files with 409 additions and 48 deletions.
209 changes: 189 additions & 20 deletions kms/capi/capi.go
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,10 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash := u.Get(HashArg)
sha1Hash, err := u.GetHexEncoded(HashArg)
if err != nil {
return nil, fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
}
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)
Expand All @@ -521,7 +524,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return nil, fmt.Errorf("invalid cert store location %v", storeLocation)
return nil, fmt.Errorf("invalid cert store location %q", storeLocation)
}

var storeName string
Expand All @@ -538,24 +541,21 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return nil, fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
}

var certHandle *windows.CertContext

switch {
case sha1Hash != "":
sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial

sha1Bytes, err := hex.DecodeString(sha1Hash)
if err != nil {
return nil, fmt.Errorf("%s must be in hex format: %w", HashArg, err)
case len(sha1Hash) > 0:
if len(sha1Hash) != 20 {
return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_SHA1_HASH,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(sha1Bytes)),
data: uintptr(unsafe.Pointer(&sha1Bytes[0])),
len: uint32(len(sha1Hash)),
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
},
}
certHandle, err = findCertificateInStore(st,
Expand All @@ -567,7 +567,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", HashArg, keyID)}
}
defer windows.CertFreeCertificateContext(certHandle)
return certContextToX509(certHandle)
Expand All @@ -576,7 +576,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert

keyIDBytes, err := hex.DecodeString(keyID)
if err != nil {
return nil, fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err)
return nil, fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err)
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_KEY_IDENTIFIER,
Expand All @@ -594,7 +594,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
return nil, fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)}
}
defer windows.CertFreeCertificateContext(certHandle)
return certContextToX509(certHandle)
Expand All @@ -608,13 +608,13 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
serialBytes, err = hex.DecodeString(serialNumber)
if err != nil {
return nil, fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err)
return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
}
} else {
bi := new(big.Int)
bi, ok := bi.SetString(serialNumber, 10)
if !ok {
return nil, fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg)
return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
}
serialBytes = bi.Bytes()
}
Expand All @@ -631,7 +631,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
}

if certHandle == nil {
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)}
return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q and %s=%q not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)}
}

x509Cert, err := certContextToX509(certHandle)
Expand All @@ -648,7 +648,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert
prevCert = certHandle
}
default:
return nil, fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
return nil, fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
}
}

Expand All @@ -670,7 +670,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return fmt.Errorf("invalid cert store location %v", storeLocation)
return fmt.Errorf("invalid cert store location %q", storeLocation)
}

var storeName string
Expand Down Expand Up @@ -703,7 +703,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err)
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
}

// Add the cert context to the system certificate store
Expand All @@ -714,6 +714,175 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
return nil
}

// DeleteCertificate deletes a certificate from the Windows certificate store. It uses
// largely the same logic for searching for the certificate as [LoadCertificate], but
// deletes it as soon as it's found.
//
// # Experimental
//
// Notice: This method is EXPERIMENTAL and may be changed or removed in a later
// release.
func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error {
u, err := uri.ParseWithScheme(Scheme, req.Name)
if err != nil {
return fmt.Errorf("failed to parse URI: %w", err)
}

sha1Hash, err := u.GetHexEncoded(HashArg)
if err != nil {
return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err)
}
keyID := u.Get(KeyIDArg)
issuerName := u.Get(IssuerNameArg)
serialNumber := u.Get(SerialNumberArg)

var storeLocation string
if storeLocation = u.Get(StoreLocationArg); storeLocation == "" {
storeLocation = "user"
}

var certStoreLocation uint32
switch storeLocation {
case "user":
certStoreLocation = certStoreCurrentUser
case "machine":
certStoreLocation = certStoreLocalMachine
default:
return fmt.Errorf("invalid cert store location %q", storeLocation)
}

var storeName string
if storeName = u.Get(StoreNameArg); storeName == "" {
storeName = "My"
}

st, err := windows.CertOpenStore(
certStoreProvSystem,
0,
0,
certStoreLocation,
uintptr(unsafe.Pointer(wide(storeName))))
if err != nil {
return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err)
}

var certHandle *windows.CertContext

switch {
case len(sha1Hash) > 0:
if len(sha1Hash) != 20 {
return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash))
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_SHA1_HASH,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(sha1Hash)),
data: uintptr(unsafe.Pointer(&sha1Hash[0])),
},
}
certHandle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findCertID,
uintptr(unsafe.Pointer(&searchData)), nil)
if err != nil {
return fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil
}
defer windows.CertFreeCertificateContext(certHandle)

if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
return fmt.Errorf("failed removing certificate: %w", err)
}
return nil
case keyID != "":
keyID = strings.TrimPrefix(keyID, "0x") // Support specifying the hash as 0x like with serial

keyIDBytes, err := hex.DecodeString(keyID)
if err != nil {
return fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err)
}
searchData := CERT_ID_KEYIDORHASH{
idChoice: CERT_ID_KEY_IDENTIFIER,
KeyIDOrHash: CRYPTOAPI_BLOB{
len: uint32(len(keyIDBytes)),
data: uintptr(unsafe.Pointer(&keyIDBytes[0])),
},
}
certHandle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findCertID,
uintptr(unsafe.Pointer(&searchData)), nil)
if err != nil {
return fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil
}
defer windows.CertFreeCertificateContext(certHandle)

if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
return fmt.Errorf("failed removing certificate: %w", err)
}
return nil
case issuerName != "" && serialNumber != "":
//TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id
// https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number
var serialBytes []byte
if strings.HasPrefix(serialNumber, "0x") {
serialNumber = strings.TrimPrefix(serialNumber, "0x")
serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed
serialBytes, err = hex.DecodeString(serialNumber)
if err != nil {
return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err)
}
} else {
bi := new(big.Int)
bi, ok := bi.SetString(serialNumber, 10)
if !ok {
return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg)
}
serialBytes = bi.Bytes()
}
var prevCert *windows.CertContext
for {
certHandle, err = findCertificateInStore(st,
encodingX509ASN|encodingPKCS7,
0,
findIssuerStr,
uintptr(unsafe.Pointer(wide(issuerName))), prevCert)

if err != nil {
return fmt.Errorf("findCertificateInStore failed: %w", err)
}
if certHandle == nil {
return nil
}
defer windows.CertFreeCertificateContext(certHandle)

x509Cert, err := certContextToX509(certHandle)
if err != nil {
return fmt.Errorf("could not unmarshal certificate to DER: %w", err)
}

if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) {
if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil {
return fmt.Errorf("failed removing certificate: %w", err)
}

return nil
}
prevCert = certHandle
}
default:
return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg)
}
}

type CAPISigner struct {
algorithmGroup string
keyHandle uintptr
Expand Down
Loading

0 comments on commit 1d8dca8

Please sign in to comment.