Skip to content

Commit

Permalink
Implement FindCertificateChain function to find certificate chain
Browse files Browse the repository at this point in the history
Certificate chain is found by following algorithm:

* find first certificate either by id or/and label or/and serial (same as
  existing FindCertificate does);
* if issuer is not nil, find next certificate by CKA_SUBJECT (issuer should
  be equal subject);
* if certificate with required subject not found then read all certificates
  and try to find next certificate by AuthorityKeyId (AuthorityKeyId should
  be equal to SubjectKeyId);
* finding stops if last found certificate is selfsigned (issuer is nil or
  equals to subject).

Signed-off-by: Oleksandr Grytsov <[email protected]>
  • Loading branch information
al1img committed Feb 4, 2022
1 parent 1e36df6 commit c72fb44
Show file tree
Hide file tree
Showing 3 changed files with 195 additions and 9 deletions.
98 changes: 98 additions & 0 deletions certificates.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
package crypto11

import (
"bytes"
"crypto/tls"
"crypto/x509"
"encoding/asn1"
Expand Down Expand Up @@ -111,6 +112,67 @@ func findCertificatesWithAttributes(session *pkcs11Session, template []*pkcs11.A
return handles, nil
}

func findCertificateByKeyID(session *pkcs11Session, keyID []byte) (cert *x509.Certificate, err error) {
handles, err := findCertificatesWithAttributes(session, nil)
if err != nil {
return nil, err
}

for _, handle := range handles {
if cert, err = getX509Certificate(session, handle); err != nil {
return nil, err
}

if bytes.Equal(cert.SubjectKeyId, keyID) {
return cert, nil
}
}

return nil, errors.New("no certificate with required subject key ID found")
}

func findCertificateChain(session *pkcs11Session, cert *x509.Certificate) (certs []*x509.Certificate, err error) {
if len(cert.RawIssuer) == 0 || bytes.Equal(cert.RawIssuer, cert.RawSubject) {
return nil, nil
}

template := []*pkcs11.Attribute{pkcs11.NewAttribute(pkcs11.CKA_SUBJECT, cert.RawIssuer)}

handles, err := findCertificatesWithAttributes(session, template)
if err != nil {
return nil, err
}

if len(handles) == 0 {
if cert, err = findCertificateByKeyID(session, cert.AuthorityKeyId); err != nil {
return nil, err
}
} else {
if cert, err = getX509Certificate(session, handles[0]); err != nil {
return nil, err
}
}

for _, foundCert := range certs {
if bytes.Equal(cert.RawSubject, foundCert.RawSubject) {
return certs, nil
}
}

certs = append(certs, cert)

certChain, err := findCertificateChain(session, cert)
if err != nil {
return nil, err
}

if len(certChain) != 0 {
certs = append(certs, certChain...)
}

return certs, nil
}

// FindCertificate retrieves a previously imported certificate. Any combination of id, label
// and serial can be provided. An error is return if all are nil.
func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x509.Certificate, error) {
Expand All @@ -128,6 +190,42 @@ func (c *Context) FindCertificate(id []byte, label []byte, serial *big.Int) (*x5
return cert, err
}

// FindCertificateChain retrieves a previously imported certificate chain. Any combination of id, label
// and serial can be provided. An error is return if all are nil.
func (c *Context) FindCertificateChain(id []byte, label []byte, serial *big.Int) (certs []*x509.Certificate, err error) {
if c.closed.Get() {
return nil, errClosed
}

err = c.withSession(func(session *pkcs11Session) (err error) {
cert, err := findCertificate(session, id, label, serial)
if err != nil {
return err
}

if cert == nil {
return nil
}

certs = append(certs, cert)

certChain, err := findCertificateChain(session, cert)
if err != nil {
return err
}

if len(certChain) == 0 {
return nil
}

certs = append(certs, certChain...)

return nil
})

return certs, err
}

func (c *Context) FindAllPairedCertificates() (certificates []tls.Certificate, err error) {
if c.closed.Get() {
return nil, errClosed
Expand Down
104 changes: 96 additions & 8 deletions certificates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func TestCertificate(t *testing.T) {
id := randomBytes()
label := randomBytes()

cert := generateRandomCert(t)
cert := generateRandomCert(t, nil, "Foo", nil, nil)

err = ctx.ImportCertificateWithLabel(id, label, cert)
require.NoError(t, err)
Expand Down Expand Up @@ -81,7 +81,7 @@ func TestCertificateAttributes(t *testing.T) {
require.NoError(t, ctx.Close())
}()

cert := generateRandomCert(t)
cert := generateRandomCert(t, nil, "Foo", nil, nil)

// We import this with a different serial number, to test this is obeyed
ourSerial := new(big.Int)
Expand Down Expand Up @@ -116,7 +116,7 @@ func TestCertificateRequiredArgs(t *testing.T) {
require.NoError(t, ctx.Close())
}()

cert := generateRandomCert(t)
cert := generateRandomCert(t, nil, "Foo", nil, nil)

val := randomBytes()

Expand All @@ -143,7 +143,7 @@ func TestDeleteCertificate(t *testing.T) {
randomCert := func() ([]byte, []byte, *x509.Certificate) {
id := randomBytes()
label := randomBytes()
cert := generateRandomCert(t)
cert := generateRandomCert(t, nil, "Foo", nil, nil)
return id, label, cert
}
importCertificate := func() ([]byte, []byte, *big.Int) {
Expand Down Expand Up @@ -207,14 +207,98 @@ func TestDeleteCertificate(t *testing.T) {
require.Nil(t, cert)
}

func generateRandomCert(t *testing.T) *x509.Certificate {
func TestCertificateChain(t *testing.T) {
skipTest(t, skipTestCert)

ctx, err := ConfigureFromFile("config")
require.NoError(t, err)

defer func() {
require.NoError(t, ctx.Close())
}()

certNames := []string{"Cert0", "Cert1", "Cert2"}

var (
parent *x509.Certificate
originCertChain []*x509.Certificate
authorityKeyId, subjectKeyID []byte
ids [][]byte
)

for _, name := range certNames {
subjectKeyID = randomBytes()

cert := generateRandomCert(t, parent, name, authorityKeyId, subjectKeyID)

id := randomBytes()
ids = append([][]byte{id}, ids...)

err = ctx.ImportCertificate(id, cert)
require.NoError(t, err)

originCertChain = append([]*x509.Certificate{cert}, originCertChain...)

parent = cert
authorityKeyId = subjectKeyID
}

foundCertChain, err := ctx.FindCertificateChain(ids[0], nil, nil)
require.NoError(t, err)
require.NotNil(t, foundCertChain)

assert.Equal(t, len(foundCertChain), len(originCertChain))

for i := 0; i < len(foundCertChain); i++ {
assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature)
}

err = ctx.DeleteCertificate(ids[len(ids)-1], nil, nil)
require.NoError(t, err)

oldCert := originCertChain[len(originCertChain)-1]
newCert := generateRandomCert(t, nil, "NewCert", oldCert.AuthorityKeyId, oldCert.SubjectKeyId)

originCertChain[len(originCertChain)-1] = newCert

id := randomBytes()

err = ctx.ImportCertificate(id, newCert)
require.NoError(t, err)

ids[len(ids)-1] = id

foundCertChain, err = ctx.FindCertificateChain(ids[0], nil, nil)
require.NoError(t, err)
require.NotNil(t, foundCertChain)

assert.Equal(t, len(foundCertChain), len(originCertChain))

for i := 0; i < len(foundCertChain); i++ {
assert.Equal(t, foundCertChain[i].Signature, originCertChain[i].Signature)
}

for _, id := range ids {
err = ctx.DeleteCertificate(id, nil, nil)
require.NoError(t, err)
}

foundCertChain, err = ctx.FindCertificateChain([]byte("test2"), nil, nil)
require.NoError(t, err)
assert.Nil(t, foundCertChain)
}

func generateRandomCert(t *testing.T, parent *x509.Certificate, commonName string,
authorityKeyId, subjectKeyID []byte) *x509.Certificate {
serial, err := rand.Int(rand.Reader, big.NewInt(20000))
require.NoError(t, err)

ca := &x509.Certificate{
template := &x509.Certificate{
Subject: pkix.Name{
CommonName: "Foo",
CommonName: commonName,
},
AuthorityKeyId: authorityKeyId,
SubjectKeyId: subjectKeyID,
SerialNumber: serial,
NotAfter: time.Now().Add(365 * 24 * time.Hour),
IsCA: true,
Expand All @@ -223,11 +307,15 @@ func generateRandomCert(t *testing.T) *x509.Certificate {
BasicConstraintsValid: true,
}

if parent == nil {
parent = template
}

key, err := rsa.GenerateKey(rand.Reader, 4096)
require.NoError(t, err)

csr := &key.PublicKey
certBytes, err := x509.CreateCertificate(rand.Reader, ca, ca, csr, key)
certBytes, err := x509.CreateCertificate(rand.Reader, template, parent, csr, key)
require.NoError(t, err)

cert, err := x509.ParseCertificate(certBytes)
Expand Down
2 changes: 1 addition & 1 deletion close_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func TestErrorAfterClosed(t *testing.T) {
_, err = ctx.NewRandomReader()
assert.Equal(t, errClosed, err)

cert := generateRandomCert(t)
cert := generateRandomCert(t, nil, "Foo", nil, nil)

err = ctx.ImportCertificate(bytes, cert)
assert.Equal(t, errClosed, err)
Expand Down

0 comments on commit c72fb44

Please sign in to comment.