...

Source file src/github.com/MicahParks/keyfunc/v2/get.go

Documentation: github.com/MicahParks/keyfunc/v2

     1  package keyfunc
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  )
    12  
    13  var (
    14  	// ErrRefreshImpossible is returned when a refresh is attempted on a JWKS that was not created from a remote
    15  	// resource.
    16  	ErrRefreshImpossible = errors.New("refresh impossible: JWKS was not created from a remote resource")
    17  
    18  	// defaultRefreshTimeout is the default duration for the context used to create the HTTP request for a refresh of
    19  	// the JWKS.
    20  	defaultRefreshTimeout = time.Minute
    21  )
    22  
    23  // Get loads the JWKS at the given URL.
    24  func Get(jwksURL string, options Options) (jwks *JWKS, err error) {
    25  	jwks = &JWKS{
    26  		jwksURL: jwksURL,
    27  	}
    28  
    29  	applyOptions(jwks, options)
    30  
    31  	if jwks.client == nil {
    32  		jwks.client = http.DefaultClient
    33  	}
    34  	if jwks.requestFactory == nil {
    35  		jwks.requestFactory = defaultRequestFactory
    36  	}
    37  	if jwks.responseExtractor == nil {
    38  		jwks.responseExtractor = ResponseExtractorStatusOK
    39  	}
    40  	if jwks.refreshTimeout == 0 {
    41  		jwks.refreshTimeout = defaultRefreshTimeout
    42  	}
    43  	if !options.JWKUseNoWhitelist && len(jwks.jwkUseWhitelist) == 0 {
    44  		jwks.jwkUseWhitelist = map[JWKUse]struct{}{
    45  			UseOmitted:   {},
    46  			UseSignature: {},
    47  		}
    48  	}
    49  
    50  	err = jwks.refresh()
    51  	if err != nil {
    52  		if options.TolerateInitialJWKHTTPError {
    53  			if jwks.refreshErrorHandler != nil {
    54  				jwks.refreshErrorHandler(err)
    55  			}
    56  			jwks.keys = make(map[string]parsedJWK)
    57  		} else {
    58  			return nil, err
    59  		}
    60  	}
    61  
    62  	if jwks.refreshInterval != 0 || jwks.refreshUnknownKID {
    63  		if jwks.ctx == nil {
    64  			jwks.ctx = context.Background()
    65  		}
    66  		jwks.ctx, jwks.cancel = context.WithCancel(jwks.ctx)
    67  		jwks.refreshRequests = make(chan refreshRequest, 1)
    68  		go jwks.backgroundRefresh()
    69  	}
    70  
    71  	return jwks, nil
    72  }
    73  
    74  // Refresh manually refreshes the JWKS with the remote resource. It can bypass the rate limit if configured to do so.
    75  // This function will return an ErrRefreshImpossible if the JWKS was created from a static source like given keys or raw
    76  // JSON, because there is no remote resource to refresh from.
    77  //
    78  // This function will block until the refresh is finished or an error occurs.
    79  func (j *JWKS) Refresh(ctx context.Context, options RefreshOptions) error {
    80  	if j.jwksURL == "" {
    81  		return ErrRefreshImpossible
    82  	}
    83  
    84  	// Check if the background goroutine was launched.
    85  	if j.refreshInterval != 0 || j.refreshUnknownKID {
    86  		ctx, cancel := context.WithCancel(ctx)
    87  
    88  		req := refreshRequest{
    89  			cancel:          cancel,
    90  			ignoreRateLimit: options.IgnoreRateLimit,
    91  		}
    92  
    93  		select {
    94  		case <-ctx.Done():
    95  			return fmt.Errorf("failed to send request refresh to background goroutine: %w", j.ctx.Err())
    96  		case j.refreshRequests <- req:
    97  		}
    98  
    99  		<-ctx.Done()
   100  
   101  		if !errors.Is(ctx.Err(), context.Canceled) {
   102  			return fmt.Errorf("unexpected keyfunc background refresh context error: %w", ctx.Err())
   103  		}
   104  	} else {
   105  		err := j.refresh()
   106  		if err != nil {
   107  			return fmt.Errorf("failed to refresh JWKS: %w", err)
   108  		}
   109  	}
   110  
   111  	return nil
   112  }
   113  
