...
1
16
17 package certwatcher
18
19 import (
20 "context"
21 "crypto/tls"
22 "fmt"
23 "sync"
24 "time"
25
26 "github.com/fsnotify/fsnotify"
27 kerrors "k8s.io/apimachinery/pkg/util/errors"
28 "k8s.io/apimachinery/pkg/util/sets"
29 "k8s.io/apimachinery/pkg/util/wait"
30 "sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
31 logf "sigs.k8s.io/controller-runtime/pkg/internal/log"
32 )
33
34 var log = logf.RuntimeLog.WithName("certwatcher")
35
36
37
38
39 type CertWatcher struct {
40 sync.RWMutex
41
42 currentCert *tls.Certificate
43 watcher *fsnotify.Watcher
44
45 certPath string
46 keyPath string
47
48
49 callback func(tls.Certificate)
50 }
51
52
53 func New(certPath, keyPath string) (*CertWatcher, error) {
54 var err error
55
56 cw := &CertWatcher{
57 certPath: certPath,
58 keyPath: keyPath,
59 }
60
61
62 if err := cw.ReadCertificate(); err != nil {
63 return nil, err
64 }
65
66 cw.watcher, err = fsnotify.NewWatcher()
67 if err != nil {
68 return nil, err
69 }
70
71 return cw, nil
72 }
73
74
// RegisterCallback sets the function that receives each newly loaded
// certificate. This replaces any previously registered callback.
//
// If a certificate is already loaded, callback is invoked with it
// synchronously before RegisterCallback returns. That first invocation runs
// while cw's write lock is held. sync.RWMutex is not reentrant, so the
// callback must not call methods on cw that lock it (such as GetCertificate),
// or it will deadlock.
func (cw *CertWatcher) RegisterCallback(callback func(tls.Certificate)) {
	cw.Lock()
	defer cw.Unlock()

	if loaded := cw.currentCert; loaded != nil {
		callback(*loaded)
	}
	cw.callback = callback
}
84
85
// GetCertificate returns the most recently loaded certificate. Its signature
// matches tls.Config.GetCertificate, so it can be used there directly. The
// ClientHelloInfo argument is ignored, and the returned error is always nil.
func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
	cw.RLock()
	cert := cw.currentCert
	cw.RUnlock()
	return cert, nil
}
91
92
// Start adds fsnotify watches on the certificate and key files. It then
// handles file events in a background goroutine and blocks until ctx is done,
// returning the result of closing the underlying watcher.
//
// Adding the watches is retried once per second for up to ten seconds. If the
// watches still cannot all be added, or ctx ends first, the watcher is closed
// before Start returns so its resources are not leaked. After that the
// CertWatcher cannot be started again.
func (cw *CertWatcher) Start(ctx context.Context) error {
	files := sets.New(cw.certPath, cw.keyPath)

	var watchErr error
	if err := wait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Second, true, func(ctx context.Context) (done bool, err error) {
		for _, f := range files.UnsortedList() {
			if err := cw.watcher.Add(f); err != nil {
				// Remember the failure and retry on the next tick instead of
				// aborting the poll.
				watchErr = err
				return false, nil
			}
			// This file is watched now; don't add it again on later ticks.
			files.Delete(f)
		}
		return true, nil
	}); err != nil {
		// Watch never started and nothing else closes the watcher on this
		// path, so release it here. NewAggregate drops nil entries.
		errs := []error{err, watchErr}
		if closeErr := cw.watcher.Close(); closeErr != nil {
			errs = append(errs, closeErr)
		}
		return fmt.Errorf("failed to add watches: %w", kerrors.NewAggregate(errs))
	}

	go cw.Watch()

	log.Info("Starting certificate watcher")

	// Block until the context is done. Closing the watcher closes its
	// channels, which makes the Watch goroutine return.
	<-ctx.Done()

	return cw.watcher.Close()
}
122
123
124 func (cw *CertWatcher) Watch() {
125 for {
126 select {
127 case event, ok := <-cw.watcher.Events:
128
129 if !ok {
130 return
131 }
132
133 cw.handleEvent(event)
134
135 case err, ok := <-cw.watcher.Errors:
136
137 if !ok {
138 return
139 }
140
141 log.Error(err, "certificate watch error")
142 }
143 }
144 }
145
146
147
148
// ReadCertificate loads the key pair from the watcher's certificate and key
// paths. On success it becomes the current certificate.
//
// Every attempt increments metrics.ReadCertificateTotal, and failures also
// increment metrics.ReadCertificateErrors. On failure the previous
// certificate is kept and the error from tls.LoadX509KeyPair is returned
// unchanged. On success, a registered callback (if any) is invoked with the
// new certificate in a separate goroutine.
func (cw *CertWatcher) ReadCertificate() error {
	metrics.ReadCertificateTotal.Inc()
	cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
	if err != nil {
		metrics.ReadCertificateErrors.Inc()
		return err
	}

	// Swap in the new certificate and snapshot the callback in one critical
	// section. The goroutine below runs after the lock is released, so it
	// must not read cw.callback itself; that would race with
	// RegisterCallback.
	cw.Lock()
	cw.currentCert = &cert
	callback := cw.callback
	cw.Unlock()

	log.Info("Updated current TLS certificate")

	if callback != nil {
		go callback(cert)
	}
	return nil
}
173
174 func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
175
176 if !(isWrite(event) || isRemove(event) || isCreate(event)) {
177 return
178 }
179
180 log.V(1).Info("certificate event", "event", event)
181
182
183 if isRemove(event) {
184 if err := cw.watcher.Add(event.Name); err != nil {
185 log.Error(err, "error re-watching file")
186 }
187 }
188
189 if err := cw.ReadCertificate(); err != nil {
190 log.Error(err, "error re-reading certificate")
191 }
192 }
193
194 func isWrite(event fsnotify.Event) bool {
195 return event.Op.Has(fsnotify.Write)
196 }
197
198 func isCreate(event fsnotify.Event) bool {
199 return event.Op.Has(fsnotify.Create)
200 }
201
202 func isRemove(event fsnotify.Event) bool {
203 return event.Op.Has(fsnotify.Remove)
204 }
205
View as plain text