1
16
17 package certwatcher_test
18
19 import (
20 "context"
21 "crypto/rand"
22 "crypto/rsa"
23 "crypto/tls"
24 "crypto/x509"
25 "crypto/x509/pkix"
26 "encoding/pem"
27 "fmt"
28 "math/big"
29 "net"
30 "os"
31 "sync/atomic"
32 "time"
33
34 . "github.com/onsi/ginkgo/v2"
35 . "github.com/onsi/gomega"
36 "github.com/prometheus/client_golang/prometheus/testutil"
37 "sigs.k8s.io/controller-runtime/pkg/certwatcher"
38 "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
39 )
40
41 var _ = Describe("CertWatcher", func() {
42 var _ = Describe("certwatcher New", func() {
43 It("should errors without cert/key", func() {
44 _, err := certwatcher.New("", "")
45 Expect(err).To(HaveOccurred())
46 })
47 })
48
49 var _ = Describe("certwatcher Start", func() {
50 var (
51 ctx context.Context
52 ctxCancel context.CancelFunc
53 watcher *certwatcher.CertWatcher
54 )
55
56 BeforeEach(func() {
57 ctx, ctxCancel = context.WithCancel(context.Background())
58
59 err := writeCerts(certPath, keyPath, "127.0.0.1")
60 Expect(err).ToNot(HaveOccurred())
61
62 Eventually(func() error {
63 for _, file := range []string{certPath, keyPath} {
64 _, err := os.ReadFile(file)
65 if err != nil {
66 return err
67 }
68 continue
69 }
70
71 return nil
72 }).Should(Succeed())
73
74 watcher, err = certwatcher.New(certPath, keyPath)
75 Expect(err).ToNot(HaveOccurred())
76 })
77
78 startWatcher := func() (done <-chan struct{}) {
79 doneCh := make(chan struct{})
80 go func() {
81 defer GinkgoRecover()
82 defer close(doneCh)
83 Expect(watcher.Start(ctx)).To(Succeed())
84 }()
85
86 Eventually(func() error {
87 err := watcher.ReadCertificate()
88 return err
89 }).Should(Succeed())
90 return doneCh
91 }
92
93 It("should read the initial cert/key", func() {
94 doneCh := startWatcher()
95
96 ctxCancel()
97 Eventually(doneCh, "4s").Should(BeClosed())
98 })
99
100 It("should reload currentCert when changed", func() {
101 doneCh := startWatcher()
102 called := atomic.Int64{}
103 watcher.RegisterCallback(func(crt tls.Certificate) {
104 called.Add(1)
105 Expect(crt.Certificate).ToNot(BeEmpty())
106 })
107
108 firstcert, _ := watcher.GetCertificate(nil)
109
110 err := writeCerts(certPath, keyPath, "192.168.0.1")
111 Expect(err).ToNot(HaveOccurred())
112
113 Eventually(func() bool {
114 secondcert, _ := watcher.GetCertificate(nil)
115 first := firstcert.PrivateKey.(*rsa.PrivateKey)
116 return first.Equal(secondcert.PrivateKey)
117 }).ShouldNot(BeTrue())
118
119 ctxCancel()
120 Eventually(doneCh, "4s").Should(BeClosed())
121 Expect(called.Load()).To(BeNumerically(">=", 1))
122 })
123
124 Context("prometheus metric read_certificate_total", func() {
125 var readCertificateTotalBefore float64
126 var readCertificateErrorsBefore float64
127
128 BeforeEach(func() {
129 readCertificateTotalBefore = testutil.ToFloat64(metrics.ReadCertificateTotal)
130 readCertificateErrorsBefore = testutil.ToFloat64(metrics.ReadCertificateErrors)
131 })
132
133 It("should get updated on successful certificate read", func() {
134 doneCh := startWatcher()
135
136 Eventually(func() error {
137 readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
138 if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
139 return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
140 }
141 return nil
142 }, "4s").Should(Succeed())
143
144 ctxCancel()
145 Eventually(doneCh, "4s").Should(BeClosed())
146 })
147
148 It("should get updated on read certificate errors", func() {
149 doneCh := startWatcher()
150
151 Eventually(func() error {
152 readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
153 if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
154 return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
155 }
156 readCertificateTotalBefore = readCertificateTotalAfter
157 return nil
158 }, "4s").Should(Succeed())
159
160 Expect(os.Remove(keyPath)).To(Succeed())
161
162 Eventually(func() error {
163 readCertificateTotalAfter := testutil.ToFloat64(metrics.ReadCertificateTotal)
164 if readCertificateTotalAfter != readCertificateTotalBefore+1.0 {
165 return fmt.Errorf("metric read certificate total expected: %v and got: %v", readCertificateTotalBefore+1.0, readCertificateTotalAfter)
166 }
167 return nil
168 }, "4s").Should(Succeed())
169 Eventually(func() error {
170 readCertificateErrorsAfter := testutil.ToFloat64(metrics.ReadCertificateErrors)
171 if readCertificateErrorsAfter != readCertificateErrorsBefore+1.0 {
172 return fmt.Errorf("metric read certificate errors expected: %v and got: %v", readCertificateErrorsBefore+1.0, readCertificateErrorsAfter)
173 }
174 return nil
175 }, "4s").Should(Succeed())
176
177 ctxCancel()
178 Eventually(doneCh, "4s").Should(BeClosed())
179 })
180 })
181 })
182 })
183
// writeCerts generates a fresh 2048-bit RSA key and a self-signed server
// certificate for it. The certificate is valid for one hour from now, has
// organization "Kubernetes", and carries ip as its only IP SAN. It writes the
// PEM-encoded certificate to certPath and the PKCS#8 PEM-encoded private key
// to keyPath. Existing files are truncated; a newly created key file gets
// mode 0600.
//
// It returns an error if ip is not a valid IP address, if key or certificate
// generation fails, or if either file cannot be written. All material is
// encoded in memory before any file is touched, and os.WriteFile always closes
// the file, so no handle leaks on an error path.
func writeCerts(certPath, keyPath, ip string) error {
	addr := net.ParseIP(ip)
	if addr == nil {
		return fmt.Errorf("invalid IP address %q", ip)
	}

	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}

	// Random 128-bit serial number.
	serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
	serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
	if err != nil {
		return err
	}

	notBefore := time.Now()
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"Kubernetes"},
		},
		NotBefore: notBefore,
		NotAfter:  notBefore.Add(1 * time.Hour),

		// The key is always RSA, so key encipherment is allowed alongside
		// digital signatures.
		KeyUsage:              x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		IPAddresses:           []net.IP{addr},
	}

	// Self-signed: the template is both the certificate and its issuer.
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return err
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return err
	}

	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})

	// 0o666 (before umask) matches os.Create; the key stays owner-only.
	if err := os.WriteFile(certPath, certPEM, 0o666); err != nil {
		return err
	}
	return os.WriteFile(keyPath, keyPEM, 0o600)
}
252
View as plain text