...

Source file src/github.com/MicahParks/keyfunc/jwks.go

Documentation: github.com/MicahParks/keyfunc

     1  package keyfunc
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"errors"
     7  	"fmt"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  )
    12  
    13  var (
    14  	// ErrJWKUseWhitelist indicates that the given JWK was found, but its "use" parameter's value was not whitelisted.
    15  	ErrJWKUseWhitelist = errors.New(`the given JWK was found, but its "use" parameter's value was not whitelisted`)
    16  
    17  	// ErrKIDNotFound indicates that the given key ID was not found in the JWKS.
    18  	ErrKIDNotFound = errors.New("the given key ID was not found in the JWKS")
    19  
    20  	// ErrMissingAssets indicates there are required assets are missing to create a public key.
    21  	ErrMissingAssets = errors.New("required assets are missing to create a public key")
    22  )
    23  
    24  // ErrorHandler is a function signature that consumes an error.
    25  type ErrorHandler func(err error)
    26  
    27  const (
    28  	// UseEncryption is a JWK "use" parameter value indicating the JSON Web Key is to be used for encryption.
    29  	UseEncryption JWKUse = "enc"
    30  	// UseOmitted is a JWK "use" parameter value that was not specified or was empty.
    31  	UseOmitted JWKUse = ""
    32  	// UseSignature is a JWK "use" parameter value indicating the JSON Web Key is to be used for signatures.
    33  	UseSignature JWKUse = "sig"
    34  )
    35  
    36  // JWKUse is a set of values for the "use" parameter of a JWK.
    37  // See https://tools.ietf.org/html/rfc7517#section-4.2.
    38  type JWKUse string
    39  
    40  // jsonWebKey represents a JSON Web Key inside a JWKS.
    41  type jsonWebKey struct {
    42  	Curve    string `json:"crv"`
    43  	Exponent string `json:"e"`
    44  	K        string `json:"k"`
    45  	ID       string `json:"kid"`
    46  	Modulus  string `json:"n"`
    47  	Type     string `json:"kty"`
    48  	Use      string `json:"use"`
    49  	X        string `json:"x"`
    50  	Y        string `json:"y"`
    51  }
    52  
    53  // parsedJWK represents a JSON Web Key parsed with fields as the correct Go types.
    54  type parsedJWK struct {
    55  	use    JWKUse
    56  	public interface{}
    57  }
    58  
    59  // JWKS represents a JSON Web Key Set (JWK Set).
    60  type JWKS struct {
    61  	jwkUseWhitelist     map[JWKUse]struct{}
    62  	cancel              context.CancelFunc
    63  	client              *http.Client
    64  	ctx                 context.Context
    65  	raw                 []byte
    66  	givenKeys           map[string]GivenKey
    67  	givenKIDOverride    bool
    68  	jwksURL             string
    69  	keys                map[string]parsedJWK
    70  	mux                 sync.RWMutex
    71  	refreshErrorHandler ErrorHandler
    72  	refreshInterval     time.Duration
    73  	refreshRateLimit    time.Duration
    74  	refreshRequests     chan context.CancelFunc
    75  	refreshTimeout      time.Duration
    76  	refreshUnknownKID   bool
    77  	requestFactory      func(ctx context.Context, url string) (*http.Request, error)
    78  	responseExtractor   func(ctx context.Context, resp *http.Response) (json.RawMessage, error)
    79  }
    80  
    81  // rawJWKS represents a JWKS in JSON format.
    82  type rawJWKS struct {
    83  	Keys []*jsonWebKey `json:"keys"`
    84  }
    85  
    86  // NewJSON creates a new JWKS from a raw JSON message.
    87  func NewJSON(jwksBytes json.RawMessage) (jwks *JWKS, err error) {
    88  	var rawKS rawJWKS
    89  	err = json.Unmarshal(jwksBytes, &rawKS)
    90  	if err != nil {
    91  		return nil, err
    92  	}
    93  
    94  	// Iterate through the keys in the raw JWKS. Add them to the JWKS.
    95  	jwks = &JWKS{
    96  		keys: make(map[string]parsedJWK, len(rawKS.Keys)),
    97  	}
    98  	for _, key := range rawKS.Keys {
    99  		var keyInter interface{}
   100  		switch keyType := key.Type; keyType {
   101  		case ktyEC:
   102  			keyInter, err = key.ECDSA()
   103  			if err != nil {
   104  				continue
   105  			}
   106  		case ktyOKP:
   107  			keyInter, err = key.EdDSA()
   108  			if err != nil {
   109  				continue
   110  			}
   111  		case ktyOct:
   112  			keyInter, err = key.Oct()
   113  			if err != nil {
   114  				continue
   115  			}
   116  		case ktyRSA:
   117  			keyInter, err = key.RSA()
   118  			if err != nil {
   119  				continue
   120  			}
   121  		default:
   122  			// Ignore unknown key types silently.
   123  			continue
   124  		}
   125  
   126  		jwks.keys[key.ID] = parsedJWK{
   127  			use:    JWKUse(key.Use),
   128  			public: keyInter,
   129  		}
   130  	}
   131  
   132  	return jwks, nil
   133  }
   134  
   135  // EndBackground ends the background goroutine to update the JWKS. It can only happen once and is only effective if the
   136  // JWKS has a background goroutine refreshing the JWKS keys.
   137  func (j *JWKS) EndBackground() {
   138  	if j.cancel != nil {
   139  		j.cancel()
   140  	}
   141  }
   142  
   143  // KIDs returns the key IDs (`kid`) for all keys in the JWKS.
   144  func (j *JWKS) KIDs() (kids []string) {
   145  	j.mux.RLock()
   146  	defer j.mux.RUnlock()
   147  	kids = make([]string, len(j.keys))
   148  	index := 0
   149  	for kid := range j.keys {
   150  		kids[index] = kid
   151  		index++
   152  	}
   153  	return kids
   154  }
   155  
   156  // Len returns the number of keys in the JWKS.
   157  func (j *JWKS) Len() int {
   158  	j.mux.RLock()
   159  	defer j.mux.RUnlock()
   160  	return len(j.keys)
   161  }
   162  
   163  // RawJWKS returns a copy of the raw JWKS received from the given JWKS URL.
   164  func (j *JWKS) RawJWKS() []byte {
   165  	j.mux.RLock()
   166  	defer j.mux.RUnlock()
   167  	raw := make([]byte, len(j.raw))
   168  	copy(raw, j.raw)
   169  	return raw
   170  }
   171  
   172  // ReadOnlyKeys returns a read-only copy of the mapping of key IDs (`kid`) to cryptographic keys.
   173  func (j *JWKS) ReadOnlyKeys() map[string]interface{} {
   174  	keys := make(map[string]interface{})
   175  	j.mux.Lock()
   176  	for kid, cryptoKey := range j.keys {
   177  		keys[kid] = cryptoKey.public
   178  	}
   179  	j.mux.Unlock()
   180  	return keys
   181  }
   182  
   183  // getKey gets the jsonWebKey from the given KID from the JWKS. It may refresh the JWKS if configured to.
   184  func (j *JWKS) getKey(kid string) (jsonKey interface{}, err error) {
   185  	j.mux.RLock()
   186  	pubKey, ok := j.keys[kid]
   187  	j.mux.RUnlock()
   188  
   189  	if !ok {
   190  		if !j.refreshUnknownKID {
   191  			return nil, ErrKIDNotFound
   192  		}
   193  
   194  		ctx, cancel := context.WithCancel(j.ctx)
   195  
   196  		// Refresh the JWKS.
   197  		select {
   198  		case <-j.ctx.Done():
   199  			return
   200  		case j.refreshRequests <- cancel:
   201  		default:
   202  			// If the j.refreshRequests channel is full, return the error early.
   203  			return nil, ErrKIDNotFound
   204  		}
   205  
   206  		// Wait for the JWKS refresh to finish.
   207  		<-ctx.Done()
   208  
   209  		j.mux.RLock()
   210  		defer j.mux.RUnlock()
   211  		if pubKey, ok = j.keys[kid]; !ok {
   212  			return nil, ErrKIDNotFound
   213  		}
   214  	}
   215  
   216  	// jwkUseWhitelist might be empty if the jwks was from keyfunc.NewJSON() or if JWKUseNoWhitelist option was true.
   217  	if len(j.jwkUseWhitelist) > 0 {
   218  		_, ok = j.jwkUseWhitelist[pubKey.use]
   219  		if !ok {
   220  			return nil, fmt.Errorf(`%w: JWK "use" parameter value %q is not whitelisted`, ErrJWKUseWhitelist, pubKey.use)
   221  		}
   222  	}
   223  
   224  	return pubKey.public, nil
   225  }
   226  

View as plain text