Skip to content

Commit 75c304a

Browse files
committedNov 18, 2021
update unit tests to check for valid stunnel secret
Signed-off-by: Alay Patel <[email protected]>
1 parent e10d983 commit 75c304a

File tree

6 files changed

+100
-72
lines changed

6 files changed

+100
-72
lines changed
 

‎transport/stunnel/server.go

+6-5
Original file line numberDiff line numberDiff line change
@@ -194,16 +194,17 @@ func (s *server) prefixedName(name string) string {
194194
}
195195

196196
func (s *server) reconcileSecret(ctx context.Context, c ctrlclient.Client) error {
197-
_, _, found, err := getExistingCert(ctx, c, s.logger, s.namespacedName, serverSecretNameSuffix())
198-
if found {
199-
return nil
200-
}
201-
197+
secretValid, err := isSecretValid(ctx, c, s.logger, s.namespacedName, serverSecretNameSuffix())
202198
if err != nil {
203199
s.logger.Error(err, "error getting existing ssl certs from secret")
204200
return err
205201
}
202+
if secretValid {
203+
s.logger.V(4).Info("found secret with valid certs")
204+
return nil
205+
}
206206

207+
s.logger.Info("generating new certificate bundle")
207208
crtBundle, err := certs.New()
208209
if err != nil {
209210
s.logger.Error(err, "error generating ssl certs for stunnel server")

‎transport/stunnel/server_test.go

+12-7
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import (
2121

2222
func fakeClientWithObjects(objs ...ctrlclient.Object) ctrlclient.WithWatch {
2323
scheme := runtime.NewScheme()
24-
AddToScheme(scheme)
24+
_ = AddToScheme(scheme)
2525
return fake.NewClientBuilder().WithScheme(scheme).WithObjects(objs...).Build()
2626
}
2727

@@ -57,11 +57,11 @@ func (f fakeEndpoint) IngressPort() int32 {
5757
return 1234
5858
}
5959

60-
func (f fakeEndpoint) IsHealthy(ctx context.Context, c ctrlclient.Client) (bool, error) {
60+
func (f fakeEndpoint) IsHealthy(_ context.Context, _ ctrlclient.Client) (bool, error) {
6161
return true, nil
6262
}
6363

64-
func (f fakeEndpoint) MarkForCleanup(ctx context.Context, c ctrlclient.Client, key, value string) error {
64+
func (f fakeEndpoint) MarkForCleanup(_ context.Context, _ ctrlclient.Client, _, _ string) error {
6565
return nil
6666
}
6767

@@ -164,7 +164,7 @@ func TestNewServer(t *testing.T) {
164164
t.Run(tt.name, func(t *testing.T) {
165165
fakeClient := fakeClientWithObjects(tt.objects...)
166166
ctx := context.WithValue(context.Background(), "test", tt.name)
167-
fakeLogger := logrtesting.TestLogger{t}
167+
fakeLogger := logrtesting.TestLogger{T: t}
168168
stunnelServer, err := NewServer(ctx, fakeClient, fakeLogger, tt.namespacedName, tt.endpoint, &transport.Options{Labels: tt.labels, Owners: tt.ownerReferences})
169169
if (err != nil) != tt.wantErr {
170170
t.Errorf("NewServer() error = %v, wantErr %v", err, tt.wantErr)
@@ -179,11 +179,11 @@ func TestNewServer(t *testing.T) {
179179
panic(fmt.Errorf("%#v should not be getting error from fake client", err))
180180
}
181181

182-
configdata, ok := cm.Data["stunnel.conf"]
182+
configData, ok := cm.Data["stunnel.conf"]
183183
if !ok {
184184
t.Error("unable to find stunnel config data in configmap")
185185
}
186-
if !strings.Contains(configdata, "foreground = yes") {
186+
if !strings.Contains(configData, "foreground = yes") {
187187
t.Error("configmap data does not contain the right data")
188188
}
189189

@@ -206,6 +206,11 @@ func TestNewServer(t *testing.T) {
206206
t.Error("unable to find tls.crt in stunnel secret")
207207
}
208208

209+
_, ok = secret.Data["ca.crt"]
210+
if !ok {
211+
t.Error("unable to find ca.crt in stunnel secret")
212+
}
213+
209214
if len(stunnelServer.Volumes()) == 0 {
210215
t.Error("stunnel server volumes not set properly")
211216
}
@@ -257,7 +262,7 @@ func Test_server_MarkForCleanup(t *testing.T) {
257262
for _, tt := range tests {
258263
t.Run(tt.name, func(t *testing.T) {
259264
s := &server{
260-
logger: logrtesting.TestLogger{t},
265+
logger: logrtesting.TestLogger{T: t},
261266
options: &transport.Options{
262267
Labels: tt.labels,
263268
Owners: testOwnerReferences(),

‎transport/stunnel/stunnel.go

+23-13
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import (
55
"context"
66

77
"github.com/backube/pvc-transfer/transport"
8+
"github.com/backube/pvc-transfer/transport/tls/certs"
89
"github.com/go-logr/logr"
910
corev1 "k8s.io/api/core/v1"
1011
k8serrors "k8s.io/apimachinery/pkg/api/errors"
@@ -47,36 +48,45 @@ func getResourceName(obj types.NamespacedName, suffix string) string {
4748
return obj.Name + "-" + suffix
4849
}
4950

50-
func getExistingCert(ctx context.Context, c ctrlclient.Client, logger logr.Logger, secretName types.NamespacedName, suffix string) (*bytes.Buffer, *bytes.Buffer, bool, error) {
51+
func isSecretValid(ctx context.Context, c ctrlclient.Client, logger logr.Logger, key types.NamespacedName, suffix string) (bool, error) {
5152
secret := &corev1.Secret{}
5253
err := c.Get(ctx, types.NamespacedName{
53-
Namespace: secretName.Namespace,
54-
Name: getResourceName(secretName, suffix),
54+
Namespace: key.Namespace,
55+
Name: getResourceName(key, suffix),
5556
}, secret)
5657
switch {
5758
case k8serrors.IsNotFound(err):
58-
return nil, nil, false, nil
59+
return false, nil
5960
case err != nil:
60-
return nil, nil, false, err
61+
return false, err
6162
}
6263

63-
key, ok := secret.Data["tls.key"]
64+
_, ok := secret.Data["tls.key"]
6465
if !ok {
6566
logger.Info("secret data missing key tls.key", "secret", types.NamespacedName{
66-
Namespace: secretName.Namespace,
67-
Name: getResourceName(secretName, suffix),
67+
Namespace: key.Namespace,
68+
Name: getResourceName(key, suffix),
6869
})
69-
return nil, nil, false, nil
70+
return false, nil
7071
}
7172

7273
crt, ok := secret.Data["tls.crt"]
7374
if !ok {
7475
logger.Info("secret data missing key tls.crt", "secret", types.NamespacedName{
75-
Namespace: secretName.Namespace,
76-
Name: getResourceName(secretName, suffix),
76+
Namespace: key.Namespace,
77+
Name: getResourceName(key, suffix),
7778
})
78-
return nil, nil, false, nil
79+
return false, nil
7980
}
8081

81-
return bytes.NewBuffer(key), bytes.NewBuffer(crt), true, nil
82+
ca, ok := secret.Data["ca.crt"]
83+
if !ok {
84+
logger.Info("secret data missing key ca.crt", "secret", types.NamespacedName{
85+
Namespace: key.Namespace,
86+
Name: getResourceName(key, suffix),
87+
})
88+
return false, nil
89+
}
90+
91+
return certs.VerifyCertificate(bytes.NewBuffer(ca), bytes.NewBuffer(crt))
8292
}

‎transport/stunnel/stunnel_test.go

+25-13
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@ import (
55
"testing"
66

77
"github.com/backube/pvc-transfer/transport"
8+
"github.com/backube/pvc-transfer/transport/tls/certs"
89
logrtesting "github.com/go-logr/logr/testing"
910
corev1 "k8s.io/api/core/v1"
1011
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
1112
"k8s.io/apimachinery/pkg/types"
1213
ctrlclient "sigs.k8s.io/controller-runtime/pkg/client"
1314
)
1415

16+
var certificateBundle, _ = certs.New()
17+
1518
func Test_getExistingCert(t *testing.T) {
1619
tests := []struct {
1720
name string
@@ -42,7 +45,7 @@ func Test_getExistingCert(t *testing.T) {
4245
Namespace: "bar",
4346
Labels: map[string]string{"test": "me"},
4447
},
45-
Data: map[string][]byte{"tls.crt": []byte(`crt`)},
48+
Data: map[string][]byte{"tls.crt": certificateBundle.ServerCrt.Bytes()},
4649
},
4750
},
4851
},
@@ -59,15 +62,32 @@ func Test_getExistingCert(t *testing.T) {
5962
Namespace: "bar",
6063
Labels: map[string]string{"test": "me"},
6164
},
62-
Data: map[string][]byte{"tls.key": []byte(`key`)},
65+
Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes()},
66+
},
67+
},
68+
},
69+
{
70+
name: "test with secret missing ca.crt",
71+
namespacedName: types.NamespacedName{Namespace: "bar", Name: "foo"},
72+
labels: map[string]string{"test": "me"},
73+
wantErr: true,
74+
wantFound: false,
75+
objects: []ctrlclient.Object{
76+
&corev1.Secret{
77+
ObjectMeta: metav1.ObjectMeta{
78+
Name: "foo-stunnel-credentials",
79+
Namespace: "bar",
80+
Labels: map[string]string{"test": "me"},
81+
},
82+
Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes(), "tls.crt": certificateBundle.ServerKey.Bytes()},
6383
},
6484
},
6585
},
6686
{
6787
name: "test with valid secret",
6888
namespacedName: types.NamespacedName{Namespace: "bar", Name: "foo"},
6989
labels: map[string]string{"test": "me"},
70-
wantErr: false,
90+
wantErr: true,
7191
wantFound: true,
7292
objects: []ctrlclient.Object{
7393
&corev1.Secret{
@@ -76,7 +96,7 @@ func Test_getExistingCert(t *testing.T) {
7696
Namespace: "bar",
7797
Labels: map[string]string{"test": "me"},
7898
},
79-
Data: map[string][]byte{"tls.key": []byte(`key`), "tls.crt": []byte(`crt`)},
99+
Data: map[string][]byte{"tls.key": certificateBundle.ServerKey.Bytes(), "tls.crt": certificateBundle.ServerCrt.Bytes(), "ca.crt": certificateBundle.CACrt.Bytes()},
80100
},
81101
},
82102
},
@@ -92,7 +112,7 @@ func Test_getExistingCert(t *testing.T) {
92112
},
93113
}
94114
ctx := context.WithValue(context.Background(), "test", tt.name)
95-
key, crt, found, err := getExistingCert(ctx, fakeClientWithObjects(tt.objects...), s.logger, s.namespacedName, stunnelSecret)
115+
found, err := isSecretValid(ctx, fakeClientWithObjects(tt.objects...), s.logger, s.namespacedName, stunnelSecret)
96116
if err != nil {
97117
t.Error("found unexpected error", err)
98118
}
@@ -102,14 +122,6 @@ func Test_getExistingCert(t *testing.T) {
102122
if tt.wantFound && !found {
103123
t.Error("not found unexpected")
104124
}
105-
106-
if tt.wantFound && found && key == nil {
107-
t.Error("secret found but empty key, unexpected")
108-
}
109-
110-
if tt.wantFound && found && crt == nil {
111-
t.Error("secret found but empty crt, unexpected")
112-
}
113125
})
114126
}
115127
}

‎transport/tls/certs/generate.go

+30
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"crypto/x509"
88
"crypto/x509/pkix"
99
"encoding/pem"
10+
"fmt"
1011
"math/big"
1112
"time"
1213
)
@@ -121,6 +122,35 @@ func Generate(subject *pkix.Name, caCrtTemplate x509.Certificate, caKey rsa.Priv
121122
return
122123
}
123124

125+
// VerifyCertificate returns true if the crt is signed by the caCrt as the root CA
126+
// with no intermediate DCAs in the chain
127+
func VerifyCertificate(caCrt *bytes.Buffer, crt *bytes.Buffer) (bool, error) {
128+
roots := x509.NewCertPool()
129+
ok := roots.AppendCertsFromPEM(caCrt.Bytes())
130+
if !ok {
131+
panic("failed to parse root certificate")
132+
}
133+
134+
block, _ := pem.Decode(crt.Bytes())
135+
if block == nil {
136+
return false, fmt.Errorf("unable to decode certificate")
137+
}
138+
cert, err := x509.ParseCertificate(block.Bytes)
139+
if err != nil {
140+
return false, fmt.Errorf("failed to parse certificate: %#v", err)
141+
}
142+
143+
opts := x509.VerifyOptions{
144+
Roots: roots,
145+
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
146+
}
147+
148+
if _, err := cert.Verify(opts); err != nil {
149+
return false, nil
150+
}
151+
return true, nil
152+
}
153+
124154
func createCrtKeyPair(crtTemplate, parent *x509.Certificate, signer *rsa.PrivateKey) (crt *bytes.Buffer, key *rsa.PrivateKey, err error) {
125155
key, err = rsa.GenerateKey(rand.Reader, keySize)
126156
if err != nil {

‎transport/tls/certs/generate_test.go

+4-34
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
package certs
22

33
import (
4-
"bytes"
5-
"crypto/x509"
6-
"encoding/pem"
74
"testing"
85
)
96

@@ -54,10 +51,10 @@ func TestNew(t *testing.T) {
5451
// t.Error("client cert is not verified with root CA")
5552
//}
5653

57-
if !verifySingedCA(got.CACrt, got.ClientCrt) {
54+
if ok, _ := VerifyCertificate(got.CACrt, got.ClientCrt); !ok {
5855
t.Error("client cert is not verified with root CA")
5956
}
60-
if !verifySingedCA(got.CACrt, got.ServerCrt) {
57+
if ok, _ := VerifyCertificate(got.CACrt, got.ServerCrt); !ok {
6158
t.Error("server cert is not verified with root CA")
6259
}
6360

@@ -66,39 +63,12 @@ func TestNew(t *testing.T) {
6663
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
6764
return
6865
}
69-
if verifySingedCA(got.CACrt, got2.ClientCrt) {
66+
if ok, _ := VerifyCertificate(got.CACrt, got2.ClientCrt); ok {
7067
t.Error("client cert is verified with different root CA")
7168
}
72-
if verifySingedCA(got.CACrt, got2.ServerCrt) {
69+
if ok, _ := VerifyCertificate(got.CACrt, got2.ServerCrt); ok {
7370
t.Error("server cert is not verified with different root CA")
7471
}
7572
})
7673
}
7774
}
78-
79-
func verifySingedCA(caCrt *bytes.Buffer, crt *bytes.Buffer) bool {
80-
roots := x509.NewCertPool()
81-
ok := roots.AppendCertsFromPEM(caCrt.Bytes())
82-
if !ok {
83-
panic("failed to parse root certificate")
84-
}
85-
86-
block, _ := pem.Decode(crt.Bytes())
87-
if block == nil {
88-
panic("failed to parse certificate")
89-
}
90-
cert, err := x509.ParseCertificate(block.Bytes)
91-
if err != nil {
92-
panic("failed to parse certificate: " + err.Error())
93-
}
94-
95-
opts := x509.VerifyOptions{
96-
Roots: roots,
97-
KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageAny},
98-
}
99-
100-
if _, err := cert.Verify(opts); err != nil {
101-
return false
102-
}
103-
return true
104-
}

0 commit comments

Comments
 (0)
Please sign in to comment.