...

Source file src/github.com/MicahParks/keyfunc/get.go

Documentation: github.com/MicahParks/keyfunc

     1  package keyfunc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"net/http"
     8  	"sync"
     9  	"time"
    10  )
    11  
    12  var (
    13  	// defaultRefreshTimeout is the default duration for the context used to create the HTTP request for a refresh of
    14  	// the JWKS.
    15  	defaultRefreshTimeout = time.Minute
    16  )
    17  
    18  // Get loads the JWKS at the given URL.
    19  func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
    20  	jwks = &JWKS{
    21  		jwksURL: jwksURL,
    22  	}
    23  
    24  	applyOptions(jwks, options)
    25  
    26  	if jwks.client == nil {
    27  		jwks.client = http.DefaultClient
    28  	}
    29  	if jwks.requestFactory == nil {
    30  		jwks.requestFactory = defaultRequestFactory
    31  	}
    32  	if jwks.responseExtractor == nil {
    33  		jwks.responseExtractor = ResponseExtractorStatusOK
    34  	}
    35  	if jwks.refreshTimeout == 0 {
    36  		jwks.refreshTimeout = defaultRefreshTimeout
    37  	}
    38  	if !options.JWKUseNoWhitelist && len(jwks.jwkUseWhitelist) == 0 {
    39  		jwks.jwkUseWhitelist = map[JWKUse]struct{}{
    40  			UseOmitted:   {},
    41  			UseSignature: {},
    42  		}
    43  	}
    44  
    45  	err = jwks.refresh()
    46  	if err != nil {
    47  		return nil, err
    48  	}
    49  
    50  	if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
    51  		jwks.ctx, jwks.cancel = context.WithCancel(context.Background())
    52  		jwks.refreshRequests = make(chan context.CancelFunc, 1)
    53  		go jwks.backgroundRefresh()
    54  	}
    55  
    56  	return jwks, nil
    57  }
    58  
// backgroundRefresh is meant to be a separate goroutine that will update the keys in a JWKS over a given interval of
// time.
//
// Get starts it only when a refresh interval is configured or refreshing on an unknown KID is enabled, after creating
// j.ctx, j.cancel, and j.refreshRequests (a channel with a buffer of one). It returns once j.ctx is done.
//
// Each value received from j.refreshRequests is a context.CancelFunc that this goroutine calls once the request has
// been handled. Per the comments below, the sender (a JWT parsing goroutine) waits until it is called. When
// j.refreshRateLimit is non-zero and the previous refresh happened less than j.refreshRateLimit ago, the request is
// released immediately without refreshing. At most one delayed refresh is then queued to run when the rate-limit
// window ends.
func (j *JWKS) backgroundRefresh() {
	var lastRefresh time.Time
	var queueOnce sync.Once
	var refreshMux sync.Mutex
	// Start as if the last refresh happened one full rate-limit window ago, so the first request is not rate limited.
	if j.refreshRateLimit != 0 {
		lastRefresh = time.Now().Add(-j.refreshRateLimit)
	}

	// Create a channel that will never send anything unless there is a refresh interval.
	refreshInterval := make(<-chan time.Time)

	// Enter an infinite loop that ends when the background ends.
	for {
		// The timer is re-armed on every iteration. The interval is therefore measured from the most recently handled
		// event (a tick or a refresh request), not from the previous tick.
		if j.refreshInterval != 0 {
			refreshInterval = time.After(j.refreshInterval)
		}

		select {
		case <-refreshInterval:
			// Turn the tick into an ordinary refresh request carrying a no-op CancelFunc. This same loop receives it
			// from j.refreshRequests on a later iteration, so ticks go through the same rate limiting as other requests.
			select {
			case <-j.ctx.Done():
				return
			case j.refreshRequests <- func() {}:
			default: // If the j.refreshRequests channel is full, don't send another request.
			}

		case cancel := <-j.refreshRequests:
			// refreshMux has two jobs. It guards lastRefresh and queueOnce, which are shared with the queued goroutine
			// below. It also keeps refreshes from this loop and from the queued goroutine from running concurrently.
			refreshMux.Lock()
			if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
				// Don't make the JWT parsing goroutine wait for the JWKS to refresh.
				cancel()

				// Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return.
				// queueOnce keeps at most one such goroutine pending. The goroutine replaces queueOnce after its
				// refresh, so a later rate-limited request can queue another one.
				queueOnce.Do(func() {
					go func() {
						refreshMux.Lock()
						wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
						refreshMux.Unlock()
						select {
						case <-j.ctx.Done():
							return
						case <-time.After(wait):
						}

						refreshMux.Lock()
						defer refreshMux.Unlock()
						err := j.refresh()
						if err != nil && j.refreshErrorHandler != nil {
							j.refreshErrorHandler(err)
						}

						// Updated even when the refresh failed, so failed attempts also count against the rate limit.
						lastRefresh = time.Now()
						queueOnce = sync.Once{}
					}()
				})
			} else {
				err := j.refresh()
				if err != nil && j.refreshErrorHandler != nil {
					j.refreshErrorHandler(err)
				}

				// Updated even when the refresh failed, so failed attempts also count against the rate limit.
				lastRefresh = time.Now()

				// Allow the JWT parsing goroutine to continue with the refreshed JWKS.
				cancel()
			}
			refreshMux.Unlock()

		// Clean up this goroutine when its context expires.
		case <-j.ctx.Done():
			return
		}
	}
}
   135  
   136  func defaultRequestFactory(ctx context.Context, url string) (*http.Request, error) {
   137  	return http.NewRequestWithContext(ctx, http.MethodGet, url, bytes.NewReader(nil))
   138  }
   139  
   140  // refresh does an HTTP GET on the JWKS URL to rebuild the JWKS.
   141  func (j *JWKS) refresh() (err error) {
   142  	var ctx context.Context
   143  	var cancel context.CancelFunc
   144  	if j.ctx != nil {
   145  		ctx, cancel = context.WithTimeout(j.ctx, j.refreshTimeout)
   146  	} else {
   147  		ctx, cancel = context.WithTimeout(context.Background(), j.refreshTimeout)
   148  	}
   149  	defer cancel()
   150  
   151  	req, err := j.requestFactory(ctx, j.jwksURL)
   152  	if err != nil {
   153  		return fmt.Errorf("failed to create request via factory function: %w", err)
   154  	}
   155  
   156  	resp, err := j.client.Do(req)
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	jwksBytes, err := j.responseExtractor(ctx, resp)
   162  	if err != nil {
   163  		return fmt.Errorf("failed to extract response via extractor function: %w", err)
   164  	}
   165  
   166  	// Only reprocess if the JWKS has changed.
   167  	if len(jwksBytes) != 0 && bytes.Equal(jwksBytes, j.raw) {
   168  		return nil
   169  	}
   170  	j.raw = jwksBytes
   171  
   172  	updated, err := NewJSON(jwksBytes)
   173  	if err != nil {
   174  		return err
   175  	}
   176  
   177  	j.mux.Lock()
   178  	defer j.mux.Unlock()
   179  	j.keys = updated.keys
   180  
   181  	if j.givenKeys != nil {
   182  		for kid, key := range j.givenKeys {
   183  			// Only overwrite the key if configured to do so.
   184  			if !j.givenKIDOverride {
   185  				if _, ok := j.keys[kid]; ok {
   186  					continue
   187  				}
   188  			}
   189  
   190  			j.keys[kid] = parsedJWK{public: key.inter}
   191  		}
   192  	}
   193  
   194  	return nil
   195  }
   196  

View as plain text