// backgroundRefresh is meant to be a separate goroutine that will update the keys in a JWKS over a given interval of
// time.
//
// It loops until j.ctx is done. Each iteration waits for one of:
//   - the refresh interval elapsing (only when j.refreshInterval != 0). This enqueues a refresh request onto
//     j.refreshRequests without blocking, or drops it if the channel cannot accept it right away;
//   - a refresh request on j.refreshRequests (e.g. from Refresh or from the interval case). The request is served
//     immediately, or deferred until the rate limit allows it, and is then acknowledged by calling req.cancel when it
//     is non-nil;
//   - j.ctx being done, which ends the goroutine.
//
// Refresh errors are not returned; they are passed to j.refreshErrorHandler when it is set.
func (j *JWKS) backgroundRefresh() {
	// Time of the most recent refresh attempt made by this goroutine (updated even when the refresh fails).
	var lastRefresh time.Time
	// Ensures at most one deferred (rate-limited) refresh goroutine is pending at a time; reset after it runs.
	var queueOnce sync.Once
	// Serializes refreshes and guards lastRefresh and queueOnce between this loop and the deferred goroutine.
	var refreshMux sync.Mutex
	if j.refreshRateLimit != 0 {
		// Backdate by one rate-limit window so the first request served by this loop is not delayed.
		lastRefresh = time.Now().Add(-j.refreshRateLimit)
	}

	// Create a channel that will never send anything unless there is a refresh interval.
	// Nothing ever sends on it, so its select case only fires once it is replaced by a timer channel below.
	refreshInterval := make(<-chan time.Time)

	// refresh performs one refresh, reports any error to the handler, and records the attempt time.
	// Every call site below holds refreshMux.
	refresh := func() {
		err := j.refresh()
		if err != nil && j.refreshErrorHandler != nil {
			j.refreshErrorHandler(err)
		}
		lastRefresh = time.Now()
	}

	// Enter an infinite loop that ends when the background ends.
	for {
		// A new timer is armed on every iteration, so handling any request restarts the refresh interval.
		if j.refreshInterval != 0 {
			refreshInterval = time.After(j.refreshInterval)
		}

		select {
		case <-refreshInterval:
			// Turn the interval tick into an ordinary request so it goes through the same rate-limit logic.
			select {
			case <-j.ctx.Done():
				return
			case j.refreshRequests <- refreshRequest{}:
			default: // If the j.refreshRequests channel is full, don't send another request.
			}

		case req := <-j.refreshRequests:
			refreshMux.Lock()
			if req.ignoreRateLimit {
				refresh()
			} else if j.refreshRateLimit != 0 && lastRefresh.Add(j.refreshRateLimit).After(time.Now()) {
				// Launch a goroutine that will get a reservation for a JWKS refresh or fail to and immediately return.
				// queueOnce makes further rate-limited requests no-ops while one deferred refresh is already pending.
				queueOnce.Do(func() {
					go func() {
						// Blocks until the main loop releases refreshMux, then computes the remaining wait.
						refreshMux.Lock()
						wait := time.Until(lastRefresh.Add(j.refreshRateLimit))
						refreshMux.Unlock()
						select {
						case <-j.ctx.Done():
							return
						case <-time.After(wait):
						}

						refreshMux.Lock()
						defer refreshMux.Unlock()
						refresh()
						// Reset while holding refreshMux (which the main loop holds around queueOnce.Do) so a later
						// rate-limited request can schedule another deferred refresh.
						queueOnce = sync.Once{}
					}()
				})
			} else {
				refresh()
			}
			// Acknowledge the request so a Refresh caller waiting on this context can return. For a rate-limited
			// request this happens before the deferred refresh has run.
			if req.cancel != nil {
				req.cancel()
			}
			refreshMux.Unlock()

		// Clean up this goroutine when its context expires.
		case <-j.ctx.Done():
			return
		}
	}
}
   187  
   188  func defaultRequestFactory(ctx context.Context, url string) (*http.Request, error) {
   189  	return http.NewRequestWithContext(ctx, http.MethodGet, url, bytes.NewReader(nil))
   190  }
   191  
   192  // refresh does an HTTP GET on the JWKS URL to rebuild the JWKS.
   193  func (j *JWKS) refresh() (err error) {
   194  	var ctx context.Context
   195  	var cancel context.CancelFunc
   196  	if j.ctx != nil {
   197  		ctx, cancel = context.WithTimeout(j.ctx, j.refreshTimeout)
   198  	} else {
   199  		ctx, cancel = context.WithTimeout(context.Background(), j.refreshTimeout)
   200  	}
   201  	defer cancel()
   202  
   203  	req, err := j.requestFactory(ctx, j.jwksURL)
   204  	if err != nil {
   205  		return fmt.Errorf("failed to create request via factory function: %w", err)
   206  	}
   207  
   208  	resp, err := j.client.Do(req)
   209  	if err != nil {
   210  		return err
   211  	}
   212  
   213  	jwksBytes, err := j.responseExtractor(ctx, resp)
   214  	if err != nil {
   215  		return fmt.Errorf("failed to extract response via extractor function: %w", err)
   216  	}
   217  
   218  	// Only reprocess if the JWKS has changed.
   219  	if len(jwksBytes) != 0 && bytes.Equal(jwksBytes, j.raw) {
   220  		return nil
   221  	}
   222  	j.raw = jwksBytes
   223  
   224  	updated, err := NewJSON(jwksBytes)
   225  	if err != nil {
   226  		return err
   227  	}
   228  
   229  	j.mux.Lock()
   230  	defer j.mux.Unlock()
   231  	j.keys = updated.keys
   232  
   233  	if j.givenKeys != nil {
   234  		for kid, key := range j.givenKeys {
   235  			// Only overwrite the key if configured to do so.
   236  			if !j.givenKIDOverride {
   237  				if _, ok := j.keys[kid]; ok {
   238  					continue
   239  				}
   240  			}
   241  
   242  			j.keys[kid] = parsedJWK{public: key.inter}
   243  		}
   244  	}
   245  
   246  	return nil
   247  }
   248  

View as plain text