diff --git a/kms/kmsfs_test.go b/kms/kmsfs_test.go index 8062b89e..c9c00f9b 100644 --- a/kms/kmsfs_test.go +++ b/kms/kmsfs_test.go @@ -11,8 +11,11 @@ import ( "reflect" "testing" + "github.com/stretchr/testify/assert" + "go.step.sm/crypto/kms/apiv1" "go.step.sm/crypto/kms/softkms" + "go.step.sm/crypto/kms/tpmkms" ) type fakeCM struct { @@ -67,21 +70,27 @@ func Test_new(t *testing.T) { {"ok softkms", args{ctx, "softkms:"}, &kmsfs{ KeyManager: &softkms.SoftKMS{}, }, false}, + {"ok tpmkms", args{ctx, "tpmkms:"}, &kmsfs{ + KeyManager: &tpmkms.TPMKMS{}, + }, false}, {"fail", args{ctx, "fail:"}, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := newFS(tt.args.ctx, tt.args.kmsuri) - if (err != nil) != tt.wantErr { - t.Errorf("new() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("new() = %v, want %v", got, tt.want) - } - if err := got.Close(); err != nil { - t.Errorf("Close() error = %v, wantErr false", err) + + assert.NoError(t, err) + if assert.NotNil(t, got) { + assert.IsType(t, tt.want, got) } + + err = got.Close() + assert.NoError(t, err) }) } } @@ -103,6 +112,8 @@ func Test_kmsfs_getKMS(t *testing.T) { {"ok empty", fields{nil}, args{""}, &softkms.SoftKMS{}, false}, {"ok softkms", fields{&softkms.SoftKMS{}}, args{""}, &softkms.SoftKMS{}, false}, {"ok softkms with uri", fields{nil}, args{"softkms:"}, &softkms.SoftKMS{}, false}, + {"ok tpmkms", fields{&tpmkms.TPMKMS{}}, args{""}, &tpmkms.TPMKMS{}, false}, + {"ok tpmkms with uri", fields{nil}, args{"tpmkms:"}, &tpmkms.TPMKMS{}, false}, {"fail", fields{nil}, args{"fail:"}, nil, true}, } for _, tt := range tests { @@ -111,12 +122,15 @@ func Test_kmsfs_getKMS(t *testing.T) { KeyManager: tt.fields.KeyManager, } got, err := f.getKMS(tt.args.kmsuri) - if (err != nil) != tt.wantErr { - t.Errorf("kmsfs.getKMS() error = %v, wantErr %v", err, tt.wantErr) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) return } - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("kmsfs.getKMS() = %v, want %v", got, tt.want) + + assert.NoError(t, err) + if assert.NotNil(t, got) { + assert.IsType(t, tt.want, got) } }) } diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index c4786a44..7d38ac29 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -34,8 +34,12 @@ const Scheme = string(apiv1.TPMKMS) const ( // DefaultRSASize is the number of bits of a new RSA key if no size has been - // specified. - DefaultRSASize = 3072 + // specified. Whereas we're generally defaulting to 3072 bits for new RSA keys, + // 2048 is used as the default for the TPMKMS, because we've observed the TPMs + // we're testing with to be supporting this as the maximum RSA key size. We might + // increase the default in the (near) future, but we want to be more confident + // about the supported size for a specific TPM (model) in that case. + DefaultRSASize = 2048 // defaultRSAAKSize is the default number of bits for a new RSA Attestation // Key. It is currently set to 2048, because that's what's mentioned in the // TCG TPM specification and is used by the AK template in `go-attestation`. @@ -202,7 +206,7 @@ func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyRespons return nil, fmt.Errorf("creating %d bit AKs is not supported; AKs must be RSA 2048 bits", req.Bits) } - size := DefaultRSASize // defaults to 3072 + size := DefaultRSASize // defaults to 2048 if req.Bits > 0 { size = req.Bits } @@ -319,11 +323,19 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err) } + ctx := context.Background() if properties.ak { - return nil, fmt.Errorf("retrieving AK public key currently not supported") + ak, err := k.tpm.GetAK(ctx, properties.name) + if err != nil { + return nil, err + } + akPub := ak.Public() + if akPub == nil { + return nil, errors.New("failed getting AK public key") + } + return akPub, nil } - ctx := context.Background() key, err := k.tpm.GetKey(ctx, properties.name) if err != nil { return nil, err @@ -337,6 +349,7 @@ func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, return signer.Public(), nil } +// LoadCertificate loads the certificate for the key identified by name from the TPMKMS. func (k *TPMKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { if req.Name == "" { return nil, errors.New("loadCertificateRequest 'name' cannot be empty") @@ -350,6 +363,8 @@ func (k *TPMKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certi return chain[0], nil } +// LoadCertificateCertificate loads the certificate chain for the key identified by +// name from the TPMKMS. func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { if req.Name == "" { return nil, errors.New("loadCertificateChainRequest 'name' cannot be empty") @@ -361,7 +376,7 @@ func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([ } ctx := context.Background() - var chain []*x509.Certificate // TODO(hs): support returning chain? + var chain []*x509.Certificate if properties.ak { ak, err := k.tpm.GetAK(ctx, properties.name) if err != nil { @@ -383,6 +398,7 @@ func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([ return chain, nil } +// StoreCertificate stores the certificate for the key identified by name to the TPMKMS. func (k *TPMKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { switch { case req.Name == "": @@ -394,6 +410,7 @@ func (k *TPMKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return k.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{Name: req.Name, CertificateChain: []*x509.Certificate{req.Certificate}}) } +// StoreCertificateChain stores the certificate for the key identified by name to the TPMKMS. func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { switch { case req.Name == "": @@ -472,6 +489,27 @@ func (ac *attestationClient) Attest(ctx context.Context) ([]*x509.Certificate, e return ac.c.Attest(ctx, ac.t, ac.ek, ac.ak) } +// CreateAttestation implements the [apiv1.Attester] interface for the TPMKMS. It +// can be used to request the required information to verify that an application +// key was created in and by a specific TPM. +// +// It is expected that an application key has been attested at creation time by +// an attestation key (AK) before calling this method. An error will be returned +// otherwise. +// +// The response will include an attestation key (AK) certificate (chain) issued +// to the AK that was used to certify creation of the (application) key, as well +// as the key certification parameters at the time of key creation. Together these +// can be used by a relying party to attest that the key was created by a specific +// TPM. +// +// If no valid AK certificate is available when calling CreateAttestation, an +// enrolment with an instance of the Smallstep Attestation CA is performed. This +// will use the TPM Endorsement Key and the AK as inputs. The Attestation CA will +// return an AK certificate chain on success. +// +// When CreateAttestation is called for an AK, the AK certificate chain will be +// returned. Currently no AK creation parameters are returned. func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) { if req.Name == "" { return nil, errors.New("createAttestationRequest 'name' cannot be empty") @@ -483,6 +521,48 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1. } ctx := context.Background() + eks, err := k.tpm.GetEKs(ctx) // TODO(hs): control the EK used as the caller of this method? + if err != nil { + return nil, fmt.Errorf("failed getting EKs: %w", err) + } + ek := getPreferredEK(eks) + ekPublic := ek.Public() + ekKeyID, err := generateKeyID(ekPublic) + if err != nil { + return nil, fmt.Errorf("failed getting EK public key ID: %w", err) + } + ekKeyURL := ekURL(ekKeyID) + permanentIdentifier := ekKeyURL.String() + + if properties.ak { + // TODO(hs): decide if we actually want to support this case? TPM attestation + // is about attesting application keys using attestation keys. + ak, err := k.tpm.GetAK(ctx, properties.name) + if err != nil { + return nil, err + } + akPub := ak.Public() + if akPub == nil { + return nil, fmt.Errorf("failed getting AK public key") + } + akChain := ak.CertificateChain() + if len(akChain) == 0 { + return nil, fmt.Errorf("no certificate chain available for AK %q", properties.name) + } + // TODO(hs): decide if we want/need to return these; their purpose is slightly + // different from the key certification parameters. + _, err = ak.AttestationParameters(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting AK attestation parameters: %w", err) + } + return &apiv1.CreateAttestationResponse{ + Certificate: akChain[0], // certificate for the AK + CertificateChain: akChain, // chain for the AK, including the leaf + PublicKey: akPub, // returns the public key of the attestation key + PermanentIdentifier: permanentIdentifier, + }, nil + } + key, err := k.tpm.GetKey(ctx, properties.name) if err != nil { return nil, err @@ -497,18 +577,6 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1. return nil, fmt.Errorf("failed getting AK for key %q: %w", key.Name(), err) } - eks, err := k.tpm.GetEKs(ctx) // TODO(hs): control the EK used as the caller of this method? - if err != nil { - return nil, fmt.Errorf("failed getting EKs: %w", err) - } - ek := getPreferredEK(eks) - ekPublic := ek.Public() - ekKeyID, err := generateKeyID(ekPublic) - if err != nil { - return nil, fmt.Errorf("failed getting EK public key ID: %w", err) - } - ekKeyURL := ekURL(ekKeyID) - // check if the derived EK URI fingerprint representation matches the provided // permanent identifier value. The current implementation requires the EK URI to // be used as the AK identity, so an error is returned if there's no match. This @@ -569,7 +637,6 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1. // prepare the response to return akCert := akChain[0] - permanentIdentifier := ekKeyURL.String() // NOTE: should always match the valid value of the AK identity (for now) return &apiv1.CreateAttestationResponse{ Certificate: akCert, // certificate for the AK that attested the key CertificateChain: akChain, // chain for the AK that attested the key, including the leaf @@ -578,7 +645,7 @@ func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1. Public: params.Public, CreateData: params.CreateData, CreateAttestation: params.CreateAttestation, - CreateSignature: params.CreateSignature, + CreateSignature: params.CreateSignature, // NOTE: should always match the valid value of the AK identity (for now) }, PermanentIdentifier: permanentIdentifier, }, nil diff --git a/kms/tpmkms/tpmkms_simulator_test.go b/kms/tpmkms/tpmkms_simulator_test.go index 2dd30520..2633a8aa 100644 --- a/kms/tpmkms/tpmkms_simulator_test.go +++ b/kms/tpmkms/tpmkms_simulator_test.go @@ -497,6 +497,8 @@ func TestTPMKMS_CreateSigner(t *testing.T) { func TestTPMKMS_GetPublicKey(t *testing.T) { tpmWithKey := newSimulatedTPM(t, withKey("key1")) + _, err := tpmWithKey.CreateAK(context.Background(), "ak1") + require.NoError(t, err) type fields struct { tpm *tpmp.TPM } @@ -511,7 +513,7 @@ func TestTPMKMS_GetPublicKey(t *testing.T) { expErr error }{ { - name: "ok", + name: "ok/key", fields: fields{ tpm: tpmWithKey, }, @@ -522,28 +524,27 @@ func TestTPMKMS_GetPublicKey(t *testing.T) { }, }, { - name: "fail/empty", + name: "ok/ak", fields: fields{ tpm: tpmWithKey, }, args: args{ req: &apiv1.GetPublicKeyRequest{ - Name: "", + Name: "tpmkms:name=ak1;ak=true", }, }, - expErr: errors.New("getPublicKeyRequest 'name' cannot be empty"), }, { - name: "fail/ak", + name: "fail/empty", fields: fields{ tpm: tpmWithKey, }, args: args{ req: &apiv1.GetPublicKeyRequest{ - Name: "tpmkms:name=ak1;ak=true", + Name: "", }, }, - expErr: errors.New("retrieving AK public key currently not supported"), + expErr: errors.New("getPublicKeyRequest 'name' cannot be empty"), }, { name: "fail/unknown-key", @@ -557,6 +558,18 @@ func TestTPMKMS_GetPublicKey(t *testing.T) { }, expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), }, + { + name: "fail/unknown-ak", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.GetPublicKeyRequest{ + Name: "tpmkms:name=unknown-ak;ak=true", + }, + }, + expErr: fmt.Errorf(`failed getting AK "unknown-ak": not found`), + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -1182,6 +1195,19 @@ func TestTPMKMS_StoreCertificateChain(t *testing.T) { }, expErr: errors.New("storeCertificateChainRequest 'name' cannot be empty"), }, + { + name: "fail/empty-chain", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=key1", + CertificateChain: []*x509.Certificate{}, + }, + }, + expErr: errors.New("storeCertificateChainRequest 'certificateChain' cannot be empty"), + }, { name: "fail/unknown-ak", fields: fields{ @@ -1356,6 +1382,37 @@ func TestTPMKMS_CreateAttestation(t *testing.T) { expErr: errors.New(`failed parsing "tpmkms:name=keyx;ak=true;attest-by=ak1": "ak" and "attest-by" are mutually exclusive`), } }, + "fail/unknown-ak": func(t *testing.T) test { + return test{ + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=unknownAK;ak=true", + }, + }, + expErr: errors.New(`failed getting AK "unknownAK": not found`), + } + }, + "fail/ak-withoutCertificate": func(t *testing.T) test { + akWithoutCert, err := tpm.CreateAK(ctx, "anotherAKWithoutCert") + require.NoError(t, err) + akPub := akWithoutCert.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + return test{ + fields: fields{ + tpm: tpm, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=anotherAKWithoutCert;ak=true", // key1 was attested by the akWithExistingCert at creation time + }, + }, + expErr: errors.New(`no certificate chain available for AK "anotherAKWithoutCert"`), + } + }, "fail/unknown-key": func(t *testing.T) test { return test{ fields: fields{ @@ -1809,6 +1866,42 @@ func TestTPMKMS_CreateAttestation(t *testing.T) { expErr: nil, } }, + "ok/ak": func(t *testing.T) test { + akWithCert, err := tpm.CreateAK(ctx, "akWithCert") + require.NoError(t, err) + akPub := akWithCert.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + URIs: []*url.URL{ekKeyURL}, + PublicKey: akPub, + } + validAKCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, validAKCert) + err = akWithCert.SetCertificateChain(ctx, []*x509.Certificate{validAKCert, ca.Intermediate}) + require.NoError(t, err) + return test{ + fields: fields{ + tpm: tpm, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=akWithCert;ak=true", // key1 was attested by the akWithExistingCert at creation time + }, + }, + want: &apiv1.CreateAttestationResponse{ + Certificate: validAKCert, + CertificateChain: []*x509.Certificate{validAKCert, ca.Intermediate}, + PublicKey: akWithCert.Public(), + PermanentIdentifier: ekKeyURL.String(), + }, + expErr: nil, + } + }, } for name, tt := range tests { tc := tt(t) diff --git a/tpm/ek.go b/tpm/ek.go index 15965900..fce75a44 100644 --- a/tpm/ek.go +++ b/tpm/ek.go @@ -89,16 +89,22 @@ func (ek *EK) MarshalJSON() ([]byte, error) { if err != nil { return nil, fmt.Errorf("failed getting EK fingerprint: %w", err) } + fpURI, err := ek.FingerprintURI() + if err != nil { + return nil, fmt.Errorf("failed getting EK fingerprint URI: %w", err) + } o := struct { - Type string `json:"type"` - Fingerprint string `json:"fingerprint"` - DER []byte `json:"der,omitempty"` - URL string `json:"url,omitempty"` + Type string `json:"type"` + Fingerprint string `json:"fingerprint"` + FingerprintURI string `json:"fingerprintURI"` + DER []byte `json:"der,omitempty"` // TODO: support for EK certificate chain? + URL string `json:"url,omitempty"` }{ - Type: ek.Type(), - Fingerprint: fp, - DER: der, - URL: ek.certificateURL, + Type: ek.Type(), + Fingerprint: fp, + FingerprintURI: fpURI.String(), + DER: der, + URL: ek.certificateURL, } return json.Marshal(o) } diff --git a/tpm/ek_test.go b/tpm/ek_test.go index ac19a1cb..444b49a8 100644 --- a/tpm/ek_test.go +++ b/tpm/ek_test.go @@ -8,8 +8,10 @@ import ( "crypto/elliptic" "crypto/rand" "crypto/rsa" + "crypto/x509" "encoding/base64" "encoding/json" + "encoding/pem" "errors" "fmt" "io" @@ -76,9 +78,11 @@ func TestEK_MarshalJSON(t *testing.T) { keyID, err := generateKeyID(signer.Public()) require.NoError(t, err) fp := "sha256:" + base64.StdEncoding.EncodeToString(keyID) + fpURI := "urn:ek:" + fp require.Equal(t, m["type"], "RSA 2048") require.Equal(t, m["fingerprint"], fp) + require.Equal(t, m["fingerprintURI"], fpURI) require.Equal(t, m["der"], base64.StdEncoding.EncodeToString(cert.Raw)) require.Equal(t, m["url"], "https://certificate.example.com") } @@ -193,3 +197,74 @@ func TestEK_FingerprintURI(t *testing.T) { assert.Error(t, err) assert.Nil(t, u) } + +func TestEK_PEM(t *testing.T) { + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + require.NoError(t, err) + + signer, err := keyutil.GenerateSigner("RSA", "", 2048) + require.NoError(t, err) + + cr, err := x509util.NewCertificateRequest(signer) + require.NoError(t, err) + cr.Subject.CommonName = "testek" + + csr, err := cr.GetCertificateRequest() + require.NoError(t, err) + + cert, err := ca.SignCSR(csr) + require.NoError(t, err) + + pemString := pem.EncodeToMemory(&pem.Block{ + Type: "CERTIFICATE", + Bytes: cert.Raw, + }) + + type fields struct { + public crypto.PublicKey + certificate *x509.Certificate + certificateURL string + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + { + name: "fail/no-ek-cert", + fields: fields{}, + wantErr: true, + }, + { + name: "ok", + fields: fields{ + certificate: cert, + }, + want: string(pemString), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ek := &EK{ + public: tt.fields.public, + certificate: tt.fields.certificate, + certificateURL: tt.fields.certificateURL, + } + got, err := ek.PEM() + if (err != nil) != tt.wantErr { + t.Errorf("EK.PEM() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("EK.PEM() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tpm/internal/key/key.go b/tpm/internal/key/key.go index ad16c6d6..2a56af1f 100644 --- a/tpm/internal/key/key.go +++ b/tpm/internal/key/key.go @@ -108,6 +108,20 @@ type CreateConfig struct { Size int } +func (c *CreateConfig) Validate() error { + switch c.Algorithm { + case "RSA": + if c.Size > 2048 { + return fmt.Errorf("%d bits RSA keys are (currently) not supported in go.step.sm/crypto; maximum is 2048", c.Size) + } + case "ECDSA": + break + default: + return fmt.Errorf("unsupported algorithm %q", c.Algorithm) + } + return nil +} + var tpmEkTemplate *tpm2.Public func ekTemplate(rwc io.ReadWriteCloser) (tpm2.Public, error) { diff --git a/tpm/key.go b/tpm/key.go index fa52aa14..cb1b08c3 100644 --- a/tpm/key.go +++ b/tpm/key.go @@ -144,7 +144,7 @@ type AttestKeyConfig struct { // a random 10 character name is generated. If a Key with the same name exists, // `ErrExists` is returned. The Key won't be attested by an AK. func (t *TPM) CreateKey(ctx context.Context, name string, config CreateKeyConfig) (key *Key, err error) { - if err = t.open(ctx); err != nil { + if err = t.open(goTPMCall(ctx)); err != nil { return nil, fmt.Errorf("failed opening TPM: %w", err) } defer closeTPM(ctx, t, &err) @@ -162,6 +162,9 @@ func (t *TPM) CreateKey(ctx context.Context, name string, config CreateKeyConfig Algorithm: config.Algorithm, Size: config.Size, } + if err := t.validate(&createConfig); err != nil { + return nil, fmt.Errorf("invalid key creation parameters: %w", err) + } data, err := internalkey.Create(t.rwc, prefixKey(name), createConfig) if err != nil { return nil, fmt.Errorf("failed creating key %q: %w", name, err) @@ -185,6 +188,22 @@ func (t *TPM) CreateKey(ctx context.Context, name string, config CreateKeyConfig return } +type attestValidationWrapper attest.KeyConfig + +func (w attestValidationWrapper) Validate() error { + switch w.Algorithm { + case "RSA": + if w.Size > 2048 { + return fmt.Errorf("%d bits RSA keys are (currently) not supported in go.step.sm/crypto; maximum is 2048", w.Size) + } + case "ECDSA": + break + default: + return fmt.Errorf("unsupported algorithm %q", w.Algorithm) + } + return nil +} + // AttestKey creates a new Key identified by `name` and attested by the AK // identified by `akName`. If no name is provided, a random 10 character // name is generated. If a Key with the same name exists, `ErrExists` is @@ -215,13 +234,16 @@ func (t *TPM) AttestKey(ctx context.Context, akName, name string, config AttestK } defer loadedAK.Close(t.attestTPM) - keyConfig := &attest.KeyConfig{ + keyConfig := attest.KeyConfig{ Algorithm: attest.Algorithm(config.Algorithm), Size: config.Size, QualifyingData: config.QualifyingData, Name: prefixKey(name), } - akey, err := t.attestTPM.NewKey(loadedAK, keyConfig) + if err := t.validate(attestValidationWrapper(keyConfig)); err != nil { + return nil, fmt.Errorf("invalid key attestation parameters: %w", err) + } + akey, err := t.attestTPM.NewKey(loadedAK, &keyConfig) if err != nil { return nil, fmt.Errorf("failed creating key %q: %w", name, err) } diff --git a/tpm/tpm.go b/tpm/tpm.go index ea6fb391..31445867 100644 --- a/tpm/tpm.go +++ b/tpm/tpm.go @@ -203,6 +203,14 @@ func (t *TPM) close(ctx context.Context) error { return nil } +type validatableConfig interface { + Validate() error +} + +func (t *TPM) validate(config validatableConfig) error { + return config.Validate() +} + // closeTPM closes TPM `t`. It must be called as a deferred function // every time TPM `t` is opened. If `ep` is nil and closing the TPM // returned an error, `ep` will be pointed to the latter. In practice diff --git a/tpm/tpm_simulator_test.go b/tpm/tpm_simulator_test.go index 99013886..29370c63 100644 --- a/tpm/tpm_simulator_test.go +++ b/tpm/tpm_simulator_test.go @@ -446,6 +446,29 @@ func TestTPM_CreateKey(t *testing.T) { require.NotEqual(t, 0, len(key.Data())) require.Same(t, tpm, key.tpm) require.False(t, key.WasAttested()) + + config = CreateKeyConfig{ + Algorithm: "RSA", + Size: 1024, + } + key, err = tpm.CreateKey(context.Background(), "1024", config) + require.NoError(t, err) + + config = CreateKeyConfig{ + Algorithm: "RSA", + Size: 3072, + } + key, err = tpm.CreateKey(context.Background(), "3072", config) + assert.EqualError(t, err, "invalid key creation parameters: 3072 bits RSA keys are (currently) not supported in go.step.sm/crypto; maximum is 2048") + assert.Nil(t, key) + + config = CreateKeyConfig{ + Algorithm: "RSA", + Size: 4096, + } + key, err = tpm.CreateKey(context.Background(), "4096", config) + assert.EqualError(t, err, "invalid key creation parameters: 4096 bits RSA keys are (currently) not supported in go.step.sm/crypto; maximum is 2048") + assert.Nil(t, key) } func TestTPM_AttestKey(t *testing.T) { @@ -468,6 +491,14 @@ func TestTPM_AttestKey(t *testing.T) { require.Same(t, tpm, key.tpm) require.True(t, key.WasAttested()) require.True(t, key.WasAttestedBy(ak)) + + config = AttestKeyConfig{ + Algorithm: "RSA", + Size: 3072, + } + key, err = tpm.AttestKey(context.Background(), "first-ak", "3072", config) + assert.EqualError(t, err, "invalid key attestation parameters: 3072 bits RSA keys are (currently) not supported in go.step.sm/crypto; maximum is 2048") + assert.Nil(t, key) } func TestTPM_GetKey(t *testing.T) {