Skip to content

Commit

Permalink
Allow to set the TPM channel to the tss2.Signer
Browse files Browse the repository at this point in the history
This commit allows to set the TPM channel to the tss2.Signer if this was
already closed.
  • Loading branch information
maraino committed Nov 3, 2023
1 parent eb54143 commit 5e540c7
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 22 deletions.
1 change: 1 addition & 0 deletions tpm/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ func (s *tss2Signer) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts)
return nil, fmt.Errorf("failed opening TPM: %w", err)
}
defer closeTPM(ctx, s.tpm, &err)
s.SetTPM(s.tpm.rwc)
signature, err = s.Signer.Sign(rand, digest, opts)
return
}
Expand Down
19 changes: 19 additions & 0 deletions tpm/tss2/signer.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"fmt"
"io"
"math/big"
"sync"

"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
Expand Down Expand Up @@ -80,6 +81,7 @@ func (k *TPMKey) Public() (crypto.PublicKey, error) {

// Signer implements [crypto.Signer] using a [TPMKey].
type Signer struct {
m sync.Mutex
rw io.ReadWriter
publicKey crypto.PublicKey
tpmKey *TPMKey
Expand Down Expand Up @@ -134,13 +136,30 @@ func CreateSigner(rw io.ReadWriter, key *TPMKey) (*Signer, error) {
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (s *Signer) SetSRKTemplate(p tpm2.Public) {
s.m.Lock()
s.srkTemplate = p
s.m.Unlock()
}

// SetTPM allows to change the TPM channel. This operation is useful if the
// channel set in [CreateSigner] is closed and opened again before calling [Signer.Sign].
//
// # Experimental
//
// Notice: This API is EXPERIMENTAL and may be changed or removed in a later
// release.
func (s *Signer) SetTPM(rw io.ReadWriter) {
s.m.Lock()
s.rw = rw
s.m.Unlock()
}

// Public implements the [crypto.Signer] interface.
func (s *Signer) Public() crypto.PublicKey {
return s.publicKey
}

// Sign implements the [crypto.Signer] interface.
func (s *Signer) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) {
parentHandle := tpmutil.Handle(s.tpmKey.Parent)
if !handleIsPersistent(s.tpmKey.Parent) {
Expand Down
204 changes: 188 additions & 16 deletions tpm/tss2/signer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"bytes"
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/rsa"
"encoding/pem"
"io"
"testing"

"github.com/google/go-tpm/legacy/tpm2"
"github.com/google/go-tpm/tpmutil"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
Expand Down Expand Up @@ -54,6 +55,10 @@ var defaultKeyParamsRSAPSS = tpm2.Public{
},
}

func assertMaybeError(t assert.TestingT, err error, msgAndArgs ...interface{}) bool {
return true
}

func TestSign(t *testing.T) {
rw := openTPM(t)
t.Cleanup(func() {
Expand All @@ -67,15 +72,32 @@ func TestSign(t *testing.T) {
})

tests := []struct {
name string
params tpm2.Public
opts crypto.SignerOpts
name string
params tpm2.Public
opts crypto.SignerOpts
assertion assert.ErrorAssertionFunc
}{
{"ok ECDSA", defaultKeyParamsEC, crypto.SHA256},
{"ok RSA", defaultKeyParamsRSA, crypto.SHA256},
{"ok RSAPSS", defaultKeyParamsRSAPSS, &rsa.PSSOptions{
{"ok ECDSA", defaultKeyParamsEC, crypto.SHA256, assert.NoError},
{"ok RSA", defaultKeyParamsRSA, crypto.SHA256, assert.NoError},
{"ok RSAPSS PSSSaltLengthAuto", defaultKeyParamsRSAPSS, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthAuto, Hash: crypto.SHA256,
}},
}, assert.NoError},
{"ok RSAPSS PSSSaltLengthEqualsHash", defaultKeyParamsRSAPSS, &rsa.PSSOptions{
SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: crypto.SHA256,
}, assert.NoError},
{"ok RSAPSS SaltLength=32", defaultKeyParamsRSAPSS, &rsa.PSSOptions{
SaltLength: 32, Hash: crypto.SHA256,
}, assert.NoError},
// 222 is the largest salt possible when signing with a 2048 bit key. Go
// crypto will use this value when rsa.PSSSaltLengthAuto is set.
//
// TPM 2.0's TPM_ALG_RSAPSS algorithm, uses the maximum possible salt
// length. However, as of TPM revision 1.16, TPMs which follow FIPS
// 186-4 will interpret TPM_ALG_RSAPSS using salt length equal to the
// digest length.
{"RSAPSS SaltLength=222", defaultKeyParamsRSAPSS, &rsa.PSSOptions{
SaltLength: 222, Hash: crypto.SHA256,
}, assertMaybeError},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
Expand All @@ -85,16 +107,20 @@ func TestSign(t *testing.T) {
signer, err := CreateSigner(rw, New(pub, priv))
require.NoError(t, err)

// Set the ECC SRK template used for testing.
// Set the ECC SRK template used for testing
signer.SetSRKTemplate(ECCSRKTemplate)

hash := crypto.SHA256.New()
hash.Write([]byte("rulingly-quailed-cloacal-indifferentist-roughhoused-self-mad"))
sum := hash.Sum(nil)

sig, err := signer.Sign(rand.Reader, sum, tt.opts)
require.NoError(t, err)
tt.assertion(t, err)
if err != nil {
return
}

// Signature validation using Go crypto
switch pub := signer.Public().(type) {
case *ecdsa.PublicKey:
assert.Equal(t, tpm2.AlgECC, tt.params.Type)
Expand All @@ -105,7 +131,9 @@ func TestSign(t *testing.T) {
case tpm2.AlgRSASSA:
assert.NoError(t, rsa.VerifyPKCS1v15(pub, tt.opts.HashFunc(), sum, sig))
case tpm2.AlgRSAPSS:
assert.NoError(t, rsa.VerifyPSS(pub, crypto.SHA256, sum, sig, nil))
opts, ok := tt.opts.(*rsa.PSSOptions)
require.True(t, ok)
assert.NoError(t, rsa.VerifyPSS(pub, opts.Hash, sum, sig, opts))
default:
t.Errorf("unexpected RSAParameters.Sign.Alg %v", tt.params.RSAParameters.Sign.Alg)
}
Expand All @@ -116,12 +144,53 @@ func TestSign(t *testing.T) {
}
}

func TestCreateSigner(t *testing.T) {
parsePEM := func(s string) []byte {
block, _ := pem.Decode([]byte(s))
return block.Bytes
}
func TestSign_SetTPM(t *testing.T) {
var signer *Signer

t.Run("Setup", func(t *testing.T) {
rw := openTPM(t)
t.Cleanup(func() {
assert.NoError(t, rw.Close())
})
keyHnd, _, err := tpm2.CreatePrimary(rw, tpm2.HandleOwner, tpm2.PCRSelection{}, "", "", ECCSRKTemplate)
require.NoError(t, err)
t.Cleanup(func() {
assert.NoError(t, tpm2.FlushContext(rw, keyHnd))
})

priv, pub, _, _, _, err := tpm2.CreateKey(rw, keyHnd, tpm2.PCRSelection{}, "", "", defaultKeyParamsEC)
require.NoError(t, err)

signer, err = CreateSigner(rw, New(pub, priv))
require.NoError(t, err)
})

require.NotNil(t, signer)

rw := openTPM(t)
t.Cleanup(func() {
assert.NoError(t, rw.Close())
})

// Set new tpm channel
signer.SetTPM(rw)

// Set the ECC SRK template used for testing
signer.SetSRKTemplate(ECCSRKTemplate)

hash := crypto.SHA256.New()
hash.Write([]byte("ungymnastic-theirn-cotwin-Summer-pemphigous-propagate"))
sum := hash.Sum(nil)

sig, err := signer.Sign(rand.Reader, sum, crypto.SHA256)
require.NoError(t, err)

publicKey, ok := signer.Public().(*ecdsa.PublicKey)
require.True(t, ok)
assert.True(t, ecdsa.VerifyASN1(publicKey, sum, sig))
}

func TestCreateSigner(t *testing.T) {
var rw bytes.Buffer
key, err := ParsePrivateKey(parsePEM(p256TSS2PEM))
require.NoError(t, err)
Expand Down Expand Up @@ -202,3 +271,106 @@ func TestCreateSigner(t *testing.T) {
})
}
}

func Test_curveSigScheme(t *testing.T) {
type args struct {
curve elliptic.Curve
}
tests := []struct {
name string
args args
want *tpm2.SigScheme
assertion assert.ErrorAssertionFunc
}{
{"ok P-256", args{elliptic.P256()}, &tpm2.SigScheme{
Alg: tpm2.AlgECDSA, Hash: tpm2.AlgSHA256,
}, assert.NoError},
{"ok P-2384", args{elliptic.P384()}, &tpm2.SigScheme{
Alg: tpm2.AlgECDSA, Hash: tpm2.AlgSHA384,
}, assert.NoError},
{"ok P-521", args{elliptic.P521()}, &tpm2.SigScheme{
Alg: tpm2.AlgECDSA, Hash: tpm2.AlgSHA512,
}, assert.NoError},
{"fail P-224", args{elliptic.P224()}, nil, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := curveSigScheme(tt.args.curve)
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func Test_signECDSA_fail(t *testing.T) {
rw := openTPM(t)
t.Cleanup(func() {
assert.NoError(t, rw.Close())
})

digest := func(h crypto.Hash) []byte {
hh := h.New()
hh.Write([]byte("Subotica-chronique-radiancy-inspirationally-transuming-Melbeta"))
return hh.Sum(nil)
}

type args struct {
rw io.ReadWriter
key tpmutil.Handle
digest []byte
curve elliptic.Curve
}
tests := []struct {
name string
args args
want []byte
assertion assert.ErrorAssertionFunc
}{
{"fail curve", args{rw, handleOwner, digest(crypto.SHA224), elliptic.P224()}, nil, assert.Error},
{"fail sign", args{nil, handleOwner, digest(crypto.SHA256), elliptic.P256()}, nil, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := signECDSA(tt.args.rw, tt.args.key, tt.args.digest, tt.args.curve)
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}

func Test_signRSA_fail(t *testing.T) {
rw := openTPM(t)
t.Cleanup(func() {
assert.NoError(t, rw.Close())
})

h := crypto.SHA256.New()
h.Write([]byte("murmur-squinance-hoghide-jubilation-enteraden-samadh"))
digest := h.Sum(nil)

type args struct {
rw io.ReadWriter
key tpmutil.Handle
digest []byte
opts crypto.SignerOpts
}
tests := []struct {
name string
args args
want []byte
assertion assert.ErrorAssertionFunc
}{
{"fail HashToAlgorithm", args{rw, handleOwner, digest, crypto.SHA224}, nil, assert.Error},
{"fail PSSOptions", args{rw, handleOwner, digest, &rsa.PSSOptions{
Hash: crypto.SHA256, SaltLength: 222,
}}, nil, assert.Error},
{"fail sign", args{nil, handleOwner, digest, crypto.SHA256}, nil, assert.Error},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := signRSA(tt.args.rw, tt.args.key, tt.args.digest, tt.args.opts)
tt.assertion(t, err)
assert.Equal(t, tt.want, got)
})
}
}
14 changes: 13 additions & 1 deletion tpm/tss2/simulator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,29 @@
package tss2

import (
"crypto/rand"
"encoding/hex"
"io"
"testing"

"github.com/stretchr/testify/require"
"go.step.sm/crypto/tpm/simulator"
)

var seed string

func init() {
b := make([]byte, 8)
if _, err := io.ReadFull(rand.Reader, b); err != nil {
panic(err)
}
seed = hex.EncodeToString(b)
}

func openTPM(t *testing.T) io.ReadWriteCloser {
t.Helper()

sim, err := simulator.New()
sim, err := simulator.New(simulator.WithSeed(seed))
require.NoError(t, err)
require.NoError(t, sim.Open())
return sim
Expand Down
10 changes: 5 additions & 5 deletions tpm/tss2/tss2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,12 +53,12 @@ ePVypgEUeJGw68er7UZb4ZSVfoGId6KLX9JE7IwyBkRWLhBU3sLANdgjTqlXUhAD
mnYo
-----END TSS2 PRIVATE KEY-----`

func TestParsePrivateKey(t *testing.T) {
parsePEM := func(s string) []byte {
block, _ := pem.Decode([]byte(s))
return block.Bytes
}
func parsePEM(s string) []byte {
block, _ := pem.Decode([]byte(s))
return block.Bytes
}

func TestParsePrivateKey(t *testing.T) {
type args struct {
derBytes []byte
}
Expand Down

0 comments on commit 5e540c7

Please sign in to comment.