Skip to content

Commit

Permalink
update unit tests to check for valid stunnel secret
Browse files Browse the repository at this point in the history
Signed-off-by: Alay Patel <[email protected]>
  • Loading branch information
alaypatel07 committed Nov 18, 2021
1 parent e10d983 commit 75c304a
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 72 deletions.
11 changes: 6 additions & 5 deletions transport/stunnel/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,17 @@ func (s *server) prefixedName(name string) string {
}

func (s *server) reconcileSecret(ctx context.Context, c ctrlclient.Client) error {
_, _, found, err := getExistingCert(ctx, c, s.logger, s.namespacedName, serverSecretNameSuffix())
if found {
return nil
}

secretValid, err := isSecretValid(ctx, c, s.logger, s.namespacedName, serverSecretNameSuffix())
if err != nil {
s.logger.Error(err, "error getting existing ssl certs from secret")
return err
}
if secretValid {
s.logger.V(4).Info("found secret with valid certs")
return nil
}

s.logger.Info("generating new certificate bundle")
crtBundle, err := certs.New()
if err != nil {
s.logger.Error(err, "error generating ssl certs for stunnel server")
Expand Down
19 changes: 12 additions & 7 deletions transport/stunnel/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (

func fakeClientWithObjects(objs ...ctrlclient.Object) ctrlclient.WithWatch {
scheme := runtime.NewScheme()
AddToScheme(scheme)
_ = AddToScheme(scheme)
return fake.NewClientBuilder().WithScheme(scheme).WithObjects(objs...).Build()
}

Expand Down Expand Up @@ -57,11 +57,11 @@ func (f fakeEndpoint) IngressPort() int32 {
return 1234
}

func (f fakeEndpoint) IsHealthy(ctx context.Context, c ctrlclient.Client) (bool, error) {
func (f fakeEndpoint) IsHealthy(_ context.Context, _ ctrlclient.Client) (bool, error) {
return true, nil
}

func (f fakeEndpoint) MarkForCleanup(ctx context.Context, c ctrlclient.Client, key, value string) error {
func (f fakeEndpoint) MarkForCleanup(_ context.Context, _ ctrlclient.Client, _, _ string) error {
return nil
}

Expand Down Expand Up @@ -164,7 +164,7 @@ func TestNewServer(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
fakeClient := fakeClientWithObjects(tt.objects...)
ctx := context.WithValue(context.Background(), "test", tt.name)
fakeLogger := logrtesting.TestLogger{t}
fakeLogger := logrtesting.TestLogger{T: t}
stunnelServer, err := NewServer(ctx, fakeClient, fakeLogger, tt.namespacedName, tt.endpoint, &transport.Options{Labels: tt.labels, Owners: tt.ownerReferences})
if (err != nil) != tt.wantErr {
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
Expand All @@ -179,11 +179,11 @@ func TestNewServer(t *testing.T) {
panic(fmt.Errorf("%#v should not be getting error from fake client", err))
}

configdata, ok := cm.Data["stunnel.conf"]
configData, ok := cm.Data["stunnel.conf"]
if !ok {
t.Error("unable to find stunnel config data in configmap")
}
if !strings.Contains(configdata, "foreground = yes") {
if !strings.Contains(configData, "foreground = yes") {
t.Error("configmap data does not contain the right data")
}

Expand All @@ -206,6 +206,11 @@ func TestNewServer(t *testing.T) {
t.Error("unable to find tls.crt in stunnel secret")
}

_, ok = secret.Data["ca.crt"]
if !ok {
t.Error("unable to find ca.crt in stunnel secret")
}

if len(stunnelServer.Volumes()) == 0 {
t.Error("stunnel server volumes not set properly")
}
Expand Down Expand Up @@ -257,7 +262,7 @@ func Test_server_MarkForCleanup(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
s := &server{
logger: logrtesting.TestLogger{t},
logger: logrtesting.TestLogger{T: t},
options: &transport.Options{
Labels: tt.labels,
Owners: testOwnerReferences(),
Expand Down
36 changes: 23 additions & 13 deletions transport/stunnel/stunnel.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"context"

"github.com/backube/pvc-transfer/transport"
"github.com/backube/pvc-transfer/transport/tls/certs"
"github.com/go-logr/logr"
corev1 "k8s.io/api/core/v1"
k8serrors "k8s.io/apimachinery/pkg/api/errors"
Expand Down Expand Up @@ -47,36 +48,45 @@ func getResourceName(obj types.NamespacedName, suffix string) string {
return obj.Name + "-" + suffix
}

func getExistingCert(ctx context.Context, c ctrlclient.Client, logger logr.Logger, secretName types.NamespacedName, suffix string) (*bytes.Buffer, *bytes.Buffer, bool, error) {
func isSecretValid(ctx context.Context, c ctrlclient.Client, logger logr.Logger, key types.NamespacedName, suffix string) (bool, error) {
secret := &corev1.Secret{}
err := c.Get(ctx, types.NamespacedName{
Namespace: secretName.Namespace,
Name: getResourceName(secretName, suffix),
Namespace: key.Namespace,
Name: getResourceName(key, suffix),
}, secret)
switch {
case k8serrors.IsNotFound(err):
return nil, nil, false, nil
return false, nil
case err != nil:
return nil, nil, false, err
return false, err
}

key, ok := secret.Data["tls.key"]
_, ok := secret.Data["tls.key"]
if !ok {
logger.Info("secret data missing key tls.key", "secret", types.NamespacedName{
Namespace: secretName.Namespace,
Name: getResourceName(secretName, suffix),
Namespace: key.Namespace,
Name: getResourceName(key, suffix),
})
return nil, nil, false, nil
return false, nil
}

crt, ok := secret.Data["tls.crt"]
if !ok {
logger.Info("secret data missing key tls.crt", "secret", types.NamespacedName{
Namespace: secretName.Namespace,
Name: getResourceName(secretName, suffix),
Namespace: key.Namespace,
Name: getResourceName(key, suffix),
})
return nil, nil, false, nil
return false, nil
}

return bytes.NewBuffer(key), bytes.NewBuffer(crt), true, nil
ca, ok := secret.Data["ca.crt"]
if !ok {
logger.Info("secret data missing key ca.crt", "secret", types.NamespacedName{
Namespace: key.Namespace,
Name: getResourceName(key, suffix),
})
return false, nil
}

return certs.VerifyCertificate(bytes.NewBuffer(ca), bytes.NewBuffer(crt))
}
38 changes: 25 additions & 13 deletions transport/stunnel/stunnel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,16 @@ import (
"testing"

"github.com/backube/pvc-transfer/transport"
"github.com/backube/pvc-transfer/transport/tls/certs"
logrtesting "github.com/go-logr/logr/testing"
corev1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/types"
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
)

var certificateBundle, _ = certs.New()

func Test_getExistingCert(t *testing.T) {
tests := []struct {
name string
Expand Down Expand Up @@ -42,7 +45,7 @@ func Test_getExistingCert(t *testing.T) {
Namespace: "bar",
Labels: map[string]string{"test": "me"},
},
Data: map[string][]byte{"tls.crt": []byte(`crt`)},
Data: map[string][]byte{"tls.crt": certificateBundle.ServerCrt.Bytes()},
},
},
},
Expand All @@ -59,15 +62,32 @@ func Test_getExistingCert(t *testing.T) {
Namespace: "bar",
Labels: map[string]string{"test": "me"},
},
Data: map[string][]byte{"tls.key": []byte(`key`)},
Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes()},
},
},
},
{
name: "test with secret missing ca.crt",
namespacedName: types.NamespacedName{Namespace: "bar", Name: "foo"},
labels: map[string]string{"test": "me"},
wantErr: true,
wantFound: false,
objects: []ctrlclient.Object{
&corev1.Secret{
ObjectMeta: metav1.ObjectMeta{
Name: "foo-stunnel-credentials",
Namespace: "bar",
Labels: map[string]string{"test": "me"},
},
Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes(), "tls.crt": certificateBundle.ServerKey.Bytes()},
},
},
},
{
name: "test with valid secret",
namespacedName: types.NamespacedName{Namespace: "bar", Name: "foo"},
labels: map[string]string{"test": "me"},
wantErr: false,
wantErr: true,
wantFound: true,
objects: []ctrlclient.Object{
&corev1.Secret{
Expand All @@ -76,7 +96,7 @@ func Test_getExistingCert(t *testing.T) {
Namespace: "bar",
Labels: map[string]string{"test": "me"},
},
Data: map[string][]byte{"tls.key": []byte(`key`), "tls.crt": []byte(`crt`)},
Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes(), "tls.crt": certificateBundle.ServerCrt.Bytes(), "ca.crt": certificateBundle.CACrt.Bytes()},
},
},
},
Expand All @@ -92,7 +112,7 @@ func Test_getExistingCert(t *testing.T) {
},
}
ctx := context.WithValue(context.Background(), "test", tt.name)
key, crt, found, err := getExistingCert(ctx, fakeClientWithObjects(tt.objects...), s.logger, s.namespacedName, stunnelSecret)
found, err := isSecretValid(ctx, fakeClientWithObjects(tt.objects...), s.logger, s.namespacedName, stunnelSecret)
if err != nil {
t.Error("found unexpected error", err)
}
Expand All @@ -102,14 +122,6 @@ func Test_getExistingCert(t *testing.T) {
if tt.wantFound && !found {
t.Error("not found unexpected")
}

if tt.wantFound && found && key == nil {
t.Error("secret found but empty key, unexpected")
}

if tt.wantFound && found && crt == nil {
t.Error("secret found but empty crt, unexpected")
}
})
}
}
30 changes: 30 additions & 0 deletions transport/tls/certs/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"time"
)
Expand Down Expand Up @@ -121,6 +122,35 @@ func Generate(subject *pkix.Name, caCrtTemplate x509.Certificate, caKey rsa.Priv
return
}

