...

Source file src/sigs.k8s.io/controller-runtime/pkg/certwatcher/certwatcher.go

Documentation: sigs.k8s.io/controller-runtime/pkg/certwatcher

     1  /*
     2  Copyright 2021 The Kubernetes Authors.
     3  
     4  Licensed under the Apache License, Version 2.0 (the "License");
     5  you may not use this file except in compliance with the License.
     6  You may obtain a copy of the License at
     7  
     8      http://www.apache.org/licenses/LICENSE-2.0
     9  
    10  Unless required by applicable law or agreed to in writing, software
    11  distributed under the License is distributed on an "AS IS" BASIS,
    12  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    13  See the License for the specific language governing permissions and
    14  limitations under the License.
    15  */
    16  
    17  package certwatcher
    18  
    19  import (
    20  	"context"
    21  	"crypto/tls"
    22  	"fmt"
    23  	"sync"
    24  	"time"
    25  
    26  	"github.com/fsnotify/fsnotify"
    27  	kerrors "k8s.io/apimachinery/pkg/util/errors"
    28  	"k8s.io/apimachinery/pkg/util/sets"
    29  	"k8s.io/apimachinery/pkg/util/wait"
    30  	"sigs.k8s.io/controller-runtime/pkg/certwatcher/metrics"
    31  	logf "sigs.k8s.io/controller-runtime/pkg/internal/log"
    32  )
    33  
// log is the logger shared by this package; it is derived from the
// controller-runtime RuntimeLog and carries the name "certwatcher".
var log = logf.RuntimeLog.WithName("certwatcher")
    35  
    36  // CertWatcher watches certificate and key files for changes.  When either file
    37  // changes, it reads and parses both and calls an optional callback with the new
    38  // certificate.
    39  type CertWatcher struct {
    40  	sync.RWMutex
    41  
    42  	currentCert *tls.Certificate
    43  	watcher     *fsnotify.Watcher
    44  
    45  	certPath string
    46  	keyPath  string
    47  
    48  	// callback is a function to be invoked when the certificate changes.
    49  	callback func(tls.Certificate)
    50  }
    51  
    52  // New returns a new CertWatcher watching the given certificate and key.
    53  func New(certPath, keyPath string) (*CertWatcher, error) {
    54  	var err error
    55  
    56  	cw := &CertWatcher{
    57  		certPath: certPath,
    58  		keyPath:  keyPath,
    59  	}
    60  
    61  	// Initial read of certificate and key.
    62  	if err := cw.ReadCertificate(); err != nil {
    63  		return nil, err
    64  	}
    65  
    66  	cw.watcher, err = fsnotify.NewWatcher()
    67  	if err != nil {
    68  		return nil, err
    69  	}
    70  
    71  	return cw, nil
    72  }
    73  
    74  // RegisterCallback registers a callback to be invoked when the certificate changes.
    75  func (cw *CertWatcher) RegisterCallback(callback func(tls.Certificate)) {
    76  	cw.Lock()
    77  	defer cw.Unlock()
    78  	// If the current certificate is not nil, invoke the callback immediately.
    79  	if cw.currentCert != nil {
    80  		callback(*cw.currentCert)
    81  	}
    82  	cw.callback = callback
    83  }
    84  
    85  // GetCertificate fetches the currently loaded certificate, which may be nil.
    86  func (cw *CertWatcher) GetCertificate(_ *tls.ClientHelloInfo) (*tls.Certificate, error) {
    87  	cw.RLock()
    88  	defer cw.RUnlock()
    89  	return cw.currentCert, nil
    90  }
    91  
    92  // Start starts the watch on the certificate and key files.
    93  func (cw *CertWatcher) Start(ctx context.Context) error {
    94  	files := sets.New(cw.certPath, cw.keyPath)
    95  
    96  	{
    97  		var watchErr error
    98  		if err := wait.PollUntilContextTimeout(ctx, 1*time.Second, 10*time.Second, true, func(ctx context.Context) (done bool, err error) {
    99  			for _, f := range files.UnsortedList() {
   100  				if err := cw.watcher.Add(f); err != nil {
   101  					watchErr = err
   102  					return false, nil //nolint:nilerr // We want to keep trying.
   103  				}
   104  				// We've added the watch, remove it from the set.
   105  				files.Delete(f)
   106  			}
   107  			return true, nil
   108  		}); err != nil {
   109  			return fmt.Errorf("failed to add watches: %w", kerrors.NewAggregate([]error{err, watchErr}))
   110  		}
   111  	}
   112  
   113  	go cw.Watch()
   114  
   115  	log.Info("Starting certificate watcher")
   116  
   117  	// Block until the context is done.
   118  	<-ctx.Done()
   119  
   120  	return cw.watcher.Close()
   121  }
   122  
   123  // Watch reads events from the watcher's channel and reacts to changes.
   124  func (cw *CertWatcher) Watch() {
   125  	for {
   126  		select {
   127  		case event, ok := <-cw.watcher.Events:
   128  			// Channel is closed.
   129  			if !ok {
   130  				return
   131  			}
   132  
   133  			cw.handleEvent(event)
   134  
   135  		case err, ok := <-cw.watcher.Errors:
   136  			// Channel is closed.
   137  			if !ok {
   138  				return
   139  			}
   140  
   141  			log.Error(err, "certificate watch error")
   142  		}
   143  	}
   144  }
   145  
   146  // ReadCertificate reads the certificate and key files from disk, parses them,
   147  // and updates the current certificate on the watcher.  If a callback is set, it
   148  // is invoked with the new certificate.
   149  func (cw *CertWatcher) ReadCertificate() error {
   150  	metrics.ReadCertificateTotal.Inc()
   151  	cert, err := tls.LoadX509KeyPair(cw.certPath, cw.keyPath)
   152  	if err != nil {
   153  		metrics.ReadCertificateErrors.Inc()
   154  		return err
   155  	}
   156  
   157  	cw.Lock()
   158  	cw.currentCert = &cert
   159  	cw.Unlock()
   160  
   161  	log.Info("Updated current TLS certificate")
   162  
   163  	// If a callback is registered, invoke it with the new certificate.
   164  	cw.RLock()
   165  	defer cw.RUnlock()
   166  	if cw.callback != nil {
   167  		go func() {
   168  			cw.callback(cert)
   169  		}()
   170  	}
   171  	return nil
   172  }
   173  
   174  func (cw *CertWatcher) handleEvent(event fsnotify.Event) {
   175  	// Only care about events which may modify the contents of the file.
   176  	if !(isWrite(event) || isRemove(event) || isCreate(event)) {
   177  		return
   178  	}
   179  
   180  	log.V(1).Info("certificate event", "event", event)
   181  
   182  	// If the file was removed, re-add the watch.
   183  	if isRemove(event) {
   184  		if err := cw.watcher.Add(event.Name); err != nil {
   185  			log.Error(err, "error re-watching file")
   186  		}
   187  	}
   188  
   189  	if err := cw.ReadCertificate(); err != nil {
   190  		log.Error(err, "error re-reading certificate")
   191  	}
   192  }
   193  
   194  func isWrite(event fsnotify.Event) bool {
   195  	return event.Op.Has(fsnotify.Write)
   196  }
   197  
   198  func isCreate(event fsnotify.Event) bool {
   199  	return event.Op.Has(fsnotify.Create)
   200  }
   201  
   202  func isRemove(event fsnotify.Event) bool {
   203  	return event.Op.Has(fsnotify.Remove)
   204  }
   205  

View as plain text