1
16
17 package transport
18
19 import (
20 "context"
21 "crypto/tls"
22 "crypto/x509"
23 "encoding/pem"
24 "fmt"
25 "net/http"
26 "os"
27 "sync"
28 "time"
29
30 utilnet "k8s.io/apimachinery/pkg/util/net"
31 "k8s.io/klog/v2"
32 )
33
34
35
36 func New(config *Config) (http.RoundTripper, error) {
37
38 if config.Transport != nil && (config.HasCA() || config.HasCertAuth() || config.HasCertCallback() || config.TLS.Insecure) {
39 return nil, fmt.Errorf("using a custom transport with TLS certificate options or the insecure flag is not allowed")
40 }
41
42 if !isValidHolders(config) {
43 return nil, fmt.Errorf("misconfigured holder for dialer or cert callback")
44 }
45
46 var (
47 rt http.RoundTripper
48 err error
49 )
50
51 if config.Transport != nil {
52 rt = config.Transport
53 } else {
54 rt, err = tlsCache.get(config)
55 if err != nil {
56 return nil, err
57 }
58 }
59
60 return HTTPWrappersForConfig(config, rt)
61 }
62
63 func isValidHolders(config *Config) bool {
64 if config.TLS.GetCertHolder != nil && config.TLS.GetCertHolder.GetCert == nil {
65 return false
66 }
67
68 if config.DialHolder != nil && config.DialHolder.Dial == nil {
69 return false
70 }
71
72 return true
73 }
74
75
76
77 func TLSConfigFor(c *Config) (*tls.Config, error) {
78 if !(c.HasCA() || c.HasCertAuth() || c.HasCertCallback() || c.TLS.Insecure || len(c.TLS.ServerName) > 0 || len(c.TLS.NextProtos) > 0) {
79 return nil, nil
80 }
81 if c.HasCA() && c.TLS.Insecure {
82 return nil, fmt.Errorf("specifying a root certificates file with the insecure flag is not allowed")
83 }
84 if err := loadTLSFiles(c); err != nil {
85 return nil, err
86 }
87
88 tlsConfig := &tls.Config{
89
90
91
92 MinVersion: tls.VersionTLS12,
93 InsecureSkipVerify: c.TLS.Insecure,
94 ServerName: c.TLS.ServerName,
95 NextProtos: c.TLS.NextProtos,
96 }
97
98 if c.HasCA() {
99
124
125 rootCAs, err := rootCertPool(c.TLS.CAData)
126 if err != nil {
127 return nil, fmt.Errorf("unable to load root certificates: %w", err)
128 }
129 tlsConfig.RootCAs = rootCAs
130 }
131
132 var staticCert *tls.Certificate
133
134 if c.HasCertAuth() && !c.TLS.ReloadTLSFiles {
135
136
137 cert, err := tls.X509KeyPair(c.TLS.CertData, c.TLS.KeyData)
138 if err != nil {
139 return nil, err
140 }
141 staticCert = &cert
142 }
143
144 var dynamicCertLoader func() (*tls.Certificate, error)
145 if c.TLS.ReloadTLSFiles {
146 dynamicCertLoader = cachingCertificateLoader(c.TLS.CertFile, c.TLS.KeyFile)
147 }
148
149 if c.HasCertAuth() || c.HasCertCallback() {
150
151
178
179 tlsConfig.GetClientCertificate = func(*tls.CertificateRequestInfo) (*tls.Certificate, error) {
180
181
182 if staticCert != nil {
183 return staticCert, nil
184 }
185
186 if dynamicCertLoader != nil {
187 return dynamicCertLoader()
188 }
189 if c.HasCertCallback() {
190 cert, err := c.TLS.GetCertHolder.GetCert()
191 if err != nil {
192 return nil, err
193 }
194
195 if cert != nil {
196 return cert, nil
197 }
198 }
199
200
201
202
203 return &tls.Certificate{}, nil
204 }
205 }
206
207 return tlsConfig, nil
208 }
209
210
211
212
213 func loadTLSFiles(c *Config) error {
214 var err error
215 c.TLS.CAData, err = dataFromSliceOrFile(c.TLS.CAData, c.TLS.CAFile)
216 if err != nil {
217 return err
218 }
219
220
221 if len(c.TLS.CertFile) > 0 && len(c.TLS.CertData) == 0 && len(c.TLS.KeyFile) > 0 && len(c.TLS.KeyData) == 0 {
222 c.TLS.ReloadTLSFiles = true
223 }
224
225 c.TLS.CertData, err = dataFromSliceOrFile(c.TLS.CertData, c.TLS.CertFile)
226 if err != nil {
227 return err
228 }
229
230 c.TLS.KeyData, err = dataFromSliceOrFile(c.TLS.KeyData, c.TLS.KeyFile)
231 return err
232 }
233
234
235
// dataFromSliceOrFile returns data when it is non-empty. Otherwise, if file is
// non-empty, it returns the contents of that file. When both are empty it
// returns (nil, nil).
//
// On a read error the returned slice is nil, so callers that assign the result
// before checking the error (as loadTLSFiles does) are not left holding a
// spurious empty-but-non-nil value.
func dataFromSliceOrFile(data []byte, file string) ([]byte, error) {
	if len(data) > 0 {
		return data, nil
	}
	if len(file) == 0 {
		return nil, nil
	}
	fileData, err := os.ReadFile(file)
	if err != nil {
		return nil, err
	}
	return fileData, nil
}
249
250
251
252 func rootCertPool(caData []byte) (*x509.CertPool, error) {
253
254
255
256 if len(caData) == 0 {
257 return nil, nil
258 }
259
260
261 certPool := x509.NewCertPool()
262 if ok := certPool.AppendCertsFromPEM(caData); !ok {
263 return nil, createErrorParsingCAData(caData)
264 }
265 return certPool, nil
266 }
267
268
269
// createErrorParsingCAData explains why pemCerts yielded no usable CA
// certificate. It walks the PEM blocks in order:
//   - Input that does not decode as a PEM block produces
//     "unable to parse bytes as PEM block".
//   - Blocks that are not header-free "CERTIFICATE" blocks are skipped.
//   - The first certificate that fails x509 parsing produces a wrapped
//     "failed to parse certificate" error.
//
// If every block was skipped or parsed, it reports that no valid CA data was
// seen.
func createErrorParsingCAData(pemCerts []byte) error {
	rest := pemCerts
	for len(rest) > 0 {
		block, remaining := pem.Decode(rest)
		if block == nil {
			return fmt.Errorf("unable to parse bytes as PEM block")
		}
		rest = remaining

		if block.Type == "CERTIFICATE" && len(block.Headers) == 0 {
			if _, err := x509.ParseCertificate(block.Bytes); err != nil {
				return fmt.Errorf("failed to parse certificate: %w", err)
			}
		}
	}
	return fmt.Errorf("no valid certificate authority data seen")
}
288
289
290
291
// WrapperFunc decorates an http.RoundTripper, returning a RoundTripper that
// typically delegates to rt.
type WrapperFunc func(rt http.RoundTripper) http.RoundTripper

