...

Source file src/github.com/MicahParks/keyfunc/v2/jwks.go

Documentation: github.com/MicahParks/keyfunc/v2

     1  package keyfunc
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  )
    12  
    13  var (
    14  	// ErrJWKAlgMismatch indicates that the given JWK was found, but its "alg" parameter's value did not match that of
    15  	// the JWT.
    16  	ErrJWKAlgMismatch = errors.New(`the given JWK was found, but its "alg" parameter's value did not match the expected algorithm`)
    17  
    18  	// ErrJWKUseWhitelist indicates that the given JWK was found, but its "use" parameter's value was not whitelisted.
    19  	ErrJWKUseWhitelist = errors.New(`the given JWK was found, but its "use" parameter's value was not whitelisted`)
    20  
    21  	// ErrKIDNotFound indicates that the given key ID was not found in the JWKS.
    22  	ErrKIDNotFound = errors.New("the given key ID was not found in the JWKS")
    23  
    24  	// ErrMissingAssets indicates there are required assets are missing to create a public key.
    25  	ErrMissingAssets = errors.New("required assets are missing to create a public key")
    26  )
    27  
    28  // ErrorHandler is a function signature that consumes an error.
    29  type ErrorHandler func(err error)
    30  
    31  const (
    32  	// UseEncryption is a JWK "use" parameter value indicating the JSON Web Key is to be used for encryption.
    33  	UseEncryption JWKUse = "enc"
    34  	// UseOmitted is a JWK "use" parameter value that was not specified or was empty.
    35  	UseOmitted JWKUse = ""
    36  	// UseSignature is a JWK "use" parameter value indicating the JSON Web Key is to be used for signatures.
    37  	UseSignature JWKUse = "sig"
    38  )
    39  
    40  // JWKUse is a set of values for the "use" parameter of a JWK.
    41  // See https://tools.ietf.org/html/rfc7517#section-4.2.
    42  type JWKUse string
    43  
// jsonWebKey represents a JSON Web Key inside a JWKS, exactly as decoded from JSON. Every field holds the raw string
// value of the JWK parameter named in its struct tag; conversion into Go key types happens elsewhere (e.g. via the
// ECDSA/EdDSA/Oct/RSA methods called from NewJSON). Parameter definitions: RFC 7517 and RFC 7518.
type jsonWebKey struct {
	// Algorithm is the "alg" parameter. When non-empty, the token's "alg" must match it.
	Algorithm string `json:"alg"`
	// Curve is the "crv" parameter (curve name for curve-based key types).
	Curve     string `json:"crv"`
	// Exponent is the "e" parameter (RSA public exponent).
	Exponent  string `json:"e"`
	// K is the "k" parameter (symmetric key value for "oct" keys, per RFC 7518).
	K         string `json:"k"`
	// ID is the "kid" parameter; NewJSON uses it as the lookup key in the JWKS key map.
	ID        string `json:"kid"`
	// Modulus is the "n" parameter (RSA modulus).
	Modulus   string `json:"n"`
	// Type is the "kty" parameter; NewJSON uses it to choose which parser to apply.
	Type      string `json:"kty"`
	// Use is the "use" parameter; it is checked against the JWKS "use" whitelist when one is configured.
	Use       string `json:"use"`
	// X is the "x" coordinate/public-key parameter.
	X         string `json:"x"`
	// Y is the "y" coordinate parameter.
	Y         string `json:"y"`
}
    57  
// parsedJWK represents a JSON Web Key parsed with fields as the correct Go types.
type parsedJWK struct {
	// algorithm is the JWK's "alg" value; empty means the key does not restrict the token algorithm.
	algorithm string
	// public is the parsed key returned by one of the jsonWebKey parser methods (ECDSA, EdDSA, Oct or RSA).
	public    interface{}
	// use is the JWK's "use" value, checked against the JWKS "use" whitelist when one is configured.
	use       JWKUse
}
    64  
