diff --git a/tpm/signer.go b/tpm/signer.go index 6036f2c2..9aa896fa 100644 --- a/tpm/signer.go +++ b/tpm/signer.go @@ -105,6 +105,7 @@ func (s *tss2Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) return nil, fmt.Errorf("failed opening TPM: %w", err) } defer closeTPM(ctx, s.tpm, &err) + s.SetTPM(s.tpm.rwc) signature, err = s.Signer.Sign(rand, digest, opts) return } diff --git a/tpm/tss2/signer.go b/tpm/tss2/signer.go index 15ce04ac..59861d82 100644 --- a/tpm/tss2/signer.go +++ b/tpm/tss2/signer.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "math/big" + "sync" "github.com/google/go-tpm/legacy/tpm2" "github.com/google/go-tpm/tpmutil" @@ -80,6 +81,7 @@ func (k *TPMKey) Public() (crypto.PublicKey, error) { // Signer implements [crypto.Signer] using a [TPMKey]. type Signer struct { + m sync.Mutex rw io.ReadWriter publicKey crypto.PublicKey tpmKey *TPMKey @@ -134,13 +136,30 @@ func CreateSigner(rw io.ReadWriter, key *TPMKey) (*Signer, error) { // Notice: This API is EXPERIMENTAL and may be changed or removed in a later // release. func (s *Signer) SetSRKTemplate(p tpm2.Public) { + s.m.Lock() s.srkTemplate = p + s.m.Unlock() } +// SetTPM allows to change the TPM channel. This operation is useful if the +// channel set in [CreateSigner] is closed and opened again before calling [Signer.Sign]. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func (s *Signer) SetTPM(rw io.ReadWriter) { + s.m.Lock() + s.rw = rw + s.m.Unlock() +} + +// Public implements the [crypto.Signer] interface. func (s *Signer) Public() crypto.PublicKey { return s.publicKey } +// Sign implements the [crypto.Signer] interface. func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { parentHandle := tpmutil.Handle(s.tpmKey.Parent) if !handleIsPersistent(s.tpmKey.Parent) { diff --git a/tpm/tss2/signer_test.go b/tpm/tss2/signer_test.go index 197636ef..18523958 100644 --- a/tpm/tss2/signer_test.go +++ b/tpm/tss2/signer_test.go @@ -4,13 +4,14 @@ import ( "bytes" "crypto" "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" "crypto/rsa" - "encoding/pem" "io" "testing" "github.com/google/go-tpm/legacy/tpm2" + "github.com/google/go-tpm/tpmutil" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -54,6 +55,10 @@ var defaultKeyParamsRSAPSS = tpm2.Public{ }, } +func assertMaybeError(t assert.TestingT, err error, msgAndArgs ...interface{}) bool { + return true +} + func TestSign(t *testing.T) { rw := openTPM(t) t.Cleanup(func() { @@ -67,15 +72,32 @@ func TestSign(t *testing.T) { }) tests := []struct { - name string - params tpm2.Public - opts crypto.SignerOpts + name string + params tpm2.Public + opts crypto.SignerOpts + assertion assert.ErrorAssertionFunc }{ - {"ok ECDSA", defaultKeyParamsEC, crypto.SHA256}, - {"ok RSA", defaultKeyParamsRSA, crypto.SHA256}, - {"ok RSAPSS", defaultKeyParamsRSAPSS, &rsa.PSSOptions{ + {"ok ECDSA", defaultKeyParamsEC, crypto.SHA256, assert.NoError}, + {"ok RSA", defaultKeyParamsRSA, crypto.SHA256, assert.NoError}, + {"ok RSAPSS PSSSaltLengthAuto", defaultKeyParamsRSAPSS, &rsa.PSSOptions{ SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA256, - }}, + }, assert.NoError}, + {"ok RSAPSS PSSSaltLengthEqualsHash", defaultKeyParamsRSAPSS, &rsa.PSSOptions{ + SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256, + }, assert.NoError}, + {"ok RSAPSS SaltLength=32", defaultKeyParamsRSAPSS, &rsa.PSSOptions{ + SaltLength: 32, Hash: crypto.SHA256, + }, assert.NoError}, + // 222 is the largest salt possible when signing with a 2048 bit key. Go + // crypto will use this value when rsa.PSSSaltLengthAuto is set. + // + // TPM 2.0's TPM_ALG_RSAPSS algorithm, uses the maximum possible salt + // length. However, as of TPM revision 1.16, TPMs which follow FIPS + // 186-4 will interpret TPM_ALG_RSAPSS using salt length equal to the + // digest length. + {"RSAPSS SaltLength=222", defaultKeyParamsRSAPSS, &rsa.PSSOptions{ + SaltLength: 222, Hash: crypto.SHA256, + }, assertMaybeError}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -85,7 +107,7 @@ func TestSign(t *testing.T) { signer, err := CreateSigner(rw, New(pub, priv)) require.NoError(t, err) - // Set the ECC SRK template used for testing. + // Set the ECC SRK template used for testing signer.SetSRKTemplate(ECCSRKTemplate) hash := crypto.SHA256.New() @@ -93,8 +115,12 @@ func TestSign(t *testing.T) { sum := hash.Sum(nil) sig, err := signer.Sign(rand.Reader, sum, tt.opts) - require.NoError(t, err) + tt.assertion(t, err) + if err != nil { + return + } + // Signature validation using Go crypto switch pub := signer.Public().(type) { case *ecdsa.PublicKey: assert.Equal(t, tpm2.AlgECC, tt.params.Type) @@ -105,7 +131,9 @@ func TestSign(t *testing.T) { case tpm2.AlgRSASSA: assert.NoError(t, rsa.VerifyPKCS1v15(pub, tt.opts.HashFunc(), sum, sig)) case tpm2.AlgRSAPSS: - assert.NoError(t, rsa.VerifyPSS(pub, crypto.SHA256, sum, sig, nil)) + opts, ok := tt.opts.(*rsa.PSSOptions) + require.True(t, ok) + assert.NoError(t, rsa.VerifyPSS(pub, opts.Hash, sum, sig, opts)) default: t.Errorf("unexpected RSAParameters.Sign.Alg %v", tt.params.RSAParameters.Sign.Alg) } @@ -116,12 +144,53 @@ func TestSign(t *testing.T) { } } -func TestCreateSigner(t *testing.T) { - parsePEM := func(s string) []byte { - block, _ := pem.Decode([]byte(s)) - return block.Bytes - } +func TestSign_SetTPM(t *testing.T) { + var signer *Signer + + t.Run("Setup", func(t *testing.T) { + rw := openTPM(t) + t.Cleanup(func() { + assert.NoError(t, rw.Close()) + }) + keyHnd, _, err := tpm2.CreatePrimary(rw, tpm2.HandleOwner, tpm2.PCRSelection{}, "", "", ECCSRKTemplate) + require.NoError(t, err) + t.Cleanup(func() { + assert.NoError(t, tpm2.FlushContext(rw, keyHnd)) + }) + + priv, pub, _, _, _, err := tpm2.CreateKey(rw, keyHnd, tpm2.PCRSelection{}, "", "", defaultKeyParamsEC) + require.NoError(t, err) + + signer, err = CreateSigner(rw, New(pub, priv)) + require.NoError(t, err) + }) + + require.NotNil(t, signer) + + rw := openTPM(t) + t.Cleanup(func() { + assert.NoError(t, rw.Close()) + }) + // Set new tpm channel + signer.SetTPM(rw) + + // Set the ECC SRK template used for testing + signer.SetSRKTemplate(ECCSRKTemplate) + + hash := crypto.SHA256.New() + hash.Write([]byte("ungymnastic-theirn-cotwin-Summer-pemphigous-propagate")) + sum := hash.Sum(nil) + + sig, err := signer.Sign(rand.Reader, sum, crypto.SHA256) + require.NoError(t, err) + + publicKey, ok := signer.Public().(*ecdsa.PublicKey) + require.True(t, ok) + assert.True(t, ecdsa.VerifyASN1(publicKey, sum, sig)) +} + +func TestCreateSigner(t *testing.T) { var rw bytes.Buffer key, err := ParsePrivateKey(parsePEM(p256TSS2PEM)) require.NoError(t, err) @@ -202,3 +271,106 @@ func TestCreateSigner(t *testing.T) { }) } } + +func Test_curveSigScheme(t *testing.T) { + type args struct { + curve elliptic.Curve + } + tests := []struct { + name string + args args + want *tpm2.SigScheme + assertion assert.ErrorAssertionFunc + }{ + {"ok P-256", args{elliptic.P256()}, &tpm2.SigScheme{ + Alg: tpm2.AlgECDSA, Hash: tpm2.AlgSHA256, + }, assert.NoError}, + {"ok P-2384", args{elliptic.P384()}, &tpm2.SigScheme{ + Alg: tpm2.AlgECDSA, Hash: tpm2.AlgSHA384, + }, assert.NoError}, + {"ok P-521", args{elliptic.P521()}, &tpm2.SigScheme{ + Alg: tpm2.AlgECDSA, Hash: tpm2.AlgSHA512, + }, assert.NoError}, + {"fail P-224", args{elliptic.P224()}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := curveSigScheme(tt.args.curve) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_signECDSA_fail(t *testing.T) { + rw := openTPM(t) + t.Cleanup(func() { + assert.NoError(t, rw.Close()) + }) + + digest := func(h crypto.Hash) []byte { + hh := h.New() + hh.Write([]byte("Subotica-chronique-radiancy-inspirationally-transuming-Melbeta")) + return hh.Sum(nil) + } + + type args struct { + rw io.ReadWriter + key tpmutil.Handle + digest []byte + curve elliptic.Curve + } + tests := []struct { + name string + args args + want []byte + assertion assert.ErrorAssertionFunc + }{ + {"fail curve", args{rw, handleOwner, digest(crypto.SHA224), elliptic.P224()}, nil, assert.Error}, + {"fail sign", args{nil, handleOwner, digest(crypto.SHA256), elliptic.P256()}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := signECDSA(tt.args.rw, tt.args.key, tt.args.digest, tt.args.curve) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_signRSA_fail(t *testing.T) { + rw := openTPM(t) + t.Cleanup(func() { + assert.NoError(t, rw.Close()) + }) + + h := crypto.SHA256.New() + h.Write([]byte("murmur-squinance-hoghide-jubilation-enteraden-samadh")) + digest := h.Sum(nil) + + type args struct { + rw io.ReadWriter + key tpmutil.Handle + digest []byte + opts crypto.SignerOpts + } + tests := []struct { + name string + args args + want []byte + assertion assert.ErrorAssertionFunc + }{ + {"fail HashToAlgorithm", args{rw, handleOwner, digest, crypto.SHA224}, nil, assert.Error}, + {"fail PSSOptions", args{rw, handleOwner, digest, &rsa.PSSOptions{ + Hash: crypto.SHA256, SaltLength: 222, + }}, nil, assert.Error}, + {"fail sign", args{nil, handleOwner, digest, crypto.SHA256}, nil, assert.Error}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := signRSA(tt.args.rw, tt.args.key, tt.args.digest, tt.args.opts) + tt.assertion(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/tpm/tss2/simulator_test.go b/tpm/tss2/simulator_test.go index 72fb697d..6028934b 100644 --- a/tpm/tss2/simulator_test.go +++ b/tpm/tss2/simulator_test.go @@ -3,6 +3,8 @@ package tss2 import ( + "crypto/rand" + "encoding/hex" "io" "testing" @@ -10,10 +12,20 @@ import ( "go.step.sm/crypto/tpm/simulator" ) +var seed string + +func init() { + b := make([]byte, 8) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + panic(err) + } + seed = hex.EncodeToString(b) +} + func openTPM(t *testing.T) io.ReadWriteCloser { t.Helper() - sim, err := simulator.New() + sim, err := simulator.New(simulator.WithSeed(seed)) require.NoError(t, err) require.NoError(t, sim.Open()) return sim diff --git a/tpm/tss2/tss2_test.go b/tpm/tss2/tss2_test.go index 185b46f4..9a8cf1b6 100644 --- a/tpm/tss2/tss2_test.go +++ b/tpm/tss2/tss2_test.go @@ -53,12 +53,12 @@ ePVypgEUeJGw68er7UZb4ZSVfoGId6KLX9JE7IwyBkRWLhBU3sLANdgjTqlXUhAD mnYo -----END TSS2 PRIVATE KEY-----` -func TestParsePrivateKey(t *testing.T) { - parsePEM := func(s string) []byte { - block, _ := pem.Decode([]byte(s)) - return block.Bytes - } +func parsePEM(s string) []byte { + block, _ := pem.Decode([]byte(s)) + return block.Bytes +} +func TestParsePrivateKey(t *testing.T) { type args struct { derBytes []byte }