diff --git a/kms/pkcs11/pkcs11.go b/kms/pkcs11/pkcs11.go index da80be63..250d26b0 100644 --- a/kms/pkcs11/pkcs11.go +++ b/kms/pkcs11/pkcs11.go @@ -11,13 +11,16 @@ import ( "crypto/x509" "encoding/hex" "fmt" + "io" "math/big" "runtime" "strconv" "sync" "github.com/ThalesIgnite/crypto11" + "github.com/miekg/pkcs11" "github.com/pkg/errors" + "golang.org/x/sync/singleflight" "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/uri" @@ -50,6 +53,7 @@ var p11Configure = func(config *crypto11.Config) (P11, error) { type PKCS11 struct { p11 P11 closed sync.Once + config crypto11.Config } // New returns a new PKCS#11 KMS. To initialize it, you need to provide a URI @@ -139,7 +143,8 @@ func New(_ context.Context, opts apiv1.Options) (*PKCS11, error) { } return &PKCS11{ - p11: p11, + p11: p11, + config: config, }, nil } @@ -207,7 +212,95 @@ func (k *PKCS11) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, er return nil, errors.Wrap(err, "createSigner failed") } - return signer, nil + return &reconnectSigner{ + Signer: signer, + kms: k, + signingKey: req.SigningKey, + }, nil +} + +// reconnectSigner is a crypto.Signer that reconfigures the PKCS#11 session on +// specific errors. The locking mechanism does not avoid that concurrent signs +// might reconnect the connection several times. +type reconnectSigner struct { + crypto.Signer + rw sync.RWMutex + sf singleflight.Group + kms *PKCS11 + signingKey string +} + +// Sign calls the crypto.Signer Sign method and attempts to reconfigure the +// connection if it the PKCS#11 session fails. +func (s *reconnectSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + signature, err := s.sign(rand, digest, opts) + if err == nil { + return signature, nil + } + + // Reconnect with errors from functions that use a session handle: + // - CKR_DEVICE_REMOVED: The token was removed from its slot during the + // execution of the function. Utimaco simulator fails with this code if + // we keep more than 256 connections open. + // - CKR_SESSION_HANDLE_INVALID: The specified session handle was invalid + // at the time that the function was invoked. + // - CKR_SESSION_CLOSED: The session was closed during the execution of the + // function. + var p11Err pkcs11.Error + if !errors.As(err, &p11Err) { + return nil, err + } + switch p11Err { + case pkcs11.CKR_DEVICE_REMOVED, pkcs11.CKR_SESSION_HANDLE_INVALID, pkcs11.CKR_SESSION_CLOSED: + default: + return nil, err + } + + // Reconnect and prepare signer + if err := s.reconnect(); err != nil { + return nil, err + } + + // Sign again + return s.sign(rand, digest, opts) +} + +// sign signs the digest using the PKCS#11 module. The sign is protected by a +// RWMutex that will be locked if there is we need to reconnect. +func (s *reconnectSigner) sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) ([]byte, error) { + s.rw.RLock() + defer s.rw.RUnlock() + return s.Signer.Sign(rand, digest, opts) +} + +// reconnect closes the current connection to the module and reconnects again. +// It uses singleflight to avoid simultaneous attempts, locks the signer RWMutex +// to avoid new signs while reconnecting and reconfigures the signer. +func (s *reconnectSigner) reconnect() error { + _, err, _ := s.sf.Do("all", func() (interface{}, error) { + s.rw.Lock() + defer s.rw.Unlock() + if err := s.kms.unsafeReconnect(); err != nil { + return nil, err + } + signer, err := findSigner(s.kms.p11, s.signingKey) + if err != nil { + return nil, err + } + s.Signer = signer + return nil, nil //nolint:nilnil // there's nothing to return + }) + return err +} + +func (k *PKCS11) unsafeReconnect() error { + _ = k.p11.Close() + p11, err := p11Configure(&k.config) + if err != nil { + return err + } + k.p11 = p11 + return nil } // CreateDecrypter creates a decrypter using a key present in the PKCS#11 diff --git a/kms/pkcs11/pkcs11_test.go b/kms/pkcs11/pkcs11_test.go index 7684e96b..b89cffe9 100644 --- a/kms/pkcs11/pkcs11_test.go +++ b/kms/pkcs11/pkcs11_test.go @@ -12,13 +12,18 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509" + "fmt" + "io" "math/big" "reflect" "strings" "testing" "github.com/ThalesIgnite/crypto11" + "github.com/miekg/pkcs11" "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "go.step.sm/crypto/kms/apiv1" "golang.org/x/crypto/cryptobyte" "golang.org/x/crypto/cryptobyte/asn1" @@ -42,6 +47,15 @@ func TestNew(t *testing.T) { return k.p11, nil } + withConfig := func(c crypto11.Config) *PKCS11 { + return &PKCS11{ + p11: k.p11, + config: c, + } + } + + var slot int + type args struct { ctx context.Context opts apiv1.Options @@ -55,60 +69,88 @@ func TestNew(t *testing.T) { {"ok", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test?pin-value=password", - }}, k, false}, + }}, withConfig(crypto11.Config{ + Path: "/usr/local/lib/softhsm/libsofthsm2.so", + TokenLabel: "pkcs11-test", + Pin: "password", + }), false}, {"ok with serial", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789?pin-value=password", - }}, k, false}, + }}, withConfig(crypto11.Config{ + Path: "/usr/local/lib/softhsm/libsofthsm2.so", + TokenSerial: "0123456789", + Pin: "password", + }), false}, {"ok with slot-id", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;slot-id=0?pin-value=password", - }}, k, false}, + }}, withConfig(crypto11.Config{ + Path: "/usr/local/lib/softhsm/libsofthsm2.so", + SlotNumber: &slot, + Pin: "password", + }), false}, {"ok with max-sessions", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;max-sessions=100?pin-value=password", - }}, k, false}, + }}, withConfig(crypto11.Config{ + Path: "/usr/local/lib/softhsm/libsofthsm2.so", + TokenLabel: "pkcs11-test", + Pin: "password", + MaxSessions: 100, + }), false}, {"ok with pin", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test", - Pin: "passowrd", - }}, k, false}, + Pin: "password", + }}, withConfig(crypto11.Config{ + Path: "/usr/local/lib/softhsm/libsofthsm2.so", + TokenLabel: "pkcs11-test", + Pin: "password", + }), false}, {"ok no pin", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test", - }}, k, false}, + }}, withConfig(crypto11.Config{ + Path: "/usr/local/lib/softhsm/libsofthsm2.so", + TokenLabel: "pkcs11-test", + }), false}, {"ok with missing module", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:token=pkcs11-test", - Pin: "passowrd", - }}, k, false}, + Pin: "password", + }}, withConfig(crypto11.Config{ + Path: defaultModule, + TokenLabel: "pkcs11-test", + Pin: "password", + }), false}, {"fail missing uri", args{context.Background(), apiv1.Options{ Type: "pkcs11", }}, nil, true}, {"fail missing token/serial/slot-id", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so", - Pin: "passowrd", + Pin: "password", }}, nil, true}, {"fail token+serial+slot-id", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789;slot-id=0", - Pin: "passowrd", + Pin: "password", }}, nil, true}, {"fail token+serial", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;serial=0123456789", - Pin: "passowrd", + Pin: "password", }}, nil, true}, {"fail token+slot-id", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;token=pkcs11-test;slot-id=0", - Pin: "passowrd", + Pin: "password", }}, nil, true}, {"fail serial+slot-id", args{context.Background(), apiv1.Options{ Type: "pkcs11", URI: "pkcs11:module-path=/usr/local/lib/softhsm/libsofthsm2.so;serial=0123456789;slot-id=0", - Pin: "passowrd", + Pin: "password", }}, nil, true}, {"fail slot-id", args{context.Background(), apiv1.Options{ Type: "pkcs11", @@ -134,9 +176,7 @@ func TestNew(t *testing.T) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("New() = %v, want %v", got, tt.want) - } + assert.Equal(t, tt.want, got) }) } } @@ -506,13 +546,102 @@ func TestPKCS11_CreateSigner(t *testing.T) { default: t.Errorf("signature algorithm %s is not supported", tt.algorithm) } - } - }) } } +type badSigner struct { + err error +} + +func (badSigner) Public() crypto.PublicKey { + return []byte("foo") +} + +func (p badSigner) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { + return nil, p.err +} + +func Test_reconnectSigner(t *testing.T) { + if testModule != "Golang crypto" { + t.Skipf("skipping test with %s", testModule) + } + + tmp0 := p11Configure + t.Cleanup(func() { + p11Configure = tmp0 + }) + + k := setupPKCS11(t) + + p11Configure = func(config *crypto11.Config) (P11, error) { + if strings.Contains(config.Path, "fail") { + return nil, errors.New("an error") + } + return k.p11, nil + } + + signer, err := k.CreateSigner(&apiv1.CreateSignerRequest{ + SigningKey: "pkcs11:id=7373;object=ecdsa-p256-key", + }) + require.NoError(t, err) + require.IsType(t, &ecdsa.PublicKey{}, signer.Public()) + require.IsType(t, &reconnectSigner{}, signer) + + h := crypto.SHA256.New() + h.Write([]byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")) + digest := h.Sum(nil) + + signature, err := signer.Sign(rand.Reader, digest, crypto.SHA256) + assert.NoError(t, err) + assert.True(t, ecdsa.VerifyASN1(signer.Public().(*ecdsa.PublicKey), digest, signature)) + + rcs := signer.(*reconnectSigner) + t.Run("pkcs11 error and reconnect", func(t *testing.T) { + for _, v := range []int{ + pkcs11.CKR_DEVICE_REMOVED, + pkcs11.CKR_SESSION_HANDLE_INVALID, + pkcs11.CKR_SESSION_CLOSED, + } { + rcs.Signer = badSigner{err: pkcs11.Error(v)} + signature, err = rcs.Sign(rand.Reader, digest, crypto.SHA256) + fmt.Printf("%#v\n", rcs.kms.config) + assert.NoError(t, err) + assert.True(t, ecdsa.VerifyASN1(signer.Public().(*ecdsa.PublicKey), digest, signature)) + } + }) + + t.Run("no reconnect errors", func(t *testing.T) { + for _, err := range []error{ + errors.New("some error"), + pkcs11.Error(pkcs11.CKR_DATA_INVALID), + } { + rcs.Signer = badSigner{err: err} + signature, err = rcs.Sign(rand.Reader, digest, crypto.SHA256) + assert.Error(t, err) + assert.Nil(t, signature) + } + }) + + t.Run("find signer error", func(t *testing.T) { + rcs.Signer = badSigner{err: pkcs11.Error(pkcs11.CKR_DEVICE_REMOVED)} + rcs.signingKey = "pkcs11:id=fail" + signature, err = rcs.Sign(rand.Reader, digest, crypto.SHA256) + assert.Error(t, err) + assert.Nil(t, signature) + }) + + t.Run("reconnect error", func(t *testing.T) { + rcs.Signer = badSigner{err: pkcs11.Error(pkcs11.CKR_DEVICE_REMOVED)} + rcs.signingKey = "pkcs11:id=7373;object=ecdsa-p256-key" + rcs.kms.config.Path = "/usr/local/lib/fail.so" + signature, err = rcs.Sign(rand.Reader, digest, crypto.SHA256) + assert.Error(t, err) + assert.Nil(t, signature) + }) +} + func TestPKCS11_CreateDecrypter(t *testing.T) { k := setupPKCS11(t) data := []byte("buggy-coheir-RUBRIC-rabbet-liberal-eaglet-khartoum-stagger")