diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 4dd92503..3acb47dc 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -503,7 +503,10 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("failed to parse URI: %w", err) } - sha1Hash := u.Get(HashArg) + sha1Hash, err := u.GetHexEncoded(HashArg) + if err != nil { + return nil, fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err) + } keyID := u.Get(KeyIDArg) issuerName := u.Get(IssuerNameArg) serialNumber := u.Get(SerialNumberArg) @@ -521,7 +524,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert case "machine": certStoreLocation = certStoreLocalMachine default: - return nil, fmt.Errorf("invalid cert store location %v", storeLocation) + return nil, fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -538,24 +541,21 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return nil, fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return nil, fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } var certHandle *windows.CertContext switch { - case sha1Hash != "": - sha1Hash = strings.TrimPrefix(sha1Hash, "0x") // Support specifying the hash as 0x like with serial - - sha1Bytes, err := hex.DecodeString(sha1Hash) - if err != nil { - return nil, fmt.Errorf("%s must be in hex format: %w", HashArg, err) + case len(sha1Hash) > 0: + if len(sha1Hash) != 20 { + return nil, fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash)) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_SHA1_HASH, KeyIDOrHash: CRYPTOAPI_BLOB{ - len: uint32(len(sha1Bytes)), - data: uintptr(unsafe.Pointer(&sha1Bytes[0])), + len: uint32(len(sha1Hash)), + data: uintptr(unsafe.Pointer(&sha1Hash[0])), }, } certHandle, err = findCertificateInStore(st, @@ -567,7 +567,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", HashArg, keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", HashArg, keyID)} } defer windows.CertFreeCertificateContext(certHandle) return certContextToX509(certHandle) @@ -576,7 +576,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert keyIDBytes, err := hex.DecodeString(keyID) if err != nil { - return nil, fmt.Errorf("%v must be in hex format: %w", KeyIDArg, err) + return nil, fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err) } searchData := CERT_ID_KEYIDORHASH{ idChoice: CERT_ID_KEY_IDENTIFIER, @@ -594,7 +594,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert return nil, fmt.Errorf("findCertificateInStore failed: %w", err) } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%s not found", KeyIDArg, keyID)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%s not found", KeyIDArg, keyID)} } defer windows.CertFreeCertificateContext(certHandle) return certContextToX509(certHandle) @@ -608,13 +608,13 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed serialBytes, err = hex.DecodeString(serialNumber) if err != nil { - return nil, fmt.Errorf("invalid hex format for %v: %w", SerialNumberArg, err) + return nil, fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) } } else { bi := new(big.Int) bi, ok := bi.SetString(serialNumber, 10) if !ok { - return nil, fmt.Errorf("invalid %v - must be in hex or integer format", SerialNumberArg) + return nil, fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) } serialBytes = bi.Bytes() } @@ -631,7 +631,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert } if certHandle == nil { - return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %v=%v and %v=%v not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} + return nil, apiv1.NotFoundError{Message: fmt.Sprintf("certificate with %s=%q and %s=%q not found", IssuerNameArg, issuerName, SerialNumberArg, serialNumber)} } x509Cert, err := certContextToX509(certHandle) @@ -648,7 +648,7 @@ func (k *CAPIKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Cert prevCert = certHandle } default: - return nil, fmt.Errorf("%s, %s, or %s and %s is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) + return nil, fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } } @@ -670,7 +670,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { case "machine": certStoreLocation = certStoreLocalMachine default: - return fmt.Errorf("invalid cert store location %v", storeLocation) + return fmt.Errorf("invalid cert store location %q", storeLocation) } var storeName string @@ -703,7 +703,7 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { certStoreLocation, uintptr(unsafe.Pointer(wide(storeName)))) if err != nil { - return fmt.Errorf("CertOpenStore for the %v store %v returned: %w", storeLocation, storeName, err) + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } // Add the cert context to the system certificate store @@ -714,6 +714,175 @@ func (k *CAPIKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return nil } +// DeleteCertificate deletes a certificate from the Windows certificate store. It uses +// largely the same logic for searching for the certificate as [LoadCertificate], but +// deletes it as soon as it's found. +// +// # Experimental +// +// Notice: This method is EXPERIMENTAL and may be changed or removed in a later +// release. +func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + u, err := uri.ParseWithScheme(Scheme, req.Name) + if err != nil { + return fmt.Errorf("failed to parse URI: %w", err) + } + + sha1Hash, err := u.GetHexEncoded(HashArg) + if err != nil { + return fmt.Errorf("failed getting %s from URI %q: %w", HashArg, req.Name, err) + } + keyID := u.Get(KeyIDArg) + issuerName := u.Get(IssuerNameArg) + serialNumber := u.Get(SerialNumberArg) + + var storeLocation string + if storeLocation = u.Get(StoreLocationArg); storeLocation == "" { + storeLocation = "user" + } + + var certStoreLocation uint32 + switch storeLocation { + case "user": + certStoreLocation = certStoreCurrentUser + case "machine": + certStoreLocation = certStoreLocalMachine + default: + return fmt.Errorf("invalid cert store location %q", storeLocation) + } + + var storeName string + if storeName = u.Get(StoreNameArg); storeName == "" { + storeName = "My" + } + + st, err := windows.CertOpenStore( + certStoreProvSystem, + 0, + 0, + certStoreLocation, + uintptr(unsafe.Pointer(wide(storeName)))) + if err != nil { + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) + } + + var certHandle *windows.CertContext + + switch { + case len(sha1Hash) > 0: + if len(sha1Hash) != 20 { + return fmt.Errorf("decoded %s has length %d; expected 20 bytes for SHA-1", HashArg, len(sha1Hash)) + } + searchData := CERT_ID_KEYIDORHASH{ + idChoice: CERT_ID_SHA1_HASH, + KeyIDOrHash: CRYPTOAPI_BLOB{ + len: uint32(len(sha1Hash)), + data: uintptr(unsafe.Pointer(&sha1Hash[0])), + }, + } + certHandle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findCertID, + uintptr(unsafe.Pointer(&searchData)), nil) + if err != nil { + return fmt.Errorf("findCertificateInStore failed: %w", err) + } + if certHandle == nil { + return nil + } + defer windows.CertFreeCertificateContext(certHandle) + + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil + case keyID != "": + keyID = strings.TrimPrefix(keyID, "0x") // Support specifying the hash as 0x like with serial + + keyIDBytes, err := hex.DecodeString(keyID) + if err != nil { + return fmt.Errorf("%s must be in hex format: %w", KeyIDArg, err) + } + searchData := CERT_ID_KEYIDORHASH{ + idChoice: CERT_ID_KEY_IDENTIFIER, + KeyIDOrHash: CRYPTOAPI_BLOB{ + len: uint32(len(keyIDBytes)), + data: uintptr(unsafe.Pointer(&keyIDBytes[0])), + }, + } + certHandle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findCertID, + uintptr(unsafe.Pointer(&searchData)), nil) + if err != nil { + return fmt.Errorf("findCertificateInStore failed: %w", err) + } + if certHandle == nil { + return nil + } + defer windows.CertFreeCertificateContext(certHandle) + + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil + case issuerName != "" && serialNumber != "": + //TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead + // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id + // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_issuer_serial_number + var serialBytes []byte + if strings.HasPrefix(serialNumber, "0x") { + serialNumber = strings.TrimPrefix(serialNumber, "0x") + serialNumber = strings.TrimPrefix(serialNumber, "00") // Comparison fails if leading 00 is not removed + serialBytes, err = hex.DecodeString(serialNumber) + if err != nil { + return fmt.Errorf("invalid hex format for %s: %w", SerialNumberArg, err) + } + } else { + bi := new(big.Int) + bi, ok := bi.SetString(serialNumber, 10) + if !ok { + return fmt.Errorf("invalid %s - must be in hex or integer format", SerialNumberArg) + } + serialBytes = bi.Bytes() + } + var prevCert *windows.CertContext + for { + certHandle, err = findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findIssuerStr, + uintptr(unsafe.Pointer(wide(issuerName))), prevCert) + + if err != nil { + return fmt.Errorf("findCertificateInStore failed: %w", err) + } + if certHandle == nil { + return nil + } + defer windows.CertFreeCertificateContext(certHandle) + + x509Cert, err := certContextToX509(certHandle) + if err != nil { + return fmt.Errorf("could not unmarshal certificate to DER: %w", err) + } + + if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) { + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + + return nil + } + prevCert = certHandle + } + default: + return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) + } +} + type CAPISigner struct { algorithmGroup string keyHandle uintptr diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index 05f7871e..2f0b425d 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -778,6 +778,10 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store if o.store != "" { store = o.store } + skipFindCertificateKey := "false" + if o.skipFindCertificateKey { + skipFindCertificateKey = "true" + } leaf := req.CertificateChain[0] fp, err := fingerprint.New(leaf.Raw, crypto.SHA1, fingerprint.HexFingerprint) @@ -785,8 +789,14 @@ func (k *TPMKMS) storeCertificateChainToWindowsCertificateStore(req *apiv1.Store return fmt.Errorf("failed calculating certificate SHA1 fingerprint: %w", err) } + uv := url.Values{} + uv.Set("sha1", fp) + uv.Set("store-location", location) + uv.Set("store", store) + uv.Set("skip-find-certificate-key", skipFindCertificateKey) + if err := k.windowsCertificateManager.StoreCertificate(&apiv1.StoreCertificateRequest{ - Name: fmt.Sprintf("capi:sha1=%s;store-location=%s;store=%s;", fp, location, store), + Name: uri.New("capi", uv).String(), Certificate: leaf, }); err != nil { return fmt.Errorf("failed storing certificate using Windows platform cryptography provider: %w", err) @@ -850,6 +860,109 @@ func (k *TPMKMS) storeIntermediateToWindowsCertificateStore(c *x509.Certificate, return nil } +// DeleteCertificate deletes a certificate for the key identified by name from the +// TPMKMS. If the instance is configured to use the Windows certificate store, it'll +// delete the certificate from the certificate store, backed by a CAPIKMS instance. +// +// It's possible to delete a specific certificate for a key by specifying it's SHA1 +// or serial. This is only supported if the instance is configured to use the Windows +// certificate store. +// +// # Experimental +// +// Notice: This method is EXPERIMENTAL and may be changed or removed in a later +// release. +func (k *TPMKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { + if req.Name == "" { + return errors.New("deleteCertificateRequest 'name' cannot be empty") + } + + if k.usesWindowsCertificateStore() { + if err := k.deleteCertificateFromWindowsCertificateStore(&apiv1.DeleteCertificateRequest{ + Name: req.Name, + }); err != nil { + return fmt.Errorf("failed deleting certificate from Windows platform cryptography provider: %w", err) + } + + return nil + } + + // TODO(hs): support delete by serial? If not, the behavior for TPM storage and Windows + // certificate store storage will be different, and may need different behavior when + // implementing certificate management. + + properties, err := parseNameURI(req.Name) + if err != nil { + return fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + ctx := context.Background() + if properties.ak { + ak, err := k.tpm.GetAK(ctx, properties.name) + if err != nil { + return err + } + if err := ak.SetCertificateChain(ctx, nil); err != nil { + return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err) + } + } else { + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return err + } + if err := key.SetCertificateChain(ctx, nil); err != nil { + return fmt.Errorf("failed storing certificate for key %q: %w", properties.name, err) + } + } + + return nil +} + +func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteCertificateRequest) error { + o, err := parseNameURI(req.Name) + if err != nil { + return fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + location := k.windowsCertificateStoreLocation + if o.storeLocation != "" { + location = o.storeLocation + } + store := k.windowsCertificateStore + if o.store != "" { + store = o.store + } + + uv := url.Values{} + uv.Set("store-location", location) + uv.Set("store", store) + + switch { + case o.serial != "": + uv.Set("serial", o.serial) + uv.Set("issuer", o.issuer) + case o.keyID != "": + uv.Set("key-id", o.keyID) + case o.sha1 != "": + uv.Set("sha1", o.sha1) + default: + return errors.New(`at least one of "serial", "key-id" or "sha1" is expected to be set`) + } + + dk, ok := k.windowsCertificateManager.(deletingCertificateManager) + if !ok { + return fmt.Errorf("expected Windows certificate manager to implement DeleteCertificate") + } + + if err := dk.DeleteCertificate(&apiv1.DeleteCertificateRequest{ + Name: uri.New("capi", uv).String(), + }); err != nil { + return fmt.Errorf("failed deleting certificate using Windows platform cryptography provider: %w", err) + } + + return nil +} + // attestationClient is a wrapper for [attestation.Client], containing // all of the required references to perform attestation against the // Smallstep Attestation CA. @@ -1173,8 +1286,19 @@ func generateWindowsSubjectKeyID(pub crypto.PublicKey) (string, error) { return hex.EncodeToString(hash[:]), nil } +type deletingCertificateManager interface { + apiv1.CertificateManager + DeleteCertificate(req *apiv1.DeleteCertificateRequest) error +} + +type deletingCertificateChainManager interface { + apiv1.CertificateChainManager + DeleteCertificate(req *apiv1.DeleteCertificateRequest) error +} + var _ apiv1.KeyManager = (*TPMKMS)(nil) var _ apiv1.Attester = (*TPMKMS)(nil) var _ apiv1.CertificateManager = (*TPMKMS)(nil) var _ apiv1.CertificateChainManager = (*TPMKMS)(nil) +var _ deletingCertificateChainManager = (*TPMKMS)(nil) var _ apiv1.AttestationClient = (*attestationClient)(nil) diff --git a/kms/tpmkms/uri.go b/kms/tpmkms/uri.go index 993c0021..2614d6b7 100644 --- a/kms/tpmkms/uri.go +++ b/kms/tpmkms/uri.go @@ -21,6 +21,11 @@ type objectProperties struct { store string intermediateStoreLocation string intermediateStore string + skipFindCertificateKey bool + keyID string + sha1 string + serial string + issuer string } func parseNameURI(nameURI string) (o objectProperties, err error) { @@ -59,6 +64,11 @@ func parseNameURI(nameURI string) (o objectProperties, err error) { o.store = u.Get("store") o.intermediateStoreLocation = u.Get("intermediate-store-location") o.intermediateStore = u.Get("intermediate-store") + o.skipFindCertificateKey = u.GetBool("skip-find-certificate-key") + o.keyID = u.Get("key-id") + o.sha1 = u.Get("sha1") + o.serial = u.Get("serial") + o.issuer = u.Get("issuer") // validation if o.ak && o.attestBy != "" { diff --git a/kms/uri/uri.go b/kms/uri/uri.go index 9953407c..a3e325b8 100644 --- a/kms/uri/uri.go +++ b/kms/uri/uri.go @@ -3,6 +3,7 @@ package uri import ( "bytes" "encoding/hex" + "fmt" "net/url" "os" "strconv" @@ -150,13 +151,30 @@ func (u *URI) GetEncoded(key string) []byte { return nil } if len(v)%2 == 0 { - if b, err := hex.DecodeString(v); err == nil { + if b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")); err == nil { return b } } return []byte(v) } +// GetHexEncoded returns the first value in the uri with the given key. It +// returns nil if the field is not present or is empty. It will return an +// error if the the value is not properly hex encoded. +func (u *URI) GetHexEncoded(key string) ([]byte, error) { + v := u.Get(key) + if v == "" { + return nil, nil + } + + b, err := hex.DecodeString(strings.TrimPrefix(v, "0x")) + if err != nil { + return nil, fmt.Errorf("failed decoding %q: %w", v, err) + } + + return b, nil +} + // Pin returns the pin encoded in the url. It will read the pin from the // pin-value or the pin-source attributes. func (u *URI) Pin() string { diff --git a/kms/uri/uri_test.go b/kms/uri/uri_test.go index ae724481..6688ee73 100644 --- a/kms/uri/uri_test.go +++ b/kms/uri/uri_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNew(t *testing.T) { @@ -237,6 +238,7 @@ func TestURI_GetEncoded(t *testing.T) { want []byte }{ {"ok", mustParse("yubikey:slot-id=9a"), args{"slot-id"}, []byte{0x9a}}, + {"ok prefix", mustParse("yubikey:slot-id=0x9a"), args{"slot-id"}, []byte{0x9a}}, {"ok first", mustParse("yubikey:slot-id=9a9b;slot-id=9b"), args{"slot-id"}, []byte{0x9a, 0x9b}}, {"ok percent", mustParse("yubikey:slot-id=9a;foo=%9a%9b%9c"), args{"foo"}, []byte{0x9a, 0x9b, 0x9c}}, {"ok in query", mustParse("yubikey:slot-id=9a?foo=9a"), args{"foo"}, []byte{0x9a}}, @@ -342,3 +344,41 @@ func TestURI_GetInt(t *testing.T) { }) } } + +func TestURI_GetHexEncoded(t *testing.T) { + mustParse := func(t *testing.T, s string) *URI { + t.Helper() + u, err := Parse(s) + require.NoError(t, err) + return u + } + type args struct { + key string + } + tests := []struct { + name string + uri *URI + args args + want []byte + wantErr bool + }{ + {"ok", mustParse(t, "capi:sha1=9a"), args{"sha1"}, []byte{0x9a}, false}, + {"ok first", mustParse(t, "capi:sha1=9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, + {"ok prefix", mustParse(t, "capi:sha1=0x9a9b;sha1=9b"), args{"sha1"}, []byte{0x9a, 0x9b}, false}, + {"ok missing", mustParse(t, "capi:foo=9a"), args{"sha1"}, nil, false}, + {"fail odd hex", mustParse(t, "capi:sha1=09a?bar=zar"), args{"sha1"}, nil, true}, + {"fail invalid hex", mustParse(t, "capi:sha1=9z?bar=zar"), args{"sha1"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.uri.GetHexEncoded(tt.args.key) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/tpm/ak.go b/tpm/ak.go index a3d2c176..4a0f9baf 100644 --- a/tpm/ak.go +++ b/tpm/ak.go @@ -384,28 +384,26 @@ func (ak *AK) SetCertificateChain(ctx context.Context, chain []*x509.Certificate } defer closeTPM(ctx, ak.tpm, &err) - if len(chain) == 0 { - return errors.New("certificate chain must contain at least one certificate") - } - akPublic, err := ak.public(internalCall(ctx)) if err != nil { return fmt.Errorf("failed getting AK public key: %w", err) } - leaf := chain[0] - leafPK, ok := leaf.PublicKey.(crypto.PublicKey) - if !ok { - return fmt.Errorf("unexpected type for AK certificate public key: %T", leaf.PublicKey) - } + if len(chain) > 0 { + leaf := chain[0] + leafPK, ok := leaf.PublicKey.(crypto.PublicKey) + if !ok { + return fmt.Errorf("unexpected type for AK certificate public key: %T", leaf.PublicKey) + } - publicKey, ok := leafPK.(comparablePublicKey) - if !ok { - return errors.New("certificate public key can't be compared to a crypto.PublicKey") - } + publicKey, ok := leafPK.(comparablePublicKey) + if !ok { + return errors.New("certificate public key can't be compared to a crypto.PublicKey") + } - if !publicKey.Equal(akPublic) { - return errors.New("AK public key does not match the leaf certificate public key") + if !publicKey.Equal(akPublic) { + return errors.New("AK public key does not match the leaf certificate public key") + } } ak.chain = chain // TODO(hs): deep copy, so that certs can't be changed by pointer? diff --git a/tpm/key.go b/tpm/key.go index cb1b08c3..87a91b76 100644 --- a/tpm/key.go +++ b/tpm/key.go @@ -435,19 +435,21 @@ func (k *Key) SetCertificateChain(ctx context.Context, chain []*x509.Certificate return fmt.Errorf("failed getting signer for key: %w", err) } - leaf := chain[0] - leafPK, ok := leaf.PublicKey.(crypto.PublicKey) - if !ok { - return fmt.Errorf("unexpected type for certificate public key: %T", leaf.PublicKey) - } + if len(chain) > 0 { + leaf := chain[0] + leafPK, ok := leaf.PublicKey.(crypto.PublicKey) + if !ok { + return fmt.Errorf("unexpected type for certificate public key: %T", leaf.PublicKey) + } - publicKey, ok := leafPK.(comparablePublicKey) - if !ok { - return errors.New("certificate public key can't be compared to a crypto.PublicKey") - } + publicKey, ok := leafPK.(comparablePublicKey) + if !ok { + return errors.New("certificate public key can't be compared to a crypto.PublicKey") + } - if !publicKey.Equal(signer.Public()) { - return errors.New("public key does not match the leaf certificate public key") + if !publicKey.Equal(signer.Public()) { + return errors.New("public key does not match the leaf certificate public key") + } } k.chain = chain // TODO(hs): deep copy, so that certs can't be changed by pointer?