...

Source file src/github.com/lestrrat-go/jwx/jwt/jwt_test.go

Documentation: github.com/lestrrat-go/jwx/jwt

     1  package jwt_test
     2  
import (
	"bytes"
	"context"
	"crypto/ecdsa"
	"crypto/ed25519"
	"crypto/rand"
	"crypto/rsa"
	"encoding/base64"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/http/httptest"
	"net/url"
	"os"
	"strconv"
	"strings"
	"sync"
	"testing"
	"time"

	"github.com/lestrrat-go/backoff/v2"
	"github.com/lestrrat-go/jwx/internal/ecutil"
	"github.com/lestrrat-go/jwx/internal/json"
	"github.com/lestrrat-go/jwx/internal/jwxtest"
	"github.com/lestrrat-go/jwx/jwa"
	"github.com/lestrrat-go/jwx/jwe"
	"github.com/lestrrat-go/jwx/jwk"
	"github.com/lestrrat-go/jwx/jws"
	"github.com/lestrrat-go/jwx/jwt"
	"github.com/pkg/errors"
	"github.com/stretchr/testify/assert"
)
    35  
    36  /* This is commented out, because it is intended to cause compilation errors */
    37  /*
    38  func TestOption(t *testing.T) {
    39  	var p jwt.ParseOption
    40  	var v jwt.ValidateOption
    41  	var o jwt.Option
    42  	p = o // should be error
    43  	v = o // should be error
    44  	_ = p
    45  	_ = v
    46  }
    47  */
    48  
    49  func TestJWTParse(t *testing.T) {
    50  	t.Parallel()
    51  
    52  	alg := jwa.RS256
    53  
    54  	key, err := jwxtest.GenerateRsaKey()
    55  	if !assert.NoError(t, err, `jwxtest.GenerateRsaKey should succeed`) {
    56  		return
    57  	}
    58  	t1 := jwt.New()
    59  	signed, err := jwt.Sign(t1, alg, key)
    60  	if !assert.NoError(t, err, `jwt.Sign should succeed`) {
    61  		return
    62  	}
    63  
    64  	t.Logf("%s", signed)
    65  
    66  	t.Run("Parse (no signature verification)", func(t *testing.T) {
    67  		t.Parallel()
    68  		t2, err := jwt.Parse(signed)
    69  		if !assert.NoError(t, err, `jwt.Parse should succeed`) {
    70  			return
    71  		}
    72  		if !assert.True(t, jwt.Equal(t1, t2), `t1 == t2`) {
    73  			return
    74  		}
    75  	})
    76  	t.Run("ParseString (no signature verification)", func(t *testing.T) {
    77  		t.Parallel()
    78  		t2, err := jwt.ParseString(string(signed))
    79  		if !assert.NoError(t, err, `jwt.ParseString should succeed`) {
    80  			return
    81  		}
    82  		if !assert.True(t, jwt.Equal(t1, t2), `t1 == t2`) {
    83  			return
    84  		}
    85  	})
    86  	t.Run("ParseReader (no signature verification)", func(t *testing.T) {
    87  		t.Parallel()
    88  		t2, err := jwt.ParseReader(bytes.NewReader(signed))
    89  		if !assert.NoError(t, err, `jwt.ParseReader should succeed`) {
    90  			return
    91  		}
    92  		if !assert.True(t, jwt.Equal(t1, t2), `t1 == t2`) {
    93  			return
    94  		}
    95  	})
    96  	t.Run("Parse (correct signature key)", func(t *testing.T) {
    97  		t.Parallel()
    98  		t2, err := jwt.Parse(signed, jwt.WithVerify(alg, &key.PublicKey))
    99  		if !assert.NoError(t, err, `jwt.Parse should succeed`) {
   100  			return
   101  		}
   102  		if !assert.True(t, jwt.Equal(t1, t2), `t1 == t2`) {
   103  			return
   104  		}
   105  	})
   106  	t.Run("parse (wrong signature algorithm)", func(t *testing.T) {
   107  		t.Parallel()
   108  		_, err := jwt.Parse(signed, jwt.WithVerify(jwa.RS512, &key.PublicKey))
   109  		if !assert.Error(t, err, `jwt.Parse should fail`) {
   110  			return
   111  		}
   112  	})
   113  	t.Run("parse (wrong signature key)", func(t *testing.T) {
   114  		t.Parallel()
   115  		pubkey := key.PublicKey
   116  		pubkey.E = 0 // bogus value
   117  		_, err := jwt.Parse(signed, jwt.WithVerify(alg, &pubkey))
   118  		if !assert.Error(t, err, `jwt.Parse should fail`) {
   119  			return
   120  		}
   121  	})
   122  }
   123  
   124  func TestJWTParseVerify(t *testing.T) {
   125  	t.Parallel()
   126  
   127  	keys := make([]interface{}, 0, 6)
   128  
   129  	keys = append(keys, []byte("abra cadabra"))
   130  
   131  	rsaPrivKey, err := jwxtest.GenerateRsaKey()
   132  	if !assert.NoError(t, err, "RSA key generated") {
   133  		return
   134  	}
   135  	keys = append(keys, rsaPrivKey)
   136  
   137  	for _, alg := range []jwa.EllipticCurveAlgorithm{jwa.P256, jwa.P384, jwa.P521} {
   138  		ecdsaPrivKey, err := jwxtest.GenerateEcdsaKey(alg)
   139  		if !assert.NoError(t, err, "jwxtest.GenerateEcdsaKey should succeed for %s", alg) {
   140  			return
   141  		}
   142  		keys = append(keys, ecdsaPrivKey)
   143  	}
   144  
   145  	ed25519PrivKey, err := jwxtest.GenerateEd25519Key()
   146  	if !assert.NoError(t, err, `jwxtest.GenerateEd25519Key should succeed`) {
   147  		return
   148  	}
   149  	keys = append(keys, ed25519PrivKey)
   150  
   151  	for _, key := range keys {
   152  		key := key
   153  		t.Run(fmt.Sprintf("Key=%T", key), func(t *testing.T) {
   154  			t.Parallel()
   155  			algs, err := jws.AlgorithmsForKey(key)
   156  			if !assert.NoError(t, err, `jwas.AlgorithmsForKey should succeed`) {
   157  				return
   158  			}
   159  
   160  			var dummyRawKey interface{}
   161  			switch pk := key.(type) {
   162  			case *rsa.PrivateKey:
   163  				dummyRawKey, err = jwxtest.GenerateRsaKey()
   164  				if !assert.NoError(t, err, `jwxtest.GenerateRsaKey should succeed`) {
   165  					return
   166  				}
   167  			case *ecdsa.PrivateKey:
   168  				curveAlg, ok := ecutil.AlgorithmForCurve(pk.Curve)
   169  				if !assert.True(t, ok, `ecutil.AlgorithmForCurve should succeed`) {
   170  					return
   171  				}
   172  				dummyRawKey, err = jwxtest.GenerateEcdsaKey(curveAlg)
   173  				if !assert.NoError(t, err, `jwxtest.GenerateEcdsaKey should succeed`) {
   174  					return
   175  				}
   176  			case ed25519.PrivateKey:
   177  				dummyRawKey, err = jwxtest.GenerateEd25519Key()
   178  				if !assert.NoError(t, err, `jwxtest.GenerateEd25519Key should succeed`) {
   179  					return
   180  				}
   181  			case []byte:
   182  				dummyRawKey = jwxtest.GenerateSymmetricKey()
   183  			default:
   184  				assert.Fail(t, fmt.Sprintf("Unhandled key type %T", key))
   185  				return
   186  			}
   187  
   188  			testcases := []struct {
   189  				SetAlgorithm   bool
   190  				SetKid         bool
   191  				InferAlgorithm bool
   192  				Error          bool
   193  			}{
   194  				{
   195  					SetAlgorithm:   true,
   196  					SetKid:         true,
   197  					InferAlgorithm: true,
   198  				},
   199  				{
   200  					SetAlgorithm:   true,
   201  					SetKid:         true,
   202  					InferAlgorithm: false,
   203  				},
   204  				{
   205  					SetAlgorithm:   true,
   206  					SetKid:         false,
   207  					InferAlgorithm: true,
   208  					Error:          true,
   209  				},
   210  				{
   211  					SetAlgorithm:   false,
   212  					SetKid:         true,
   213  					InferAlgorithm: true,
   214  				},
   215  				{
   216  					SetAlgorithm:   false,
   217  					SetKid:         true,
   218  					InferAlgorithm: false,
   219  					Error:          true,
   220  				},
   221  				{
   222  					SetAlgorithm:   false,
   223  					SetKid:         false,
   224  					InferAlgorithm: true,
   225  					Error:          true,
   226  				},
   227  				{
   228  					SetAlgorithm:   true,
   229  					SetKid:         false,
   230  					InferAlgorithm: false,
   231  					Error:          true,
   232  				},
   233  				{
   234  					SetAlgorithm:   false,
   235  					SetKid:         false,
   236  					InferAlgorithm: false,
   237  					Error:          true,
   238  				},
   239  			}
   240  			for _, alg := range algs {
   241  				alg := alg
   242  				for _, tc := range testcases {
   243  					tc := tc
   244  					t.Run(fmt.Sprintf("Algorithm=%s, SetAlgorithm=%t, SetKid=%t, InferAlgorithm=%t, Expect Error=%t", alg, tc.SetAlgorithm, tc.SetKid, tc.InferAlgorithm, tc.Error), func(t *testing.T) {
   245  						t.Parallel()
   246  
   247  						const kid = "test-jwt-parse-verify-kid"
   248  						const dummyKid = "test-jwt-parse-verify-dummy-kid"
   249  						hdrs := jws.NewHeaders()
   250  						hdrs.Set(jws.KeyIDKey, kid)
   251  
   252  						t1 := jwt.New()
   253  						signed, err := jwt.Sign(t1, alg, key, jwt.WithHeaders(hdrs))
   254  						if !assert.NoError(t, err, "token.Sign should succeed") {
   255  							return
   256  						}
   257  
   258  						pubkey, err := jwk.PublicKeyOf(key)
   259  						if !assert.NoError(t, err, `jwk.PublicKeyOf should succeed`) {
   260  							return
   261  						}
   262  
   263  						if tc.SetAlgorithm {
   264  							pubkey.Set(jwk.AlgorithmKey, alg)
   265  						}
   266  
   267  						dummyKey, err := jwk.PublicKeyOf(dummyRawKey)
   268  						if !assert.NoError(t, err, `jwk.PublicKeyOf should succeed`) {
   269  							return
   270  						}
   271  
   272  						if tc.SetKid {
   273  							pubkey.Set(jwk.KeyIDKey, kid)
   274  							dummyKey.Set(jwk.KeyIDKey, dummyKid)
   275  						}
   276  
   277  						// Permute on the location of the correct key, to check for possible
   278  						// cases where we loop too little or too much.
   279  						for i := 0; i < 6; i++ {
   280  							var name string
   281  							set := jwk.NewSet()
   282  							switch i {
   283  							case 0:
   284  								name = "Lone key"
   285  								set.Add(pubkey)
   286  							case 1:
   287  								name = "Two keys, correct one at the end"
   288  								set.Add(dummyKey)
   289  								set.Add(pubkey)
   290  							case 2:
   291  								name = "Two keys, correct one at the beginning"
   292  								set.Add(pubkey)
   293  								set.Add(dummyKey)
   294  							case 3:
   295  								name = "Three keys, correct one at the end"
   296  								set.Add(dummyKey)
   297  								set.Add(dummyKey)
   298  								set.Add(pubkey)
   299  							case 4:
   300  								name = "Three keys, correct one at the middle"
   301  								set.Add(dummyKey)
   302  								set.Add(pubkey)
   303  								set.Add(dummyKey)
   304  							case 5:
   305  								name = "Three keys, correct one at the beginning"
   306  								set.Add(pubkey)
   307  								set.Add(dummyKey)
   308  								set.Add(dummyKey)
   309  							}
   310  
   311  							t.Run(name, func(t *testing.T) {
   312  								options := []jwt.ParseOption{jwt.WithKeySet(set)}
   313  								if tc.InferAlgorithm {
   314  									options = append(options, jwt.InferAlgorithmFromKey(true))
   315  								}
   316  								t2, err := jwt.Parse(signed, options...)
   317  
   318  								if tc.Error {
   319  									assert.Error(t, err, `jwt.Parse should fail`)
   320  									return
   321  								}
   322  
   323  								if !assert.NoError(t, err, `jwt.Parse should succeed`) {
   324  									return
   325  								}
   326  
   327  								if !assert.True(t, jwt.Equal(t1, t2), `t1 == t2`) {
   328  									return
   329  								}
   330  							})
   331  						}
   332  					})
   333  				}
   334  			}
   335  		})
   336  	}
   337  	t.Run("Miscellaneous", func(t *testing.T) {
   338  		key, err := jwxtest.GenerateRsaKey()
   339  		if !assert.NoError(t, err, "RSA key generated") {
   340  			return
   341  		}
   342  		const alg = jwa.RS256
   343  		const kid = "my-very-special-key"
   344  		hdrs := jws.NewHeaders()
   345  		hdrs.Set(jws.KeyIDKey, kid)
   346  		t1 := jwt.New()
   347  		signed, err := jwt.Sign(t1, alg, key, jwt.WithHeaders(hdrs))
   348  		if !assert.NoError(t, err, "token.Sign should succeed") {
   349  			return
   350  		}
   351  
   352  		t.Run("Use KeySetProvider", func(t *testing.T) {
   353  			t.Parallel()
   354  			pubkey, _ := jwk.New(key.PublicKey)
   355  
   356  			pubkey.Set(jwk.AlgorithmKey, alg)
   357  			pubkey.Set(jwk.KeyIDKey, kid)
   358  
   359  			set := jwk.NewSet()
   360  			set.Add(pubkey)
   361  
   362  			t2, err := t1.Clone()
   363  			if !assert.NoError(t, err) {
   364  				return
   365  			}
   366  			if !assert.NoError(t, t2.Set(jwt.IssuerKey, "http://www.example.com")) {
   367  				return
   368  			}
   369  			signed, err := jwt.Sign(t2, alg, key, jwt.WithHeaders(hdrs))
   370  			if !assert.NoError(t, err) {
   371  				return
   372  			}
   373  
   374  			t3, err := jwt.Parse(signed, jwt.WithKeySetProvider(jwt.KeySetProviderFunc(func(tok jwt.Token) (jwk.Set, error) {
   375  				switch tok.Issuer() {
   376  				case "http://www.example.com":
   377  					return set, nil
   378  				}
   379  				return nil, fmt.Errorf("unknown issuer")
   380  			})))
   381  			if !assert.NoError(t, err, `jwt.Parse with key set func should succeed`) {
   382  				return
   383  			}
   384  
   385  			if !assert.True(t, jwt.Equal(t2, t3), `t2 == t3`) {
   386  				return
   387  			}
   388  
   389  			_, err = jwt.Parse(signed, jwt.WithKeySetProvider(jwt.KeySetProviderFunc(func(tok jwt.Token) (jwk.Set, error) {
   390  				return nil, errors.New(`dummy`)
   391  			})))
   392  			if !assert.Error(t, err, `jwt.Parse should fail`) {
   393  				return
   394  			}
   395  		})
   396  		t.Run("Alg does not match", func(t *testing.T) {
   397  			t.Parallel()
   398  			pubkey, err := jwk.PublicKeyOf(key)
   399  			if !assert.NoError(t, err) {
   400  				return
   401  			}
   402  
   403  			pubkey.Set(jwk.AlgorithmKey, jwa.HS256)
   404  			pubkey.Set(jwk.KeyIDKey, kid)
   405  			set := jwk.NewSet()
   406  			set.Add(pubkey)
   407  
   408  			_, err = jwt.Parse(signed, jwt.WithKeySet(set), jwt.InferAlgorithmFromKey(true), jwt.UseDefaultKey(true))
   409  			if !assert.Error(t, err, `jwt.Parse should fail`) {
   410  				return
   411  			}
   412  		})
   413  		t.Run("UseDefault with a key set with 1 key", func(t *testing.T) {
   414  			t.Parallel()
   415  			pubkey, err := jwk.PublicKeyOf(key)
   416  			if !assert.NoError(t, err) {
   417  				return
   418  			}
   419  
   420  			pubkey.Set(jwk.AlgorithmKey, alg)
   421  			pubkey.Set(jwk.KeyIDKey, kid)
   422  			signedNoKid, err := jwt.Sign(t1, alg, key)
   423  			if err != nil {
   424  				t.Fatal("Failed to sign JWT")
   425  			}
   426  			set := jwk.NewSet()
   427  			set.Add(pubkey)
   428  			t2, err := jwt.Parse(signedNoKid, jwt.WithKeySet(set), jwt.UseDefaultKey(true))
   429  			if !assert.NoError(t, err, `jwt.Parse with key set should succeed`) {
   430  				return
   431  			}
   432  			if !assert.True(t, jwt.Equal(t1, t2), `t1 == t2`) {
   433  				return
   434  			}
   435  		})
   436  		t.Run("UseDefault with multiple keys should fail", func(t *testing.T) {
   437  			t.Parallel()
   438  			pubkey1 := jwk.NewRSAPublicKey()
   439  			if !assert.NoError(t, pubkey1.FromRaw(&key.PublicKey)) {
   440  				return
   441  			}
   442  			pubkey2 := jwk.NewRSAPublicKey()
   443  			if !assert.NoError(t, pubkey2.FromRaw(&key.PublicKey)) {
   444  				return
   445  			}
   446  
   447  			pubkey1.Set(jwk.KeyIDKey, kid)
   448  			pubkey2.Set(jwk.KeyIDKey, "test-jwt-parse-verify-kid-2")
   449  			signedNoKid, err := jwt.Sign(t1, alg, key)
   450  			if err != nil {
   451  				t.Fatal("Failed to sign JWT")
   452  			}
   453  			set := jwk.NewSet()
   454  			set.Add(pubkey1)
   455  			set.Add(pubkey2)
   456  			_, err = jwt.Parse(signedNoKid, jwt.WithKeySet(set), jwt.UseDefaultKey(true))
   457  			if !assert.Error(t, err, `jwt.Parse should fail`) {
   458  				return
   459  			}
   460  		})
   461  		// This is a test to check if we allow alg: none in the protected header section.
   462  		// But in truth, since we delegate everything to jws.Verify anyways, it's really
   463  		// a test to see if jws.Verify returns an error if alg: none is specified in the
   464  		// header section. Move this test to jws if need be.
   465  		t.Run("Check alg=none", func(t *testing.T) {
   466  			t.Parallel()
   467  			// Create a signed payload, but use alg=none
   468  			_, payload, signature, err := jws.SplitCompact(signed)
   469  			if !assert.NoError(t, err, `jws.SplitCompact should succeed`) {
   470  				return
   471  			}
   472  
   473  			dummyHeader := jws.NewHeaders()
   474  			ctx, cancel := context.WithCancel(context.Background())
   475  			defer cancel()
   476  			for iter := hdrs.Iterate(ctx); iter.Next(ctx); {
   477  				pair := iter.Pair()
   478  				dummyHeader.Set(pair.Key.(string), pair.Value)
   479  			}
   480  			dummyHeader.Set(jws.AlgorithmKey, jwa.NoSignature)
   481  
   482  			dummyMarshaled, err := json.Marshal(dummyHeader)
   483  			if !assert.NoError(t, err, `json.Marshal should succeed`) {
   484  				return
   485  			}
   486  			dummyEncoded := make([]byte, base64.RawURLEncoding.EncodedLen(len(dummyMarshaled)))
   487  			base64.RawURLEncoding.Encode(dummyEncoded, dummyMarshaled)
   488  
   489  			signedButNot := bytes.Join([][]byte{dummyEncoded, payload, signature}, []byte{'.'})
   490  
   491  			pubkey := jwk.NewRSAPublicKey()
   492  			if !assert.NoError(t, pubkey.FromRaw(&key.PublicKey)) {
   493  				return
   494  			}
   495  
   496  			pubkey.Set(jwk.KeyIDKey, kid)
   497  
   498  			set := jwk.NewSet()
   499  			set.Add(pubkey)
   500  			_, err = jwt.Parse(signedButNot, jwt.WithKeySet(set))
   501  			// This should fail
   502  			if !assert.Error(t, err, `jwt.Parse with key set + alg=none should fail`) {
   503  				return
   504  			}
   505  		})
   506  	})
   507  }
   508  
   509  func TestValidateClaims(t *testing.T) {
   510  	t.Parallel()
   511  	// GitHub issue #37: tokens are invalid in the second they are created (because Now() is not after IssuedAt())
   512  	t.Run("Empty fields", func(t *testing.T) {
   513  		token := jwt.New()
   514  
   515  		if !assert.Error(t, jwt.Validate(token, jwt.WithIssuer("foo")), `token.Validate should fail`) {
   516  			return
   517  		}
   518  		if !assert.Error(t, jwt.Validate(token, jwt.WithJwtID("foo")), `token.Validate should fail`) {
   519  			return
   520  		}
   521  		if !assert.Error(t, jwt.Validate(token, jwt.WithSubject("foo")), `token.Validate should fail`) {
   522  			return
   523  		}
   524  	})
   525  	t.Run(jwt.IssuedAtKey+"+skew", func(t *testing.T) {
   526  		t.Parallel()
   527  		token := jwt.New()
   528  		now := time.Now().UTC()
   529  		token.Set(jwt.IssuedAtKey, now)
   530  
   531  		const DefaultSkew = 0
   532  
   533  		args := []jwt.ValidateOption{
   534  			jwt.WithClock(jwt.ClockFunc(func() time.Time { return now })),
   535  			jwt.WithAcceptableSkew(DefaultSkew),
   536  		}
   537  
   538  		if !assert.NoError(t, jwt.Validate(token, args...), "token.Validate should validate tokens in the same second they are created") {
   539  			if now.Equal(token.IssuedAt()) {
   540  				t.Errorf("iat claim failed: iat == now")
   541  			}
   542  			return
   543  		}
   544  	})
   545  }
   546  
   547  const aLongLongTimeAgo = 233431200
   548  const aLongLongTimeAgoString = "233431200"
   549  
   550  func TestUnmarshal(t *testing.T) {
   551  	t.Parallel()
   552  	testcases := []struct {
   553  		Title        string
   554  		Source       string
   555  		Expected     func() jwt.Token
   556  		ExpectedJSON string
   557  	}{
   558  		{
   559  			Title:  "single aud",
   560  			Source: `{"aud":"foo"}`,
   561  			Expected: func() jwt.Token {
   562  				t := jwt.New()
   563  				t.Set("aud", "foo")
   564  				return t
   565  			},
   566  			ExpectedJSON: `{"aud":["foo"]}`,
   567  		},
   568  		{
   569  			Title:  "multiple aud's",
   570  			Source: `{"aud":["foo","bar"]}`,
   571  			Expected: func() jwt.Token {
   572  				t := jwt.New()
   573  				t.Set("aud", []string{"foo", "bar"})
   574  				return t
   575  			},
   576  			ExpectedJSON: `{"aud":["foo","bar"]}`,
   577  		},
   578  		{
   579  			Title:  "issuedAt",
   580  			Source: `{"` + jwt.IssuedAtKey + `":` + aLongLongTimeAgoString + `}`,
   581  			Expected: func() jwt.Token {
   582  				t := jwt.New()
   583  				t.Set(jwt.IssuedAtKey, aLongLongTimeAgo)
   584  				return t
   585  			},
   586  			ExpectedJSON: `{"` + jwt.IssuedAtKey + `":` + aLongLongTimeAgoString + `}`,
   587  		},
   588  	}
   589  
   590  	for _, tc := range testcases {
   591  		tc := tc
   592  		t.Run(tc.Title, func(t *testing.T) {
   593  			t.Parallel()
   594  			token := jwt.New()
   595  			if !assert.NoError(t, json.Unmarshal([]byte(tc.Source), &token), `json.Unmarshal should succeed`) {
   596  				return
   597  			}
   598  			if !assert.Equal(t, tc.Expected(), token, `token should match expected value`) {
   599  				return
   600  			}
   601  
   602  			var buf bytes.Buffer
   603  			if !assert.NoError(t, json.NewEncoder(&buf).Encode(token), `json.Marshal should succeed`) {
   604  				return
   605  			}
   606  			if !assert.Equal(t, tc.ExpectedJSON, strings.TrimSpace(buf.String()), `json should match`) {
   607  				return
   608  			}
   609  		})
   610  	}
   611  }
   612  
   613  func TestGH52(t *testing.T) {
   614  	if testing.Short() {
   615  		t.SkipNow()
   616  	}
   617  
   618  	t.Parallel()
   619  	priv, err := jwxtest.GenerateEcdsaKey(jwa.P521)
   620  	if !assert.NoError(t, err) {
   621  		return
   622  	}
   623  
   624  	pub := &priv.PublicKey
   625  	if !assert.NoError(t, err) {
   626  		return
   627  	}
   628  	const max = 100
   629  	var wg sync.WaitGroup
   630  	wg.Add(max)
   631  	for i := 0; i < max; i++ {
   632  		// Do not use t.Run here as it will clutter up the outpuA
   633  		go func(t *testing.T, priv *ecdsa.PrivateKey, i int) {
   634  			defer wg.Done()
   635  			tok := jwt.New()
   636  
   637  			s, err := jwt.Sign(tok, jwa.ES256, priv)
   638  			if !assert.NoError(t, err) {
   639  				return
   640  			}
   641  
   642  			if _, err = jws.Verify(s, jwa.ES256, pub); !assert.NoError(t, err, `test should pass (run %d)`, i) {
   643  				return
   644  			}
   645  		}(t, priv, i)
   646  	}
   647  	wg.Wait()
   648  }
   649  
   650  func TestUnmarshalJSON(t *testing.T) {
   651  	t.Parallel()
   652  	t.Run("Unmarshal audience with multiple values", func(t *testing.T) {
   653  		t.Parallel()
   654  		t1 := jwt.New()
   655  		if !assert.NoError(t, json.Unmarshal([]byte(`{"aud":["foo", "bar", "baz"]}`), &t1), `jwt.Parse should succeed`) {
   656  			return
   657  		}
   658  		aud, ok := t1.Get(jwt.AudienceKey)
   659  		if !assert.True(t, ok, `jwt.Get(jwt.AudienceKey) should succeed`) {
   660  			t.Logf("%#v", t1)
   661  			return
   662  		}
   663  
   664  		if !assert.Equal(t, aud.([]string), []string{"foo", "bar", "baz"}, "audience should match. got %v", aud) {
   665  			return
   666  		}
   667  	})
   668  }
   669  
   670  func TestSignErrors(t *testing.T) {
   671  	t.Parallel()
   672  	priv, err := jwxtest.GenerateEcdsaKey(jwa.P521)
   673  	if !assert.NoError(t, err, `jwxtest.GenerateEcdsaKey should succeed`) {
   674  		return
   675  	}
   676  
   677  	tok := jwt.New()
   678  	_, err = jwt.Sign(tok, jwa.SignatureAlgorithm("BOGUS"), priv)
   679  	if !assert.Error(t, err) {
   680  		return
   681  	}
   682  
   683  	if !assert.Contains(t, err.Error(), `unsupported signature algorithm "BOGUS"`) {
   684  		return
   685  	}
   686  
   687  	_, err = jwt.Sign(tok, jwa.ES256, nil)
   688  	if !assert.Error(t, err) {
   689  		return
   690  	}
   691  
   692  	if !assert.Contains(t, err.Error(), "missing private key") {
   693  		return
   694  	}
   695  }
   696  
   697  func TestSignJWK(t *testing.T) {
   698  	t.Parallel()
   699  	priv, err := jwxtest.GenerateRsaKey()
   700  	assert.Nil(t, err)
   701  
   702  	key := jwk.NewRSAPrivateKey()
   703  	err = key.FromRaw(priv)
   704  	assert.Nil(t, err)
   705  
   706  	key.Set(jwk.KeyIDKey, "test")
   707  	key.Set(jwk.AlgorithmKey, jwa.RS256)
   708  
   709  	tok := jwt.New()
   710  	signed, err := jwt.Sign(tok, jwa.SignatureAlgorithm(key.Algorithm()), key)
   711  	assert.Nil(t, err)
   712  
   713  	header, err := jws.ParseString(string(signed))
   714  	assert.Nil(t, err)
   715  
   716  	signatures := header.LookupSignature("test")
   717  	assert.Len(t, signatures, 1)
   718  }
   719  
   720  func getJWTHeaders(jwt []byte) (jws.Headers, error) {
   721  	msg, err := jws.Parse(jwt)
   722  	if err != nil {
   723  		return nil, err
   724  	}
   725  	return msg.Signatures()[0].ProtectedHeaders(), nil
   726  }
   727  
   728  func TestSignTyp(t *testing.T) {
   729  	t.Parallel()
   730  	key, err := jwxtest.GenerateRsaKey()
   731  	if !assert.NoError(t, err) {
   732  		return
   733  	}
   734  
   735  	t.Run(`"typ" header parameter should be set to JWT by default`, func(t *testing.T) {
   736  		t.Parallel()
   737  		t1 := jwt.New()
   738  		signed, err := jwt.Sign(t1, jwa.RS256, key)
   739  		if !assert.NoError(t, err) {
   740  			return
   741  		}
   742  		got, err := getJWTHeaders(signed)
   743  		if !assert.NoError(t, err) {
   744  			return
   745  		}
   746  		if !assert.Equal(t, `JWT`, got.Type(), `"typ" header parameter should be set to JWT`) {
   747  			return
   748  		}
   749  	})
   750  
   751  	t.Run(`"typ" header parameter should be customizable by WithHeaders`, func(t *testing.T) {
   752  		t.Parallel()
   753  		t1 := jwt.New()
   754  		hdrs := jws.NewHeaders()
   755  		hdrs.Set(`typ`, `custom-typ`)
   756  		signed, err := jwt.Sign(t1, jwa.RS256, key, jwt.WithHeaders(hdrs))
   757  		if !assert.NoError(t, err) {
   758  			return
   759  		}
   760  		got, err := getJWTHeaders(signed)
   761  		if !assert.NoError(t, err) {
   762  			return
   763  		}
   764  		if !assert.Equal(t, `custom-typ`, got.Type(), `"typ" header parameter should be set to the custom value`) {
   765  			return
   766  		}
   767  	})
   768  }
   769  
   770  func TestReadFile(t *testing.T) {
   771  	t.Parallel()
   772  
   773  	f, err := ioutil.TempFile("", "test-read-file-*.jwt")
   774  	if !assert.NoError(t, err, `ioutil.TempFile should succeed`) {
   775  		return
   776  	}
   777  	defer f.Close()
   778  
   779  	token := jwt.New()
   780  	token.Set(jwt.IssuerKey, `lestrrat`)
   781  	if !assert.NoError(t, json.NewEncoder(f).Encode(token), `json.NewEncoder.Encode should succeed`) {
   782  		return
   783  	}
   784  
   785  	if _, err := jwt.ReadFile(f.Name(), jwt.WithValidate(true), jwt.WithIssuer("lestrrat")); !assert.NoError(t, err, `jwt.ReadFile should succeed`) {
   786  		return
   787  	}
   788  	if _, err := jwt.ReadFile(f.Name(), jwt.WithValidate(true), jwt.WithIssuer("lestrrrrrat")); !assert.Error(t, err, `jwt.ReadFile should fail`) {
   789  		return
   790  	}
   791  }
   792  
   793  func TestCustomField(t *testing.T) {
   794  	// XXX has global effect!!!
   795  	jwt.RegisterCustomField(`x-birthday`, time.Time{})
   796  	defer jwt.RegisterCustomField(`x-birthday`, nil)
   797  
   798  	expected := time.Date(2015, 11, 4, 5, 12, 52, 0, time.UTC)
   799  	bdaybytes, _ := expected.MarshalText() // RFC3339
   800  
   801  	var b strings.Builder
   802  	b.WriteString(`{"iss": "github.com/lesstrrat-go/jwx", "x-birthday": "`)
   803  	b.Write(bdaybytes)
   804  	b.WriteString(`"}`)
   805  	src := b.String()
   806  
   807  	t.Run("jwt.Parse", func(t *testing.T) {
   808  		token, err := jwt.Parse([]byte(src))
   809  		if !assert.NoError(t, err, `jwt.Parse should succeed`) {
   810  			return
   811  		}
   812  
   813  		v, ok := token.Get(`x-birthday`)
   814  		if !assert.True(t, ok, `token.Get("x-birthday") should succeed`) {
   815  			return
   816  		}
   817  
   818  		if !assert.Equal(t, expected, v, `values should match`) {
   819  			return
   820  		}
   821  	})
   822  	t.Run("json.Unmarshal", func(t *testing.T) {
   823  		token := jwt.New()
   824  		if !assert.NoError(t, json.Unmarshal([]byte(src), token), `json.Unmarshal should succeed`) {
   825  			return
   826  		}
   827  
   828  		v, ok := token.Get(`x-birthday`)
   829  		if !assert.True(t, ok, `token.Get("x-birthday") should succeed`) {
   830  			return
   831  		}
   832  
   833  		if !assert.Equal(t, expected, v, `values should match`) {
   834  			return
   835  		}
   836  	})
   837  }
   838  
   839  func TestParseRequest(t *testing.T) {
   840  	const u = "https://github.com/lestrrat-gow/jwx/jwt"
   841  
   842  	privkey, _ := jwxtest.GenerateEcdsaJwk()
   843  	privkey.Set(jwk.AlgorithmKey, jwa.ES256)
   844  	privkey.Set(jwk.KeyIDKey, `my-awesome-key`)
   845  	pubkey, _ := jwk.PublicKeyOf(privkey)
   846  	pubkey.Set(jwk.AlgorithmKey, jwa.ES256)
   847  
   848  	tok := jwt.New()
   849  	tok.Set(jwt.IssuerKey, u)
   850  	tok.Set(jwt.IssuedAtKey, time.Now().Round(0))
   851  
   852  	signed, _ := jwt.Sign(tok, jwa.ES256, privkey)
   853  
   854  	testcases := []struct {
   855  		Request func() *http.Request
   856  		Parse   func(*http.Request) (jwt.Token, error)
   857  		Name    string
   858  		Error   bool
   859  	}{
   860  		{
   861  			Name: "Token not present (w/ multiple options)",
   862  			Request: func() *http.Request {
   863  				return httptest.NewRequest(http.MethodGet, u, nil)
   864  			},
   865  			Parse: func(req *http.Request) (jwt.Token, error) {
   866  				return jwt.ParseRequest(req,
   867  					jwt.WithHeaderKey("Authorization"),
   868  					jwt.WithHeaderKey("x-authorization"),
   869  					jwt.WithFormKey("access_token"),
   870  					jwt.WithFormKey("token"),
   871  					jwt.WithVerify(jwa.ES256, pubkey))
   872  			},
   873  			Error: true,
   874  		},
   875  		{
   876  			Name: "Token not present (w/o options)",
   877  			Request: func() *http.Request {
   878  				return httptest.NewRequest(http.MethodGet, u, nil)
   879  			},
   880  			Parse: func(req *http.Request) (jwt.Token, error) {
   881  				return jwt.ParseRequest(req, jwt.WithVerify(jwa.ES256, pubkey))
   882  			},
   883  			Error: true,
   884  		},
   885  		{
   886  			Name: "Token in Authorization header (w/o extra options)",
   887  			Request: func() *http.Request {
   888  				req := httptest.NewRequest(http.MethodGet, u, nil)
   889  				req.Header.Add("Authorization", "Bearer "+string(signed))
   890  				return req
   891  			},
   892  			Parse: func(req *http.Request) (jwt.Token, error) {
   893  				return jwt.ParseRequest(req, jwt.WithVerify(jwa.ES256, pubkey))
   894  			},
   895  		},
   896  		{
   897  			Name: "Token in Authorization header (w/o extra options, using jwk.Set)",
   898  			Request: func() *http.Request {
   899  				req := httptest.NewRequest(http.MethodGet, u, nil)
   900  				req.Header.Add("Authorization", "Bearer "+string(signed))
   901  				return req
   902  			},
   903  			Parse: func(req *http.Request) (jwt.Token, error) {
   904  				set := jwk.NewSet()
   905  				set.Add(pubkey)
   906  				return jwt.ParseRequest(req, jwt.WithKeySet(set))
   907  			},
   908  		},
   909  		{
   910  			Name: "Token in Authorization header but we specified another header key",
   911  			Request: func() *http.Request {
   912  				req := httptest.NewRequest(http.MethodGet, u, nil)
   913  				req.Header.Add("Authorization", "Bearer "+string(signed))
   914  				return req
   915  			},
   916  			Parse: func(req *http.Request) (jwt.Token, error) {
   917  				return jwt.ParseRequest(req, jwt.WithHeaderKey("x-authorization"), jwt.WithVerify(jwa.ES256, pubkey))
   918  			},
   919  			Error: true,
   920  		},
   921  		{
   922  			Name: "Token in x-authorization header (w/ option)",
   923  			Request: func() *http.Request {
   924  				req := httptest.NewRequest(http.MethodGet, u, nil)
   925  				req.Header.Add("x-authorization", string(signed))
   926  				return req
   927  			},
   928  			Parse: func(req *http.Request) (jwt.Token, error) {
   929  				return jwt.ParseRequest(req, jwt.WithHeaderKey("x-authorization"), jwt.WithVerify(jwa.ES256, pubkey))
   930  			},
   931  		},
   932  		{
   933  			Name: "Invalid token in x-authorization header",
   934  			Request: func() *http.Request {
   935  				req := httptest.NewRequest(http.MethodGet, u, nil)
   936  				req.Header.Add("x-authorization", string(signed)+"foobarbaz")
   937  				return req
   938  			},
   939  			Parse: func(req *http.Request) (jwt.Token, error) {
   940  				return jwt.ParseRequest(req, jwt.WithHeaderKey("x-authorization"), jwt.WithVerify(jwa.ES256, pubkey))
   941  			},
   942  			Error: true,
   943  		},
   944  		{
   945  			Name: "Token in access_token form field (w/ option)",
   946  			Request: func() *http.Request {
   947  				req := httptest.NewRequest(http.MethodPost, u, nil)
   948  				// for whatever reason, I can't populate req.Body and get this to work
   949  				// so populating req.Form directly instead
   950  				req.Form = url.Values{}
   951  				req.Form.Add("access_token", string(signed))
   952  				return req
   953  			},
   954  			Parse: func(req *http.Request) (jwt.Token, error) {
   955  				return jwt.ParseRequest(req, jwt.WithFormKey("access_token"), jwt.WithVerify(jwa.ES256, pubkey))
   956  			},
   957  		},
   958  		{
   959  			Name: "Token in access_token form field (w/o option)",
   960  			Request: func() *http.Request {
   961  				req := httptest.NewRequest(http.MethodPost, u, nil)
   962  				// for whatever reason, I can't populate req.Body and get this to work
   963  				// so populating req.Form directly instead
   964  				req.Form = url.Values{}
   965  				req.Form.Add("access_token", string(signed))
   966  				return req
   967  			},
   968  			Parse: func(req *http.Request) (jwt.Token, error) {
   969  				return jwt.ParseRequest(req, jwt.WithVerify(jwa.ES256, pubkey))
   970  			},
   971  			Error: true,
   972  		},
   973  		{
   974  			Name: "Invalid token in access_token form field",
   975  			Request: func() *http.Request {
   976  				req := httptest.NewRequest(http.MethodPost, u, nil)
   977  				// for whatever reason, I can't populate req.Body and get this to work
   978  				// so populating req.Form directly instead
   979  				req.Form = url.Values{}
   980  				req.Form.Add("access_token", string(signed)+"foobarbarz")
   981  				return req
   982  			},
   983  			Parse: func(req *http.Request) (jwt.Token, error) {
   984  				return jwt.ParseRequest(req, jwt.WithVerify(jwa.ES256, pubkey), jwt.WithFormKey("access_token"))
   985  			},
   986  			Error: true,
   987  		},
   988  	}
   989  
   990  	for _, tc := range testcases {
   991  		tc := tc
   992  		t.Run(tc.Name, func(t *testing.T) {
   993  			got, err := tc.Parse(tc.Request())
   994  			if tc.Error {
   995  				t.Logf("%s", err)
   996  				assert.Error(t, err, `tc.Parse should fail`)
   997  				return
   998  			}
   999  
  1000  			if !assert.NoError(t, err, `tc.Parse should succeed`) {
  1001  				return
  1002  			}
  1003  
  1004  			if !assert.True(t, jwt.Equal(tok, got), `tokens should match`) {
  1005  				{
  1006  					buf, _ := json.MarshalIndent(tok, "", "  ")
  1007  					t.Logf("expected: %s", buf)
  1008  				}
  1009  				{
  1010  					buf, _ := json.MarshalIndent(got, "", "  ")
  1011  					t.Logf("got: %s", buf)
  1012  				}
  1013  				return
  1014  			}
  1015  		})
  1016  	}
  1017  }
  1018  
  1019  func TestGHIssue368(t *testing.T) {
  1020  	// DO NOT RUN THIS IN PARALLEL
  1021  	for _, flatten := range []bool{true, false} {
  1022  		flatten := flatten
  1023  		t.Run(fmt.Sprintf("Test serialization (WithFlattenAudience(%t))", flatten), func(t *testing.T) {
  1024  			jwt.Settings(jwt.WithFlattenAudience(flatten))
  1025  
  1026  			t.Run("Single Key", func(t *testing.T) {
  1027  				tok := jwt.New()
  1028  				_ = tok.Set(jwt.AudienceKey, "hello")
  1029  
  1030  				buf, err := json.MarshalIndent(tok, "", "  ")
  1031  				if !assert.NoError(t, err, `json.MarshalIndent should succeed`) {
  1032  					return
  1033  				}
  1034  
  1035  				var expected string
  1036  				if flatten {
  1037  					expected = `{
  1038    "aud": "hello"
  1039  }`
  1040  				} else {
  1041  					expected = `{
  1042    "aud": [
  1043      "hello"
  1044    ]
  1045  }`
  1046  				}
  1047  
  1048  				if !assert.Equal(t, expected, string(buf), `output should match`) {
  1049  					return
  1050  				}
  1051  			})
  1052  			t.Run("Multiple Keys", func(t *testing.T) {
  1053  				tok, err := jwt.NewBuilder().
  1054  					Audience([]string{"hello", "world"}).
  1055  					Build()
  1056  				if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1057  					return
  1058  				}
  1059  
  1060  				buf, err := json.MarshalIndent(tok, "", "  ")
  1061  				if !assert.NoError(t, err, `json.MarshalIndent should succeed`) {
  1062  					return
  1063  				}
  1064  
  1065  				const expected = `{
  1066    "aud": [
  1067      "hello",
  1068      "world"
  1069    ]
  1070  }`
  1071  
  1072  				if !assert.Equal(t, expected, string(buf), `output should match`) {
  1073  					return
  1074  				}
  1075  			})
  1076  		})
  1077  	}
  1078  }
  1079  
  1080  func TestGH375(t *testing.T) {
  1081  	key, err := jwxtest.GenerateRsaJwk()
  1082  	if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
  1083  		return
  1084  	}
  1085  	key.Set(jwk.KeyIDKey, `test`)
  1086  
  1087  	token, err := jwt.NewBuilder().
  1088  		Issuer(`foobar`).
  1089  		Build()
  1090  	if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1091  		return
  1092  	}
  1093  
  1094  	signAlg := jwa.RS512
  1095  	signed, err := jwt.Sign(token, signAlg, key)
  1096  	if !assert.NoError(t, err, `jwt.Sign should succeed`) {
  1097  		return
  1098  	}
  1099  
  1100  	verifyKey, err := jwk.PublicKeyOf(key)
  1101  	if !assert.NoError(t, err, `jwk.PublicKeyOf should succeed`) {
  1102  		return
  1103  	}
  1104  
  1105  	verifyKey.Set(jwk.KeyIDKey, `test`)
  1106  	verifyKey.Set(jwk.AlgorithmKey, jwa.RS256) // != jwa.RS512
  1107  
  1108  	ks := jwk.NewSet()
  1109  	ks.Add(verifyKey)
  1110  
  1111  	_, err = jwt.Parse(signed, jwt.WithKeySet(ks))
  1112  	if !assert.Error(t, err, `jwt.Parse should fail`) {
  1113  		return
  1114  	}
  1115  }
  1116  
  1117  type Claim struct {
  1118  	Foo string
  1119  	Bar int
  1120  }
  1121  
  1122  func TestJWTParseWithTypedClaim(t *testing.T) {
  1123  	testcases := []struct {
  1124  		Name        string
  1125  		Options     []jwt.ParseOption
  1126  		PostProcess func(*testing.T, interface{}) (*Claim, error)
  1127  	}{
  1128  		{
  1129  			Name:    "Basic",
  1130  			Options: []jwt.ParseOption{jwt.WithTypedClaim("typed-claim", Claim{})},
  1131  			PostProcess: func(t *testing.T, claim interface{}) (*Claim, error) {
  1132  				t.Helper()
  1133  				v, ok := claim.(Claim)
  1134  				if !ok {
  1135  					return nil, errors.Errorf(`claim value should be of type "Claim", but got %T`, claim)
  1136  				}
  1137  				return &v, nil
  1138  			},
  1139  		},
  1140  		{
  1141  			Name:    "json.RawMessage",
  1142  			Options: []jwt.ParseOption{jwt.WithTypedClaim("typed-claim", json.RawMessage{})},
  1143  			PostProcess: func(t *testing.T, claim interface{}) (*Claim, error) {
  1144  				t.Helper()
  1145  				v, ok := claim.(json.RawMessage)
  1146  				if !ok {
  1147  					return nil, errors.Errorf(`claim value should be of type "json.RawMessage", but got %T`, claim)
  1148  				}
  1149  
  1150  				var c Claim
  1151  				if err := json.Unmarshal(v, &c); err != nil {
  1152  					return nil, errors.Wrap(err, `json.Unmarshal failed`)
  1153  				}
  1154  
  1155  				return &c, nil
  1156  			},
  1157  		},
  1158  	}
  1159  
  1160  	expected := &Claim{Foo: "Foo", Bar: 0xdeadbeef}
  1161  	key, err := jwxtest.GenerateRsaKey()
  1162  	if !assert.NoError(t, err, `jwxtest.GenerateRsaKey should succeed`) {
  1163  		return
  1164  	}
  1165  
  1166  	var signed []byte
  1167  	{
  1168  		token := jwt.New()
  1169  		if !assert.NoError(t, token.Set("typed-claim", expected), `expected.Set should succeed`) {
  1170  			return
  1171  		}
  1172  		v, err := jwt.Sign(token, jwa.RS256, key)
  1173  		if !assert.NoError(t, err, `jws.Sign should succeed`) {
  1174  			return
  1175  		}
  1176  		signed = v
  1177  	}
  1178  
  1179  	for _, tc := range testcases {
  1180  		tc := tc
  1181  		t.Run(tc.Name, func(t *testing.T) {
  1182  			got, err := jwt.Parse(signed, tc.Options...)
  1183  			if !assert.NoError(t, err, `jwt.Parse should succeed`) {
  1184  				return
  1185  			}
  1186  
  1187  			v, ok := got.Get("typed-claim")
  1188  			if !assert.True(t, ok, `got.Get() should succeed`) {
  1189  				return
  1190  			}
  1191  			claim, err := tc.PostProcess(t, v)
  1192  			if !assert.NoError(t, err, `tc.PostProcess should succeed`) {
  1193  				return
  1194  			}
  1195  
  1196  			if !assert.Equal(t, claim, expected, `claim should match expected value`) {
  1197  				return
  1198  			}
  1199  		})
  1200  	}
  1201  }
  1202  
  1203  func TestGH393(t *testing.T) {
  1204  	t.Run("Non-existent required claims", func(t *testing.T) {
  1205  		tok := jwt.New()
  1206  		if !assert.Error(t, jwt.Validate(tok, jwt.WithRequiredClaim(jwt.IssuedAtKey)), `jwt.Validate should fail`) {
  1207  			return
  1208  		}
  1209  	})
  1210  	t.Run("exp - iat < WithMaxDelta(10 secs)", func(t *testing.T) {
  1211  		now := time.Now()
  1212  		tok, err := jwt.NewBuilder().
  1213  			IssuedAt(now).
  1214  			Expiration(now.Add(5 * time.Second)).
  1215  			Build()
  1216  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1217  			return
  1218  		}
  1219  
  1220  		if !assert.Error(t, jwt.Validate(tok, jwt.WithMaxDelta(2*time.Second, jwt.ExpirationKey, jwt.IssuedAtKey)), `jwt.Validate should fail`) {
  1221  			return
  1222  		}
  1223  
  1224  		if !assert.NoError(t, jwt.Validate(tok, jwt.WithMaxDelta(10*time.Second, jwt.ExpirationKey, jwt.IssuedAtKey)), `jwt.Validate should succeed`) {
  1225  			return
  1226  		}
  1227  	})
  1228  	t.Run("iat - exp (5 secs) < WithMinDelta(10 secs)", func(t *testing.T) {
  1229  		now := time.Now()
  1230  		tok, err := jwt.NewBuilder().
  1231  			IssuedAt(now).
  1232  			Expiration(now.Add(5 * time.Second)).
  1233  			Build()
  1234  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1235  			return
  1236  		}
  1237  
  1238  		if !assert.Error(t, jwt.Validate(tok, jwt.WithMinDelta(10*time.Second, jwt.ExpirationKey, jwt.IssuedAtKey)), `jwt.Validate should fail`) {
  1239  			return
  1240  		}
  1241  	})
  1242  	t.Run("iat - exp (5 secs) > WithMinDelta(10 secs)", func(t *testing.T) {
  1243  		now := time.Now()
  1244  		tok, err := jwt.NewBuilder().
  1245  			IssuedAt(now).
  1246  			Expiration(now.Add(5 * time.Second)).
  1247  			Build()
  1248  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1249  			return
  1250  		}
  1251  
  1252  		if !assert.NoError(t, jwt.Validate(tok, jwt.WithMinDelta(10*time.Second, jwt.ExpirationKey, jwt.IssuedAtKey), jwt.WithAcceptableSkew(5*time.Second)), `jwt.Validate should succeed`) {
  1253  			return
  1254  		}
  1255  	})
  1256  	t.Run("now - iat < WithMaxDelta(10 secs)", func(t *testing.T) {
  1257  		now := time.Now()
  1258  		tok, err := jwt.NewBuilder().
  1259  			IssuedAt(now).
  1260  			Build()
  1261  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1262  			return
  1263  		}
  1264  
  1265  		if !assert.NoError(t, jwt.Validate(tok, jwt.WithMaxDelta(10*time.Second, "", jwt.IssuedAtKey), jwt.WithClock(jwt.ClockFunc(func() time.Time { return now.Add(5 * time.Second) }))), `jwt.Validate should succeed`) {
  1266  			return
  1267  		}
  1268  	})
  1269  	t.Run("invalid claim name (c1)", func(t *testing.T) {
  1270  		now := time.Now()
  1271  		tok, err := jwt.NewBuilder().
  1272  			Claim("foo", now).
  1273  			Expiration(now.Add(5 * time.Second)).
  1274  			Build()
  1275  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1276  			return
  1277  		}
  1278  
  1279  		if !assert.Error(t, jwt.Validate(tok, jwt.WithMinDelta(10*time.Second, jwt.ExpirationKey, "foo"), jwt.WithAcceptableSkew(5*time.Second)), `jwt.Validate should fail`) {
  1280  			return
  1281  		}
  1282  	})
  1283  	t.Run("invalid claim name (c2)", func(t *testing.T) {
  1284  		now := time.Now()
  1285  		tok, err := jwt.NewBuilder().
  1286  			Claim("foo", now.Add(5*time.Second)).
  1287  			IssuedAt(now).
  1288  			Build()
  1289  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1290  			return
  1291  		}
  1292  
  1293  		if !assert.Error(t, jwt.Validate(tok, jwt.WithMinDelta(10*time.Second, "foo", jwt.IssuedAtKey), jwt.WithAcceptableSkew(5*time.Second)), `jwt.Validate should fail`) {
  1294  			return
  1295  		}
  1296  	})
  1297  
  1298  	// Following tests deviate a little from the original issue, but
  1299  	// since they were added for the same issue, we just bundle the
  1300  	// tests together
  1301  	t.Run(`WithRequiredClaim fails for non-existent claim`, func(t *testing.T) {
  1302  		tok := jwt.New()
  1303  		if !assert.Error(t, jwt.Validate(tok, jwt.WithRequiredClaim("foo")), `jwt.Validate should fail`) {
  1304  			return
  1305  		}
  1306  	})
  1307  	t.Run(`WithRequiredClaim succeeds for existing claim`, func(t *testing.T) {
  1308  		tok, err := jwt.NewBuilder().
  1309  			Claim(`foo`, 1).
  1310  			Build()
  1311  		if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1312  			return
  1313  		}
  1314  		if !assert.NoError(t, jwt.Validate(tok, jwt.WithRequiredClaim("foo")), `jwt.Validate should fail`) {
  1315  			return
  1316  		}
  1317  	})
  1318  }
  1319  
  1320  func TestNested(t *testing.T) {
  1321  	key, err := jwxtest.GenerateRsaKey()
  1322  	if !assert.NoError(t, err, `jwxtest.GenerateRsaKey should succeed`) {
  1323  		return
  1324  	}
  1325  
  1326  	token, err := jwt.NewBuilder().
  1327  		Issuer(`https://github.com/lestrrat-go/jwx`).
  1328  		Build()
  1329  	if !assert.NoError(t, err, `jwt.Builder should succeed`) {
  1330  		return
  1331  	}
  1332  
  1333  	serialized, err := jwt.NewSerializer().
  1334  		Sign(jwa.RS256, key).
  1335  		Encrypt(jwa.RSA_OAEP, key.PublicKey, jwa.A256GCM, jwa.NoCompress).
  1336  		Serialize(token)
  1337  
  1338  	if !assert.NoError(t, err, `jwt.NewSerializer should succeed`) {
  1339  		return
  1340  	}
  1341  
  1342  	// First layer should be JWE
  1343  	jweMessage := jwe.NewMessage()
  1344  	decrypted, err := jwe.Decrypt(serialized, jwa.RSA_OAEP, key, jwe.WithMessage(jweMessage))
  1345  	if !assert.NoError(t, err, `jwe.Decrypt should succeed`) {
  1346  		return
  1347  	}
  1348  
  1349  	// The message should have cty = JWT
  1350  	cty := jweMessage.ProtectedHeaders().ContentType()
  1351  	if !assert.Equal(t, cty, `JWT`, `cty should be JWT`) {
  1352  		return
  1353  	}
  1354  
  1355  	// Second layer should JWS.
  1356  	jwsMessage := jws.NewMessage()
  1357  	verified, err := jws.Verify(decrypted, jwa.RS256, key.PublicKey, jws.WithMessage(jwsMessage))
  1358  	if !assert.NoError(t, err, `jws.Verify should succeed`) {
  1359  		return
  1360  	}
  1361  
  1362  	typ := jwsMessage.Signatures()[0].ProtectedHeaders().Type()
  1363  	if !assert.Equal(t, typ, `JWT`, `cty should be JWT`) {
  1364  		return
  1365  	}
  1366  
  1367  	t.Logf("%s", verified)
  1368  
  1369  	parsed, err := jwt.Parse(serialized,
  1370  		jwt.WithPedantic(true),
  1371  		jwt.WithVerify(jwa.RS256, key.PublicKey),
  1372  		jwt.WithDecrypt(jwa.RSA_OAEP, key),
  1373  	)
  1374  	if !assert.NoError(t, err, `jwt.Parse with both decryption and verification should succeed`) {
  1375  		return
  1376  	}
  1377  	_ = parsed
  1378  }
  1379  
  1380  func TestRFC7797(t *testing.T) {
  1381  	key, err := jwxtest.GenerateRsaKey()
  1382  	if !assert.NoError(t, err, `jwxtest.GenerateRsaKey should succeed`) {
  1383  		return
  1384  	}
  1385  
  1386  	hdrs := jws.NewHeaders()
  1387  	hdrs.Set("b64", false)
  1388  	hdrs.Set("crit", "b64")
  1389  
  1390  	token := jwt.New()
  1391  	token.Set(jwt.AudienceKey, `foo`)
  1392  
  1393  	_, err = jwt.Sign(token, jwa.RS256, key, jwt.WithJwsHeaders(hdrs))
  1394  	if !assert.Error(t, err, `jwt.Sign should fail`) {
  1395  		return
  1396  	}
  1397  }
  1398  
  1399  func TestGH430(t *testing.T) {
  1400  	t1 := jwt.New()
  1401  	err := t1.Set("payload", map[string]interface{}{
  1402  		"name": "someone",
  1403  	})
  1404  	if !assert.NoError(t, err, `t1.Set should succeed`) {
  1405  		return
  1406  	}
  1407  
  1408  	key := []byte("secret")
  1409  	signed, err := jwt.Sign(t1, jwa.HS256, key)
  1410  	if !assert.NoError(t, err, `jwt.Sign should succeed`) {
  1411  		return
  1412  	}
  1413  
  1414  	if _, err = jwt.Parse(signed, jwt.WithVerify(jwa.HS256, key)); !assert.NoError(t, err, `jwt.Parse should succeed`) {
  1415  		return
  1416  	}
  1417  }
  1418  
  1419  func TestBenHigginsByPassRegression(t *testing.T) {
  1420  	key, err := rsa.GenerateKey(rand.Reader, 2048)
  1421  	if err != nil {
  1422  		panic(err)
  1423  	}
  1424  	// Test if an access token JSON payload parses when provided directly
  1425  	//
  1426  	// The JSON below is slightly modified example payload from:
  1427  	// https://docs.aws.amazon.com/cognito/latest/developerguide/amazon-cognito-user-pools-using-the-access-token.html
  1428  
  1429  	// Case 1: add "aud", and adjust exp to be valid
  1430  	// Case 2: do not add "aud", adjust exp
  1431  
  1432  	exp := strconv.Itoa(int(time.Now().Unix()) + 1000)
  1433  	const tmpl = `{%s
  1434      "sub": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
  1435      "device_key": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
  1436      "cognito:groups": ["admin"],
  1437      "token_use": "access",
  1438      "scope": "aws.cognito.signin.user.admin",
  1439      "auth_time": 1562190524,
  1440      "iss": "https://cognito-idp.us-west-2.amazonaws.com/us-west-2_example",
  1441      "exp": %s,
  1442      "iat": 1562190524,
  1443      "origin_jti": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
  1444      "jti": "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
  1445      "client_id": "57cbishk4j24pabc1234567890",
  1446      "username": "janedoe@example.com"
  1447    }`
  1448  
  1449  	testcases := [][]byte{
  1450  		[]byte(fmt.Sprintf(tmpl, `"aud": ["test"],`, exp)),
  1451  		[]byte(fmt.Sprintf(tmpl, ``, exp)),
  1452  	}
  1453  
  1454  	for _, tc := range testcases {
  1455  		for _, pedantic := range []bool{true, false} {
  1456  			_, err = jwt.Parse(
  1457  				tc,
  1458  				jwt.WithValidate(true),
  1459  				jwt.WithPedantic(pedantic),
  1460  				jwt.WithVerify(jwa.RS256, &key.PublicKey),
  1461  			)
  1462  			t.Logf("%s", err)
  1463  			if !assert.Error(t, err, `jwt.Parse should fail`) {
  1464  				return
  1465  			}
  1466  		}
  1467  	}
  1468  }
  1469  
  1470  func TestVerifyAuto(t *testing.T) {
  1471  	key, err := jwxtest.GenerateRsaJwk()
  1472  	if !assert.NoError(t, err, `jwxtest.GenerateRsaJwk should succeed`) {
  1473  		return
  1474  	}
  1475  
  1476  	key.Set(jwk.KeyIDKey, `my-awesome-key`)
  1477  
  1478  	pubkey, err := jwk.PublicKeyOf(key)
  1479  	if !assert.NoError(t, err, `jwk.PublicKeyOf should succeed`) {
  1480  		return
  1481  	}
  1482  	set := jwk.NewSet()
  1483  	set.Add(pubkey)
  1484  	backoffCount := 0
  1485  	srv := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  1486  		switch r.URL.Query().Get(`type`) {
  1487  		case "backoff":
  1488  			backoffCount++
  1489  			if backoffCount == 1 {
  1490  				w.WriteHeader(http.StatusInternalServerError)
  1491  				return
  1492  			}
  1493  		}
  1494  		w.WriteHeader(http.StatusOK)
  1495  		json.NewEncoder(w).Encode(set)
  1496  	}))
  1497  	defer srv.Close()
  1498  
  1499  	tok, err := jwt.NewBuilder().
  1500  		Claim(jwt.IssuerKey, `https://github.com/lestrrat-go/jwx`).
  1501  		Claim(jwt.SubjectKey, `jku-test`).
  1502  		Build()
  1503  
  1504  	if !assert.NoError(t, err, `jwt.NewBuilder.Build() should succeed`) {
  1505  		return
  1506  	}
  1507  
  1508  	hdrs := jws.NewHeaders()
  1509  	hdrs.Set(jws.JWKSetURLKey, srv.URL)
  1510  
  1511  	signed, err := jwt.Sign(tok, jwa.RS256, key, jwt.WithHeaders(hdrs))
  1512  	if !assert.NoError(t, err, `jwt.Sign() should succeed`) {
  1513  		return
  1514  	}
  1515  
  1516  	wl := jwk.NewMapWhitelist().
  1517  		Add(srv.URL)
  1518  
  1519  	parsed, err := jwt.Parse(signed, jwt.WithVerifyAuto(true), jwt.WithFetchWhitelist(wl), jwt.WithHTTPClient(srv.Client()))
  1520  	if !assert.NoError(t, err, `jwt.Parse should succeed`) {
  1521  		return
  1522  	}
  1523  
  1524  	if !assert.True(t, jwt.Equal(tok, parsed), `tokens should be equal`) {
  1525  		return
  1526  	}
  1527  
  1528  	_, err = jwt.Parse(signed, jwt.WithVerifyAuto(true))
  1529  	if !assert.Error(t, err, `jwt.Parse should fail`) {
  1530  		return
  1531  	}
  1532  	wl = jwk.NewMapWhitelist().
  1533  		Add(`https://github.com/lestrrat-go/jwx`)
  1534  	_, err = jwt.Parse(signed, jwt.WithVerifyAuto(true), jwt.WithFetchWhitelist(wl))
  1535  	if !assert.Error(t, err, `jwt.Parse should fail`) {
  1536  		return
  1537  	}
  1538  
  1539  	// now with backoff
  1540  	bo := backoff.NewConstantPolicy(backoff.WithInterval(500 * time.Millisecond))
  1541  	parsed, err = jwt.Parse(signed,
  1542  		jwt.WithVerifyAuto(true),
  1543  		jwt.WithFetchWhitelist(jwk.InsecureWhitelist{}),
  1544  		jwt.WithHTTPClient(srv.Client()),
  1545  		jwt.WithFetchBackoff(bo),
  1546  	)
  1547  	if !assert.NoError(t, err, `jwt.Parse should succeed`) {
  1548  		return
  1549  	}
  1550  
  1551  	if !assert.True(t, jwt.Equal(tok, parsed), `tokens should be equal`) {
  1552  		return
  1553  	}
  1554  
  1555  	// now with AutoRefresh
  1556  	ar := jwk.NewAutoRefresh(context.TODO())
  1557  	parsed, err = jwt.Parse(signed,
  1558  		jwt.WithVerifyAuto(true),
  1559  		jwt.WithJWKSetFetcher(jws.JWKSetFetchFunc(func(u string) (jwk.Set, error) {
  1560  			ar.Configure(u,
  1561  				jwk.WithHTTPClient(srv.Client()),
  1562  				jwk.WithFetchWhitelist(jwk.InsecureWhitelist{}),
  1563  			)
  1564  			return ar.Fetch(context.TODO(), u)
  1565  		})),
  1566  	)
  1567  	if !assert.NoError(t, err, `jwt.Parse should succeed`) {
  1568  		return
  1569  	}
  1570  
  1571  	if !assert.True(t, jwt.Equal(tok, parsed), `tokens should be equal`) {
  1572  		return
  1573  	}
  1574  }
  1575  

View as plain text