Skip to content

Commit

Permalink
Merge pull request #261 from smallstep/herman/more-tpm-improvements
Browse files Browse the repository at this point in the history
TPM improvements
  • Loading branch information
hslatman authored Jun 12, 2023
2 parents 1e0726a + faa54c3 commit 271865e
Show file tree
Hide file tree
Showing 9 changed files with 379 additions and 49 deletions.
36 changes: 25 additions & 11 deletions kms/kmsfs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"

"go.step.sm/crypto/kms/apiv1"
"go.step.sm/crypto/kms/softkms"
"go.step.sm/crypto/kms/tpmkms"
)

type fakeCM struct {
Expand Down Expand Up @@ -67,21 +70,27 @@ func Test_new(t *testing.T) {
{"ok softkms", args{ctx, "softkms:"}, &kmsfs{
KeyManager: &softkms.SoftKMS{},
}, false},
{"ok tpmkms", args{ctx, "tpmkms:"}, &kmsfs{
KeyManager: &tpmkms.TPMKMS{},
}, false},
{"fail", args{ctx, "fail:"}, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := newFS(tt.args.ctx, tt.args.kmsuri)
if (err != nil) != tt.wantErr {
t.Errorf("new() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("new() = %v, want %v", got, tt.want)
}
if err := got.Close(); err != nil {
t.Errorf("Close() error = %v, wantErr false", err)

assert.NoError(t, err)
if assert.NotNil(t, got) {
assert.IsType(t, tt.want, got)
}

err = got.Close()
assert.NoError(t, err)
})
}
}
Expand All @@ -103,6 +112,8 @@ func Test_kmsfs_getKMS(t *testing.T) {
{"ok empty", fields{nil}, args{""}, &softkms.SoftKMS{}, false},
{"ok softkms", fields{&softkms.SoftKMS{}}, args{""}, &softkms.SoftKMS{}, false},
{"ok softkms with uri", fields{nil}, args{"softkms:"}, &softkms.SoftKMS{}, false},
{"ok tpmkms", fields{&tpmkms.TPMKMS{}}, args{""}, &tpmkms.TPMKMS{}, false},
{"ok tpmkms with uri", fields{nil}, args{"tpmkms:"}, &tpmkms.TPMKMS{}, false},
{"fail", fields{nil}, args{"fail:"}, nil, true},
}
for _, tt := range tests {
Expand All @@ -111,12 +122,15 @@ func Test_kmsfs_getKMS(t *testing.T) {
KeyManager: tt.fields.KeyManager,
}
got, err := f.getKMS(tt.args.kmsuri)
if (err != nil) != tt.wantErr {
t.Errorf("kmsfs.getKMS() error = %v, wantErr %v", err, tt.wantErr)
if tt.wantErr {
assert.Error(t, err)
assert.Nil(t, got)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("kmsfs.getKMS() = %v, want %v", got, tt.want)

assert.NoError(t, err)
if assert.NotNil(t, got) {
assert.IsType(t, tt.want, got)
}
})
}
Expand Down
107 changes: 87 additions & 20 deletions kms/tpmkms/tpmkms.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,12 @@ const Scheme = string(apiv1.TPMKMS)

const (
// DefaultRSASize is the number of bits of a new RSA key if no size has been
// specified.
DefaultRSASize = 3072
// specified. Whereas we're generally defaulting to 3072 bits for new RSA keys,
// 2048 is used as the default for the TPMKMS, because we've observed the TPMs
// we're testing with to be supporting this as the maximum RSA key size. We might
// increase the default in the (near) future, but we want to be more confident
// about the supported size for a specific TPM (model) in that case.
DefaultRSASize = 2048
// defaultRSAAKSize is the default number of bits for a new RSA Attestation
// Key. It is currently set to 2048, because that's what's mentioned in the
// TCG TPM specification and is used by the AK template in `go-attestation`.
Expand Down Expand Up @@ -202,7 +206,7 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons
return nil, fmt.Errorf("creating %d bit AKs is not supported; AKs must be RSA 2048 bits", req.Bits)
}

size := DefaultRSASize // defaults to 3072
size := DefaultRSASize // defaults to 2048
if req.Bits > 0 {
size = req.Bits
}
Expand Down Expand Up @@ -319,11 +323,19 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err)
}

ctx := context.Background()
if properties.ak {
return nil, fmt.Errorf("retrieving AK public key currently not supported")
ak, err := k.tpm.GetAK(ctx, properties.name)
if err != nil {
return nil, err
}
akPub := ak.Public()
if akPub == nil {
return nil, errors.New("failed getting AK public key")
}
return akPub, nil
}

ctx := context.Background()
key, err := k.tpm.GetKey(ctx, properties.name)
if err != nil {
return nil, err
Expand All @@ -337,6 +349,7 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey,
return signer.Public(), nil
}

