From 75c304a1976b1751bbea501d79f9e102cd61adde Mon Sep 17 00:00:00 2001 From: Alay Patel Date: Wed, 17 Nov 2021 23:52:23 -0500 Subject: [PATCH] update unit tests to check for valid stunnel secret Signed-off-by: Alay Patel --- transport/stunnel/server.go | 11 ++++---- transport/stunnel/server_test.go | 19 +++++++++----- transport/stunnel/stunnel.go | 36 ++++++++++++++++---------- transport/stunnel/stunnel_test.go | 38 ++++++++++++++++++---------- transport/tls/certs/generate.go | 30 ++++++++++++++++++++++ transport/tls/certs/generate_test.go | 38 +++------------------------- 6 files changed, 100 insertions(+), 72 deletions(-) diff --git a/transport/stunnel/server.go b/transport/stunnel/server.go index f301801..58ddeb1 100644 --- a/transport/stunnel/server.go +++ b/transport/stunnel/server.go @@ -194,16 +194,17 @@ func (s *server) prefixedName(name string) string { } func (s *server) reconcileSecret(ctx context.Context, c ctrlclient.Client) error { - _, _, found, err := getExistingCert(ctx, c, s.logger, s.namespacedName, serverSecretNameSuffix()) - if found { - return nil - } - + secretValid, err := isSecretValid(ctx, c, s.logger, s.namespacedName, serverSecretNameSuffix()) if err != nil { s.logger.Error(err, "error getting existing ssl certs from secret") return err } + if secretValid { + s.logger.V(4).Info("found secret with valid certs") + return nil + } + s.logger.Info("generating new certificate bundle") crtBundle, err := certs.New() if err != nil { s.logger.Error(err, "error generating ssl certs for stunnel server") diff --git a/transport/stunnel/server_test.go b/transport/stunnel/server_test.go index 6ce9b7a..07736d9 100644 --- a/transport/stunnel/server_test.go +++ b/transport/stunnel/server_test.go @@ -21,7 +21,7 @@ import ( func fakeClientWithObjects(objs ...ctrlclient.Object) ctrlclient.WithWatch { scheme := runtime.NewScheme() - AddToScheme(scheme) + _ = AddToScheme(scheme) return fake.NewClientBuilder().WithScheme(scheme).WithObjects(objs...).Build() } @@ -57,11 +57,11 @@ func (f fakeEndpoint) IngressPort() int32 { return 1234 } -func (f fakeEndpoint) IsHealthy(ctx context.Context, c ctrlclient.Client) (bool, error) { +func (f fakeEndpoint) IsHealthy(_ context.Context, _ ctrlclient.Client) (bool, error) { return true, nil } -func (f fakeEndpoint) MarkForCleanup(ctx context.Context, c ctrlclient.Client, key, value string) error { +func (f fakeEndpoint) MarkForCleanup(_ context.Context, _ ctrlclient.Client, _, _ string) error { return nil } @@ -164,7 +164,7 @@ func TestNewServer(t *testing.T) { t.Run(tt.name, func(t *testing.T) { fakeClient := fakeClientWithObjects(tt.objects...) ctx := context.WithValue(context.Background(), "test", tt.name) - fakeLogger := logrtesting.TestLogger{t} + fakeLogger := logrtesting.TestLogger{T: t} stunnelServer, err := NewServer(ctx, fakeClient, fakeLogger, tt.namespacedName, tt.endpoint, &transport.Options{Labels: tt.labels, Owners: tt.ownerReferences}) if (err != nil) != tt.wantErr { t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr) @@ -179,11 +179,11 @@ func TestNewServer(t *testing.T) { panic(fmt.Errorf("%#v should not be getting error from fake client", err)) } - configdata, ok := cm.Data["stunnel.conf"] + configData, ok := cm.Data["stunnel.conf"] if !ok { t.Error("unable to find stunnel config data in configmap") } - if !strings.Contains(configdata, "foreground = yes") { + if !strings.Contains(configData, "foreground = yes") { t.Error("configmap data does not contain the right data") } @@ -206,6 +206,11 @@ func TestNewServer(t *testing.T) { t.Error("unable to find tls.crt in stunnel secret") } + _, ok = secret.Data["ca.crt"] + if !ok { + t.Error("unable to find ca.crt in stunnel secret") + } + if len(stunnelServer.Volumes()) == 0 { t.Error("stunnel server volumes not set properly") } @@ -257,7 +262,7 @@ func Test_server_MarkForCleanup(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &server{ - logger: logrtesting.TestLogger{t}, + logger: logrtesting.TestLogger{T: t}, options: &transport.Options{ Labels: tt.labels, Owners: testOwnerReferences(), diff --git a/transport/stunnel/stunnel.go b/transport/stunnel/stunnel.go index 5bb96c5..018d7bd 100644 --- a/transport/stunnel/stunnel.go +++ b/transport/stunnel/stunnel.go @@ -5,6 +5,7 @@ import ( "context" "github.com/backube/pvc-transfer/transport" + "github.com/backube/pvc-transfer/transport/tls/certs" "github.com/go-logr/logr" corev1 "k8s.io/api/core/v1" k8serrors "k8s.io/apimachinery/pkg/api/errors" @@ -47,36 +48,45 @@ func getResourceName(obj types.NamespacedName, suffix string) string { return obj.Name + "-" + suffix } -func getExistingCert(ctx context.Context, c ctrlclient.Client, logger logr.Logger, secretName types.NamespacedName, suffix string) (*bytes.Buffer, *bytes.Buffer, bool, error) { +func isSecretValid(ctx context.Context, c ctrlclient.Client, logger logr.Logger, key types.NamespacedName, suffix string) (bool, error) { secret := &corev1.Secret{} err := c.Get(ctx, types.NamespacedName{ - Namespace: secretName.Namespace, - Name: getResourceName(secretName, suffix), + Namespace: key.Namespace, + Name: getResourceName(key, suffix), }, secret) switch { case k8serrors.IsNotFound(err): - return nil, nil, false, nil + return false, nil case err != nil: - return nil, nil, false, err + return false, err } - key, ok := secret.Data["tls.key"] + _, ok := secret.Data["tls.key"] if !ok { logger.Info("secret data missing key tls.key", "secret", types.NamespacedName{ - Namespace: secretName.Namespace, - Name: getResourceName(secretName, suffix), + Namespace: key.Namespace, + Name: getResourceName(key, suffix), }) - return nil, nil, false, nil + return false, nil } crt, ok := secret.Data["tls.crt"] if !ok { logger.Info("secret data missing key tls.crt", "secret", types.NamespacedName{ - Namespace: secretName.Namespace, - Name: getResourceName(secretName, suffix), + Namespace: key.Namespace, + Name: getResourceName(key, suffix), }) - return nil, nil, false, nil + return false, nil } - return bytes.NewBuffer(key), bytes.NewBuffer(crt), true, nil + ca, ok := secret.Data["ca.crt"] + if !ok { + logger.Info("secret data missing key ca.crt", "secret", types.NamespacedName{ + Namespace: key.Namespace, + Name: getResourceName(key, suffix), + }) + return false, nil + } + + return certs.VerifyCertificate(bytes.NewBuffer(ca), bytes.NewBuffer(crt)) } diff --git a/transport/stunnel/stunnel_test.go b/transport/stunnel/stunnel_test.go index f0676f5..f34796d 100644 --- a/transport/stunnel/stunnel_test.go +++ b/transport/stunnel/stunnel_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/backube/pvc-transfer/transport" + "github.com/backube/pvc-transfer/transport/tls/certs" logrtesting "github.com/go-logr/logr/testing" corev1 "k8s.io/api/core/v1" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" @@ -12,6 +13,8 @@ import ( ctrlclient "sigs.k8s.io/controller-runtime/pkg/client" ) +var certificateBundle, _ = certs.New() + func Test_getExistingCert(t *testing.T) { tests := []struct { name string @@ -42,7 +45,7 @@ func Test_getExistingCert(t *testing.T) { Namespace: "bar", Labels: map[string]string{"test": "me"}, }, - Data: map[string][]byte{"tls.crt": []byte(`crt`)}, + Data: map[string][]byte{"tls.crt": certificateBundle.ServerCrt.Bytes()}, }, }, }, @@ -59,7 +62,24 @@ func Test_getExistingCert(t *testing.T) { Namespace: "bar", Labels: map[string]string{"test": "me"}, }, - Data: map[string][]byte{"tls.key": []byte(`key`)}, + Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes()}, + }, + }, + }, + { + name: "test with secret missing ca.crt", + namespacedName: types.NamespacedName{Namespace: "bar", Name: "foo"}, + labels: map[string]string{"test": "me"}, + wantErr: true, + wantFound: false, + objects: []ctrlclient.Object{ + &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: "foo-stunnel-credentials", + Namespace: "bar", + Labels: map[string]string{"test": "me"}, + }, + Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes(), "tls.crt": certificateBundle.ServerKey.Bytes()}, }, }, }, @@ -67,7 +87,7 @@ func Test_getExistingCert(t *testing.T) { name: "test with valid secret", namespacedName: types.NamespacedName{Namespace: "bar", Name: "foo"}, labels: map[string]string{"test": "me"}, - wantErr: false, + wantErr: true, wantFound: true, objects: []ctrlclient.Object{ &corev1.Secret{ @@ -76,7 +96,7 @@ func Test_getExistingCert(t *testing.T) { Namespace: "bar", Labels: map[string]string{"test": "me"}, }, - Data: map[string][]byte{"tls.key": []byte(`key`), "tls.crt": []byte(`crt`)}, + Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes(), "tls.crt": certificateBundle.ServerCrt.Bytes(), "ca.crt": certificateBundle.CACrt.Bytes()}, }, }, }, @@ -92,7 +112,7 @@ func Test_getExistingCert(t *testing.T) { }, } ctx := context.WithValue(context.Background(), "test", tt.name) - key, crt, found, err := getExistingCert(ctx, fakeClientWithObjects(tt.objects...), s.logger, s.namespacedName, stunnelSecret) + found, err := isSecretValid(ctx, fakeClientWithObjects(tt.objects...), s.logger, s.namespacedName, stunnelSecret) if err != nil { t.Error("found unexpected error", err) } @@ -102,14 +122,6 @@ func Test_getExistingCert(t *testing.T) { if tt.wantFound && !found { t.Error("not found unexpected") } - - if tt.wantFound && found && key == nil { - t.Error("secret found but empty key, unexpected") - } - - if tt.wantFound && found && crt == nil { - t.Error("secret found but empty crt, unexpected") - } }) } } diff --git a/transport/tls/certs/generate.go b/transport/tls/certs/generate.go index c69acb3..2c6ac58 100644 --- a/transport/tls/certs/generate.go +++ b/transport/tls/certs/generate.go @@ -7,6 +7,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "encoding/pem" + "fmt" "math/big" "time" ) @@ -121,6 +122,35 @@ func Generate(subject *pkix.Name, caCrtTemplate x509.Certificate, caKey rsa.Priv return } +// VerifyCertificate returns true if the crt is signed by the caCrt as the root CA +// with no intermediate DCAs in the chain +func VerifyCertificate(caCrt *bytes.Buffer, crt *bytes.Buffer) (bool, error) { + roots := x509.NewCertPool() + ok := roots.AppendCertsFromPEM(caCrt.Bytes()) + if !ok { + panic("failed to parse root certificate") + } + + block, _ := pem.Decode(crt.Bytes()) + if block == nil { + return false, fmt.Errorf("unable to decode certificate") + } + cert, err := x509.ParseCertificate(block.Bytes) + if err != nil { + return false, fmt.Errorf("failed to parse certificate: %#v", err) + } + + opts := x509.VerifyOptions{ + Roots: roots, + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, + } + + if _, err := cert.Verify(opts); err != nil { + return false, nil + } + return true, nil +} + func createCrtKeyPair(crtTemplate, parent *x509.Certificate, signer *rsa.PrivateKey) (crt *bytes.Buffer, key *rsa.PrivateKey, err error) { key, err = rsa.GenerateKey(rand.Reader, keySize) if err != nil { diff --git a/transport/tls/certs/generate_test.go b/transport/tls/certs/generate_test.go index 4d07d23..b72cbea 100644 --- a/transport/tls/certs/generate_test.go +++ b/transport/tls/certs/generate_test.go @@ -1,9 +1,6 @@ package certs import ( - "bytes" - "crypto/x509" - "encoding/pem" "testing" ) @@ -54,10 +51,10 @@ func TestNew(t *testing.T) { // t.Error("client cert is not verified with root CA") //} - if !verifySingedCA(got.CACrt, got.ClientCrt) { + if ok, _ := VerifyCertificate(got.CACrt, got.ClientCrt); !ok { t.Error("client cert is not verified with root CA") } - if !verifySingedCA(got.CACrt, got.ServerCrt) { + if ok, _ := VerifyCertificate(got.CACrt, got.ServerCrt); !ok { t.Error("server cert is not verified with root CA") } @@ -66,39 +63,12 @@ func TestNew(t *testing.T) { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return } - if verifySingedCA(got.CACrt, got2.ClientCrt) { + if ok, _ := VerifyCertificate(got.CACrt, got2.ClientCrt); ok { t.Error("client cert is verified with different root CA") } - if verifySingedCA(got.CACrt, got2.ServerCrt) { + if ok, _ := VerifyCertificate(got.CACrt, got2.ServerCrt); ok { t.Error("server cert is not verified with different root CA") } }) } } - -func verifySingedCA(caCrt *bytes.Buffer, crt *bytes.Buffer) bool { - roots := x509.NewCertPool() - ok := roots.AppendCertsFromPEM(caCrt.Bytes()) - if !ok { - panic("failed to parse root certificate") - } - - block, _ := pem.Decode(crt.Bytes()) - if block == nil { - panic("failed to parse certificate") - } - cert, err := x509.ParseCertificate(block.Bytes) - if err != nil { - panic("failed to parse certificate: " + err.Error()) - } - - opts := x509.VerifyOptions{ - Roots: roots, - KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny}, - } - - if _, err := cert.Verify(opts); err != nil { - return false - } - return true -}