// Wrappers composes fns into a single WrapperFunc that applies them in order:
// fns[0] wraps the base RoundTripper, fns[1] wraps that result, and so on.
// Nil entries are skipped.
//
// Special cases:
//   - With no arguments, Wrappers returns nil.
//   - With exactly two arguments where the first is nil, the second is
//     returned unchanged (which may itself be nil). This handles the common
//     "append one wrapper to an unset one" call without allocating a
//     composing closure.
func Wrappers(fns ...WrapperFunc) WrapperFunc {
	switch {
	case len(fns) == 0:
		return nil
	case len(fns) == 2 && fns[0] == nil:
		return fns[1]
	}

	return func(rt http.RoundTripper) http.RoundTripper {
		wrapped := rt
		for i := range fns {
			if fn := fns[i]; fn != nil {
				wrapped = fn(wrapped)
			}
		}
		return wrapped
	}
}
317
318
319
320
321 func ContextCanceller(ctx context.Context, err error) WrapperFunc {
322 return func(rt http.RoundTripper) http.RoundTripper {
323 return &contextCanceller{
324 ctx: ctx,
325 rt: rt,
326 err: err,
327 }
328 }
329 }
330
331 type contextCanceller struct {
332 ctx context.Context
333 rt http.RoundTripper
334 err error
335 }
336
337 func (b *contextCanceller) RoundTrip(req *http.Request) (*http.Response, error) {
338 select {
339 case <-b.ctx.Done():
340 return nil, b.err
341 default:
342 return b.rt.RoundTrip(req)
343 }
344 }
345
346 func tryCancelRequest(rt http.RoundTripper, req *http.Request) {
347 type canceler interface {
348 CancelRequest(*http.Request)
349 }
350 switch rt := rt.(type) {
351 case canceler:
352 rt.CancelRequest(req)
353 case utilnet.RoundTripperWrapper:
354 tryCancelRequest(rt.WrappedRoundTripper(), req)
355 default:
356 klog.Warningf("Unable to cancel request for %T", rt)
357 }
358 }
359
// certificateCacheEntry holds the outcome of one attempt to load a client key
// pair from disk, stamped with when the attempt finished. Exactly one of cert
// and err is non-nil.
type certificateCacheEntry struct {
	cert  *tls.Certificate
	err   error
	birth time.Time
}

// isStale reports whether more than one second has elapsed since the entry was
// created, meaning the key pair should be read from disk again. A zero birth
// time is always stale.
func (c *certificateCacheEntry) isStale() bool {
	return time.Since(c.birth) > time.Second
}

// newCertificateCacheEntry loads the key pair at certFile/keyFile and records
// the result with the current time.
//
// On failure the entry carries the error and a nil certificate. Previously a
// pointer to a zero-value tls.Certificate was stored alongside the error and
// handed on to GetClientCertificate callers. The birth time is taken after the
// load completes.
func newCertificateCacheEntry(certFile, keyFile string) certificateCacheEntry {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return certificateCacheEntry{err: err, birth: time.Now()}
	}
	return certificateCacheEntry{cert: &cert, birth: time.Now()}
}
375
376
377
378 func cachingCertificateLoader(certFile, keyFile string) func() (*tls.Certificate, error) {
379 current := newCertificateCacheEntry(certFile, keyFile)
380 var currentMtx sync.RWMutex
381
382 return func() (*tls.Certificate, error) {
383 currentMtx.RLock()
384 if current.isStale() {
385 currentMtx.RUnlock()
386
387 currentMtx.Lock()
388 defer currentMtx.Unlock()
389
390 if current.isStale() {
391 current = newCertificateCacheEntry(certFile, keyFile)
392 }
393 } else {
394 defer currentMtx.RUnlock()
395 }
396
397 return current.cert, current.err
398 }
399 }
400
View as plain text