diff --git a/Makefile b/Makefile index 30aa5250..f6c5cbcc 100644 --- a/Makefile +++ b/Makefile @@ -29,10 +29,11 @@ defaulttest: $Q $(GOFLAGS) gotestsum -- -coverpkg=./... -coverprofile=defaultcoverage.out -covermode=atomic ./... simulatortest: - $Q $(GOFLAGS) CGO_ENALBED=1 gotestsum -- -coverpkg=./tpm -coverprofile=simulatorcoverage.out -covermode=atomic -tags tpmsimulator ./tpm + $Q $(GOFLAGS) CGO_ENABLED=1 gotestsum -- -coverpkg=./tpm/...,./kms/tpmkms -coverprofile=simulatorcoverage.out -covermode=atomic -tags tpmsimulator ./tpm ./kms/tpmkms combinecoverage: - cat defaultcoverage.out simulatorcoverage.out > coverage.out + cat defaultcoverage.out > coverage.out + tail -n +2 simulatorcoverage.out >> coverage.out race: $Q $(GOFLAGS) gotestsum -- -race ./... diff --git a/kms/apiv1/options.go b/kms/apiv1/options.go index f075d3f1..439fa0d9 100644 --- a/kms/apiv1/options.go +++ b/kms/apiv1/options.go @@ -30,6 +30,17 @@ type CertificateManager interface { StoreCertificate(req *StoreCertificateRequest) error } +// CertificateChainManager is the interface implemented by KMS implementations +// that can load certificate chains. The LoadCertificateChain method uses the +// same request object as the LoadCertificate method of the CertificateManager +// interfaces. When the LoadCertificateChain method is called, the certificate +// chain stored through the CertificateChain property in the StoreCertificateRequest +// will be returned, partially reusing the StoreCertificateRequest object. +type CertificateChainManager interface { + LoadCertificateChain(req *LoadCertificateChainRequest) ([]*x509.Certificate, error) + StoreCertificateChain(req *StoreCertificateChainRequest) error +} + // NameValidator is an interface that KeyManager can implement to validate a // given name or URI. type NameValidator interface { @@ -61,7 +72,7 @@ func (e NotImplementedError) Error() string { } // AlreadyExistsError is the type of error returned if a key already exists. This -// is currently only implmented on pkcs11. +// is currently only implmented for pkcs11 and tpmkms. type AlreadyExistsError struct { Message string } @@ -95,6 +106,8 @@ const ( AzureKMS Type = "azurekms" // CAPIKMS CAPIKMS Type = "capi" + // TPMKMS + TPMKMS Type = "tpmkms" ) // Options are the KMS options. They represent the kms object in the ca.json. @@ -109,7 +122,7 @@ type Options struct { // https://tools.ietf.org/html/rfc7512 and represents the configuration used // to connect to the KMS. // - // Used by: pkcs11 + // Used by: pkcs11, tpmkms URI string `json:"uri,omitempty"` // Pin used to access the PKCS11 module. It can be defined in the URI using @@ -130,6 +143,10 @@ type Options struct { // Profile to use in AmazonKMS. Profile string `json:"profile,omitempty"` + + // StorageDirectory is the path to a directory to + // store serialized TPM objects. Only used by the TPMKMS. + StorageDirectory string `json:"storageDirectory,omitempty"` } // Validate checks the fields in Options. @@ -142,7 +159,7 @@ func (o *Options) Validate() error { switch Type(typ) { case DefaultKMS, SoftKMS: // Go crypto based kms. case CloudKMS, AmazonKMS, AzureKMS: // Cloud based kms. - case YubiKey, PKCS11: // Hardware based kms. + case YubiKey, PKCS11, TPMKMS: // Hardware based kms. case SSHAgentKMS, CAPIKMS: // Others default: return fmt.Errorf("unsupported kms type %s", o.Type) diff --git a/kms/apiv1/requests.go b/kms/apiv1/requests.go index 0862cab8..fb804f54 100644 --- a/kms/apiv1/requests.go +++ b/kms/apiv1/requests.go @@ -1,6 +1,7 @@ package apiv1 import ( + "context" "crypto" "crypto/x509" "fmt" @@ -129,7 +130,7 @@ type GetPublicKeyRequest struct { type CreateKeyRequest struct { // Name represents the key name or label used to identify a key. // - // Used by: awskms, cloudkms, azurekms, pkcs11, yubikey. + // Used by: awskms, cloudkms, azurekms, pkcs11, yubikey, tpmkms. Name string // SignatureAlgorithm represents the type of key to create. @@ -163,8 +164,9 @@ type CreateKeyRequest struct { // CreateKeyResponse is the response value of the kms.CreateKey method. type CreateKeyResponse struct { - Name string - PublicKey crypto.PublicKey + Name string + PublicKey crypto.PublicKey + // PrivateKey is only used by softkms PrivateKey crypto.PrivateKey CreateSignerRequest CreateSignerRequest } @@ -194,6 +196,10 @@ type LoadCertificateRequest struct { Name string } +// LoadCertificateChainRequest is the parameter used in the LoadCertificateChain method of +// a CertificateChainManager. It's an alias for LoadCertificateRequest. +type LoadCertificateChainRequest LoadCertificateRequest + // StoreCertificateRequest is the parameter used in the StoreCertificate method // of a CertificateManager. type StoreCertificateRequest struct { @@ -207,6 +213,13 @@ type StoreCertificateRequest struct { Extractable bool } +// StoreCertificateChainRequest is the parameter used in the StoreCertificateChain method +// of a CertificateChainManager. +type StoreCertificateChainRequest struct { + Name string + CertificateChain []*x509.Certificate +} + // CreateAttestationRequest is the parameter used in the kms.CreateAttestation // method. // @@ -215,19 +228,54 @@ type StoreCertificateRequest struct { // Notice: This API is EXPERIMENTAL and may be changed or removed in a later // release. type CreateAttestationRequest struct { - Name string + Name string + AttestationClient AttestationClient // TODO(hs): a better name; Attestor perhaps, but that's already taken +} + +// AttestationClient is an interface that provides a pluggable method for +// attesting Attestation Keys (AKs). +type AttestationClient interface { + Attest(context.Context) ([]*x509.Certificate, error) +} + +// CertificationParameters encapsulates the inputs for certifying an application key. +// Only TPM 2.0 is supported at this point. +// +// This struct was copied from github.com/google/go-attestation, preventing an +// additional dependency in this package. +type CertificationParameters struct { + // Public represents the key's canonical encoding (a TPMT_PUBLIC structure). + // It includes the public key and signing parameters. + Public []byte + // CreateData represents the properties of a TPM 2.0 key. It is encoded + // as a TPMS_CREATION_DATA structure. + CreateData []byte + // CreateAttestation represents an assertion as to the details of the key. + // It is encoded as a TPMS_ATTEST structure. + CreateAttestation []byte + // CreateSignature represents a signature of the CreateAttestation structure. + // It is encoded as a TPMT_SIGNATURE structure. + CreateSignature []byte } // CreateAttestationResponse is the response value of the kms.CreateAttestation // method. // +// If a non-empty CertificateChain is returned, the first x509.Certificate is +// the same as the one in the Certificate property. +// +// When an attestation is created for a TPM key, the CertificationParameters +// property will have a record of the certification parameters at the time of +// key attestation. +// // # Experimental // // Notice: This API is EXPERIMENTAL and may be changed or removed in a later // release. type CreateAttestationResponse struct { - Certificate *x509.Certificate - CertificateChain []*x509.Certificate - PublicKey crypto.PublicKey - PermanentIdentifier string + Certificate *x509.Certificate + CertificateChain []*x509.Certificate + PublicKey crypto.PublicKey + CertificationParameters *CertificationParameters + PermanentIdentifier string } diff --git a/kms/capi/capi.go b/kms/capi/capi.go index d0d7d013..ebf37906 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -755,3 +755,5 @@ func (s *CAPISigner) Sign(_ io.Reader, digest []byte, opts crypto.SignerOpts) ([ func (s *CAPISigner) Public() crypto.PublicKey { return s.PublicKey } + +var _ apiv1.CertificateManager = (*CAPIKMS)(nil) diff --git a/kms/kmsfs_test.go b/kms/kmsfs_test.go index 2f4d5ebd..8062b89e 100644 --- a/kms/kmsfs_test.go +++ b/kms/kmsfs_test.go @@ -42,6 +42,8 @@ func (f *fakeCM) StoreCertificate(req *apiv1.StoreCertificateRequest) error { return nil } +var _ apiv1.CertificateManager = (*fakeCM)(nil) + func TestMain(m *testing.M) { apiv1.Register(apiv1.Type("fake"), func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { return &fakeCM{}, nil diff --git a/kms/pkcs11/pkcs11.go b/kms/pkcs11/pkcs11.go index e839a093..46d0f962 100644 --- a/kms/pkcs11/pkcs11.go +++ b/kms/pkcs11/pkcs11.go @@ -401,3 +401,5 @@ func findCertificate(ctx P11, rawuri string) (*x509.Certificate, error) { } return cert, nil } + +var _ apiv1.CertificateManager = (*PKCS11)(nil) diff --git a/kms/tpmkms/no_tpmkms.go b/kms/tpmkms/no_tpmkms.go new file mode 100644 index 00000000..2a84569a --- /dev/null +++ b/kms/tpmkms/no_tpmkms.go @@ -0,0 +1,11 @@ +//go:build notpmkms +// +build notpmkms + +package tpmkms + +func init() { + apiv1.Register(apiv1.TPMKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { + name := filepath.Base(os.Args[0]) + return nil, errors.Errorf("unsupported KMS type 'tpmkms': %s is compiled without TPM KMS support", name) + }) +} diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go new file mode 100644 index 00000000..c4786a44 --- /dev/null +++ b/kms/tpmkms/tpmkms.go @@ -0,0 +1,658 @@ +//go:build !notpmkms +// +build !notpmkms + +package tpmkms + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/sha256" + "crypto/x509" + "encoding/base64" + "errors" + "fmt" + "net/url" + "path/filepath" + + "go.step.sm/crypto/internal/step" + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/kms/uri" + "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/attestation" + "go.step.sm/crypto/tpm/storage" +) + +func init() { + apiv1.Register(apiv1.TPMKMS, func(ctx context.Context, opts apiv1.Options) (apiv1.KeyManager, error) { + return New(ctx, opts) + }) +} + +// Scheme is the scheme used in TPM KMS URIs, the string "tpmkms". +const Scheme = string(apiv1.TPMKMS) + +const ( + // DefaultRSASize is the number of bits of a new RSA key if no size has been + // specified. + DefaultRSASize = 3072 + // defaultRSAAKSize is the default number of bits for a new RSA Attestation + // Key. It is currently set to 2048, because that's what's mentioned in the + // TCG TPM specification and is used by the AK template in `go-attestation`. + defaultRSAAKSize = 2048 +) + +// TPMKMS is a KMS implementation backed by a TPM. +type TPMKMS struct { + tpm *tpm.TPM + attestationCABaseURL string + attestationCARootFile string + attestationCAInsecure bool + permanentIdentifier string +} + +type algorithmAttributes struct { + Type string + Curve int +} + +var signatureAlgorithmMapping = map[apiv1.SignatureAlgorithm]algorithmAttributes{ + apiv1.UnspecifiedSignAlgorithm: {"RSA", -1}, + apiv1.SHA256WithRSA: {"RSA", -1}, + apiv1.SHA384WithRSA: {"RSA", -1}, + apiv1.SHA512WithRSA: {"RSA", -1}, + apiv1.SHA256WithRSAPSS: {"RSA", -1}, + apiv1.SHA384WithRSAPSS: {"RSA", -1}, + apiv1.SHA512WithRSAPSS: {"RSA", -1}, + apiv1.ECDSAWithSHA256: {"ECDSA", 256}, + apiv1.ECDSAWithSHA384: {"ECDSA", 384}, + apiv1.ECDSAWithSHA512: {"ECDSA", 521}, +} + +// New initializes a new KMS backed by a TPM. +// +// A new TPMKMS can be initialized with a configuration by providing +// a URI in the options: +// +// New(ctx, &apiv1.Options{ +// URI: tpmkms:device=/dev/tpmrm0;storage-directory=/path/to/tpmstorage/directory +// }) +// +// It's also possible to set the storage directory as follows: +// +// New(ctx, &apiv1.Options{ +// URI: tpmkms:device=/dev/tpmrm0 +// StorageDirectory: /path/to/tpmstorage/directory +// }) +// +// The default storage location for serialized TPM objects when +// an instance of TPMKMS is created, is $STEPPATH/tpm. +// +// The system default TPM device will be used when not configured. A +// specific TPM device can be selected by setting the device: +// +// tpmkms:device=/dev/tpmrm0 +// +// By default newly created TPM objects won't be persisted, so can't +// be readily used. The location for storage can be set using +// storage-directory: +// +// tpmkms:storage-directory=/path/to/tpmstorage/directory +// +// For attestation use cases that involve the Smallstep Attestation CA +// or a compatible one, several properties can be set. The following +// specify the Attestation CA base URL, the path to a bundle of root CAs +// to trust when setting up a TLS connection to the Attestation CA and +// disable TLS certificate validation, respectively. +// +// tpmkms:attestation-ca-url=https://my.attestation.ca +// tpmkms:attestation-ca-root=/path/to/trusted/roots.pem +// tpmkms:attestation-ca-insecure=true +// +// The system may not always have a PermanentIdentifier assigned, so +// when initializing the TPMKMS, it's possible to set this value: +// +// tpmkms:permanent-identifier= +// +// Attestation support in the TPMKMS is considered EXPERIMENTAL. It +// is expected that there will be changes to the configuration that +// be provided and the attestation flow. +// +// The TPMKMS implementation is backed by an instance of the TPM from +// the `tpm` package. If the TPMKMS operations aren't sufficient for +// your use case, use a tpm.TPM instance instead. +func New(ctx context.Context, opts apiv1.Options) (kms *TPMKMS, err error) { + kms = &TPMKMS{} + storageDirectory := filepath.Join(step.Path(), "tpm") // store TPM objects in $STEPPATH/tpm by default + if opts.StorageDirectory != "" { + storageDirectory = opts.StorageDirectory + } + tpmOpts := []tpm.NewTPMOption{tpm.WithStore(storage.NewDirstore(storageDirectory))} + if opts.URI != "" { + u, err := uri.ParseWithScheme(Scheme, opts.URI) + if err != nil { + return nil, fmt.Errorf("failed parsing %q as URI: %w", opts.URI, err) + } + if device := u.Get("device"); device != "" { + tpmOpts = append(tpmOpts, tpm.WithDeviceName(device)) + } + if storageDirectory := u.Get("storage-directory"); storageDirectory != "" { + tpmOpts = append(tpmOpts, tpm.WithStore(storage.NewDirstore(storageDirectory))) + } + kms.attestationCABaseURL = u.Get("attestation-ca-url") + kms.attestationCARootFile = u.Get("attestation-ca-root") + kms.attestationCAInsecure = u.GetBool("attestation-ca-insecure") + kms.permanentIdentifier = u.Get("permanent-identifier") // TODO(hs): determine if this is needed + } + + kms.tpm, err = tpm.New(tpmOpts...) + if err != nil { + return nil, fmt.Errorf("failed creating new TPM: %w", err) + } + + return +} + +// CreateKey generates a new key in the TPM KMS and returns the public key. +// +// The `name` in the [apiv1.CreateKeyRequest] can be used to specify +// some key properties. These are as follows: +// +// - name=: specify the name to identify the key with +// - ak=true: if set to true, an Attestation Key (AK) will be created instead of an application key +// - attest-by=: attest an application key at creation time with the AK identified by `akName` +// - qualifying-data=: hexadecimal coded binary data that can be used to guarantee freshness when attesting creation of a key +// +// Some examples usages: +// +// Create an application key, without attesting it: +// +// tpmkms:name=my-key +// +// Create an Attestation Key (AK): +// +// tpmkms:name=my-ak;ak=true +// +// Create an application key, attested by `my-ak` with "1234" as the Qualifying Data: +// +// tpmkms:name=my-attested-key;attest-by=my-ak;qualifying-data=61626364 +func (k *TPMKMS) CreateKey(req *apiv1.CreateKeyRequest) (*apiv1.CreateKeyResponse, error) { + switch { + case req.Name == "": + return nil, errors.New("createKeyRequest 'name' cannot be empty") + case req.Bits < 0: + return nil, errors.New("createKeyRequest 'bits' cannot be negative") + } + + properties, err := parseNameURI(req.Name) + if err != nil { + return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + v, ok := signatureAlgorithmMapping[req.SignatureAlgorithm] + if !ok { + return nil, fmt.Errorf("TPMKMS does not support signature algorithm %q", req.SignatureAlgorithm) + } + + if properties.ak && v.Type == "ECDSA" { + return nil, errors.New("AKs must be RSA keys") + } + + if properties.ak && req.Bits != 0 && req.Bits != defaultRSAAKSize { // 2048 + return nil, fmt.Errorf("creating %d bit AKs is not supported; AKs must be RSA 2048 bits", req.Bits) + } + + size := DefaultRSASize // defaults to 3072 + if req.Bits > 0 { + size = req.Bits + } + + if v.Type == "ECDSA" { + size = v.Curve + } + + ctx := context.Background() + if properties.ak { + ak, err := k.tpm.CreateAK(ctx, properties.name) // NOTE: size is never passed for AKs; it's hardcoded to 2048 in lower levels. + if err != nil { + if errors.Is(err, tpm.ErrExists) { + return nil, apiv1.AlreadyExistsError{Message: err.Error()} + } + return nil, fmt.Errorf("failed creating AK: %w", err) + } + createdAKURI := fmt.Sprintf("tpmkms:name=%s;ak=true", ak.Name()) + return &apiv1.CreateKeyResponse{ + Name: createdAKURI, + PublicKey: ak.Public(), + }, nil + } + + var key *tpm.Key + if properties.attestBy != "" { + config := tpm.AttestKeyConfig{ + Algorithm: v.Type, + Size: size, + QualifyingData: properties.qualifyingData, + } + key, err = k.tpm.AttestKey(ctx, properties.attestBy, properties.name, config) + if err != nil { + if errors.Is(err, tpm.ErrExists) { + return nil, apiv1.AlreadyExistsError{Message: err.Error()} + } + return nil, fmt.Errorf("failed creating attested key: %w", err) + } + } else { + config := tpm.CreateKeyConfig{ + Algorithm: v.Type, + Size: size, + } + key, err = k.tpm.CreateKey(ctx, properties.name, config) + if err != nil { + if errors.Is(err, tpm.ErrExists) { + return nil, apiv1.AlreadyExistsError{Message: err.Error()} + } + return nil, fmt.Errorf("failed creating key: %w", err) + } + } + + signer, err := key.Signer(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting signer for key: %w", err) + } + + createdKeyURI := fmt.Sprintf("tpmkms:name=%s", key.Name()) + if properties.attestBy != "" { + createdKeyURI = fmt.Sprintf("%s;attest-by=%s", createdKeyURI, key.AttestedBy()) + } + + return &apiv1.CreateKeyResponse{ + Name: createdKeyURI, + PublicKey: signer.Public(), + CreateSignerRequest: apiv1.CreateSignerRequest{ + SigningKey: createdKeyURI, + Signer: signer, + }, + }, nil +} + +// CreateSigner creates a signer using a key present in the TPM KMS. +func (k *TPMKMS) CreateSigner(req *apiv1.CreateSignerRequest) (crypto.Signer, error) { + if req.Signer != nil { + return req.Signer, nil + } + + if req.SigningKey == "" { + return nil, errors.New("createSignerRequest 'signingKey' cannot be empty") + } + + properties, err := parseNameURI(req.SigningKey) + if err != nil { + return nil, fmt.Errorf("failed parsing %q: %w", req.SigningKey, err) + } + + if properties.ak { + return nil, fmt.Errorf("signing with an AK currently not supported") + } + + ctx := context.Background() + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return nil, err + } + + signer, err := key.Signer(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting signer for key %q: %w", properties.name, err) + } + + return signer, nil +} + +// GetPublicKey returns the public key .... +func (k *TPMKMS) GetPublicKey(req *apiv1.GetPublicKeyRequest) (crypto.PublicKey, error) { + if req.Name == "" { + return nil, errors.New("getPublicKeyRequest 'name' cannot be empty") + } + + properties, err := parseNameURI(req.Name) + if err != nil { + return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + if properties.ak { + return nil, fmt.Errorf("retrieving AK public key currently not supported") + } + + ctx := context.Background() + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return nil, err + } + + signer, err := key.Signer(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting signer for key %q: %w", properties.name, err) + } + + return signer.Public(), nil +} + +func (k *TPMKMS) LoadCertificate(req *apiv1.LoadCertificateRequest) (*x509.Certificate, error) { + if req.Name == "" { + return nil, errors.New("loadCertificateRequest 'name' cannot be empty") + } + + chain, err := k.LoadCertificateChain(&apiv1.LoadCertificateChainRequest{Name: req.Name}) + if err != nil { + return nil, err + } + + return chain[0], nil +} + +func (k *TPMKMS) LoadCertificateChain(req *apiv1.LoadCertificateChainRequest) ([]*x509.Certificate, error) { + if req.Name == "" { + return nil, errors.New("loadCertificateChainRequest 'name' cannot be empty") + } + + properties, err := parseNameURI(req.Name) + if err != nil { + return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + ctx := context.Background() + var chain []*x509.Certificate // TODO(hs): support returning chain? + if properties.ak { + ak, err := k.tpm.GetAK(ctx, properties.name) + if err != nil { + return nil, err + } + chain = ak.CertificateChain() + } else { + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return nil, err + } + chain = key.CertificateChain() + } + + if len(chain) == 0 { + return nil, fmt.Errorf("failed getting certificate chain for %q: no certificate chain stored", properties.name) + } + + return chain, nil +} + +func (k *TPMKMS) StoreCertificate(req *apiv1.StoreCertificateRequest) error { + switch { + case req.Name == "": + return errors.New("storeCertificateRequest 'name' cannot be empty") + case req.Certificate == nil: + return errors.New("storeCertificateRequest 'certificate' cannot be empty") + } + + return k.StoreCertificateChain(&apiv1.StoreCertificateChainRequest{Name: req.Name, CertificateChain: []*x509.Certificate{req.Certificate}}) +} + +func (k *TPMKMS) StoreCertificateChain(req *apiv1.StoreCertificateChainRequest) error { + switch { + case req.Name == "": + return errors.New("storeCertificateChainRequest 'name' cannot be empty") + case len(req.CertificateChain) == 0: + return errors.New("storeCertificateChainRequest 'certificateChain' cannot be empty") + } + + properties, err := parseNameURI(req.Name) + if err != nil { + return fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + ctx := context.Background() + if properties.ak { + ak, err := k.tpm.GetAK(ctx, properties.name) + if err != nil { + return err + } + err = ak.SetCertificateChain(ctx, req.CertificateChain) + if err != nil { + return fmt.Errorf("failed storing certificate for AK %q: %w", properties.name, err) + } + } else { + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return err + } + + err = key.SetCertificateChain(ctx, req.CertificateChain) + if err != nil { + return fmt.Errorf("failed storing certificate for key %q: %w", properties.name, err) + } + } + + return nil +} + +// attestationClient is a wrapper for [attestation.Client], containing +// all of the required references to perform attestation against the +// Smallstep Attestation CA. +type attestationClient struct { + c *attestation.Client + t *tpm.TPM + ek *tpm.EK + ak *tpm.AK +} + +// newAttestorClient creates a new [attestationClient], wrapping references +// to the [tpm.TPM] instance, the EK and the AK to use when attesting. +func (k *TPMKMS) newAttestorClient(ek *tpm.EK, ak *tpm.AK) (*attestationClient, error) { + if k.attestationCABaseURL == "" { + return nil, errors.New("failed creating attestation client: attestation CA base URL must not be empty") + } + // prepare a client to perform attestation with an Attestation CA + attestationClientOptions := []attestation.Option{attestation.WithRootsFile(k.attestationCARootFile)} + if k.attestationCAInsecure { + attestationClientOptions = append(attestationClientOptions, attestation.WithInsecure()) + } + client, err := attestation.NewClient(k.attestationCABaseURL, attestationClientOptions...) + if err != nil { + return nil, fmt.Errorf("failed creating attestation client: %w", err) + } + return &attestationClient{ + c: client, + t: k.tpm, + ek: ek, + ak: ak, + }, nil +} + +// Attest implements the [apiv1.AttestationClient] interface, calling into the +// underlying [attestation.Client] to perform an attestation flow with the +// Smallstep Attestation CA. +func (ac *attestationClient) Attest(ctx context.Context) ([]*x509.Certificate, error) { + return ac.c.Attest(ctx, ac.t, ac.ek, ac.ak) +} + +func (k *TPMKMS) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1.CreateAttestationResponse, error) { + if req.Name == "" { + return nil, errors.New("createAttestationRequest 'name' cannot be empty") + } + + properties, err := parseNameURI(req.Name) + if err != nil { + return nil, fmt.Errorf("failed parsing %q: %w", req.Name, err) + } + + ctx := context.Background() + key, err := k.tpm.GetKey(ctx, properties.name) + if err != nil { + return nil, err + } + + if !key.WasAttested() { + return nil, fmt.Errorf("key %q was not attested", key.Name()) + } + + ak, err := k.tpm.GetAK(ctx, key.AttestedBy()) + if err != nil { + return nil, fmt.Errorf("failed getting AK for key %q: %w", key.Name(), err) + } + + eks, err := k.tpm.GetEKs(ctx) // TODO(hs): control the EK used as the caller of this method? + if err != nil { + return nil, fmt.Errorf("failed getting EKs: %w", err) + } + ek := getPreferredEK(eks) + ekPublic := ek.Public() + ekKeyID, err := generateKeyID(ekPublic) + if err != nil { + return nil, fmt.Errorf("failed getting EK public key ID: %w", err) + } + ekKeyURL := ekURL(ekKeyID) + + // check if the derived EK URI fingerprint representation matches the provided + // permanent identifier value. The current implementation requires the EK URI to + // be used as the AK identity, so an error is returned if there's no match. This + // could be changed in the future, so that another attestation flow takes place, + // instead, for example. + if k.permanentIdentifier != "" && ekKeyURL.String() != k.permanentIdentifier { + return nil, fmt.Errorf("the provided permanent identifier %q does not match the EK URL %q", k.permanentIdentifier, ekKeyURL.String()) + } + + // check if a (valid) AK certificate (chain) is available. Perform attestation flow + // otherwise. If an AK certificate is available, but not considered valid, e.g. due + // to it not having the right identity, a new attestation flow will be performed and + // the old certificate (chain) will be overwritten with the result of that flow. + akChain := ak.CertificateChain() + if len(akChain) == 0 || !hasValidIdentity(ak, ekKeyURL) { + var ac apiv1.AttestationClient + if req.AttestationClient != nil { + // TODO(hs): check if it makes sense to have this; it doesn't capture all + // behavior of the built-in attestorClient, but at least it does provide + // a basic extension point for other ways of performing attestation that + // might be useful for testing or attestation flows against other systems. + // For it to be truly useful, the logic for determining the AK identity + // would have to be updated too, though. + ac = req.AttestationClient + } else { + ac, err = k.newAttestorClient(ek, ak) + if err != nil { + return nil, fmt.Errorf("failed creating attestor client: %w", err) + } + } + // perform the attestation flow with a (remote) attestation CA + if akChain, err = ac.Attest(ctx); err != nil { + return nil, fmt.Errorf("failed performing AK attestation: %w", err) + } + // store the result with the AK, so that it can be reused for future + // attestations. + if err := ak.SetCertificateChain(ctx, akChain); err != nil { + return nil, fmt.Errorf("failed storing AK certificate chain: %w", err) + } + } + + // when a new certificate was issued for the AK, it is possible the + // certificate that was issued doesn't include the expected and/or required + // identity, so this is checked before continuing. + if !hasValidIdentity(ak, ekKeyURL) { + return nil, fmt.Errorf("AK certificate (chain) not valid for EK %q", ekKeyURL) + } + + signer, err := key.Signer(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting signer for key %q: %w", properties.name, err) + } + + params, err := key.CertificationParameters(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting key certification parameters for %q: %w", key.Name(), err) + } + + // prepare the response to return + akCert := akChain[0] + permanentIdentifier := ekKeyURL.String() // NOTE: should always match the valid value of the AK identity (for now) + return &apiv1.CreateAttestationResponse{ + Certificate: akCert, // certificate for the AK that attested the key + CertificateChain: akChain, // chain for the AK that attested the key, including the leaf + PublicKey: signer.Public(), // returns the public key of the attested key + CertificationParameters: &apiv1.CertificationParameters{ // key certification parameters + Public: params.Public, + CreateData: params.CreateData, + CreateAttestation: params.CreateAttestation, + CreateSignature: params.CreateSignature, + }, + PermanentIdentifier: permanentIdentifier, + }, nil +} + +// Close releases the connection to the TPM. +func (k *TPMKMS) Close() (err error) { + return +} + +// getPreferredEK returns the first RSA TPM EK found. If no RSA +// EK exists, it returns the first ECDSA EK found. +func getPreferredEK(eks []*tpm.EK) (ek *tpm.EK) { + var fallback *tpm.EK + for _, ek = range eks { + if _, isRSA := ek.Public().(*rsa.PublicKey); isRSA { + return + } + if fallback == nil { + fallback = ek + } + } + return fallback +} + +// hasValidIdentity indicates if the AK has an associated certificate +// that includes a valid identity. Currently we only consider certificates +// that encode the TPM EK public key ID as one of its URI SANs, which is +// the default behavior of the Smallstep Attestation CA. +func hasValidIdentity(ak *tpm.AK, ekURL *url.URL) bool { + chain := ak.CertificateChain() + if len(chain) == 0 { + return false + } + akCert := chain[0] + + // TODO(hs): before continuing, add check if the cert is still valid? + + // the Smallstep Attestation CA will issue AK certifiates that + // contain the EK public key ID encoded as an URN by default. + for _, u := range akCert.URIs { + if ekURL.String() == u.String() { + return true + } + } + + // TODO(hs): we could consider checking other values to contain + // a usable identity too. + + return false +} + +// generateKeyID generates a key identifier from the +// SHA256 hash of the public key. +func generateKeyID(pub crypto.PublicKey) ([]byte, error) { + b, err := x509.MarshalPKIXPublicKey(pub) + if err != nil { + return nil, fmt.Errorf("error marshaling public key: %w", err) + } + hash := sha256.Sum256(b) + return hash[:], nil +} + +// ekURL generates an EK URI containing the encoded key identifier +// for the EK. +func ekURL(keyID []byte) *url.URL { + return &url.URL{ + Scheme: "urn", + Opaque: "ek:sha256:" + base64.StdEncoding.EncodeToString(keyID), + } +} + +var _ apiv1.KeyManager = (*TPMKMS)(nil) +var _ apiv1.Attester = (*TPMKMS)(nil) +var _ apiv1.CertificateManager = (*TPMKMS)(nil) +var _ apiv1.CertificateChainManager = (*TPMKMS)(nil) +var _ apiv1.AttestationClient = (*attestationClient)(nil) diff --git a/kms/tpmkms/tpmkms_simulator_test.go b/kms/tpmkms/tpmkms_simulator_test.go new file mode 100644 index 00000000..2dd30520 --- /dev/null +++ b/kms/tpmkms/tpmkms_simulator_test.go @@ -0,0 +1,1836 @@ +//go:build tpmsimulator +// +build tpmsimulator + +package tpmkms + +import ( + "context" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smallstep/go-attestation/attest" + + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/kms/apiv1" + "go.step.sm/crypto/minica" + tpmp "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/simulator" + "go.step.sm/crypto/tpm/storage" +) + +type newSimulatedTPMOption func(t *testing.T, tpm *tpmp.TPM) + +func withAK(name string) newSimulatedTPMOption { + return func(t *testing.T, tpm *tpmp.TPM) { + t.Helper() + _, err := tpm.CreateAK(context.Background(), name) + require.NoError(t, err) + } +} + +func withKey(name string) newSimulatedTPMOption { + return func(t *testing.T, tpm *tpmp.TPM) { + t.Helper() + config := tpmp.CreateKeyConfig{ + Algorithm: "RSA", + Size: 1024, + } + _, err := tpm.CreateKey(context.Background(), name, config) + require.NoError(t, err) + } +} + +func newSimulatedTPM(t *testing.T, opts ...newSimulatedTPMOption) *tpmp.TPM { + t.Helper() + tmpDir := t.TempDir() + tpm, err := tpmp.New(withSimulator(t), tpmp.WithStore(storage.NewDirstore(tmpDir))) + require.NoError(t, err) + for _, applyTo := range opts { + applyTo(t, tpm) + } + return tpm +} + +func withSimulator(t *testing.T) tpmp.NewTPMOption { + t.Helper() + var sim simulator.Simulator + t.Cleanup(func() { + if sim == nil { + return + } + err := sim.Close() + require.NoError(t, err) + }) + sim = simulator.New() + err := sim.Open() + require.NoError(t, err) + return tpmp.WithSimulator(sim) +} + +func TestTPMKMS_CreateKey(t *testing.T) { + tpmWithAK := newSimulatedTPM(t, withAK("ak1")) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.CreateKeyRequest + } + tests := []struct { + name string + fields fields + args args + assertFunc assert.ValueAssertionFunc + expErr error + }{ + { + name: "ok/key", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=key1", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 1024, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + if assert.NotNil(t, r) { + assert.Equal(t, "tpmkms:name=key1", r.Name) + assert.Equal(t, "tpmkms:name=key1", r.CreateSignerRequest.SigningKey) + if assert.NotNil(t, r.CreateSignerRequest.Signer) { + assert.Implements(t, (*crypto.Signer)(nil), r.CreateSignerRequest.Signer) + } + return true + } + } + return false + }, + }, + { + name: "ok/attested-key", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=key2;attest-by=ak1", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 1024, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + if assert.NotNil(t, r) { + assert.Equal(t, "tpmkms:name=key2;attest-by=ak1", r.Name) + assert.Equal(t, "tpmkms:name=key2;attest-by=ak1", r.CreateSignerRequest.SigningKey) + if assert.NotNil(t, r.CreateSignerRequest.Signer) { + assert.Implements(t, (*crypto.Signer)(nil), r.CreateSignerRequest.Signer) + } + return true + } + } + return false + }, + }, + { + name: "ok/ak2", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=ak2;ak=true", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + if assert.NotNil(t, r) { + assert.Equal(t, "tpmkms:name=ak2;ak=true", r.Name) + assert.Equal(t, apiv1.CreateSignerRequest{}, r.CreateSignerRequest) + return true + } + } + return false + }, + }, + { + name: "ok/ecdsa-key", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=ecdsa-key", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + if assert.NotNil(t, r) { + assert.Equal(t, "tpmkms:name=ecdsa-key", r.Name) + assert.Equal(t, "tpmkms:name=ecdsa-key", r.CreateSignerRequest.SigningKey) + if assert.NotNil(t, r.CreateSignerRequest.Signer) { + assert.Implements(t, (*crypto.Signer)(nil), r.CreateSignerRequest.Signer) + } + return true + } + } + return false + }, + }, + { + name: "fail/empty-name", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "", + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New("createKeyRequest 'name' cannot be empty"), + }, + { + name: "fail/negative-bits", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=key1", + Bits: -1, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New("createKeyRequest 'bits' cannot be negative"), + }, + { + name: "fail/ak-cannot-be-attested", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=akx;ak=true;attest-by=ak1", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: fmt.Errorf(`failed parsing "tpmkms:name=akx;ak=true;attest-by=ak1": "ak" and "attest-by" are mutually exclusive`), + }, + { + name: "fail/invalid-algorithm", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=key1", + SignatureAlgorithm: apiv1.SignatureAlgorithm(-1), + Bits: 1024, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New(`TPMKMS does not support signature algorithm "unknown(-1)"`), + }, + { + name: "fail/ecdsa-ak", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=invalidAK;ak=true", + SignatureAlgorithm: apiv1.ECDSAWithSHA256, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New(`AKs must be RSA keys`), + }, + { + name: "fail/ak-3072-bits", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=invalidAK;ak=true", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 3072, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New(`creating 3072 bit AKs is not supported; AKs must be RSA 2048 bits`), + }, + { + name: "fail/key-exists", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=key1", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 1024, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New(`failed creating key "key1": already exists`), + }, + { + name: "fail/attested-key-exists", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=key2;attest-by=ak1", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 1024, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New(`failed creating key "key2": already exists`), + }, + { + name: "fail/ak2-exists", + fields: fields{ + tpm: tpmWithAK, + }, + args: args{ + req: &apiv1.CreateKeyRequest{ + Name: "tpmkms:name=ak2;ak=true", + SignatureAlgorithm: apiv1.SHA256WithRSA, + Bits: 2048, + }, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &apiv1.CreateKeyResponse{}, i1) { + r, _ := i1.(*apiv1.CreateKeyResponse) + return assert.Nil(t, r) + } + return false + }, + expErr: errors.New(`failed creating AK "ak2": already exists`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + got, err := k.CreateKey(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + assert.True(t, tt.assertFunc(t, got)) + }) + } +} + +func TestTPMKMS_CreateSigner(t *testing.T) { + tpmWithKey := newSimulatedTPM(t, withKey("key1")) + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + require.NoError(t, err) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.CreateSignerRequest + } + tests := []struct { + name string + fields fields + args args + expErr error + }{ + { + name: "ok/signer", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.CreateSignerRequest{ + Signer: key, + }, + }, + }, + { + name: "ok/signing-key", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.CreateSignerRequest{ + SigningKey: "tpmkms:name=key1", + }, + }, + }, + { + name: "fail/empty", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.CreateSignerRequest{ + SigningKey: "", + }, + }, + expErr: errors.New("createSignerRequest 'signingKey' cannot be empty"), + }, + { + name: "fail/ak", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.CreateSignerRequest{ + SigningKey: "tpmkms:name=ak1;ak=true", + }, + }, + expErr: errors.New("signing with an AK currently not supported"), + }, + { + name: "fail/unknown-key", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.CreateSignerRequest{ + SigningKey: "tpmkms:name=unknown-key", + }, + }, + expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + got, err := k.CreateSigner(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + assert.NotNil(t, got) + }) + } +} + +func TestTPMKMS_GetPublicKey(t *testing.T) { + tpmWithKey := newSimulatedTPM(t, withKey("key1")) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.GetPublicKeyRequest + } + tests := []struct { + name string + fields fields + args args + want crypto.PublicKey + expErr error + }{ + { + name: "ok", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.GetPublicKeyRequest{ + Name: "tpmkms:name=key1", + }, + }, + }, + { + name: "fail/empty", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.GetPublicKeyRequest{ + Name: "", + }, + }, + expErr: errors.New("getPublicKeyRequest 'name' cannot be empty"), + }, + { + name: "fail/ak", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.GetPublicKeyRequest{ + Name: "tpmkms:name=ak1;ak=true", + }, + }, + expErr: errors.New("retrieving AK public key currently not supported"), + }, + { + name: "fail/unknown-key", + fields: fields{ + tpm: tpmWithKey, + }, + args: args{ + req: &apiv1.GetPublicKeyRequest{ + Name: "tpmkms:name=unknown-key", + }, + }, + expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + got, err := k.GetPublicKey(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + assert.NotNil(t, got) + }) + } +} + +func TestTPMKMS_LoadCertificate(t *testing.T) { + ctx := context.Background() + tpm := newSimulatedTPM(t) + config := tpmp.CreateKeyConfig{ + Algorithm: "RSA", + Size: 1024, + } + key, err := tpm.CreateKey(ctx, "key1", config) + require.NoError(t, err) + ak, err := tpm.CreateAK(ctx, "ak1") + require.NoError(t, err) + _, err = tpm.CreateKey(ctx, "keyWithoutCertificate", config) + require.NoError(t, err) + _, err = tpm.CreateAK(ctx, "akWithoutCertificate") + require.NoError(t, err) + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + publicKey := signer.Public() + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testkey", + }, + PublicKey: publicKey, + } + cert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, cert) + err = key.SetCertificateChain(ctx, []*x509.Certificate{cert, ca.Intermediate}) + require.NoError(t, err) + akPub := ak.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template = &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + PublicKey: akPub, + } + akCert, err := ca.Sign(template) + err = ak.SetCertificateChain(ctx, []*x509.Certificate{akCert, ca.Intermediate}) + require.NoError(t, err) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.LoadCertificateRequest + } + tests := []struct { + name string + fields fields + args args + want *x509.Certificate + expErr error + }{ + { + name: "ok/ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "tpmkms:name=ak1;ak=true", + }, + }, + want: akCert, + }, + { + name: "ok/key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "tpmkms:name=key1", + }, + }, + want: cert, + }, + { + name: "fail/empty", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "", + }, + }, + expErr: errors.New("loadCertificateRequest 'name' cannot be empty"), + }, + { + name: "fail/unknown-ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "tpmkms:name=unknown-ak;ak=true", + }, + }, + expErr: fmt.Errorf(`failed getting AK "unknown-ak": not found`), + }, + { + name: "fail/unknown-key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "tpmkms:name=unknown-key", + }, + }, + expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), + }, + { + name: "fail/ak-without-certificate", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "tpmkms:name=akWithoutCertificate;ak=true", + }, + }, + expErr: fmt.Errorf(`failed getting certificate chain for "akWithoutCertificate": no certificate chain stored`), + }, + { + name: "fail/key-without-certificate", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateRequest{ + Name: "tpmkms:name=keyWithoutCertificate", + }, + }, + expErr: fmt.Errorf(`failed getting certificate chain for "keyWithoutCertificate": no certificate chain stored`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + got, err := k.LoadCertificate(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + if assert.NotNil(t, got) { + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestTPMKMS_LoadCertificateChain(t *testing.T) { + ctx := context.Background() + tpm := newSimulatedTPM(t) + config := tpmp.CreateKeyConfig{ + Algorithm: "RSA", + Size: 1024, + } + key, err := tpm.CreateKey(ctx, "key1", config) + require.NoError(t, err) + ak, err := tpm.CreateAK(ctx, "ak1") + require.NoError(t, err) + _, err = tpm.CreateKey(ctx, "keyWithoutCertificate", config) + require.NoError(t, err) + _, err = tpm.CreateAK(ctx, "akWithoutCertificate") + require.NoError(t, err) + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + publicKey := signer.Public() + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testkey", + }, + PublicKey: publicKey, + } + cert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, cert) + err = key.SetCertificateChain(ctx, []*x509.Certificate{cert, ca.Intermediate}) + require.NoError(t, err) + akPub := ak.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template = &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + PublicKey: akPub, + } + akCert, err := ca.Sign(template) + err = ak.SetCertificateChain(ctx, []*x509.Certificate{akCert, ca.Intermediate}) + require.NoError(t, err) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.LoadCertificateChainRequest + } + tests := []struct { + name string + fields fields + args args + want []*x509.Certificate + expErr error + }{ + { + name: "ok/ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "tpmkms:name=ak1;ak=true", + }, + }, + want: []*x509.Certificate{ + akCert, + ca.Intermediate, + }, + }, + { + name: "ok/key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "tpmkms:name=key1", + }, + }, + want: []*x509.Certificate{ + cert, + ca.Intermediate, + }, + }, + { + name: "fail/empty", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "", + }, + }, + expErr: errors.New("loadCertificateChainRequest 'name' cannot be empty"), + }, + { + name: "fail/unknown-ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "tpmkms:name=unknown-ak;ak=true", + }, + }, + expErr: fmt.Errorf(`failed getting AK "unknown-ak": not found`), + }, + { + name: "fail/unknown-key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "tpmkms:name=unknown-key", + }, + }, + expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), + }, + { + name: "fail/ak-without-certificate", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "tpmkms:name=akWithoutCertificate;ak=true", + }, + }, + expErr: fmt.Errorf(`failed getting certificate chain for "akWithoutCertificate": no certificate chain stored`), + }, + { + name: "fail/key-without-certificate", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.LoadCertificateChainRequest{ + Name: "tpmkms:name=keyWithoutCertificate", + }, + }, + expErr: fmt.Errorf(`failed getting certificate chain for "keyWithoutCertificate": no certificate chain stored`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + got, err := k.LoadCertificateChain(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + if assert.NotNil(t, got) { + assert.Equal(t, tt.want, got) + } + }) + } +} + +func TestTPMKMS_StoreCertificate(t *testing.T) { + ctx := context.Background() + tpm := newSimulatedTPM(t) + config := tpmp.CreateKeyConfig{ + Algorithm: "RSA", + Size: 1024, + } + key, err := tpm.CreateKey(ctx, "key1", config) + require.NoError(t, err) + ak, err := tpm.CreateAK(ctx, "ak1") + require.NoError(t, err) + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + publicKey := signer.Public() + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testkey", + }, + PublicKey: publicKey, + } + cert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, cert) + anotherPublicKey, _, err := keyutil.GenerateDefaultKeyPair() + require.NoError(t, err) + template = &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testanotherkey", + }, + PublicKey: anotherPublicKey, + } + anotherCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, anotherCert) + akPub := ak.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template = &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + PublicKey: akPub, + } + akCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, akCert) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.StoreCertificateRequest + } + tests := []struct { + name string + fields fields + args args + expErr error + }{ + { + name: "ok/ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "tpmkms:name=ak1;ak=true", + Certificate: akCert, + }, + }, + }, + { + name: "ok/key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "tpmkms:name=key1", + Certificate: cert, + }, + }, + }, + { + name: "fail/empty", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "", + }, + }, + expErr: errors.New("storeCertificateRequest 'name' cannot be empty"), + }, + { + name: "fail/unknown-ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "tpmkms:name=unknown-ak;ak=true", + Certificate: akCert, + }, + }, + expErr: fmt.Errorf(`failed getting AK "unknown-ak": not found`), + }, + { + name: "fail/unknown-key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "tpmkms:name=unknown-key", + Certificate: cert, + }, + }, + expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), + }, + { + name: "fail/wrong-certificate-for-ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "tpmkms:name=ak1;ak=true", + Certificate: anotherCert, + }, + }, + expErr: errors.New(`failed storing certificate for AK "ak1": AK public key does not match the leaf certificate public key`), + }, + { + name: "fail/wrong-certificate-for-key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateRequest{ + Name: "tpmkms:name=key1", + Certificate: anotherCert, + }, + }, + expErr: errors.New(`failed storing certificate for key "key1": public key does not match the leaf certificate public key`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + err := k.StoreCertificate(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + }) + } +} + +func TestTPMKMS_StoreCertificateChain(t *testing.T) { + ctx := context.Background() + tpm := newSimulatedTPM(t) + config := tpmp.CreateKeyConfig{ + Algorithm: "RSA", + Size: 1024, + } + key, err := tpm.CreateKey(ctx, "key1", config) + require.NoError(t, err) + ak, err := tpm.CreateAK(ctx, "ak1") + require.NoError(t, err) + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + publicKey := signer.Public() + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testkey", + }, + PublicKey: publicKey, + } + cert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, cert) + anotherPublicKey, _, err := keyutil.GenerateDefaultKeyPair() + require.NoError(t, err) + template = &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testanotherkey", + }, + PublicKey: anotherPublicKey, + } + anotherCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, anotherCert) + akPub := ak.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template = &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + PublicKey: akPub, + } + akCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, akCert) + type fields struct { + tpm *tpmp.TPM + } + type args struct { + req *apiv1.StoreCertificateChainRequest + } + tests := []struct { + name string + fields fields + args args + expErr error + }{ + { + name: "ok/ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=ak1;ak=true", + CertificateChain: []*x509.Certificate{akCert, ca.Intermediate}, + }, + }, + }, + { + name: "ok/key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=key1", + CertificateChain: []*x509.Certificate{cert, ca.Intermediate}, + }, + }, + }, + { + name: "fail/empty", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "", + }, + }, + expErr: errors.New("storeCertificateChainRequest 'name' cannot be empty"), + }, + { + name: "fail/unknown-ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=unknown-ak;ak=true", + CertificateChain: []*x509.Certificate{akCert, ca.Intermediate}, + }, + }, + expErr: fmt.Errorf(`failed getting AK "unknown-ak": not found`), + }, + { + name: "fail/unknown-key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=unknown-key", + CertificateChain: []*x509.Certificate{cert, ca.Intermediate}, + }, + }, + expErr: fmt.Errorf(`failed getting key "unknown-key": not found`), + }, + { + name: "fail/wrong-certificate-for-ak", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=ak1;ak=true", + CertificateChain: []*x509.Certificate{anotherCert, ca.Intermediate}, + }, + }, + expErr: errors.New(`failed storing certificate for AK "ak1": AK public key does not match the leaf certificate public key`), + }, + { + name: "fail/wrong-certificate-for-key", + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.StoreCertificateChainRequest{ + Name: "tpmkms:name=key1", + CertificateChain: []*x509.Certificate{anotherCert, ca.Intermediate}, + }, + }, + expErr: errors.New(`failed storing certificate for key "key1": public key does not match the leaf certificate public key`), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tt.fields.tpm, + } + err := k.StoreCertificateChain(tt.args.req) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + return + } + + assert.NoError(t, err) + }) + } +} + +// TODO(hs): dedupe these structs by creating some shared helper +// functions for running a fake attestation ca tpm. +type tpmInfo struct { + Version attest.TPMVersion `json:"version,omitempty"` + Manufacturer string `json:"manufacturer,omitempty"` + Model string `json:"model,omitempty"` + FirmwareVersion string `json:"firmwareVersion,omitempty"` +} + +type attestationParameters struct { + Public []byte `json:"public,omitempty"` + UseTCSDActivationFormat bool `json:"useTCSDActivationFormat,omitempty"` + CreateData []byte `json:"createData,omitempty"` + CreateAttestation []byte `json:"createAttestation,omitempty"` + CreateSignature []byte `json:"createSignature,omitempty"` +} + +type attestationRequest struct { + TPMInfo tpmInfo `json:"tpmInfo"` + EK []byte `json:"ek,omitempty"` + EKCerts [][]byte `json:"ekCerts,omitempty"` + AKCert []byte `json:"akCert,omitempty"` + AttestParams attestationParameters `json:"params,omitempty"` +} + +type attestationResponse struct { + Credential []byte `json:"credential"` + Secret []byte `json:"secret"` // encrypted secret +} + +type secretRequest struct { + Secret []byte `json:"secret"` // decrypted secret +} + +type secretResponse struct { + CertificateChain [][]byte `json:"chain"` +} + +type customAttestationClient struct { + chain []*x509.Certificate +} + +func (c *customAttestationClient) Attest(context.Context) ([]*x509.Certificate, error) { + return c.chain, nil +} + +func TestTPMKMS_CreateAttestation(t *testing.T) { + ctx := context.Background() + tpm := newSimulatedTPM(t) + eks, err := tpm.GetEKs(ctx) + require.NoError(t, err) + ek := getPreferredEK(eks) + ekKeyID, err := generateKeyID(ek.Public()) + require.NoError(t, err) + ekKeyURL := ekURL(ekKeyID) + config := tpmp.AttestKeyConfig{ + Algorithm: "RSA", + Size: 1024, + QualifyingData: []byte{1, 2, 3, 4}, + } + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + type fields struct { + tpm *tpmp.TPM + attestationCABaseURL string + attestationCARootFile string + attestationCAInsecure bool + permanentIdentifier string + } + type args struct { + req *apiv1.CreateAttestationRequest + } + type test struct { + server *httptest.Server + fields fields + args args + want *apiv1.CreateAttestationResponse + expErr error + } + tests := map[string]func(t *testing.T) test{ + "fail/empty-name": func(t *testing.T) test { + return test{ + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "", + }, + }, + expErr: errors.New("createAttestationRequest 'name' cannot be empty"), + } + }, + "fail/ak-attestby-mutually-exclusive": func(t *testing.T) test { + return test{ + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=keyx;ak=true;attest-by=ak1", + }, + }, + expErr: errors.New(`failed parsing "tpmkms:name=keyx;ak=true;attest-by=ak1": "ak" and "attest-by" are mutually exclusive`), + } + }, + "fail/unknown-key": func(t *testing.T) test { + return test{ + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=keyx", + }, + }, + expErr: errors.New(`failed getting key "keyx": not found`), + } + }, + "fail/non-attested-key": func(t *testing.T) test { + createConfig := tpmp.CreateKeyConfig{Algorithm: "RSA", Size: 1024} + _, err = tpm.CreateKey(ctx, "nonAttestedKey", createConfig) + return test{ + fields: fields{ + tpm: tpm, + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=nonAttestedKey", + }, + }, + expErr: errors.New(`key "nonAttestedKey" was not attested`), + } + }, + "fail/non-matching-permanent-identifier": func(t *testing.T) test { + _, err = tpm.CreateAK(ctx, "newAKWithoutCert") + require.NoError(t, err) + _, err = tpm.AttestKey(ctx, "newAKWithoutCert", "newkey", config) + require.NoError(t, err) + return test{ + fields: fields{ + tpm: tpm, + permanentIdentifier: "wrong-provided-permanent-identifier", + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=newkey", // newkey was attested by the newAKWithoutCert at creation time + }, + }, + expErr: fmt.Errorf(`the provided permanent identifier "wrong-provided-permanent-identifier" does not match the EK URL %q`, ekKeyURL.String()), + } + }, + "fail/create-attestor-client": func(t *testing.T) test { + _, err = tpm.CreateAK(ctx, "ak2WithoutCert") + require.NoError(t, err) + _, err = tpm.AttestKey(ctx, "ak2WithoutCert", "key3", config) + require.NoError(t, err) + return test{ + fields: fields{ + tpm: tpm, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key3", // key3 was attested by the ak2WithoutCert at creation time + }, + }, + expErr: fmt.Errorf(`failed creating attestor client: failed creating attestation client: attestation CA base URL must not be empty`), + } + }, + "fail/attest": func(t *testing.T) test { + _, err = tpm.CreateAK(ctx, "ak3WithoutCert") + require.NoError(t, err) + _, err = tpm.AttestKey(ctx, "ak3WithoutCert", "key4", config) + require.NoError(t, err) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + w.WriteHeader(http.StatusBadRequest) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + tpm: tpm, + attestationCABaseURL: s.URL, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key4", // key4 was attested by the ak3WithoutCert at creation time + }, + }, + expErr: fmt.Errorf(`failed performing AK attestation: failed attesting AK: POST %q failed with HTTP status "400 Bad Request"`, fmt.Sprintf("%s/attest", s.URL)), + } + }, + "fail/set-ak-certificate-chain": func(t *testing.T) test { + ak4WithoutCert, err := tpm.CreateAK(ctx, "ak4WithoutCert") + require.NoError(t, err) + _, err = tpm.AttestKey(ctx, "ak4WithoutCert", "key5", config) + require.NoError(t, err) + params, err := ak4WithoutCert.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + expectedSecret, encryptedCredentials, err := activation.Generate() + require.NoError(t, err) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EK) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&attestationResponse{ + Credential: encryptedCredentials.Credential, + Secret: encryptedCredentials.Secret, + }) + case "/secret": + var sr secretRequest + err := json.NewDecoder(r.Body).Decode(&sr) + require.NoError(t, err) + assert.Equal(t, expectedSecret, sr.Secret) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&secretResponse{ + CertificateChain: [][]byte{ + ca.Intermediate.Raw, // No leaf returned + }, + }) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + tpm: tpm, + attestationCABaseURL: s.URL, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key5", // key5 was attested by ak3WithoutCert at creation time + }, + }, + want: nil, + expErr: fmt.Errorf(`failed storing AK certificate chain: AK public key does not match the leaf certificate public key`), + } + }, + "fail/ak-certificate-chain-has-invalid-identity": func(t *testing.T) test { + ak5WithoutCert, err := tpm.CreateAK(ctx, "ak5WithoutCert") + require.NoError(t, err) + _, err = tpm.AttestKey(ctx, "ak5WithoutCert", "key6", config) + require.NoError(t, err) + ak5Pub := ak5WithoutCert.Public() + require.Implements(t, (*crypto.PublicKey)(nil), ak5Pub) + template := &x509.Certificate{ // NOTE: missing EK URI SAN + Subject: pkix.Name{ + CommonName: "testinvalidak", + }, + PublicKey: ak5Pub, + } + invalidAKIdentityCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, invalidAKIdentityCert) + params, err := ak5WithoutCert.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + expectedSecret, encryptedCredentials, err := activation.Generate() + require.NoError(t, err) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EK) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&attestationResponse{ + Credential: encryptedCredentials.Credential, + Secret: encryptedCredentials.Secret, + }) + case "/secret": + var sr secretRequest + err := json.NewDecoder(r.Body).Decode(&sr) + require.NoError(t, err) + assert.Equal(t, expectedSecret, sr.Secret) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&secretResponse{ + CertificateChain: [][]byte{ + invalidAKIdentityCert.Raw, // AK certificate without EK URI SAN + ca.Intermediate.Raw, + }, + }) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + tpm: tpm, + attestationCABaseURL: s.URL, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key6", // key6 was attested by ak5WithoutCert at creation time + }, + }, + want: nil, + expErr: fmt.Errorf(`AK certificate (chain) not valid for EK %q`, ekKeyURL.String()), + } + }, + "ok": func(t *testing.T) test { + akWithExistingCert, err := tpm.CreateAK(ctx, "akWithExistingCert") + require.NoError(t, err) + key, err := tpm.AttestKey(ctx, "akWithExistingCert", "key1", config) + require.NoError(t, err) + keyParams, err := key.CertificationParameters(ctx) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + akPub := akWithExistingCert.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + URIs: []*url.URL{ekKeyURL}, + PublicKey: akPub, + } + validAKCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, validAKCert) + err = akWithExistingCert.SetCertificateChain(ctx, []*x509.Certificate{validAKCert, ca.Intermediate}) + require.NoError(t, err) + return test{ + fields: fields{ + tpm: tpm, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key1", // key1 was attested by the akWithExistingCert at creation time + }, + }, + want: &apiv1.CreateAttestationResponse{ + Certificate: validAKCert, + CertificateChain: []*x509.Certificate{validAKCert, ca.Intermediate}, + PublicKey: signer.Public(), + CertificationParameters: &apiv1.CertificationParameters{ + Public: keyParams.Public, + CreateData: keyParams.CreateData, + CreateAttestation: keyParams.CreateAttestation, + CreateSignature: keyParams.CreateSignature, + }, + PermanentIdentifier: ekKeyURL.String(), + }, + expErr: nil, + } + }, + "ok/new-chain": func(t *testing.T) test { + akWithoutCert, err := tpm.CreateAK(ctx, "akWithoutCert") + require.NoError(t, err) + key, err := tpm.AttestKey(ctx, "akWithoutCert", "key2", config) + require.NoError(t, err) + keyParams, err := key.CertificationParameters(ctx) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + akPubNew := akWithoutCert.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPubNew) + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testnewak", + }, + URIs: []*url.URL{ekKeyURL}, + PublicKey: akPubNew, + } + newAKCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, newAKCert) + params, err := akWithoutCert.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + expectedSecret, encryptedCredentials, err := activation.Generate() + require.NoError(t, err) + akChain := [][]byte{ + newAKCert.Raw, + ca.Intermediate.Raw, + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EK) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&attestationResponse{ + Credential: encryptedCredentials.Credential, + Secret: encryptedCredentials.Secret, + }) + case "/secret": + var sr secretRequest + err := json.NewDecoder(r.Body).Decode(&sr) + require.NoError(t, err) + assert.Equal(t, expectedSecret, sr.Secret) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&secretResponse{ + CertificateChain: akChain, + }) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + tpm: tpm, + attestationCABaseURL: s.URL, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key2", // key2 was attested by akWithoutCert at creation time + }, + }, + want: &apiv1.CreateAttestationResponse{ + Certificate: newAKCert, + CertificateChain: []*x509.Certificate{newAKCert, ca.Intermediate}, + PublicKey: signer.Public(), + CertificationParameters: &apiv1.CertificationParameters{ + Public: keyParams.Public, + CreateData: keyParams.CreateData, + CreateAttestation: keyParams.CreateAttestation, + CreateSignature: keyParams.CreateSignature, + }, + PermanentIdentifier: ekKeyURL.String(), + }, + expErr: nil, + } + }, + "ok/new-chain-with-custom-attestor-client": func(t *testing.T) test { + ak6WithoutCert, err := tpm.CreateAK(ctx, "ak6WithoutCert") + require.NoError(t, err) + key, err := tpm.AttestKey(ctx, "ak6WithoutCert", "key7", config) + require.NoError(t, err) + keyParams, err := key.CertificationParameters(ctx) + require.NoError(t, err) + signer, err := key.Signer(ctx) + require.NoError(t, err) + ak6Pub := ak6WithoutCert.Public() + require.Implements(t, (*crypto.PublicKey)(nil), ak6Pub) + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak6", + }, + URIs: []*url.URL{ekKeyURL}, + PublicKey: ak6Pub, + } + ak6Cert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, ak6Cert) + return test{ + fields: fields{ + tpm: tpm, + permanentIdentifier: ekKeyURL.String(), + }, + args: args{ + req: &apiv1.CreateAttestationRequest{ + Name: "tpmkms:name=key7", // key7 was attested by ak6WithoutCert at creation time + AttestationClient: &customAttestationClient{ + chain: []*x509.Certificate{ak6Cert, ca.Intermediate}, + }, + }, + }, + want: &apiv1.CreateAttestationResponse{ + Certificate: ak6Cert, + CertificateChain: []*x509.Certificate{ak6Cert, ca.Intermediate}, + PublicKey: signer.Public(), + CertificationParameters: &apiv1.CertificationParameters{ + Public: keyParams.Public, + CreateData: keyParams.CreateData, + CreateAttestation: keyParams.CreateAttestation, + CreateSignature: keyParams.CreateSignature, + }, + PermanentIdentifier: ekKeyURL.String(), + }, + expErr: nil, + } + }, + } + for name, tt := range tests { + tc := tt(t) + t.Run(name, func(t *testing.T) { + k := &TPMKMS{ + tpm: tc.fields.tpm, + attestationCABaseURL: tc.fields.attestationCABaseURL, + attestationCARootFile: tc.fields.attestationCARootFile, + attestationCAInsecure: tc.fields.attestationCAInsecure, + permanentIdentifier: tc.fields.permanentIdentifier, + } + if tc.server != nil { + defer tc.server.Close() + } + got, err := k.CreateAttestation(tc.args.req) + if tc.expErr != nil { + assert.EqualError(t, err, tc.expErr.Error()) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/kms/tpmkms/tpmkms_test.go b/kms/tpmkms/tpmkms_test.go new file mode 100644 index 00000000..f071ec03 --- /dev/null +++ b/kms/tpmkms/tpmkms_test.go @@ -0,0 +1,38 @@ +package tpmkms + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.step.sm/crypto/kms/apiv1" +) + +func TestNew(t *testing.T) { + type args struct { + opts apiv1.Options + } + tests := []struct { + name string + args args + want *TPMKMS + wantErr bool + }{ + {"ok/defaults", args{apiv1.Options{Type: "tpmkms"}}, &TPMKMS{}, false}, + {"ok/uri", args{apiv1.Options{Type: "tpmkms", URI: "tpmkms:device=/dev/tpm0;storage-directory=/tmp/tpmstorage"}}, &TPMKMS{}, false}, + {"fail/uri-scheme", args{apiv1.Options{Type: "tpmkms", URI: "tpmkmz://device=/dev/tpm0"}}, &TPMKMS{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := New(context.Background(), tt.args.opts) + if tt.wantErr { + assert.Error(t, err) + return + } + + if assert.NotNil(t, got) { + assert.NotNil(t, got.tpm) + } + }) + } +} diff --git a/kms/tpmkms/uri.go b/kms/tpmkms/uri.go new file mode 100644 index 00000000..5dabd8de --- /dev/null +++ b/kms/tpmkms/uri.go @@ -0,0 +1,63 @@ +//go:build !notpmkms +// +build !notpmkms + +package tpmkms + +import ( + "errors" + "fmt" + + "go.step.sm/crypto/kms/uri" +) + +type objectProperties struct { + name string + ak bool + attestBy string + qualifyingData []byte +} + +func parseNameURI(nameURI string) (o objectProperties, err error) { + if nameURI == "" { + return o, errors.New("empty URI not supported") + } + var u *uri.URI + var parseErr error + if u, parseErr = uri.ParseWithScheme(Scheme, nameURI); parseErr == nil { + if name := u.Get("name"); name == "" { + if len(u.Values) == 1 { + o.name = u.Opaque + } else { + for k, v := range u.Values { + if len(v) == 1 && v[0] == "" { + o.name = k + break + } + } + } + } else { + o.name = name + } + o.ak = u.GetBool("ak") + o.attestBy = u.Get("attest-by") + if qualifyingData := u.GetEncoded("qualifying-data"); qualifyingData != nil { + o.qualifyingData = qualifyingData + } + + // validation + if o.ak && o.attestBy != "" { + return o, errors.New(`"ak" and "attest-by" are mutually exclusive`) + } + + return + } + + if u, parseErr := uri.Parse(nameURI); parseErr == nil { + if u.Scheme != Scheme { + return o, fmt.Errorf("URI scheme %q is not supported", u.Scheme) + } + } + + o.name = nameURI // assumes there's no other properties encoded; just a name + return +} diff --git a/kms/tpmkms/uri_test.go b/kms/tpmkms/uri_test.go new file mode 100644 index 00000000..650653d4 --- /dev/null +++ b/kms/tpmkms/uri_test.go @@ -0,0 +1,40 @@ +package tpmkms + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func Test_parseNameURI(t *testing.T) { + type args struct { + nameURI string + } + tests := []struct { + name string + args args + wantO objectProperties + wantErr bool + }{ + {"ok/key-without-scheme", args{"key1"}, objectProperties{name: "key1"}, false}, + {"ok/key", args{"tpmkms:name=key1"}, objectProperties{name: "key1"}, false}, + {"ok/key-without-name-key", args{"tpmkms:key1"}, objectProperties{name: "key1"}, false}, + {"ok/key-without-name-key-with-other-properties", args{"tpmkms:key1;attest-by=ak1"}, objectProperties{name: "key1", attestBy: "ak1"}, false}, + {"ok/attested-key", args{"tpmkms:name=key2;attest-by=ak1;qualifying-data=61626364"}, objectProperties{name: "key2", attestBy: "ak1", qualifyingData: []byte{'a', 'b', 'c', 'd'}}, false}, + {"ok/ak", args{"tpmkms:name=ak1;ak=true"}, objectProperties{name: "ak1", ak: true}, false}, + {"fail/empty", args{""}, objectProperties{}, true}, + {"fail/wrong-scheme", args{nameURI: "tpmkmz:name=bla"}, objectProperties{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotO, err := parseNameURI(tt.args.nameURI) + if tt.wantErr { + assert.Error(t, err) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.wantO, gotO) + }) + } +} diff --git a/kms/yubikey/yubikey.go b/kms/yubikey/yubikey.go index bc0263fc..aaf61385 100644 --- a/kms/yubikey/yubikey.go +++ b/kms/yubikey/yubikey.go @@ -324,7 +324,7 @@ func (k *YubiKey) CreateAttestation(req *apiv1.CreateAttestationRequest) (*apiv1 return &apiv1.CreateAttestationResponse{ Certificate: cert, - CertificateChain: []*x509.Certificate{intermediate}, + CertificateChain: []*x509.Certificate{cert, intermediate}, PublicKey: cert.PublicKey, PermanentIdentifier: getSerialNumber(cert), }, nil @@ -490,3 +490,5 @@ func getSerialNumber(cert *x509.Certificate) string { } return "" } + +var _ apiv1.CertificateManager = (*YubiKey)(nil) diff --git a/kms/yubikey/yubikey_test.go b/kms/yubikey/yubikey_test.go index 84e1b46e..14385538 100644 --- a/kms/yubikey/yubikey_test.go +++ b/kms/yubikey/yubikey_test.go @@ -976,7 +976,7 @@ func TestYubiKey_CreateAttestation(t *testing.T) { Name: "yubikey:slot-id=9a", }}, &apiv1.CreateAttestationResponse{ Certificate: yk.attestMap[piv.SlotAuthentication], - CertificateChain: []*x509.Certificate{yk.attestCA.Intermediate}, + CertificateChain: []*x509.Certificate{yk.attestMap[piv.SlotAuthentication], yk.attestCA.Intermediate}, PublicKey: yk.attestMap[piv.SlotAuthentication].PublicKey, PermanentIdentifier: "112233", }, false}, diff --git a/tpm/ak.go b/tpm/ak.go index 7c22496e..a3d2c176 100644 --- a/tpm/ak.go +++ b/tpm/ak.go @@ -67,6 +67,20 @@ func (ak *AK) CertificateChain() []*x509.Certificate { return ak.chain } +// Public returns the AK public key. This is backed +// by a call to the TPM, so it can fail. If it fails, +// nil is returned. +// +// TODO: see improvement described in the private method +// to always return a non-nil crypto.PublicKey. +func (ak *AK) Public() crypto.PublicKey { + pub, err := ak.public(context.Background()) + if err != nil { + return nil + } + return pub +} + // public returns the AK public key. This can fail, because // retrieval relies on the TPM. // diff --git a/tpm/attestation/client.go b/tpm/attestation/client.go new file mode 100644 index 00000000..8810c2b7 --- /dev/null +++ b/tpm/attestation/client.go @@ -0,0 +1,282 @@ +package attestation + +import ( + "bytes" + "context" + "crypto/tls" + "crypto/x509" + "encoding/json" + "fmt" + "net/http" + "net/url" + "os" + "strconv" + "time" + + "go.step.sm/crypto/tpm" + + "github.com/smallstep/go-attestation/attest" +) + +type Client struct { + client *http.Client + baseURL *url.URL +} + +type Options struct { + rootCAs *x509.CertPool + insecure bool +} + +type Option func(o *Options) error + +// WithRootsFile can be used to set the trusted roots when +// setting up a TLS connection. An empty filename will +// be ignored. +func WithRootsFile(filename string) Option { + return func(o *Options) error { + if filename == "" { + return nil + } + data, err := os.ReadFile(filename) + if err != nil { + return fmt.Errorf("failed reading %q: %w", filename, err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(data) { + return fmt.Errorf("failed parsing %q: no certificates found", filename) + } + o.rootCAs = pool + return nil + } +} + +// WithInsecure disables TLS server certificate chain checking. +// In general this shouldn't be used, but it can be of use in +// during development and testing. +func WithInsecure() Option { + return func(o *Options) error { + o.insecure = true + return nil + } +} + +// NewClient creates a new Client that can be used to perform remote +// attestation. +func NewClient(tpmAttestationCABaseURL string, options ...Option) (*Client, error) { + u, err := url.Parse(tpmAttestationCABaseURL) + if err != nil { + return nil, fmt.Errorf("failed parsing attestation CA base URL: %w", err) + } + + opts := &Options{} + for _, o := range options { + if err := o(opts); err != nil { + return nil, fmt.Errorf("failed applying option to attestation client: %w", err) + } + } + + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.TLSClientConfig = &tls.Config{ + RootCAs: opts.rootCAs, + InsecureSkipVerify: opts.insecure, //nolint:gosec // intentional insecure if provided as option + } + + client := &http.Client{ + Timeout: 10 * time.Second, + Transport: transport, + } + + return &Client{ + client: client, + baseURL: u, + }, nil +} + +// Attest performs remote attestation using the AK backed by TPM t. +// +// TODO: support multiple EKs again? Currently selection of the EK is left +// to the caller. +func (ac *Client) Attest(ctx context.Context, t *tpm.TPM, ek *tpm.EK, ak *tpm.AK) ([]*x509.Certificate, error) { + // TODO(hs): what about performing attestation for an existing AK identifier and/or cert, but + // with a different Attestation CA? It seems sensible to enroll with that other Attestation CA, + // but it needs capturing some knowledge about the Attestation CA with the AK (cert). Possible to + // derive that from the intermediate and/or root CA and/or fingerprint, somehow? Or the attestation URI? + + info, err := t.Info(ctx) + if err != nil { + return nil, fmt.Errorf("failed retrieving info from TPM: %w", err) + } + + attestParams, err := ak.AttestationParameters(ctx) + if err != nil { + return nil, fmt.Errorf("failed getting AK attestation parameters: %w", err) + } + + attResp, err := ac.attest(ctx, info, ek, attestParams) + if err != nil { + return nil, fmt.Errorf("failed attesting AK: %w", err) + } + + encryptedCredentials := tpm.EncryptedCredential{ + Credential: attResp.Credential, + Secret: attResp.Secret, + } + + // activate the credential with the TPM + secret, err := ak.ActivateCredential(ctx, encryptedCredentials) + if err != nil { + return nil, fmt.Errorf("failed activating credential: %w", err) + } + + secretResp, err := ac.secret(ctx, secret) + if err != nil { + return nil, fmt.Errorf("failed validating secret: %w", err) + } + + akChain := make([]*x509.Certificate, len(secretResp.CertificateChain)) + for i, certBytes := range secretResp.CertificateChain { + cert, err := x509.ParseCertificate(certBytes) + if err != nil { + return nil, fmt.Errorf("failed parsing certificate: %w", err) + } + akChain[i] = cert + } + + return akChain, nil +} + +type tpmInfo struct { + Version attest.TPMVersion `json:"version,omitempty"` + Manufacturer string `json:"manufacturer,omitempty"` + Model string `json:"model,omitempty"` + FirmwareVersion string `json:"firmwareVersion,omitempty"` +} + +type attestationParameters struct { + Public []byte `json:"public,omitempty"` + UseTCSDActivationFormat bool `json:"useTCSDActivationFormat,omitempty"` + CreateData []byte `json:"createData,omitempty"` + CreateAttestation []byte `json:"createAttestation,omitempty"` + CreateSignature []byte `json:"createSignature,omitempty"` +} + +type attestationRequest struct { + TPMInfo tpmInfo `json:"tpmInfo"` + EKPub []byte `json:"ek,omitempty"` + EKCerts [][]byte `json:"ekCerts,omitempty"` + AKCert []byte `json:"akCert,omitempty"` + AttestParams attestationParameters `json:"params"` +} + +type attestationResponse struct { + Credential []byte `json:"credential"` + Secret []byte `json:"secret"` // encrypted secret +} + +// attest performs the HTTP POST request to the `/attest` endpoint of the +// Attestation CA. +func (ac *Client) attest(ctx context.Context, info *tpm.Info, ek *tpm.EK, attestParams attest.AttestationParameters) (*attestationResponse, error) { + var ekCerts [][]byte + var ekPub []byte + var err error + + // TODO: support multiple EKs again? Currently selection of the EK is left + // to the caller. + if ekCert := ek.Certificate(); ekCert != nil { + ekCerts = append(ekCerts, ekCert.Raw) + } + if ekPub, err = x509.MarshalPKIXPublicKey(ek.Public()); err != nil { + return nil, fmt.Errorf("failed marshaling public key: %w", err) + } + + ar := attestationRequest{ + TPMInfo: tpmInfo{ + Version: attest.TPMVersion20, + Manufacturer: strconv.FormatUint(uint64(info.Manufacturer.ID), 10), + Model: info.VendorInfo, + FirmwareVersion: info.FirmwareVersion.String(), + }, + EKCerts: ekCerts, + EKPub: ekPub, + AttestParams: attestationParameters{ + Public: attestParams.Public, + UseTCSDActivationFormat: attestParams.UseTCSDActivationFormat, + CreateData: attestParams.CreateData, + CreateAttestation: attestParams.CreateAttestation, + CreateSignature: attestParams.CreateSignature, + }, + } + + body, err := json.Marshal(ar) + if err != nil { + return nil, fmt.Errorf("failed marshaling attestation request: %w", err) + } + + attestURL := ac.baseURL.JoinPath("attest").String() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, attestURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed creating POST http request for %q: %w", attestURL, err) + } + + resp, err := ac.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed performing attestation request with Attestation CA %q: %w", attestURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("POST %q failed with HTTP status %q", attestURL, resp.Status) + } + + var attResp attestationResponse + if err := json.NewDecoder(resp.Body).Decode(&attResp); err != nil { + return nil, fmt.Errorf("failed decoding attestation response: %w", err) + } + + return &attResp, nil +} + +type secretRequest struct { + Secret []byte `json:"secret"` // decrypted secret +} + +type secretResponse struct { + CertificateChain [][]byte `json:"chain"` +} + +// secret performs the HTTP POST request to the `/secret` endpoint of the +// Attestation CA. +func (ac *Client) secret(ctx context.Context, secret []byte) (*secretResponse, error) { + sr := secretRequest{ + Secret: secret, + } + + body, err := json.Marshal(sr) + if err != nil { + return nil, fmt.Errorf("failed marshaling secret request: %w", err) + } + + secretURL := ac.baseURL.JoinPath("secret").String() + req, err := http.NewRequestWithContext(ctx, http.MethodPost, secretURL, bytes.NewReader(body)) + if err != nil { + return nil, fmt.Errorf("failed creating POST http request for %q: %w", secretURL, err) + } + + resp, err := ac.client.Do(req) + if err != nil { + return nil, fmt.Errorf("failed performing secret request with attestation CA %q: %w", secretURL, err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("POST %q failed with HTTP status %q", secretURL, resp.Status) + } + + var secretResp secretResponse + if err := json.NewDecoder(resp.Body).Decode(&secretResp); err != nil { + return nil, fmt.Errorf("failed decoding secret response: %w", err) + } + + return &secretResp, nil +} diff --git a/tpm/attestation/client_simulator_test.go b/tpm/attestation/client_simulator_test.go new file mode 100644 index 00000000..ed0cbfe5 --- /dev/null +++ b/tpm/attestation/client_simulator_test.go @@ -0,0 +1,413 @@ +//go:build tpmsimulator +// +build tpmsimulator + +package attestation + +import ( + "context" + "crypto" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/smallstep/go-attestation/attest" + + "go.step.sm/crypto/keyutil" + "go.step.sm/crypto/minica" + "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/simulator" + "go.step.sm/crypto/tpm/storage" +) + +func newSimulatedTPM(t *testing.T) *tpm.TPM { + t.Helper() + tmpDir := t.TempDir() + instance, err := tpm.New(withSimulator(t), tpm.WithStore(storage.NewDirstore(tmpDir))) + require.NoError(t, err) + return instance +} + +func withSimulator(t *testing.T) tpm.NewTPMOption { + t.Helper() + var sim simulator.Simulator + t.Cleanup(func() { + if sim == nil { + return + } + err := sim.Close() + require.NoError(t, err) + }) + sim = simulator.New() + err := sim.Open() + require.NoError(t, err) + return tpm.WithSimulator(sim) +} + +// getPreferredEK returns the first RSA TPM EK found. If no RSA +// EK exists, it returns the first ECDSA EK found. +func getPreferredEK(eks []*tpm.EK) (ek *tpm.EK) { + var fallback *tpm.EK + for _, ek = range eks { + if _, isRSA := ek.Public().(*rsa.PublicKey); isRSA { + return + } + if fallback == nil { + fallback = ek + } + } + return fallback +} + +func mustParseURL(t *testing.T, urlString string) *url.URL { + t.Helper() + u, err := url.Parse(urlString) + require.NoError(t, err) + return u +} + +func TestClient_Attest(t *testing.T) { + ctx := context.Background() + instance := newSimulatedTPM(t) + ak, err := instance.CreateAK(ctx, "ak1") + require.NoError(t, err) + require.NoError(t, err) + eks, err := instance.GetEKs(ctx) + require.NoError(t, err) + ek := getPreferredEK(eks) + ca, err := minica.New( + minica.WithGetSignerFunc( + func() (crypto.Signer, error) { + return keyutil.GenerateSigner("RSA", "", 2048) + }, + ), + ) + require.NoError(t, err) + akPub := ak.Public() + require.Implements(t, (*crypto.PublicKey)(nil), akPub) + template := &x509.Certificate{ + Subject: pkix.Name{ + CommonName: "testak", + }, + PublicKey: akPub, + } + validAKCert, err := ca.Sign(template) + require.NoError(t, err) + require.NotNil(t, validAKCert) + type fields struct { + client *http.Client + baseURL *url.URL + } + type args struct { + ctx context.Context + t *tpm.TPM + ek *tpm.EK + ak *tpm.AK + } + type test struct { + fields fields + server *httptest.Server + args args + want []*x509.Certificate + expErr error + } + tests := map[string]func(t *testing.T) test{ + "ok": func(t *testing.T) test { + params, err := ak.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + expectedSecret, encryptedCredentials, err := activation.Generate() + require.NoError(t, err) + akChain := [][]byte{ + validAKCert.Raw, + ca.Intermediate.Raw, + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EKPub) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&attestationResponse{ + Credential: encryptedCredentials.Credential, + Secret: encryptedCredentials.Secret, + }) + case "/secret": + var sr secretRequest + err := json.NewDecoder(r.Body).Decode(&sr) + require.NoError(t, err) + assert.Equal(t, expectedSecret, sr.Secret) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&secretResponse{ + CertificateChain: akChain, + }) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + client: http.DefaultClient, + baseURL: mustParseURL(t, s.URL), + }, + args: args{ + ctx: ctx, + t: instance, + + ek: ek, + ak: ak, + }, + want: []*x509.Certificate{ + validAKCert, + ca.Intermediate, + }, + expErr: nil, + } + }, + "fail/attest": func(t *testing.T) test { + params, err := ak.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EKPub) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusBadRequest) + case "/secret": + t.Errorf("unexpectedly requested /secret endpoint") + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + client: http.DefaultClient, + baseURL: mustParseURL(t, s.URL), + }, + args: args{ + ctx: ctx, + t: instance, + + ek: ek, + ak: ak, + }, + want: nil, + expErr: fmt.Errorf(`failed attesting AK: POST %q failed with HTTP status "400 Bad Request"`, fmt.Sprintf("%s/attest", s.URL)), + } + }, + "fail/secret": func(t *testing.T) test { + params, err := ak.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + expectedSecret, encryptedCredentials, err := activation.Generate() + require.NoError(t, err) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EKPub) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&attestationResponse{ + Credential: encryptedCredentials.Credential, + Secret: encryptedCredentials.Secret, + }) + case "/secret": + var sr secretRequest + err := json.NewDecoder(r.Body).Decode(&sr) + require.NoError(t, err) + assert.Equal(t, expectedSecret, sr.Secret) + w.WriteHeader(http.StatusForbidden) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + client: http.DefaultClient, + baseURL: mustParseURL(t, s.URL), + }, + args: args{ + ctx: ctx, + t: instance, + + ek: ek, + ak: ak, + }, + want: nil, + expErr: fmt.Errorf(`failed validating secret: POST %q failed with HTTP status "403 Forbidden"`, fmt.Sprintf("%s/secret", s.URL)), + } + }, + "fail/pars-ak-certificate-chain": func(t *testing.T) test { + params, err := ak.AttestationParameters(context.Background()) + require.NoError(t, err) + require.NotNil(t, params) + activation := attest.ActivationParameters{ + TPMVersion: attest.TPMVersion20, + EK: ek.Public(), + AK: params, + } + expectedSecret, encryptedCredentials, err := activation.Generate() + require.NoError(t, err) + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/attest": + var ar attestationRequest + err := json.NewDecoder(r.Body).Decode(&ar) + require.NoError(t, err) + parsedEK, err := x509.ParsePKIXPublicKey(ar.EKPub) + require.NoError(t, err) + assert.Equal(t, ek.Public(), parsedEK) + attestParams := attest.AttestationParameters{ + Public: ar.AttestParams.Public, + UseTCSDActivationFormat: ar.AttestParams.UseTCSDActivationFormat, + CreateData: ar.AttestParams.CreateData, + CreateAttestation: ar.AttestParams.CreateAttestation, + CreateSignature: ar.AttestParams.CreateSignature, + } + activationParams := attest.ActivationParameters{ + TPMVersion: ar.TPMInfo.Version, + EK: parsedEK, + AK: attestParams, + } + assert.Equal(t, activation, activationParams) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&attestationResponse{ + Credential: encryptedCredentials.Credential, + Secret: encryptedCredentials.Secret, + }) + case "/secret": + var sr secretRequest + err := json.NewDecoder(r.Body).Decode(&sr) + require.NoError(t, err) + assert.Equal(t, expectedSecret, sr.Secret) + w.WriteHeader(http.StatusOK) + _ = json.NewEncoder(w).Encode(&secretResponse{ + CertificateChain: [][]byte{[]byte("this-is-no-certificate")}, + }) + default: + t.Errorf("unexpected %q request to %q", r.Method, r.URL) + } + }) + s := httptest.NewServer(handler) + return test{ + server: s, + fields: fields{ + client: http.DefaultClient, + baseURL: mustParseURL(t, s.URL), + }, + args: args{ + ctx: ctx, + t: instance, + + ek: ek, + ak: ak, + }, + want: nil, + expErr: errors.New(`failed parsing certificate: x509: malformed certificate`), + } + }, + } + for name, tt := range tests { + tc := tt(t) + t.Run(name, func(t *testing.T) { + ac := &Client{ + client: tc.fields.client, + baseURL: tc.fields.baseURL, + } + if tc.server != nil { + defer tc.server.Close() + } + got, err := ac.Attest(tc.args.ctx, tc.args.t, tc.args.ek, tc.args.ak) + if tc.expErr != nil { + assert.EqualError(t, err, tc.expErr.Error()) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} diff --git a/tpm/attestation/client_test.go b/tpm/attestation/client_test.go new file mode 100644 index 00000000..9396b17b --- /dev/null +++ b/tpm/attestation/client_test.go @@ -0,0 +1,90 @@ +package attestation + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestNewClient(t *testing.T) { + baseURL := "http://localhost:1337" + type args struct { + tpmAttestationCABaseURL string + options []Option + } + tests := []struct { + name string + args args + assertFunc assert.ValueAssertionFunc + wantErr bool + }{ + { + name: "ok/no-options", + args: args{ + tpmAttestationCABaseURL: baseURL, + options: nil, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &Client{}, i1) { + c, _ := i1.(*Client) + if assert.NotNil(t, c) { + if assert.NotNil(t, c.baseURL) { + assert.Equal(t, baseURL, c.baseURL.String()) + } + if assert.NotNil(t, c.client) { + assert.Equal(t, 10*time.Second, c.client.Timeout) + } + return true + } + } + return false + }, + wantErr: false, + }, + { + name: "ok/with-options", + args: args{ + tpmAttestationCABaseURL: baseURL, + options: []Option{WithInsecure(), WithRootsFile("testdata/roots.pem")}, + }, + assertFunc: func(tt assert.TestingT, i1 interface{}, i2 ...interface{}) bool { + if assert.IsType(t, &Client{}, i1) { + c, _ := i1.(*Client) + if assert.NotNil(t, c) { + if assert.NotNil(t, c.baseURL) { + assert.Equal(t, baseURL, c.baseURL.String()) + } + if assert.NotNil(t, c.client) { + assert.Equal(t, 10*time.Second, c.client.Timeout) + } + return true + } + } + return false + }, + wantErr: false, + }, + { + name: "fail/non-existing-roots", + args: args{ + tpmAttestationCABaseURL: baseURL, + options: []Option{WithInsecure(), WithRootsFile("testdata/non-existing-roots.pem")}, + }, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewClient(tt.args.tpmAttestationCABaseURL, tt.args.options...) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + return + } + + assert.NoError(t, err) + assert.True(t, tt.assertFunc(t, got)) + }) + } +} diff --git a/tpm/attestation/testdata/roots.pem b/tpm/attestation/testdata/roots.pem new file mode 100644 index 00000000..c802b420 --- /dev/null +++ b/tpm/attestation/testdata/roots.pem @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBhzCCASygAwIBAgIRANJiwPnM38wWznkJGOcIyIYwCgYIKoZIzj0EAwIwITEf +MB0GA1UEAxMWU21hbGxzdGVwIFRlc3QgUm9vdCBDQTAeFw0xODA5MjcxODE4MDla +Fw0yODA5MjQxODE4MDlaMCExHzAdBgNVBAMTFlNtYWxsc3RlcCBUZXN0IFJvb3Qg +Q0EwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAAS15w7dx9zPjCnQ7+RlRkvUXQJN +Fjk5Hg5K9nCoiiNQQhcQMw63/pXQxHNsugiMshcN59XJC8195KJPm25nXN8co0Uw +QzAOBgNVHQ8BAf8EBAMCAaYwEgYDVR0TAQH/BAgwBgEB/wIBATAdBgNVHQ4EFgQU +B2BAXUSPZbFjnY6VzbApV48Tn3owCgYIKoZIzj0EAwIDSQAwRgIhAJRTVmc2xW8c +ESx4oIp2d/OX9KBZzpcNi9fHnnJCS0FXAiEA7OpFb2+b8KBzg1c02x21PS7pHoET +/A8LXNH4M06A7vE= +-----END CERTIFICATE----- diff --git a/tpm/rand/rand.go b/tpm/rand/rand.go new file mode 100644 index 00000000..fceadc3a --- /dev/null +++ b/tpm/rand/rand.go @@ -0,0 +1,16 @@ +package rand + +import ( + "fmt" + "io" + + "go.step.sm/crypto/tpm" +) + +func New(opts ...tpm.NewTPMOption) (io.Reader, error) { + t, err := tpm.New(opts...) + if err != nil { + return nil, fmt.Errorf("failed creating TPM instance: %w", err) + } + return t.RandomReader() +} diff --git a/tpm/rand/rand_simulator_test.go b/tpm/rand/rand_simulator_test.go new file mode 100644 index 00000000..91101403 --- /dev/null +++ b/tpm/rand/rand_simulator_test.go @@ -0,0 +1,61 @@ +//go:build tpmsimulator +// +build tpmsimulator + +package rand + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.step.sm/crypto/tpm" + "go.step.sm/crypto/tpm/simulator" +) + +func withSimulator(t *testing.T) tpm.NewTPMOption { + t.Helper() + var sim simulator.Simulator + t.Cleanup(func() { + if sim == nil { + return + } + err := sim.Close() + require.NoError(t, err) + }) + sim = simulator.New() + err := sim.Open() + require.NoError(t, err) + return tpm.WithSimulator(sim) +} + +func withNewErrorSimulator(t *testing.T) tpm.NewTPMOption { + return func(t *tpm.TPM) error { + return errors.New("forced new error") + } +} + +func TestNew(t *testing.T) { + r, err := New(withSimulator(t)) + require.NoError(t, err) + require.NotNil(t, r) + + ecdsaKey, err := ecdsa.GenerateKey(elliptic.P256(), r) + require.NoError(t, err) + if assert.NotNil(t, ecdsaKey) { + size := (ecdsaKey.D.BitLen() + 7) / 8 + require.Equal(t, 32, size) + } + + rsaKey, err := rsa.GenerateKey(r, 2048) + require.NoError(t, err) + if assert.NotNil(t, rsaKey) { + require.Equal(t, 256, rsaKey.Size()) // 2048 bits; 256 bytes expected to have been read + } + + _, err = New(withNewErrorSimulator(t)) + require.Error(t, err) +} diff --git a/tpm/random.go b/tpm/random.go index 61bdc49c..6edabf7e 100644 --- a/tpm/random.go +++ b/tpm/random.go @@ -2,11 +2,23 @@ package tpm import ( "context" + "errors" "fmt" + "io" + "math" "github.com/google/go-tpm/tpm2" ) +type ShortRandomReadError struct { + Requested int + Generated int +} + +func (s ShortRandomReadError) Error() string { + return fmt.Sprintf("generated %d random bytes instead of the requested %d", s.Generated, s.Requested) +} + // GenerateRandom returns `size` number of random bytes generated by the TPM. func (t *TPM) GenerateRandom(ctx context.Context, size uint16) (random []byte, err error) { if err = t.open(goTPMCall(ctx)); err != nil { @@ -14,14 +26,70 @@ func (t *TPM) GenerateRandom(ctx context.Context, size uint16) (random []byte, e } defer closeTPM(ctx, t, &err) + return t.generateRandom(ctx, size) +} + +func (t *TPM) generateRandom(ctx context.Context, size uint16) (random []byte, err error) { random, err = tpm2.GetRandom(t.rwc, size) if err != nil { return nil, fmt.Errorf("failed generating random data: %w", err) } if len(random) != int(size) { - return nil, fmt.Errorf("generated %d random bytes instead of the requested %d", len(random), size) + return nil, ShortRandomReadError{Requested: int(size), Generated: len(random)} } return } + +type generator struct { + t *TPM + readError error +} + +func (t *TPM) RandomReader() (io.Reader, error) { + return &generator{ + t: t, + }, nil +} + +func (g *generator) Read(p []byte) (n int, err error) { + if g.readError != nil { + errMsg := g.readError.Error() // multiple wrapped errors not (yet) allowed + return 0, fmt.Errorf("failed generating random bytes in previous call to Read: %s: %w", errMsg, io.EOF) + } + if len(p) > math.MaxUint16 { + return 0, fmt.Errorf("number of random bytes to read cannot exceed %d", math.MaxUint16) + } + + ctx := context.Background() + if err = g.t.open(goTPMCall(ctx)); err != nil { + return 0, fmt.Errorf("failed opening TPM: %w", err) + } + defer closeTPM(ctx, g.t, &err) + + var result []byte + requestedLength := len(p) + singleRequestLength := uint16(requestedLength) + for len(result) < requestedLength { + if r, err := g.t.generateRandom(ctx, singleRequestLength); err == nil { + result = append(result, r...) + } else { + var s ShortRandomReadError + if errors.As(err, &s) && s.Generated > 0 { + // adjust number of bytes to request if at least some data was read and continue loop + singleRequestLength = uint16(s.Generated) + result = append(result, r...) + } else { + g.readError = err // store the error to be returned for future calls to Read + n = copy(p, result) + return n, nil // return the result recorded so far and no error + } + } + } + + n = copy(p, result) + return +} + +var _ io.Reader = (*generator)(nil) diff --git a/tpm/tpm.go b/tpm/tpm.go index 1d0ac96d..ea6fb391 100644 --- a/tpm/tpm.go +++ b/tpm/tpm.go @@ -40,7 +40,9 @@ type NewTPMOption func(t *TPM) error // device. func WithDeviceName(name string) NewTPMOption { return func(t *TPM) error { - t.deviceName = name + if name != "" { + t.deviceName = name + } return nil } } diff --git a/tpm/tpm_simulator_test.go b/tpm/tpm_simulator_test.go index 40b9e7c6..99013886 100644 --- a/tpm/tpm_simulator_test.go +++ b/tpm/tpm_simulator_test.go @@ -12,10 +12,14 @@ import ( "crypto/x509/pkix" "encoding/base64" "encoding/binary" + "errors" + "io" + "math" "strings" "testing" "github.com/smallstep/go-attestation/attest" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.step.sm/crypto/keyutil" @@ -80,6 +84,113 @@ func TestTPM_GenerateRandom(t *testing.T) { require.Len(t, b, 10) } +func newErrorTPM(t *testing.T) *TPM { + t.Helper() + tmpDir := t.TempDir() + tpm, err := New(withWriteErrorSimulator(t), WithStore(storage.NewDirstore(tmpDir))) // TODO: provide in-memory storage implementation instead + require.NoError(t, err) + return tpm +} + +func withWriteErrorSimulator(t *testing.T) NewTPMOption { + t.Helper() + var sim simulator.Simulator + t.Cleanup(func() { + if sim == nil { + return + } + err := sim.Close() + require.NoError(t, err) + }) + sim = &writeErrorSimulator{} + err := sim.Open() + require.NoError(t, err) + return WithSimulator(sim) +} + +type writeErrorSimulator struct { +} + +func (s *writeErrorSimulator) Open() error { + return nil +} + +func (s *writeErrorSimulator) Close() error { + return nil +} + +func (s *writeErrorSimulator) Read([]byte) (int, error) { + return -1, nil +} + +func (s *writeErrorSimulator) Write([]byte) (int, error) { + return 0, errors.New("forced write error") // writing command fails +} + +func (s *writeErrorSimulator) MeasurementLog() ([]byte, error) { + return nil, nil +} + +var _ io.ReadWriteCloser = (*writeErrorSimulator)(nil) + +func Test_generator_Read(t *testing.T) { + tpm := newSimulatedTPM(t) + errorTPM := newErrorTPM(t) + type fields struct { + t *TPM + } + type args struct { + data []byte + } + short := make([]byte, 8) + long := make([]byte, 32) + tooLongForSimulator := make([]byte, 256) // I've observed the simulator to return 64 at most in one go; we loop through it, so we can get more than 64 random bytes + maximumLength := make([]byte, math.MaxUint16) + longerThanMax := make([]byte, math.MaxUint16+1) + readError := make([]byte, 32) + tests := []struct { + name string + fields fields + args args + want int + expErr error + }{ + {"ok/short", fields{tpm}, args{data: short}, 8, nil}, + {"ok/long", fields{tpm}, args{data: long}, 32, nil}, + {"ok/tooLongForSimulator", fields{tpm}, args{data: tooLongForSimulator}, 256, nil}, + {"ok/max", fields{tpm}, args{data: maximumLength}, math.MaxUint16, nil}, + {"ok/readError", fields{errorTPM}, args{data: readError}, 0, nil}, + {"fail/longerThanMax", fields{tpm}, args{data: longerThanMax}, 0, errors.New("number of random bytes to read cannot exceed 65535")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + g, err := tt.fields.t.RandomReader() + require.NoError(t, err) + + got, err := g.Read(tt.args.data) + if tt.expErr != nil { + assert.EqualError(t, err, tt.expErr.Error()) + assert.Equal(t, 0, got) + return + } + + assert.NoError(t, err) + assert.Equal(t, tt.want, got) + + // for the test cases that use the errorTPM, check that trying + // to read (again) from the same generator fails with the previous + // error. + if tt.fields.t == errorTPM { + newShort := make([]byte, 8) + n, err := g.Read(newShort) + assert.Zero(t, n) + assert.EqualError(t, err, "failed generating random bytes in previous call to Read: failed generating random data: forced write error: EOF") + assert.ErrorIs(t, err, io.EOF) + } + }) + } +} + func TestTPM_GetEKs(t *testing.T) { tpm := newSimulatedTPM(t) eks, err := tpm.GetEKs(context.Background()) @@ -252,6 +363,28 @@ func TestAK_Blobs(t *testing.T) { require.Len(t, public, int(size)+2) } +func TestAK_Public(t *testing.T) { + tpm := newSimulatedTPM(t) + ak, err := tpm.CreateAK(context.Background(), "first-ak") + require.NoError(t, err) + require.NotNil(t, ak) + require.Same(t, tpm, ak.tpm) + + akPub := ak.Public() + require.NoError(t, err) + require.NotNil(t, akPub) + require.Implements(t, (*crypto.PublicKey)(nil), ak) + _, ok := akPub.(crypto.Signer) + require.False(t, ok) + + newAK := &AK{ + tpm: tpm, + name: "second-ak", // non-existent AK; results in error + } + newAKPub := newAK.Public() + require.Nil(t, newAKPub) +} + func TestAK_CertificateOperations(t *testing.T) { tpm := newSimulatedTPM(t) @@ -260,7 +393,7 @@ func TestAK_CertificateOperations(t *testing.T) { require.NotNil(t, ak) require.Same(t, tpm, ak.tpm) - akPub, err := ak.public(context.Background()) + akPub := ak.Public() require.NoError(t, err) require.NotNil(t, akPub)