1
2
3
4 package jwt
5
6 import (
7 "bytes"
8 "io"
9 "io/ioutil"
10 "net/http"
11 "strings"
12 "sync/atomic"
13
14 "github.com/lestrrat-go/backoff/v2"
15 "github.com/lestrrat-go/jwx"
16 "github.com/lestrrat-go/jwx/internal/json"
17 "github.com/lestrrat-go/jwx/jwe"
18
19 "github.com/lestrrat-go/jwx/jwa"
20 "github.com/lestrrat-go/jwx/jwk"
21 "github.com/lestrrat-go/jwx/jws"
22 "github.com/pkg/errors"
23 )
24
// _jwt is the header value that, in pedantic mode, is compared
// case-insensitively against the "typ" and "cty" protected headers of
// JWS/JWE envelopes to decide how decoding proceeds.
const _jwt = `jwt`
26
27
// Settings applies package-global options.
//
// The only option currently understood is the one identified by
// identFlattenAudience, which controls json.FlattenAudience. Options are
// processed in order, so the last such option wins. Calling Settings
// without that option resets the flag to off.
//
// The flag is updated with a compare-and-swap against the value read at
// the start of the update. If another goroutine changes the flag in
// between, its value is left in place.
func Settings(options ...GlobalOption) {
	var flatten bool
	for _, opt := range options {
		switch opt.Ident() {
		case identFlattenAudience{}:
			flatten = opt.Value().(bool)
		}
	}

	current := atomic.LoadUint32(&json.FlattenAudience)
	if flatten == (current == 1) {
		// Already in the requested state; nothing to do.
		return
	}

	var desired uint32
	if flatten {
		desired = 1
	}
	atomic.CompareAndSwapUint32(&json.FlattenAudience, current, desired)
}
48
// registry is the package-level json.Registry that RegisterCustomField
// adds entries to. NOTE(review): how (and whether) parse consults it is not
// visible in this file — parse itself only uses the per-call ctx.localReg.
var registry = json.NewRegistry()
50
51
52 func ParseString(s string, options ...ParseOption) (Token, error) {
53 return parseBytes([]byte(s), options...)
54 }
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
// Parse decodes s into a Token.
//
// Surrounding whitespace is trimmed first. s may be a bare JWT claims
// payload or a JWS/JWE envelope around one. The options control
// verification (key, key set, key set provider or automatic fetching),
// decryption, pedantic header checks, typed claims, and post-parse
// validation.
func Parse(s []byte, options ...ParseOption) (Token, error) {
	return parseBytes(s, options...)
}
85
86
// ParseReader reads src to EOF and parses the bytes exactly as Parse would.
//
// A read failure is wrapped and returned before any parsing is attempted.
func ParseReader(src io.Reader, options ...ParseOption) (Token, error) {
	raw, readErr := ioutil.ReadAll(src)
	if readErr == nil {
		return parseBytes(raw, options...)
	}
	return nil, errors.Wrap(readErr, `failed to read from token data source`)
}
95
// parseCtx carries the state of a single Parse/ParseString/ParseReader call.
// It holds the options collected by parseBytes, plus flags consulted by
// parse and the verifyJWS* helpers while peeling JWS/JWE layers.
type parseCtx struct {
	decryptParams    DecryptParameters  // from identDecrypt; required to open JWE layers
	verifyParams     VerifyParameters   // from identVerify; alg/key used when no key set or provider is configured
	keySet           jwk.Set            // from identKeySet; keys selected by the token's "kid"
	keySetProvider   KeySetProvider     // from identKeySetProvider; consulted only when keySet is nil
	token            Token              // from identToken; token to decode into (parse uses New() when nil)
	validateOpts     []ValidateOption   // options that are ValidateOptions, forwarded to Validate
	verifyAutoOpts   []jws.VerifyOption // fetch-related options passed to jws.VerifyAuto
	localReg         *json.Registry     // per-call registry built from identTypedClaim options
	inferAlgorithm   bool               // try algorithms from jws.AlgorithmsForKey when the key has no "alg"
	pedantic         bool               // enforce "typ"/"cty" header checks and reject unknown formats
	skipVerification bool               // set internally while pre-parsing claims for keySetProvider
	useDefault       bool               // allow the sole key of a single-key set when the token has no "kid"
	validate         bool               // run Validate on the decoded token
	verifyAuto       bool               // verify via jws.VerifyAuto instead of configured keys
}
112
// parseBytes is the common implementation behind Parse, ParseString and
// ParseReader.
//
// It folds the supplied options into a fresh parseCtx, trims surrounding
// whitespace from data, and hands both to parse. Options that implement
// ValidateOption are collected for Validate. All other options are
// interpreted by applyParseOption.
func parseBytes(data []byte, options ...ParseOption) (Token, error) {
	var ctx parseCtx
	for _, option := range options {
		if vo, isValidate := option.(ValidateOption); isValidate {
			ctx.validateOpts = append(ctx.validateOpts, vo)
			continue
		}
		if err := ctx.applyParseOption(option); err != nil {
			return nil, err
		}
	}
	return parse(&ctx, bytes.TrimSpace(data))
}