// JWKS represents a JSON Web Key Set (JWK Set).
//
// keys and raw are guarded by mux. A JWKS created by NewJSON only has keys populated; the remaining fields configure
// remote fetching and background refresh.
type JWKS struct {
	// jwkUseWhitelist is the set of accepted JWK "use" values. When empty, the "use" check in getKey is skipped.
	jwkUseWhitelist     map[JWKUse]struct{}
	// cancel is invoked by EndBackground; nil when there is no background goroutine to stop.
	cancel              context.CancelFunc
	// client is presumably the HTTP client used to fetch jwksURL (usage not in this chunk).
	client              *http.Client
	// ctx is the parent context for refresh requests issued by getKey.
	ctx                 context.Context
	// raw is the raw JWKS returned (as a copy) by RawJWKS. Guarded by mux.
	raw                 []byte
	// givenKeys and givenKIDOverride presumably hold caller-supplied keys and whether they take precedence over
	// fetched ones. NOTE(review): semantics are defined outside this chunk; confirm.
	givenKeys           map[string]GivenKey
	givenKIDOverride    bool
	// jwksURL is the URL the raw JWKS is received from.
	jwksURL             string
	// keys maps each key ID ("kid") to its parsed key. Guarded by mux.
	keys                map[string]parsedJWK
	// mux guards keys and raw.
	mux                 sync.RWMutex
	// refreshErrorHandler, refreshInterval, refreshRateLimit and refreshTimeout configure the background refresh.
	// NOTE(review): they are consumed by code outside this chunk; confirm exact semantics there.
	refreshErrorHandler ErrorHandler
	refreshInterval     time.Duration
	refreshRateLimit    time.Duration
	// refreshRequests receives refresh requests from getKey; sends are non-blocking, so a full channel makes getKey
	// give up immediately with ErrKIDNotFound.
	refreshRequests     chan refreshRequest
	refreshTimeout      time.Duration
	// refreshUnknownKID makes getKey request a refresh when a key ID is not found.
	refreshUnknownKID   bool
	// requestFactory and responseExtractor presumably build the HTTP request for jwksURL and extract the JWKS JSON
	// from its response (usage not in this chunk).
	requestFactory      func(ctx context.Context, url string) (*http.Request, error)
	responseExtractor   func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
}
    86  