// LoadCertificate loads the certificate for the key identified by name from the TPMKMS.
func (k *TPMKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) {
if req.Name == "" {
return nil, errors.New("loadCertificateRequest 'name' cannot be empty")
Expand All @@ -350,6 +363,8 @@ func (k *TPMKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi
return chain[0], nil
}

// LoadCertificateCertificate loads the certificate chain for the key identified by
// name from the TPMKMS.
func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) {
if req.Name == "" {
return nil, errors.New("loadCertificateChainRequest 'name' cannot be empty")
Expand All @@ -361,7 +376,7 @@ func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([
}

ctx := context.Background()
var chain []*x509.Certificate // TODO(hs): support returning chain?
var chain []*x509.Certificate
if properties.ak {
ak, err := k.tpm.GetAK(ctx, properties.name)
if err != nil {
Expand All @@ -383,6 +398,7 @@ func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([
return chain, nil
}

// StoreCertificate stores the certificate for the key identified by name to the TPMKMS.
func (k *TPMKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
switch {
case req.Name == "":
Expand All @@ -394,6 +410,7 @@ func (k *TPMKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error {
return k.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{Name: req.Name, CertificateChain: []*x509.Certificate{req.Certificate}})
}

// StoreCertificateChain stores the certificate for the key identified by name to the TPMKMS.
func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error {
switch {
case req.Name == "":
Expand Down Expand Up @@ -472,6 +489,27 @@ func (ac *attestationClient) Attest(ctx context.Context) ([]*x509.Certificate, e
return ac.c.Attest(ctx, ac.t, ac.ek, ac.ak)
}

// CreateAttestation implements the [apiv1.Attester] interface for the TPMKMS. It
// can be used to request the required information to verify that an application
// key was created in and by a specific TPM.
//
// It is expected that an application key has been attested at creation time by
// an attestation key (AK) before calling this method. An error will be returned
// otherwise.
//
// The response will include an attestation key (AK) certificate (chain) issued
// to the AK that was used to certify creation of the (application) key, as well
// as the key certification parameters at the time of key creation. Together these
// can be used by a relying party to attest that the key was created by a specific
// TPM.
//
// If no valid AK certificate is available when calling CreateAttestation, an
// enrolment with an instance of the Smallstep Attestation CA is performed. This
// will use the TPM Endorsement Key and the AK as inputs. The Attestation CA will
// return an AK certificate chain on success.
//
// When CreateAttestation is called for an AK, the AK certificate chain will be
// returned. Currently no AK creation parameters are returned.
func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) {
if req.Name == "" {
return nil, errors.New("createAttestationRequest 'name' cannot be empty")
Expand All @@ -483,6 +521,48 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
}

ctx := context.Background()
eks, err := k.tpm.GetEKs(ctx) // TODO(hs): control the EK used as the caller of this method?
if err != nil {
return nil, fmt.Errorf("failed getting EKs: %w", err)
}
ek := getPreferredEK(eks)
ekPublic := ek.Public()
ekKeyID, err := generateKeyID(ekPublic)
if err != nil {
return nil, fmt.Errorf("failed getting EK public key ID: %w", err)
}
ekKeyURL := ekURL(ekKeyID)
permanentIdentifier := ekKeyURL.String()

if properties.ak {
// TODO(hs): decide if we actually want to support this case? TPM attestation
// is about attesting application keys using attestation keys.
ak, err := k.tpm.GetAK(ctx, properties.name)
if err != nil {
return nil, err
}
akPub := ak.Public()
if akPub == nil {
return nil, fmt.Errorf("failed getting AK public key")
}
akChain := ak.CertificateChain()
if len(akChain) == 0 {
return nil, fmt.Errorf("no certificate chain available for AK %q", properties.name)
}
// TODO(hs): decide if we want/need to return these; their purpose is slightly
// different from the key certification parameters.
_, err = ak.AttestationParameters(ctx)
if err != nil {
return nil, fmt.Errorf("failed getting AK attestation parameters: %w", err)
}
return &apiv1.CreateAttestationResponse{
Certificate: akChain[0], // certificate for the AK
CertificateChain: akChain, // chain for the AK, including the leaf
PublicKey: akPub, // returns the public key of the attestation key
PermanentIdentifier: permanentIdentifier,
}, nil
}

key, err := k.tpm.GetKey(ctx, properties.name)
if err != nil {
return nil, err
Expand All @@ -497,18 +577,6 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
return nil, fmt.Errorf("failed getting AK for key %q: %w", key.Name(), err)
}

eks, err := k.tpm.GetEKs(ctx) // TODO(hs): control the EK used as the caller of this method?
if err != nil {
return nil, fmt.Errorf("failed getting EKs: %w", err)
}
ek := getPreferredEK(eks)
ekPublic := ek.Public()
ekKeyID, err := generateKeyID(ekPublic)
if err != nil {
return nil, fmt.Errorf("failed getting EK public key ID: %w", err)
}
ekKeyURL := ekURL(ekKeyID)

// check if the derived EK URI fingerprint representation matches the provided
// permanent identifier value. The current implementation requires the EK URI to
// be used as the AK identity, so an error is returned if there's no match. This
Expand Down Expand Up @@ -569,7 +637,6 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.

// prepare the response to return
akCert := akChain[0]
permanentIdentifier := ekKeyURL.String() // NOTE: should always match the valid value of the AK identity (for now)
return &apiv1.CreateAttestationResponse{
Certificate: akCert, // certificate for the AK that attested the key
CertificateChain: akChain, // chain for the AK that attested the key, including the leaf
Expand All @@ -578,7 +645,7 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.
Public: params.Public,
CreateData: params.CreateData,
CreateAttestation: params.CreateAttestation,
CreateSignature: params.CreateSignature,
CreateSignature: params.CreateSignature, // NOTE: should always match the valid value of the AK identity (for now)
},
PermanentIdentifier: permanentIdentifier,
}, nil
Expand Down
Loading

0 comments on commit 271865e

Please sign in to comment.