// applyParseOption records a single non-validation ParseOption in pc.
//
// Options with an unrecognized identifier are ignored. An error is
// returned only when a WithKeySet/WithToken option carries a value of the
// wrong type. Every other option value is type-asserted without a check,
// so a mismatched value panics.
func (pc *parseCtx) applyParseOption(o ParseOption) error {
	switch o.Ident() {
	case identVerifyAuto{}:
		pc.verifyAuto = o.Value().(bool)
	case identFetchWhitelist{}:
		pc.verifyAutoOpts = append(pc.verifyAutoOpts, jws.WithFetchWhitelist(o.Value().(jwk.Whitelist)))
	case identHTTPClient{}:
		pc.verifyAutoOpts = append(pc.verifyAutoOpts, jws.WithHTTPClient(o.Value().(*http.Client)))
	case identFetchBackoff{}:
		pc.verifyAutoOpts = append(pc.verifyAutoOpts, jws.WithFetchBackoff(o.Value().(backoff.Policy)))
	case identJWKSetFetcher{}:
		pc.verifyAutoOpts = append(pc.verifyAutoOpts, jws.WithJWKSetFetcher(o.Value().(jws.JWKSetFetcher)))
	case identVerify{}:
		pc.verifyParams = o.Value().(VerifyParameters)
	case identDecrypt{}:
		pc.decryptParams = o.Value().(DecryptParameters)
	case identKeySet{}:
		set, isSet := o.Value().(jwk.Set)
		if !isSet {
			return errors.Errorf(`invalid JWK set passed via WithKeySet() option (%T)`, o.Value())
		}
		pc.keySet = set
	case identToken{}:
		tok, isToken := o.Value().(Token)
		if !isToken {
			return errors.Errorf(`invalid token passed via WithToken() option (%T)`, o.Value())
		}
		pc.token = tok
	case identPedantic{}:
		pc.pedantic = o.Value().(bool)
	case identDefault{}:
		pc.useDefault = o.Value().(bool)
	case identValidate{}:
		pc.validate = o.Value().(bool)
	case identTypedClaim{}:
		pair := o.Value().(claimPair)
		// The per-call registry is created lazily on the first typed claim.
		if pc.localReg == nil {
			pc.localReg = json.NewRegistry()
		}
		pc.localReg.Register(pair.Name, pair.Value)
	case identInferAlgorithmFromKey{}:
		pc.inferAlgorithm = o.Value().(bool)
	case identKeySetProvider{}:
		pc.keySetProvider = o.Value().(KeySetProvider)
	}
	return nil
}
171
// Verification states returned by verifyJWS, verifyJWSWithKeySet and
// verifyJWSWithParams, and consumed by parse.
const (
	_JwsVerifyInvalid      = iota // verification failed; always returned together with a non-nil error
	_JwsVerifyDone                // signature verified (in pedantic mode: "typ" marked the payload as the JWT)
	_JwsVerifyExpectNested        // signature verified and, in pedantic mode, "cty" announced another envelope
	_JwsVerifySkipped             // no verification material configured; the payload was not checked
)
178
179 func verifyJWS(ctx *parseCtx, payload []byte) ([]byte, int, error) {
180 if ctx.verifyAuto {
181 options := ctx.verifyAutoOpts
182 verified, err := jws.VerifyAuto(payload, options...)
183 return verified, _JwsVerifyDone, err
184 }
185
186
187 ks := ctx.keySet
188 p := ctx.keySetProvider
189 if ks != nil || p != nil {
190 return verifyJWSWithKeySet(ctx, payload)
191 }
192
193
194 vp := ctx.verifyParams
195 if vp == nil {
196 return nil, _JwsVerifySkipped, nil
197 }
198
199 return verifyJWSWithParams(ctx, payload, vp.Algorithm(), vp.Key())
200 }
201
202 func verifyJWSWithKeySet(ctx *parseCtx, payload []byte) ([]byte, int, error) {
203
204 msg, err := jws.Parse(payload)
205 if err != nil {
206 return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to parse token data as JWS message`)
207 }
208 ks := ctx.keySet
209 if ks == nil {
210 if p := ctx.keySetProvider; p != nil {
211
212 ctx.skipVerification = true
213 tok, err := parse(ctx, msg.Payload())
214 if err != nil {
215 return nil, _JwsVerifyInvalid, err
216 }
217 ctx.skipVerification = false
218
219 v, err := p.KeySetFrom(tok)
220 if err != nil {
221 return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to obtain jwk.Set from KeySetProvider`)
222 }
223 ks = v
224 }
225 }
226
227
228 if ks.Len() == 0 {
229 return nil, _JwsVerifyInvalid, errors.New(`empty keyset provided`)
230 }
231
232 var key jwk.Key
233
234
235
236 headers := msg.Signatures()[0].ProtectedHeaders()
237 kid := headers.KeyID()
238 if kid == "" {
239
240
241 if !ctx.useDefault {
242 return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token`)
243 } else if ctx.useDefault && ks.Len() > 1 {
244 return nil, _JwsVerifyInvalid, errors.New(`failed to find matching key: no key ID ("kid") specified in token but multiple keys available in key set`)
245 }
246
247
248
249 key, _ = ks.Get(0)
250 } else {
251
252 v, ok := ks.LookupKeyID(kid)
253 if !ok {
254 return nil, _JwsVerifyInvalid, errors.Errorf(`failed to find key with key ID %q in key set`, kid)
255 }
256 key = v
257 }
258
259
260
261 if v := key.Algorithm(); v != "" {
262 var alg jwa.SignatureAlgorithm
263 if err := alg.Accept(v); err != nil {
264 return nil, _JwsVerifyInvalid, errors.Wrapf(err, `invalid signature algorithm %s`, key.Algorithm())
265 }
266
267
268 return verifyJWSWithParams(ctx, payload, alg, key)
269 }
270
271 if ctx.inferAlgorithm {
272
273
274 algs, err := jws.AlgorithmsForKey(key)
275 if err != nil {
276 return nil, _JwsVerifyInvalid, errors.Wrapf(err, `failed to get a list of signature methods for key type %s`, key.KeyType())
277 }
278
279 for _, alg := range algs {
280
281 if tokAlg := headers.Algorithm(); tokAlg != "" {
282 if tokAlg != alg {
283 continue
284 }
285 }
286
287 return verifyJWSWithParams(ctx, payload, alg, key)
288 }
289 }
290
291 return nil, _JwsVerifyInvalid, errors.New(`failed to match any of the keys`)
292 }
293
294 func verifyJWSWithParams(ctx *parseCtx, payload []byte, alg jwa.SignatureAlgorithm, key interface{}) ([]byte, int, error) {
295 var m *jws.Message
296 var verifyOpts []jws.VerifyOption
297 if ctx.pedantic {
298 m = jws.NewMessage()
299 verifyOpts = []jws.VerifyOption{jws.WithMessage(m)}
300 }
301 v, err := jws.Verify(payload, alg, key, verifyOpts...)
302 if err != nil {
303 return nil, _JwsVerifyInvalid, errors.Wrap(err, `failed to verify jws signature`)
304 }
305
306 if !ctx.pedantic {
307 return v, _JwsVerifyDone, nil
308 }
309
310
311 for _, sig := range m.Signatures() {
312 hdrs := sig.ProtectedHeaders()
313 if strings.ToLower(hdrs.Type()) == _jwt {
314 return v, _JwsVerifyDone, nil
315 }
316
317 if strings.ToLower(hdrs.ContentType()) == _jwt {
318 return v, _JwsVerifyExpectNested, nil
319 }
320 }
321
322
323 return nil, _JwsVerifyInvalid, errors.Errorf(`expected "typ" or "cty" fields, neither could be found`)
324 }
325
326
327
// parse decodes data, which may be a bare JWT claims payload or a JWS/JWE
// envelope around one, into a Token.
//
// Up to maxDecodeLevels envelopes are peeled off:
//
//   - A JWS layer is verified through verifyJWS unless
//     ctx.skipVerification is set or no verification material is
//     configured. In those cases its payload is taken unverified.
//   - A JWE layer is decrypted with ctx.decryptParams, which must be
//     present.
//
// In pedantic mode the "typ"/"cty" protected headers decide whether the
// unwrapped payload is the JWT itself or another envelope.
//
// The resulting claims are unmarshaled into ctx.token, or into a fresh
// token from New() when none was supplied. The per-call typed-claim
// registry is used when present, and the token is validated when
// ctx.validate is set.
func parse(ctx *parseCtx, data []byte) (Token, error) {
	payload := data
	const maxDecodeLevels = 2

	// expectNested is set when a pedantic "cty" header announced that the
	// next payload is another JWS/JWE envelope rather than the JWT itself.
	var expectNested bool

OUTER:
	for i := 0; i < maxDecodeLevels; i++ {
		switch kind := jwx.GuessFormat(payload); kind {
		case jwx.JWT:
			if ctx.pedantic && expectNested {
				return nil, errors.New(`expected nested encrypted/signed payload, got raw JWT`)
			}

			// An outermost (non-enveloped) payload still goes through
			// verifyJWS, so any configured verification is applied to it.
			if i == 0 && !ctx.skipVerification {
				if _, _, err := verifyJWS(ctx, payload); err != nil {
					return nil, err
				}
			}
			break OUTER
		case jwx.UnknownFormat:
			// GuessFormat could not classify the payload. Pedantic mode
			// rejects it; otherwise it is treated like a bare JWT.
			if ctx.pedantic {
				return nil, errors.New(`invalid JWT`)
			}

			if i == 0 && !ctx.skipVerification {
				if _, _, err := verifyJWS(ctx, payload); err != nil {
					return nil, err
				}
			}
			break OUTER
		case jwx.JWS:
			// Every JWS layer is verified with the same ctx configuration,
			// so multiple JWS layers signed with different keys are not
			// supported. skipVerification is only set internally, by
			// verifyJWSWithKeySet, when claims must be read before a key
			// set is known.
			if !ctx.skipVerification {
				v, state, err := verifyJWS(ctx, payload)
				if err != nil {
					return nil, err
				}

				if state != _JwsVerifySkipped {
					payload = v

					// "typ"/"cty" only steer decoding in pedantic mode.
					if !ctx.pedantic {
						continue
					}

					if state == _JwsVerifyExpectNested {
						expectNested = true
						continue OUTER
					}

					// "typ" identified the payload as the JWT itself.
					break OUTER
				}
			}

			// No verification was performed, so unwrap the JWS without
			// checking its signature. Parse the current layer (payload), not
			// the outermost input (data), so a JWS nested inside a JWE is
			// unwrapped instead of the JWE being re-parsed.
			m, err := jws.Parse(payload)
			if err != nil {
				return nil, errors.Wrap(err, `invalid jws message`)
			}
			payload = m.Payload()
		case jwx.JWE:
			dp := ctx.decryptParams
			if dp == nil {
				return nil, errors.New(`jwt.Parse: cannot proceed with JWE encrypted payload without decryption parameters`)
			}

			// In pedantic mode, capture the message so its protected headers
			// can be inspected after decryption.
			var m *jwe.Message
			var decryptOpts []jwe.DecryptOption
			if ctx.pedantic {
				m = jwe.NewMessage()
				decryptOpts = []jwe.DecryptOption{jwe.WithMessage(m)}
			}

			// Decrypt the current layer, not the outermost input, so a JWE
			// nested inside a JWS can be opened.
			v, err := jwe.Decrypt(payload, dp.Algorithm(), dp.Key(), decryptOpts...)
			if err != nil {
				return nil, errors.Wrap(err, `failed to decrypt payload`)
			}
			payload = v

			if !ctx.pedantic {
				continue
			}

			hdrs := m.ProtectedHeaders()
			if strings.EqualFold(hdrs.Type(), _jwt) {
				break OUTER
			}
			if strings.EqualFold(hdrs.ContentType(), _jwt) {
				expectNested = true
				continue OUTER
			}

			// Mirror the pedantic JWS check. Without "typ" or "cty" there is
			// no way to tell whether the plaintext is the JWT or another
			// envelope, and continuing would decode the wrong bytes.
			return nil, errors.New(`expected "typ" or "cty" fields in JWE protected headers, neither could be found`)
		default:
			return nil, errors.Errorf(`unsupported format (layer: #%d)`, i+1)
		}
		expectNested = false
	}

	if ctx.token == nil {
		ctx.token = New()
	}

	// Typed claims registered for this call require a token that can carry
	// a decoding context. The context is detached again when parse returns.
	if ctx.localReg != nil {
		dcToken, ok := ctx.token.(TokenWithDecodeCtx)
		if !ok {
			return nil, errors.Errorf(`typed claim was requested, but the token (%T) does not support DecodeCtx`, ctx.token)
		}
		dcToken.SetDecodeCtx(json.NewDecodeCtx(ctx.localReg))
		defer func() { dcToken.SetDecodeCtx(nil) }()
	}

	if err := json.Unmarshal(payload, ctx.token); err != nil {
		return nil, errors.Wrap(err, `failed to parse token`)
	}

	if ctx.validate {
		if err := Validate(ctx.token, ctx.validateOpts...); err != nil {
			return nil, err
		}
	}
	return ctx.token, nil
}
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492 func Sign(t Token, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) {
493 return NewSerializer().Sign(alg, key, options...).Serialize(t)
494 }
495
496
497
498
499
500
501
502
503
504
// Equal reports whether t1 and t2 serialize to byte-identical JSON.
//
// Two nil tokens are equal, and a nil token never equals a non-nil one.
// If either token fails to marshal, the tokens are reported as unequal.
// t1 is marshaled first, and t2 is not marshaled if that fails.
func Equal(t1, t2 Token) bool {
	switch {
	case t1 == nil && t2 == nil:
		return true
	case t1 == nil || t2 == nil:
		return false
	}

	left, err := json.Marshal(t1)
	if err != nil {
		return false
	}

	right, err := json.Marshal(t2)
	if err != nil {
		return false
	}

	return bytes.Equal(left, right)
}
527
528 func (t *stdToken) Clone() (Token, error) {
529 dst := New()
530
531 for _, pair := range t.makePairs() {
532
533 key := pair.Key.(string)
534 if err := dst.Set(key, pair.Value); err != nil {
535 return nil, errors.Wrapf(err, `failed to set %s`, key)
536 }
537 }
538 return dst, nil
539 }
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
// RegisterCustomField adds a (name, object) entry to the package-level
// registry. This mutates package-global state for every caller in the
// process.
//
// NOTE(review): presumably object serves as the prototype type used when
// decoding the claim called name. This file does not show how registry is
// consulted, so confirm against internal/json. Concurrency safety of
// Register is also not visible here.
func RegisterCustomField(name string, object interface{}) {
	registry.Register(name, object)
}
562
View as plain text