...

Source file src/github.com/lestrrat-go/jwx/jwt/jwt.go

Documentation: github.com/lestrrat-go/jwx/jwt

     1  //go:generate ./gen.sh
     2  
     3  // Package jwt implements JSON Web Tokens as described in https://tools.ietf.org/html/rfc7519
     4  package jwt
     5  
     6  import (
     7  	"bytes"
     8  	"io"
     9  	"io/ioutil"
    10  	"net/http"
    11  	"strings"
    12  	"sync/atomic"
    13  
    14  	"github.com/lestrrat-go/backoff/v2"
    15  	"github.com/lestrrat-go/jwx"
    16  	"github.com/lestrrat-go/jwx/internal/json"
    17  	"github.com/lestrrat-go/jwx/jwe"
    18  
    19  	"github.com/lestrrat-go/jwx/jwa"
    20  	"github.com/lestrrat-go/jwx/jwk"
    21  	"github.com/lestrrat-go/jwx/jws"
    22  	"github.com/pkg/errors"
    23  )
    24  
    25  const _jwt = `jwt`
    26  
    27  // Settings controls global settings that are specific to JWTs.
    28  func Settings(options ...GlobalOption) {
    29  	var flattenAudienceBool bool
    30  
    31  	//nolint:forcetypeassert
    32  	for _, option := range options {
    33  		switch option.Ident() {
    34  		case identFlattenAudience{}:
    35  			flattenAudienceBool = option.Value().(bool)
    36  		}
    37  	}
    38  
    39  	v := atomic.LoadUint32(&json.FlattenAudience)
    40  	if (v == 1) != flattenAudienceBool {
    41  		var newVal uint32
    42  		if flattenAudienceBool {
    43  			newVal = 1
    44  		}
    45  		atomic.CompareAndSwapUint32(&json.FlattenAudience, v, newVal)
    46  	}
    47  }
    48  
    49  var registry = json.NewRegistry()
    50  
    51  // ParseString calls Parse against a string
    52  func ParseString(s string, options ...ParseOption) (Token, error) {
    53  	return parseBytes([]byte(s), options...)
    54  }
    55  
    56  // Parse parses the JWT token payload and creates a new `jwt.Token` object.
    57  // The token must be encoded in either JSON format or compact format.
    58  //
    59  // This function can work with encrypted and/or signed tokens. Any combination
    60  // of JWS and JWE may be applied to the token, but this function will only
    61  // attempt to verify/decrypt up to 2 levels (i.e. JWS only, JWE only, JWS then
    62  // JWE, or JWE then JWS)
    63  //
    64  // If the token is signed and you want to verify the payload matches the signature,
    65  // you must pass the jwt.WithVerify(alg, key) or jwt.WithKeySet(jwk.Set) option.
    66  // If you do not specify these parameters, no verification will be performed.
    67  //
    68  // During verification, if the JWS headers specify a key ID (`kid`), the
    69  // key used for verification must match the specified ID. If you are somehow
    70  // using a key without a `kid` (which is highly unlikely if you are working
    71  // with a JWT from a well know provider), you can workaround this by modifying
    72  // the `jwk.Key` and setting the `kid` header.
    73  //
    74  // If you also want to assert the validity of the JWT itself (i.e. expiration
    75  // and such), use the `Validate()` function on the returned token, or pass the
    76  // `WithValidate(true)` option. Validate options can also be passed to
    77  // `Parse`
    78  //
    79  // This function takes both ParseOption and ValidateOption types:
    80  // ParseOptions control the parsing behavior, and ValidateOptions are
    81  // passed to `Validate()` when `jwt.WithValidate` is specified.
    82  func Parse(s []byte, options ...ParseOption) (Token, error) {
    83  	return parseBytes(s, options...)
    84  }
    85  
    86  // ParseReader calls Parse against an io.Reader
    87  func ParseReader(src io.Reader, options ...ParseOption) (Token, error) {
    88  	// We're going to need the raw bytes regardless. Read it.
    89  	data, err := ioutil.ReadAll(src)
    90  	if err != nil {
    91  		return nil, errors.Wrap(err, `failed to read from token data source`)
    92  	}
    93  	return parseBytes(data, options...)
    94  }
    95  
    96  type parseCtx struct {
    97  	decryptParams    DecryptParameters
    98  	verifyParams     VerifyParameters
    99  	keySet           jwk.Set
   100  	keySetProvider   KeySetProvider
   101  	token            Token
   102  	validateOpts     []ValidateOption
   103  	verifyAutoOpts   []jws.VerifyOption
   104  	localReg         *json.Registry
   105  	inferAlgorithm   bool
   106  	pedantic         bool
   107  	skipVerification bool
   108  	useDefault       bool
   109  	validate         bool
   110  	verifyAuto       bool
   111  }
   112  
   113  func parseBytes(data []byte, options ...ParseOption) (Token, error) {
   114  	var ctx parseCtx
   115  	for _, o := range options {
   116  		if v, ok := o.(ValidateOption); ok {
   117  			ctx.validateOpts = append(ctx.validateOpts, v)
   118  			continue
   119  		}
   120  
   121  		//nolint:forcetypeassert
   122  		switch o.Ident() {
   123  		case identVerifyAuto{}:
   124  			ctx.verifyAuto = o.Value().(bool)
   125  		case identFetchWhitelist{}:
   126  			ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithFetchWhitelist(o.Value().(jwk.Whitelist)))
   127  		case identHTTPClient{}:
   128  			ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithHTTPClient(o.Value().(*http.Client)))
   129  		case identFetchBackoff{}:
   130  			ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithFetchBackoff(o.Value().(backoff.Policy)))
   131  		case identJWKSetFetcher{}:
   132  			ctx.verifyAutoOpts = append(ctx.verifyAutoOpts, jws.WithJWKSetFetcher(o.Value().(jws.JWKSetFetcher)))
   133  		case identVerify{}:
   134  			ctx.verifyParams = o.Value().(VerifyParameters)
   135  		case identDecrypt{}:
   136  			ctx.decryptParams = o.Value().(DecryptParameters)
   137  		case identKeySet{}:
   138  			ks, ok := o.Value().(jwk.Set)
   139  			if !ok {
   140  				return nil, errors.Errorf(`invalid JWK set passed via WithKeySet() option (%T)`, o.Value())
   141  			}
   142  			ctx.keySet = ks
   143  		case identToken{}:
   144  			token, ok := o.Value().(Token)
   145  			if !ok {
   146  				return nil, errors.Errorf(`invalid token passed via WithToken() option (%T)`, o.Value())
   147  			}
   148  			ctx.token = token
   149  		case identPedantic{}:
   150  			ctx.pedantic = o.Value().(bool)
   151  		case identDefault{}:
   152  			ctx.useDefault = o.Value().(bool)
   153  		case identValidate{}:
   154  			ctx.validate = o.Value().(bool)
   155  		case identTypedClaim{}:
   156  			pair := o.Value().(claimPair)
   157  			if ctx.localReg == nil {
   158  				ctx.localReg = json.NewRegistry()
   159  			}
   160  			ctx.localReg.Register(pair.Name, pair.Value)
   161  		case identInferAlgorithmFromKey{}:
   162  			ctx.inferAlgorithm = o.Value().(bool)
   163  		case identKeySetProvider{}:
   164  			ctx.keySetProvider = o.Value().(KeySetProvider)
   165  		}
   166  	}
   167  
   168  	data = bytes.TrimSpace(data)
   169  	return parse(&ctx, data)
   170  }
   171  
   172  const (
   173  	_JwsVerifyInvalid = iota
   174  	_JwsVerifyDone
   175  	_JwsVerifyExpectNested
   176  	_JwsVerifySkipped
   177  )
   178  
   179  func verifyJWS(ctx *parseCtx, payload []byte) ([]byte, int, error) {
   180  	if ctx.verifyAuto {
   181  		options := ctx.verifyAutoOpts
   182  		verified, err := jws.VerifyAuto(payload, options...)
   183  		return verified, _JwsVerifyDone, err
   184  	}
   185  
   186  	// if we have a key set or a provider, use that
   187  	ks := ctx.keySet
   188  	p := ctx.keySetProvider
   189  	if ks != nil || p != nil {
   190  		return verifyJWSWithKeySet(ctx, payload)
   191  	}
   192  
   193  	// We can't proceed without verification parameters
   194  	vp := ctx.verifyParams
   195  	if vp == nil {
   196  		return nil, _JwsVerifySkipped, nil
   197  	}
   198  
   199  	return verifyJWSWithParams(ctx, payload, vp.Algorithm(), vp.Key())
   200  }
   201  
   202  func verifyJWSWithKeySet(ctx *parseCtx, payload []byte) ([]byte, int, error) {
   203  	// First, get the JWS message
   204  	msg, err := jws.Parse(payload)
   205  	if err != nil {
   206  		return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to parse token data as JWS message`)
   207  	}
   208  	ks := ctx.keySet
   209  	if ks == nil { // the caller should have checked ctx.keySet || ctx.keySetProvider
   210  		if p := ctx.keySetProvider; p != nil {
   211  			// "trust" the payload, and parse it so that the provider can do its thing
   212  			ctx.skipVerification = true
   213  			tok, err := parse(ctx, msg.Payload())
   214  			if err != nil {
   215  				return nil, _JwsVerifyInvalid, err
   216  			}
   217  			ctx.skipVerification = false
   218  
   219  			v, err := p.KeySetFrom(tok)
   220  			if err != nil {
   221  				return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to obtain jwk.Set from KeySetProvider`)
   222  			}
   223  			ks = v
   224  		}
   225  	}
   226  
   227  	// Bail out early if we don't even have a key in the set
   228  	if ks.Len() == 0 {
   229  		return nil, _JwsVerifyInvalid, errors.New(`empty keyset provided`)
   230  	}
   231  
   232  	var key jwk.Key
   233  
   234  	// Find the kid. we need the kid, unless the user explicitly
   235  	// specified to use the "default" (the first and only) key in the set
   236  	headers := msg.Signatures()[0].ProtectedHeaders()
   237  	kid := headers.KeyID()
   238  	if kid == "" {
   239  		// If the kid is NOT specified... ctx.useDefault needs to be true, and the
   240  		// JWKs must have exactly one key in it
   241  		if !ctx.useDefault {
   242  			return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token`)
   243  		} else if ctx.useDefault && ks.Len() > 1 {
   244  			return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
   245  		}
   246  
   247  		// if we got here, then useDefault == true AND there is exactly
   248  		// one key in the set.
   249  		key, _ = ks.Get(0)
   250  	} else {
   251  		// Otherwise we better be able to look up the key, baby.
   252  		v, ok := ks.LookupKeyID(kid)
   253  		if !ok {
   254  			return nil, _JwsVerifyInvalid, errors.Errorf(`failed to find key with key ID %q in key set`, kid)
   255  		}
   256  		key = v
   257  	}
   258  
   259  	// We found a key with matching kid. Check fo the algorithm specified in the key.
   260  	// If we find an algorithm in the key, use that.
   261  	if v := key.Algorithm(); v != "" {
   262  		var alg jwa.SignatureAlgorithm
   263  		if err := alg.Accept(v); err != nil {
   264  			return nil, _JwsVerifyInvalid, errors.Wrapf(err, `invalid signature algorithm %s`, key.Algorithm())
   265  		}
   266  
   267  		// Okay, we have a valid algorithm, go go
   268  		return verifyJWSWithParams(ctx, payload, alg, key)
   269  	}
   270  
   271  	if ctx.inferAlgorithm {
   272  		// Check whether the JWT headers specify a valid
   273  		// algorithm, use it if it's compatible.
   274  		algs, err := jws.AlgorithmsForKey(key)
   275  		if err != nil {
   276  			return nil, _JwsVerifyInvalid, errors.Wrapf(err, `failed to get a list of signature methods for key type %s`, key.KeyType())
   277  		}
   278  
   279  		for _, alg := range algs {
   280  			// bail out if the JWT has a `alg` field, and it doesn't match
   281  			if tokAlg := headers.Algorithm(); tokAlg != "" {
   282  				if tokAlg != alg {
   283  					continue
   284  				}
   285  			}
   286  
   287  			return verifyJWSWithParams(ctx, payload, alg, key)
   288  		}
   289  	}
   290  
   291  	return nil, _JwsVerifyInvalid, errors.New(`failed to match any of the keys`)
   292  }
   293  
   294  func verifyJWSWithParams(ctx *parseCtx, payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, int, error) {
   295  	var m *jws.Message
   296  	var verifyOpts []jws.VerifyOption
   297  	if ctx.pedantic {
   298  		m = jws.NewMessage()
   299  		verifyOpts = []jws.VerifyOption{jws.WithMessage(m)}
   300  	}
   301  	v, err := jws.Verify(payload, alg, key, verifyOpts...)
   302  	if err != nil {
   303  		return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to verify jws signature`)
   304  	}
   305  
   306  	if !ctx.pedantic {
   307  		return v, _JwsVerifyDone, nil
   308  	}
   309  	// This payload could be a JWT+JWS, in which case typ: JWT should be there
   310  	// If its JWT+(JWE or JWS or...)+JWS, then cty should be JWT
   311  	for _, sig := range m.Signatures() {
   312  		hdrs := sig.ProtectedHeaders()
   313  		if strings.ToLower(hdrs.Type()) == _jwt {
   314  			return v, _JwsVerifyDone, nil
   315  		}
   316  
   317  		if strings.ToLower(hdrs.ContentType()) == _jwt {
   318  			return v, _JwsVerifyExpectNested, nil
   319  		}
   320  	}
   321  
   322  	// Hmmm, it was a JWS and we got... nothing?
   323  	return nil, _JwsVerifyInvalid, errors.Errorf(`expected "typ" or "cty" fields, neither could be found`)
   324  }
   325  
   326  // verify parameter exists to make sure that we don't accidentally skip
   327  // over verification just because alg == ""  or key == nil or something.
   328  func parse(ctx *parseCtx, data []byte) (Token, error) {
   329  	payload := data
   330  	const maxDecodeLevels = 2
   331  
   332  	// If cty = `JWT`, we expect this to be a nested structure
   333  	var expectNested bool
   334  
   335  OUTER:
   336  	for i := 0; i < maxDecodeLevels; i++ {
   337  		switch kind := jwx.GuessFormat(payload); kind {
   338  		case jwx.JWT:
   339  			if ctx.pedantic {
   340  				if expectNested {
   341  					return nil, errors.Errorf(`expected nested encrypted/signed payload, got raw JWT`)
   342  				}
   343  			}
   344  
   345  			if i == 0 {
   346  				// We were NOT enveloped in other formats
   347  				if !ctx.skipVerification {
   348  					if _, _, err := verifyJWS(ctx, payload); err != nil {
   349  						return nil, err
   350  					}
   351  				}
   352  			}
   353  
   354  			break OUTER
   355  		case jwx.UnknownFormat:
   356  			// "Unknown" may include invalid JWTs, for example, those who lack "aud"
   357  			// claim. We could be pedantic and reject these
   358  			if ctx.pedantic {
   359  				return nil, errors.Errorf(`invalid JWT`)
   360  			}
   361  
   362  			if i == 0 {
   363  				// We were NOT enveloped in other formats
   364  				if !ctx.skipVerification {
   365  					if _, _, err := verifyJWS(ctx, payload); err != nil {
   366  						return nil, err
   367  					}
   368  				}
   369  			}
   370  			break OUTER
   371  		case jwx.JWS:
   372  			// Food for thought: This is going to break if you have multiple layers of
   373  			// JWS enveloping using different keys. It is highly unlikely use case,
   374  			// but it might happen.
   375  
   376  			// skipVerification should only be set to true by us. It's used
   377  			// when we just want to parse the JWT out of a payload
   378  			if !ctx.skipVerification {
   379  				// nested return value means:
   380  				// false (next envelope _may_ need to be processed)
   381  				// true (next envelope MUST be processed)
   382  				v, state, err := verifyJWS(ctx, payload)
   383  				if err != nil {
   384  					return nil, err
   385  				}
   386  
   387  				if state != _JwsVerifySkipped {
   388  					payload = v
   389  
   390  					// We only check for cty and typ if the pedantic flag is enabled
   391  					if !ctx.pedantic {
   392  						continue
   393  					}
   394  
   395  					if state == _JwsVerifyExpectNested {
   396  						expectNested = true
   397  						continue OUTER
   398  					}
   399  
   400  					// if we're not nested, we found our target. bail out of this loop
   401  					break OUTER
   402  				}
   403  			}
   404  
   405  			// No verification.
   406  			m, err := jws.Parse(data)
   407  			if err != nil {
   408  				return nil, errors.Wrap(err, `invalid jws message`)
   409  			}
   410  			payload = m.Payload()
   411  		case jwx.JWE:
   412  			dp := ctx.decryptParams
   413  			if dp == nil {
   414  				return nil, errors.Errorf(`jwt.Parse: cannot proceed with JWE encrypted payload without decryption parameters`)
   415  			}
   416  
   417  			var m *jwe.Message
   418  			var decryptOpts []jwe.DecryptOption
   419  			if ctx.pedantic {
   420  				m = jwe.NewMessage()
   421  				decryptOpts = []jwe.DecryptOption{jwe.WithMessage(m)}
   422  			}
   423  
   424  			v, err := jwe.Decrypt(data, dp.Algorithm(), dp.Key(), decryptOpts...)
   425  			if err != nil {
   426  				return nil, errors.Wrap(err, `failed to decrypt payload`)
   427  			}
   428  
   429  			if !ctx.pedantic {
   430  				payload = v
   431  				continue
   432  			}
   433  
   434  			if strings.ToLower(m.ProtectedHeaders().Type()) == _jwt {
   435  				payload = v
   436  				break OUTER
   437  			}
   438  
   439  			if strings.ToLower(m.ProtectedHeaders().ContentType()) == _jwt {
   440  				expectNested = true
   441  				payload = v
   442  				continue OUTER
   443  			}
   444  		default:
   445  			return nil, errors.Errorf(`unsupported format (layer: #%d)`, i+1)
   446  		}
   447  		expectNested = false
   448  	}
   449  
   450  	if ctx.token == nil {
   451  		ctx.token = New()
   452  	}
   453  
   454  	if ctx.localReg != nil {
   455  		dcToken, ok := ctx.token.(TokenWithDecodeCtx)
   456  		if !ok {
   457  			return nil, errors.Errorf(`typed claim was requested, but the token (%T) does not support DecodeCtx`, ctx.token)
   458  		}
   459  		dc := json.NewDecodeCtx(ctx.localReg)
   460  		dcToken.SetDecodeCtx(dc)
   461  		defer func() { dcToken.SetDecodeCtx(nil) }()
   462  	}
   463  
   464  	if err := json.Unmarshal(payload, ctx.token); err != nil {
   465  		return nil, errors.Wrap(err, `failed to parse token`)
   466  	}
   467  
   468  	if ctx.validate {
   469  		if err := Validate(ctx.token, ctx.validateOpts...); err != nil {
   470  			return nil, err
   471  		}
   472  	}
   473  	return ctx.token, nil
   474  }
   475  
   476  // Sign is a convenience function to create a signed JWT token serialized in
   477  // compact form.
   478  //
   479  // It accepts either a raw key (e.g. rsa.PrivateKey, ecdsa.PrivateKey, etc)
   480  // or a jwk.Key, and the name of the algorithm that should be used to sign
   481  // the token.
   482  //
   483  // If the key is a jwk.Key and the key contains a key ID (`kid` field),
   484  // then it is added to the protected header generated by the signature
   485  //
   486  // The algorithm specified in the `alg` parameter must be able to support
   487  // the type of key you provided, otherwise an error is returned.
   488  //
   489  // The protected header will also automatically have the `typ` field set
   490  // to the literal value `JWT`, unless you provide a custom value for it
   491  // by jwt.WithHeaders option.
   492  func Sign(t Token, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) {
   493  	return NewSerializer().Sign(alg, key, options...).Serialize(t)
   494  }
   495  
   496  // Equal compares two JWT tokens. Do not use `reflect.Equal` or the like
   497  // to compare tokens as they will also compare extra detail such as
   498  // sync.Mutex objects used to control concurrent access.
   499  //
   500  // The comparison for values is currently done using a simple equality ("=="),
   501  // except for time.Time, which uses time.Equal after dropping the monotonic
   502  // clock and truncating the values to 1 second accuracy.
   503  //
   504  // if both t1 and t2 are nil, returns true
   505  func Equal(t1, t2 Token) bool {
   506  	if t1 == nil && t2 == nil {
   507  		return true
   508  	}
   509  
   510  	// we already checked for t1 == t2 == nil, so safe to do this
   511  	if t1 == nil || t2 == nil {
   512  		return false
   513  	}
   514  
   515  	j1, err := json.Marshal(t1)
   516  	if err != nil {
   517  		return false
   518  	}
   519  
   520  	j2, err := json.Marshal(t2)
   521  	if err != nil {
   522  		return false
   523  	}
   524  
   525  	return bytes.Equal(j1, j2)
   526  }
   527  
   528  func (t *stdToken) Clone() (Token, error) {
   529  	dst := New()
   530  
   531  	for _, pair := range t.makePairs() {
   532  		//nolint:forcetypeassert
   533  		key := pair.Key.(string)
   534  		if err := dst.Set(key, pair.Value); err != nil {
   535  			return nil, errors.Wrapf(err, `failed to set %s`, key)
   536  		}
   537  	}
   538  	return dst, nil
   539  }
   540  
   541  // RegisterCustomField allows users to specify that a private field
   542  // be decoded as an instance of the specified type. This option has
   543  // a global effect.
   544  //
   545  // For example, suppose you have a custom field `x-birthday`, which
   546  // you want to represent as a string formatted in RFC3339 in JSON,
   547  // but want it back as `time.Time`.
   548  //
   549  // In that case you would register a custom field as follows
   550  //
   551  //   jwt.RegisterCustomField(`x-birthday`, timeT)
   552  //
   553  // Then `token.Get("x-birthday")` will still return an `interface{}`,
   554  // but you can convert its type to `time.Time`
   555  //
   556  //   bdayif, _ := token.Get(`x-birthday`)
   557  //   bday := bdayif.(time.Time)
   558  //
   559  func RegisterCustomField(name string, object interface{}) {
   560  	registry.Register(name, object)
   561  }
   562  

View as plain text