1
2
3 package jwt
4
5 import (
6 "bytes"
7 "context"
8 "sort"
9 "sync"
10 "time"
11
12 "github.com/lestrrat-go/iter/mapiter"
13 "github.com/lestrrat-go/jwx/internal/base64"
14 "github.com/lestrrat-go/jwx/internal/iter"
15 "github.com/lestrrat-go/jwx/internal/json"
16 "github.com/lestrrat-go/jwx/internal/pool"
17 "github.com/lestrrat-go/jwx/jwt/internal/types"
18 "github.com/pkg/errors"
19 )
20 // Names of the claims that stdToken keeps in dedicated fields; any other claim name is stored in its privateClaims map.
21 const (
22 AudienceKey = "aud" // held in stdToken.audience
23 ExpirationKey = "exp" // held in stdToken.expiration
24 IssuedAtKey = "iat" // held in stdToken.issuedAt
25 IssuerKey = "iss" // held in stdToken.issuer
26 JwtIDKey = "jti" // held in stdToken.jwtID
27 NotBeforeKey = "nbf" // held in stdToken.notBefore
28 SubjectKey = "sub" // held in stdToken.subject
29 )
30
31
32
33
34
35
36
37
38
39
40
41
42 // Token is the claim-set API of this package; stdToken in this file implements the methods shown here.
43 type Token interface {
44
45 // Audience returns the "aud" claim (stdToken returns nil when it is unset).
46 Audience() []string
47
48 // Expiration returns the "exp" claim (stdToken returns the zero time.Time when it is unset).
49 Expiration() time.Time
50
51 // IssuedAt returns the "iat" claim (stdToken returns the zero time.Time when it is unset).
52 IssuedAt() time.Time
53
54 // Issuer returns the "iss" claim (stdToken returns "" when it is unset).
55 Issuer() string
56
57 // JwtID returns the "jti" claim (stdToken returns "" when it is unset).
58 JwtID() string
59
60 // NotBefore returns the "nbf" claim (stdToken returns the zero time.Time when it is unset).
61 NotBefore() time.Time
62
63 // Subject returns the "sub" claim (stdToken returns "" when it is unset).
64 Subject() string
65
66 // PrivateClaims returns the claims stored under names other than the
67 // registered keys ("aud", "exp", "iat", "iss", "jti", "nbf", "sub").
68 PrivateClaims() map[string]interface{}
69
70 // Get returns the value stored under the given claim name and reports
71 // whether it was present. stdToken serves the registered names from
72 // dedicated fields and every other name from its private claims.
73
74
75
76
77
78 Get(string) (interface{}, bool)
79
80 // Set stores a value under the given claim name. stdToken returns an
81 // error when a registered name is given a value it cannot accept.
82
83
84 Set(string, interface{}) error
85 Remove(string) error // Remove deletes the named claim; stdToken always returns nil.
86 Clone() (Token, error) // NOTE(review): implemented outside this section; presumably returns a copy of the token — confirm.
87 Iterate(context.Context) Iterator // stdToken yields its claims as key/value pairs sorted by key.
88 Walk(context.Context, Visitor) error // stdToken delegates to iter.WalkMap.
89 AsMap(context.Context) (map[string]interface{}, error) // stdToken delegates to iter.AsMap.
90 }
91 type stdToken struct { // stdToken is the Token implementation returned by New.
92 mu *sync.RWMutex // locked by every accessor and mutator below; a pointer, so copies of the struct share one lock
93 dc DecodeCtx // optional; when its Registry() is non-nil, UnmarshalJSON tries it first for non-registered claim names
94 audience types.StringList // "aud"; nil means unset
95 expiration *types.NumericDate // "exp"; nil means unset
96 issuedAt *types.NumericDate // "iat"; nil means unset
97 issuer *string // "iss"; nil means unset
98 jwtID *string // "jti"; nil means unset
99 notBefore *types.NumericDate // "nbf"; nil means unset
100 subject *string // "sub"; nil means unset
101 privateClaims map[string]interface{} // every claim whose name is not one of the registered keys
102 }
103
104
105
106
107 func New() Token {
108 return &stdToken{
109 mu: &sync.RWMutex{},
110 privateClaims: make(map[string]interface{}),
111 }
112 }
113
114 func (t *stdToken) Get(name string) (interface{}, bool) {
115 t.mu.RLock()
116 defer t.mu.RUnlock()
117 switch name {
118 case AudienceKey:
119 if t.audience == nil {
120 return nil, false
121 }
122 v := t.audience.Get()
123 return v, true
124 case ExpirationKey:
125 if t.expiration == nil {
126 return nil, false
127 }
128 v := t.expiration.Get()
129 return v, true
130 case IssuedAtKey:
131 if t.issuedAt == nil {
132 return nil, false
133 }
134 v := t.issuedAt.Get()
135 return v, true
136 case IssuerKey:
137 if t.issuer == nil {
138 return nil, false
139 }
140 v := *(t.issuer)
141 return v, true
142 case JwtIDKey:
143 if t.jwtID == nil {
144 return nil, false
145 }
146 v := *(t.jwtID)
147 return v, true
148 case NotBeforeKey:
149 if t.notBefore == nil {
150 return nil, false
151 }
152 v := t.notBefore.Get()
153 return v, true
154 case SubjectKey:
155 if t.subject == nil {
156 return nil, false
157 }
158 v := *(t.subject)
159 return v, true
160 default:
161 v, ok := t.privateClaims[name]
162 return v, ok
163 }
164 }
165
166 func (t *stdToken) Remove(key string) error {
167 t.mu.Lock()
168 defer t.mu.Unlock()
169 switch key {
170 case AudienceKey:
171 t.audience = nil
172 case ExpirationKey:
173 t.expiration = nil
174 case IssuedAtKey:
175 t.issuedAt = nil
176 case IssuerKey:
177 t.issuer = nil
178 case JwtIDKey:
179 t.jwtID = nil
180 case NotBeforeKey:
181 t.notBefore = nil
182 case SubjectKey:
183 t.subject = nil
184 default:
185 delete(t.privateClaims, key)
186 }
187 return nil
188 }
189
190 func (t *stdToken) Set(name string, value interface{}) error {
191 t.mu.Lock()
192 defer t.mu.Unlock()
193 return t.setNoLock(name, value)
194 }
195
196 func (t *stdToken) DecodeCtx() DecodeCtx {
197 t.mu.RLock()
198 defer t.mu.RUnlock()
199 return t.dc
200 }
201
202 func (t *stdToken) SetDecodeCtx(v DecodeCtx) {
203 t.mu.Lock()
204 defer t.mu.Unlock()
205 t.dc = v
206 }
207
208 func (t *stdToken) setNoLock(name string, value interface{}) error {
209 switch name {
210 case AudienceKey:
211 var acceptor types.StringList
212 if err := acceptor.Accept(value); err != nil {
213 return errors.Wrapf(err, `invalid value for %s key`, AudienceKey)
214 }
215 t.audience = acceptor
216 return nil
217 case ExpirationKey:
218 var acceptor types.NumericDate
219 if err := acceptor.Accept(value); err != nil {
220 return errors.Wrapf(err, `invalid value for %s key`, ExpirationKey)
221 }
222 t.expiration = &acceptor
223 return nil
224 case IssuedAtKey:
225 var acceptor types.NumericDate
226 if err := acceptor.Accept(value); err != nil {
227 return errors.Wrapf(err, `invalid value for %s key`, IssuedAtKey)
228 }
229 t.issuedAt = &acceptor
230 return nil
231 case IssuerKey:
232 if v, ok := value.(string); ok {
233 t.issuer = &v
234 return nil
235 }
236 return errors.Errorf(`invalid value for %s key: %T`, IssuerKey, value)
237 case JwtIDKey:
238 if v, ok := value.(string); ok {
239 t.jwtID = &v
240 return nil
241 }
242 return errors.Errorf(`invalid value for %s key: %T`, JwtIDKey, value)
243 case NotBeforeKey:
244 var acceptor types.NumericDate
245 if err := acceptor.Accept(value); err != nil {
246 return errors.Wrapf(err, `invalid value for %s key`, NotBeforeKey)
247 }
248 t.notBefore = &acceptor
249 return nil
250 case SubjectKey:
251 if v, ok := value.(string); ok {
252 t.subject = &v
253 return nil
254 }
255 return errors.Errorf(`invalid value for %s key: %T`, SubjectKey, value)
256 default:
257 if t.privateClaims == nil {
258 t.privateClaims = map[string]interface{}{}
259 }
260 t.privateClaims[name] = value
261 }
262 return nil
263 }
264
265 func (t *stdToken) Audience() []string {
266 t.mu.RLock()
267 defer t.mu.RUnlock()
268 if t.audience != nil {
269 return t.audience.Get()
270 }
271 return nil
272 }
273
274 func (t *stdToken) Expiration() time.Time {
275 t.mu.RLock()
276 defer t.mu.RUnlock()
277 if t.expiration != nil {
278 return t.expiration.Get()
279 }
280 return time.Time{}
281 }
282
283 func (t *stdToken) IssuedAt() time.Time {
284 t.mu.RLock()
285 defer t.mu.RUnlock()
286 if t.issuedAt != nil {
287 return t.issuedAt.Get()
288 }
289 return time.Time{}
290 }
291
292 func (t *stdToken) Issuer() string {
293 t.mu.RLock()
294 defer t.mu.RUnlock()
295 if t.issuer != nil {
296 return *(t.issuer)
297 }
298 return ""
299 }
300
301 func (t *stdToken) JwtID() string {
302 t.mu.RLock()
303 defer t.mu.RUnlock()
304 if t.jwtID != nil {
305 return *(t.jwtID)
306 }
307 return ""
308 }
309
310 func (t *stdToken) NotBefore() time.Time {
311 t.mu.RLock()
312 defer t.mu.RUnlock()
313 if t.notBefore != nil {
314 return t.notBefore.Get()
315 }
316 return time.Time{}
317 }
318
319 func (t *stdToken) Subject() string {
320 t.mu.RLock()
321 defer t.mu.RUnlock()
322 if t.subject != nil {
323 return *(t.subject)
324 }
325 return ""
326 }
327
328 func (t *stdToken) PrivateClaims() map[string]interface{} {
329 t.mu.RLock()
330 defer t.mu.RUnlock()
331 return t.privateClaims
332 }
333
334 func (t *stdToken) makePairs() []*ClaimPair {
335 t.mu.RLock()
336 defer t.mu.RUnlock()
337
338 pairs := make([]*ClaimPair, 0, 7)
339 if t.audience != nil {
340 v := t.audience.Get()
341 pairs = append(pairs, &ClaimPair{Key: AudienceKey, Value: v})
342 }
343 if t.expiration != nil {
344 v := t.expiration.Get()
345 pairs = append(pairs, &ClaimPair{Key: ExpirationKey, Value: v})
346 }
347 if t.issuedAt != nil {
348 v := t.issuedAt.Get()
349 pairs = append(pairs, &ClaimPair{Key: IssuedAtKey, Value: v})
350 }
351 if t.issuer != nil {
352 v := *(t.issuer)
353 pairs = append(pairs, &ClaimPair{Key: IssuerKey, Value: v})
354 }
355 if t.jwtID != nil {
356 v := *(t.jwtID)
357 pairs = append(pairs, &ClaimPair{Key: JwtIDKey, Value: v})
358 }
359 if t.notBefore != nil {
360 v := t.notBefore.Get()
361 pairs = append(pairs, &ClaimPair{Key: NotBeforeKey, Value: v})
362 }
363 if t.subject != nil {
364 v := *(t.subject)
365 pairs = append(pairs, &ClaimPair{Key: SubjectKey, Value: v})
366 }
367 for k, v := range t.privateClaims {
368 pairs = append(pairs, &ClaimPair{Key: k, Value: v})
369 }
370 sort.Slice(pairs, func(i, j int) bool {
371 return pairs[i].Key.(string) < pairs[j].Key.(string)
372 })
373 return pairs
374 }
375
376 func (t *stdToken) UnmarshalJSON(buf []byte) error {
377 t.mu.Lock()
378 defer t.mu.Unlock()
379 t.audience = nil
380 t.expiration = nil
381 t.issuedAt = nil
382 t.issuer = nil
383 t.jwtID = nil
384 t.notBefore = nil
385 t.subject = nil
386 dec := json.NewDecoder(bytes.NewReader(buf))
387 LOOP:
388 for {
389 tok, err := dec.Token()
390 if err != nil {
391 return errors.Wrap(err, `error reading token`)
392 }
393 switch tok := tok.(type) {
394 case json.Delim:
395
396
397 if tok == '}' {
398 break LOOP
399 } else if tok != '{' {
400 return errors.Errorf(`expected '{', but got '%c'`, tok)
401 }
402 case string:
403 switch tok {
404 case AudienceKey:
405 var decoded types.StringList
406 if err := dec.Decode(&decoded); err != nil {
407 return errors.Wrapf(err, `failed to decode value for key %s`, AudienceKey)
408 }
409 t.audience = decoded
410 case ExpirationKey:
411 var decoded types.NumericDate
412 if err := dec.Decode(&decoded); err != nil {
413 return errors.Wrapf(err, `failed to decode value for key %s`, ExpirationKey)
414 }
415 t.expiration = &decoded
416 case IssuedAtKey:
417 var decoded types.NumericDate
418 if err := dec.Decode(&decoded); err != nil {
419 return errors.Wrapf(err, `failed to decode value for key %s`, IssuedAtKey)
420 }
421 t.issuedAt = &decoded
422 case IssuerKey:
423 if err := json.AssignNextStringToken(&t.issuer, dec); err != nil {
424 return errors.Wrapf(err, `failed to decode value for key %s`, IssuerKey)
425 }
426 case JwtIDKey:
427 if err := json.AssignNextStringToken(&t.jwtID, dec); err != nil {
428 return errors.Wrapf(err, `failed to decode value for key %s`, JwtIDKey)
429 }
430 case NotBeforeKey:
431 var decoded types.NumericDate
432 if err := dec.Decode(&decoded); err != nil {
433 return errors.Wrapf(err, `failed to decode value for key %s`, NotBeforeKey)
434 }
435 t.notBefore = &decoded
436 case SubjectKey:
437 if err := json.AssignNextStringToken(&t.subject, dec); err != nil {
438 return errors.Wrapf(err, `failed to decode value for key %s`, SubjectKey)
439 }
440 default:
441 if dc := t.dc; dc != nil {
442 if localReg := dc.Registry(); localReg != nil {
443 decoded, err := localReg.Decode(dec, tok)
444 if err == nil {
445 t.setNoLock(tok, decoded)
446 continue
447 }
448 }
449 }
450 decoded, err := registry.Decode(dec, tok)
451 if err == nil {
452 t.setNoLock(tok, decoded)
453 continue
454 }
455 return errors.Wrapf(err, `could not decode field %s`, tok)
456 }
457 default:
458 return errors.Errorf(`invalid token %T`, tok)
459 }
460 }
461 return nil
462 }
463
464 func (t stdToken) MarshalJSON() ([]byte, error) {
465 t.mu.RLock()
466 defer t.mu.RUnlock()
467 buf := pool.GetBytesBuffer()
468 defer pool.ReleaseBytesBuffer(buf)
469 buf.WriteByte('{')
470 enc := json.NewEncoder(buf)
471 for i, pair := range t.makePairs() {
472 f := pair.Key.(string)
473 if i > 0 {
474 buf.WriteByte(',')
475 }
476 buf.WriteRune('"')
477 buf.WriteString(f)
478 buf.WriteString(`":`)
479 switch f {
480 case AudienceKey:
481 if err := json.EncodeAudience(enc, pair.Value.([]string)); err != nil {
482 return nil, errors.Wrap(err, `failed to encode "aud"`)
483 }
484 continue
485 case ExpirationKey, IssuedAtKey, NotBeforeKey:
486 enc.Encode(pair.Value.(time.Time).Unix())
487 continue
488 }
489 switch v := pair.Value.(type) {
490 case []byte:
491 buf.WriteRune('"')
492 buf.WriteString(base64.EncodeToString(v))
493 buf.WriteRune('"')
494 default:
495 if err := enc.Encode(v); err != nil {
496 return nil, errors.Wrapf(err, `failed to marshal field %s`, f)
497 }
498 buf.Truncate(buf.Len() - 1)
499 }
500 }
501 buf.WriteByte('}')
502 ret := make([]byte, buf.Len())
503 copy(ret, buf.Bytes())
504 return ret, nil
505 }
506
507 func (t *stdToken) Iterate(ctx context.Context) Iterator {
508 pairs := t.makePairs()
509 ch := make(chan *ClaimPair, len(pairs))
510 go func(ctx context.Context, ch chan *ClaimPair, pairs []*ClaimPair) {
511 defer close(ch)
512 for _, pair := range pairs {
513 select {
514 case <-ctx.Done():
515 return
516 case ch <- pair:
517 }
518 }
519 }(ctx, ch, pairs)
520 return mapiter.New(ch)
521 }
522
523 func (t *stdToken) Walk(ctx context.Context, visitor Visitor) error {
524 return iter.WalkMap(ctx, t, visitor)
525 }
526
527 func (t *stdToken) AsMap(ctx context.Context) (map[string]interface{}, error) {
528 return iter.AsMap(ctx, t)
529 }
530
View as plain text