...
1
16
17 package transport
18
19 import (
20 "bytes"
21 "crypto/tls"
22 "fmt"
23 "reflect"
24 "sync"
25 "time"
26
27 utilnet "k8s.io/apimachinery/pkg/util/net"
28 utilruntime "k8s.io/apimachinery/pkg/util/runtime"
29 "k8s.io/apimachinery/pkg/util/wait"
30 "k8s.io/client-go/util/connrotation"
31 "k8s.io/client-go/util/workqueue"
32 "k8s.io/klog/v2"
33 )
34
// workItemKey is the key enqueued for every reload request; the work queue
// only ever carries this one value.
const workItemKey = "key"

// CertCallbackRefreshDuration is the interval at which Run enqueues a
// reload of the client certificate. It is a variable rather than a constant,
// presumably so tests can shorten it (NOTE(review): confirm with callers).
var CertCallbackRefreshDuration = 5 * time.Minute

// reloadFunc returns the client certificate to present. It has the same
// signature as tls.Config.GetClientCertificate.
type reloadFunc func(*tls.CertificateRequestInfo) (*tls.Certificate, error)
41
42 type dynamicClientCert struct {
43 clientCert *tls.Certificate
44 certMtx sync.RWMutex
45
46 reload reloadFunc
47 connDialer *connrotation.Dialer
48
49
50 queue workqueue.RateLimitingInterface
51 }
52
53 func certRotatingDialer(reload reloadFunc, dial utilnet.DialFunc) *dynamicClientCert {
54 d := &dynamicClientCert{
55 reload: reload,
56 connDialer: connrotation.NewDialer(connrotation.DialFunc(dial)),
57 queue: workqueue.NewNamedRateLimitingQueue(workqueue.DefaultControllerRateLimiter(), "DynamicClientCertificate"),
58 }
59
60 return d
61 }
62
63
64 func (c *dynamicClientCert) loadClientCert() (*tls.Certificate, error) {
65 cert, err := c.reload(nil)
66 if err != nil {
67 return nil, err
68 }
69
70
71 c.certMtx.RLock()
72 haveCert := c.clientCert != nil
73 if certsEqual(c.clientCert, cert) {
74 c.certMtx.RUnlock()
75 return c.clientCert, nil
76 }
77 c.certMtx.RUnlock()
78
79 c.certMtx.Lock()
80 c.clientCert = cert
81 c.certMtx.Unlock()
82
83
84 if !haveCert {
85 return cert, nil
86 }
87
88 klog.V(1).Infof("certificate rotation detected, shutting down client connections to start using new credentials")
89 c.connDialer.CloseAll()
90
91 return cert, nil
92 }
93
94
95 func certsEqual(left, right *tls.Certificate) bool {
96 if left == nil || right == nil {
97 return left == right
98 }
99
100 if !byteMatrixEqual(left.Certificate, right.Certificate) {
101 return false
102 }
103
104 if !reflect.DeepEqual(left.PrivateKey, right.PrivateKey) {
105 return false
106 }
107
108 if !byteMatrixEqual(left.SignedCertificateTimestamps, right.SignedCertificateTimestamps) {
109 return false
110 }
111
112 if !bytes.Equal(left.OCSPStaple, right.OCSPStaple) {
113 return false
114 }
115
116 return true
117 }
118
// byteMatrixEqual reports whether left and right contain the same number of
// byte slices with pairwise-equal contents. Comparison follows bytes.Equal,
// so a nil row equals an empty row, and a nil matrix equals an empty one.
func byteMatrixEqual(left, right [][]byte) bool {
	if len(right) != len(left) {
		return false
	}
	for idx, row := range left {
		if !bytes.Equal(row, right[idx]) {
			return false
		}
	}
	return true
}
131
132
133 func (c *dynamicClientCert) Run(stopCh <-chan struct{}) {
134 defer utilruntime.HandleCrash()
135 defer c.queue.ShutDown()
136
137 klog.V(3).Infof("Starting client certificate rotation controller")
138 defer klog.V(3).Infof("Shutting down client certificate rotation controller")
139
140 go wait.Until(c.runWorker, time.Second, stopCh)
141
142 go wait.PollImmediateUntil(CertCallbackRefreshDuration, func() (bool, error) {
143 c.queue.Add(workItemKey)
144 return false, nil
145 }, stopCh)
146
147 <-stopCh
148 }
149
150 func (c *dynamicClientCert) runWorker() {
151 for c.processNextWorkItem() {
152 }
153 }
154
155 func (c *dynamicClientCert) processNextWorkItem() bool {
156 dsKey, quit := c.queue.Get()
157 if quit {
158 return false
159 }
160 defer c.queue.Done(dsKey)
161
162 _, err := c.loadClientCert()
163 if err == nil {
164 c.queue.Forget(dsKey)
165 return true
166 }
167
168 utilruntime.HandleError(fmt.Errorf("%v failed with : %v", dsKey, err))
169 c.queue.AddRateLimited(dsKey)
170
171 return true
172 }
173
174 func (c *dynamicClientCert) GetClientCertificate(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
175 return c.loadClientCert()
176 }
177
View as plain text