// VerifyCertificate returns true if the crt is signed by the caCrt as the root CA
// with no intermediate DCAs in the chain
func VerifyCertificate(caCrt *bytes.Buffer, crt *bytes.Buffer) (bool, error) {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM(caCrt.Bytes())
if !ok {
panic("failed to parse root certificate")
}

block, _ := pem.Decode(crt.Bytes())
if block == nil {
return false, fmt.Errorf("unable to decode certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return false, fmt.Errorf("failed to parse certificate: %#v", err)
}

opts := x509.VerifyOptions{
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
}

if _, err := cert.Verify(opts); err != nil {
return false, nil
}
return true, nil
}

func createCrtKeyPair(crtTemplate, parent *x509.Certificate, signer *rsa.PrivateKey) (crt *bytes.Buffer, key *rsa.PrivateKey, err error) {
key, err = rsa.GenerateKey(rand.Reader, keySize)
if err != nil {
Expand Down
38 changes: 4 additions & 34 deletions transport/tls/certs/generate_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
package certs

import (
"bytes"
"crypto/x509"
"encoding/pem"
"testing"
)

Expand Down Expand Up @@ -54,10 +51,10 @@ func TestNew(t *testing.T) {
// t.Error("client cert is not verified with root CA")
//}

if !verifySingedCA(got.CACrt, got.ClientCrt) {
if ok, _ := VerifyCertificate(got.CACrt, got.ClientCrt); !ok {
t.Error("client cert is not verified with root CA")
}
if !verifySingedCA(got.CACrt, got.ServerCrt) {
if ok, _ := VerifyCertificate(got.CACrt, got.ServerCrt); !ok {
t.Error("server cert is not verified with root CA")
}

Expand All @@ -66,39 +63,12 @@ func TestNew(t *testing.T) {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
}
if verifySingedCA(got.CACrt, got2.ClientCrt) {
if ok, _ := VerifyCertificate(got.CACrt, got2.ClientCrt); ok {
t.Error("client cert is verified with different root CA")
}
if verifySingedCA(got.CACrt, got2.ServerCrt) {
if ok, _ := VerifyCertificate(got.CACrt, got2.ServerCrt); ok {
t.Error("server cert is not verified with different root CA")
}
})
}
}

func verifySingedCA(caCrt *bytes.Buffer, crt *bytes.Buffer) bool {
roots := x509.NewCertPool()
ok := roots.AppendCertsFromPEM(caCrt.Bytes())
if !ok {
panic("failed to parse root certificate")
}

block, _ := pem.Decode(crt.Bytes())
if block == nil {
panic("failed to parse certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
panic("failed to parse certificate: " + err.Error())
}

opts := x509.VerifyOptions{
Roots: roots,
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
}

if _, err := cert.Verify(opts); err != nil {
return false
}
return true
}

0 comments on commit 75c304a

Please sign in to comment.