diff --git a/certificates.go b/certificates.go index b9c8225..54a1cb0 100644 --- a/certificates.go +++ b/certificates.go @@ -22,6 +22,7 @@ package crypto11 import ( + "bytes" "crypto/tls" "crypto/x509" "encoding/asn1" @@ -111,6 +112,67 @@ func findCertificatesWithAttributes(session *pkcs11Session, template []*pkcs11.A return handles, nil } +func findCertificateByKeyID(session *pkcs11Session, keyID []byte) (cert *x509.Certificate, err error) { + handles, err := findCertificatesWithAttributes(session, nil) + if err != nil { + return nil, err + } + + for _, handle := range handles { + if cert, err = getX509Certificate(session, handle); err != nil { + return nil, err + } + + if bytes.Equal(cert.SubjectKeyId, keyID) { + return cert, nil + } + } + + return nil, errors.New("no certificate with required subject key ID found") +} + +func findCertificateChain(session *pkcs11Session, cert *x509.Certificate) (certs []*x509.Certificate, err error) { + if len(cert.RawIssuer) == 0 || bytes.Equal(cert.RawIssuer, cert.RawSubject) { + return nil, nil + } + + template := []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_SUBJECT, cert.RawIssuer)} + + handles, err := findCertificatesWithAttributes(session, template) + if err != nil { + return nil, err + } + + if len(handles) == 0 { + if cert, err = findCertificateByKeyID(session, cert.AuthorityKeyId); err != nil { + return nil, err + } + } else { + if cert, err = getX509Certificate(session, handles[0]); err != nil { + return nil, err + } + } + + for _, foundCert := range certs { + if bytes.Equal(cert.RawSubject, foundCert.RawSubject) { + return certs, nil + } + } + + certs = append(certs, cert) + + certChain, err := findCertificateChain(session, cert) + if err != nil { + return nil, err + } + + if len(certChain) != 0 { + certs = append(certs, certChain...) + } + + return certs, nil +} + // FindCertificate retrieves a previously imported certificate. Any combination of id, label // and serial can be provided. An error is return if all are nil. func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x509.Certificate, error) { @@ -128,6 +190,42 @@ func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x5 return cert, err } +// FindCertificateChain retrieves a previously imported certificate chain. Any combination of id, label +// and serial can be provided. An error is return if all are nil. +func (c *Context) FindCertificateChain(id []byte, label []byte, serial *big.Int) (certs []*x509.Certificate, err error) { + if c.closed.Get() { + return nil, errClosed + } + + err = c.withSession(func(session *pkcs11Session) (err error) { + cert, err := findCertificate(session, id, label, serial) + if err != nil { + return err + } + + if cert == nil { + return nil + } + + certs = append(certs, cert) + + certChain, err := findCertificateChain(session, cert) + if err != nil { + return err + } + + if len(certChain) == 0 { + return nil + } + + certs = append(certs, certChain...) + + return nil + }) + + return certs, err +} + func (c *Context) FindAllPairedCertificates() (certificates []tls.Certificate, err error) { if c.closed.Get() { return nil, errClosed diff --git a/certificates_test.go b/certificates_test.go index 31aa856..69c43dc 100644 --- a/certificates_test.go +++ b/certificates_test.go @@ -48,7 +48,7 @@ func TestCertificate(t *testing.T) { id := randomBytes() label := randomBytes() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) err = ctx.ImportCertificateWithLabel(id, label, cert) require.NoError(t, err) @@ -81,7 +81,7 @@ func TestCertificateAttributes(t *testing.T) { require.NoError(t, ctx.Close()) }() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) // We import this with a different serial number, to test this is obeyed ourSerial := new(big.Int) @@ -116,7 +116,7 @@ func TestCertificateRequiredArgs(t *testing.T) { require.NoError(t, ctx.Close()) }() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) val := randomBytes() @@ -143,7 +143,7 @@ func TestDeleteCertificate(t *testing.T) { randomCert := func() ([]byte, []byte, *x509.Certificate) { id := randomBytes() label := randomBytes() - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) return id, label, cert } importCertificate := func() ([]byte, []byte, *big.Int) { @@ -207,14 +207,98 @@ func TestDeleteCertificate(t *testing.T) { require.Nil(t, cert) } -func generateRandomCert(t *testing.T) *x509.Certificate { +func TestCertificateChain(t *testing.T) { + skipTest(t, skipTestCert) + + ctx, err := ConfigureFromFile("config") + require.NoError(t, err) + + defer func() { + require.NoError(t, ctx.Close()) + }() + + certNames := []string{"Cert0", "Cert1", "Cert2"} + + var ( + parent *x509.Certificate + originCertChain []*x509.Certificate + authorityKeyId, subjectKeyID []byte + ids [][]byte + ) + + for _, name := range certNames { + subjectKeyID = randomBytes() + + cert := generateRandomCert(t, parent, name, authorityKeyId, subjectKeyID) + + id := randomBytes() + ids = append([][]byte{id}, ids...) + + err = ctx.ImportCertificate(id, cert) + require.NoError(t, err) + + originCertChain = append([]*x509.Certificate{cert}, originCertChain...) + + parent = cert + authorityKeyId = subjectKeyID + } + + foundCertChain, err := ctx.FindCertificateChain(ids[0], nil, nil) + require.NoError(t, err) + require.NotNil(t, foundCertChain) + + assert.Equal(t, len(foundCertChain), len(originCertChain)) + + for i := 0; i < len(foundCertChain); i++ { + assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature) + } + + err = ctx.DeleteCertificate(ids[len(ids)-1], nil, nil) + require.NoError(t, err) + + oldCert := originCertChain[len(originCertChain)-1] + newCert := generateRandomCert(t, nil, "NewCert", oldCert.AuthorityKeyId, oldCert.SubjectKeyId) + + originCertChain[len(originCertChain)-1] = newCert + + id := randomBytes() + + err = ctx.ImportCertificate(id, newCert) + require.NoError(t, err) + + ids[len(ids)-1] = id + + foundCertChain, err = ctx.FindCertificateChain(ids[0], nil, nil) + require.NoError(t, err) + require.NotNil(t, foundCertChain) + + assert.Equal(t, len(foundCertChain), len(originCertChain)) + + for i := 0; i < len(foundCertChain); i++ { + assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature) + } + + for _, id := range ids { + err = ctx.DeleteCertificate(id, nil, nil) + require.NoError(t, err) + } + + foundCertChain, err = ctx.FindCertificateChain([]byte("test2"), nil, nil) + require.NoError(t, err) + assert.Nil(t, foundCertChain) +} + +func generateRandomCert(t *testing.T, parent *x509.Certificate, commonName string, + authorityKeyId, subjectKeyID []byte) *x509.Certificate { serial, err := rand.Int(rand.Reader, big.NewInt(20000)) require.NoError(t, err) - ca := &x509.Certificate{ + template := &x509.Certificate{ Subject: pkix.Name{ - CommonName: "Foo", + CommonName: commonName, }, + AuthorityKeyId: authorityKeyId, + SubjectKeyId: subjectKeyID, SerialNumber: serial, NotAfter: time.Now().Add(365 * 24 * time.Hour), IsCA: true, @@ -223,11 +307,15 @@ func generateRandomCert(t *testing.T) *x509.Certificate { BasicConstraintsValid: true, } + if parent == nil { + parent = template + } + key, err := rsa.GenerateKey(rand.Reader, 4096) require.NoError(t, err) csr := &key.PublicKey - certBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, csr, key) + certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, csr, key) require.NoError(t, err) cert, err := x509.ParseCertificate(certBytes) diff --git a/close_test.go b/close_test.go index c1acfc9..27761b8 100644 --- a/close_test.go +++ b/close_test.go @@ -85,7 +85,7 @@ func TestErrorAfterClosed(t *testing.T) { _, err = ctx.NewRandomReader() assert.Equal(t, errClosed, err) - cert := generateRandomCert(t) + cert := generateRandomCert(t, nil, "Foo", nil, nil) err = ctx.ImportCertificate(bytes, cert) assert.Equal(t, errClosed, err)