// rawJWKS represents a JWKS in JSON format: the top-level object whose "keys" member is the array of JWKs.
// A JSON null element inside "keys" decodes to a nil pointer in Keys.
type rawJWKS struct {
	Keys []*jsonWebKey `json:"keys"`
}
    91  
    92  // NewJSON creates a new JWKS from a raw JSON message.
    93  func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
    94  	var rawKS rawJWKS
    95  	err = json.Unmarshal(jwksBytes, &rawKS)
    96  	if err != nil {
    97  		return nil, err
    98  	}
    99  
   100  	// Iterate through the keys in the raw JWKS. Add them to the JWKS.
   101  	jwks = &JWKS{
   102  		keys: make(map[string]parsedJWK, len(rawKS.Keys)),
   103  	}
   104  	for _, key := range rawKS.Keys {
   105  		var keyInter interface{}
   106  		switch keyType := key.Type; keyType {
   107  		case ktyEC:
   108  			keyInter, err = key.ECDSA()
   109  			if err != nil {
   110  				continue
   111  			}
   112  		case ktyOKP:
   113  			keyInter, err = key.EdDSA()
   114  			if err != nil {
   115  				continue
   116  			}
   117  		case ktyOct:
   118  			keyInter, err = key.Oct()
   119  			if err != nil {
   120  				continue
   121  			}
   122  		case ktyRSA:
   123  			keyInter, err = key.RSA()
   124  			if err != nil {
   125  				continue
   126  			}
   127  		default:
   128  			// Ignore unknown key types silently.
   129  			continue
   130  		}
   131  
   132  		jwks.keys[key.ID] = parsedJWK{
   133  			algorithm: key.Algorithm,
   134  			use:       JWKUse(key.Use),
   135  			public:    keyInter,
   136  		}
   137  	}
   138  
   139  	return jwks, nil
   140  }
   141  
   142  // EndBackground ends the background goroutine to update the JWKS. It can only happen once and is only effective if the
   143  // JWKS has a background goroutine refreshing the JWKS keys.
   144  func (j *JWKS) EndBackground() {
   145  	if j.cancel != nil {
   146  		j.cancel()
   147  	}
   148  }
   149  
   150  // KIDs returns the key IDs (`kid`) for all keys in the JWKS.
   151  func (j *JWKS) KIDs() (kids []string) {
   152  	j.mux.RLock()
   153  	defer j.mux.RUnlock()
   154  	kids = make([]string, len(j.keys))
   155  	index := 0
   156  	for kid := range j.keys {
   157  		kids[index] = kid
   158  		index++
   159  	}
   160  	return kids
   161  }
   162  
   163  // Len returns the number of keys in the JWKS.
   164  func (j *JWKS) Len() int {
   165  	j.mux.RLock()
   166  	defer j.mux.RUnlock()
   167  	return len(j.keys)
   168  }
   169  
   170  // RawJWKS returns a copy of the raw JWKS received from the given JWKS URL.
   171  func (j *JWKS) RawJWKS() []byte {
   172  	j.mux.RLock()
   173  	defer j.mux.RUnlock()
   174  	raw := make([]byte, len(j.raw))
   175  	copy(raw, j.raw)
   176  	return raw
   177  }
   178  
   179  // ReadOnlyKeys returns a read-only copy of the mapping of key IDs (`kid`) to cryptographic keys.
   180  func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
   181  	keys := make(map[string]interface{})
   182  	j.mux.Lock()
   183  	for kid, cryptoKey := range j.keys {
   184  		keys[kid] = cryptoKey.public
   185  	}
   186  	j.mux.Unlock()
   187  	return keys
   188  }
   189  
   190  // getKey gets the jsonWebKey from the given KID from the JWKS. It may refresh the JWKS if configured to.
   191  func (j *JWKS) getKey(alg, kid string) (jsonKey interface{}, err error) {
   192  	j.mux.RLock()
   193  	pubKey, ok := j.keys[kid]
   194  	j.mux.RUnlock()
   195  
   196  	if !ok {
   197  		if !j.refreshUnknownKID {
   198  			return nil, ErrKIDNotFound
   199  		}
   200  
   201  		ctx, cancel := context.WithCancel(j.ctx)
   202  		req := refreshRequest{
   203  			cancel: cancel,
   204  		}
   205  
   206  		// Refresh the JWKS.
   207  		select {
   208  		case <-j.ctx.Done():
   209  			return
   210  		case j.refreshRequests <- req:
   211  		default:
   212  			// If the j.refreshRequests channel is full, return the error early.
   213  			return nil, ErrKIDNotFound
   214  		}
   215  
   216  		// Wait for the JWKS refresh to finish.
   217  		<-ctx.Done()
   218  
   219  		j.mux.RLock()
   220  		defer j.mux.RUnlock()
   221  		if pubKey, ok = j.keys[kid]; !ok {
   222  			return nil, ErrKIDNotFound
   223  		}
   224  	}
   225  
   226  	// jwkUseWhitelist might be empty if the jwks was from keyfunc.NewJSON() or if JWKUseNoWhitelist option was true.
   227  	if len(j.jwkUseWhitelist) > 0 {
   228  		_, ok = j.jwkUseWhitelist[pubKey.use]
   229  		if !ok {
   230  			return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not whitelisted`, ErrJWKUseWhitelist, pubKey.use)
   231  		}
   232  	}
   233  
   234  	if pubKey.algorithm != "" && pubKey.algorithm != alg {
   235  		return nil, fmt.Errorf(`%w: JWK "alg" parameter value %q does not match token "alg" parameter value %q`, ErrJWKAlgMismatch, pubKey.algorithm, alg)
   236  	}
   237  
   238  	return pubKey.public, nil
   239  }
   240  

View as plain text