1
2
3 package jws
4
5 import (
6 "bytes"
7 "context"
8 "sort"
9 "sync"
10
11 "github.com/lestrrat-go/jwx/internal/base64"
12 "github.com/lestrrat-go/jwx/internal/json"
13 "github.com/lestrrat-go/jwx/internal/pool"
14 "github.com/lestrrat-go/jwx/jwa"
15 "github.com/lestrrat-go/jwx/jwk"
16 "github.com/pkg/errors"
17 )
18
// Names of the standard header fields. These strings are the JSON object keys
// used when headers are marshaled and unmarshaled, and the names accepted by
// Get, Set and Remove for the corresponding typed fields.
const (
	AlgorithmKey              = "alg"
	ContentTypeKey            = "cty"
	CriticalKey               = "crit"
	JWKKey                    = "jwk"
	JWKSetURLKey              = "jku"
	KeyIDKey                  = "kid"
	TypeKey                   = "typ"
	X509CertChainKey          = "x5c"
	X509CertThumbprintKey     = "x5t"
	X509CertThumbprintS256Key = "x5t#S256"
	X509URLKey                = "x5u"
)
32
33
// Headers describes a set of JWS header fields. It offers a typed accessor
// for each of the eleven standard keys declared in this package, generic
// access by key name through Get/Set/Remove, and JSON (un)marshaling.
type Headers interface {
	json.Marshaler
	json.Unmarshaler

	// Typed accessors, one per standard header key.
	Algorithm() jwa.SignatureAlgorithm
	ContentType() string
	Critical() []string
	JWK() jwk.Key
	JWKSetURL() string
	KeyID() string
	Type() string
	X509CertChain() []string
	X509CertThumbprint() string
	X509CertThumbprintS256() string
	X509URL() string

	// Traversal, conversion and combination helpers.
	Iterate(ctx context.Context) Iterator
	Walk(context.Context, Visitor) error
	AsMap(context.Context) (map[string]interface{}, error)
	Copy(context.Context, Headers) error
	Merge(context.Context, Headers) (Headers, error)

	// Generic access by key name. Get reports whether the key is present
	// through its second return value.
	Get(string) (interface{}, bool)
	Set(string, interface{}) error
	Remove(string) error

	// PrivateParams returns the entries stored under names other than the
	// standard keys declared in this package.
	PrivateParams() map[string]interface{}
}
62
// stdHeaders is the Headers implementation returned by NewHeaders. Scalar
// fields are stored as pointers so that nil can represent "not set"; slice and
// interface fields use nil directly for the same purpose.
type stdHeaders struct {
	algorithm              *jwa.SignatureAlgorithm
	contentType            *string
	critical               []string
	jwk                    jwk.Key
	jwkSetURL              *string
	keyID                  *string
	typ                    *string
	x509CertChain          []string
	x509CertThumbprint     *string
	x509CertThumbprintS256 *string
	x509URL                *string

	// privateParams holds values set under any name that is not one of the
	// standard keys. It is created lazily the first time such a name is set.
	privateParams map[string]interface{}

	// mu is locked by the accessor and mutator methods around field access.
	mu *sync.RWMutex

	// dc is the decode context installed via SetDecodeCtx. UnmarshalJSON
	// consults it to decide whether to keep the raw input.
	dc DecodeCtx

	// raw is the byte slice that was passed to UnmarshalJSON, kept only when
	// dc.CollectRaw() reports true.
	raw []byte
}
80
81 func NewHeaders() Headers {
82 return &stdHeaders{
83 mu: &sync.RWMutex{},
84 }
85 }
86
87 func (h *stdHeaders) Algorithm() jwa.SignatureAlgorithm {
88 h.mu.RLock()
89 defer h.mu.RUnlock()
90 if h.algorithm == nil {
91 return ""
92 }
93 return *(h.algorithm)
94 }
95
96 func (h *stdHeaders) ContentType() string {
97 h.mu.RLock()
98 defer h.mu.RUnlock()
99 if h.contentType == nil {
100 return ""
101 }
102 return *(h.contentType)
103 }
104
105 func (h *stdHeaders) Critical() []string {
106 h.mu.RLock()
107 defer h.mu.RUnlock()
108 return h.critical
109 }
110
111 func (h *stdHeaders) JWK() jwk.Key {
112 h.mu.RLock()
113 defer h.mu.RUnlock()
114 return h.jwk
115 }
116
117 func (h *stdHeaders) JWKSetURL() string {
118 h.mu.RLock()
119 defer h.mu.RUnlock()
120 if h.jwkSetURL == nil {
121 return ""
122 }
123 return *(h.jwkSetURL)
124 }
125
126 func (h *stdHeaders) KeyID() string {
127 h.mu.RLock()
128 defer h.mu.RUnlock()
129 if h.keyID == nil {
130 return ""
131 }
132 return *(h.keyID)
133 }
134
135 func (h *stdHeaders) Type() string {
136 h.mu.RLock()
137 defer h.mu.RUnlock()
138 if h.typ == nil {
139 return ""
140 }
141 return *(h.typ)
142 }
143
144 func (h *stdHeaders) X509CertChain() []string {
145 h.mu.RLock()
146 defer h.mu.RUnlock()
147 return h.x509CertChain
148 }
149
150 func (h *stdHeaders) X509CertThumbprint() string {
151 h.mu.RLock()
152 defer h.mu.RUnlock()
153 if h.x509CertThumbprint == nil {
154 return ""
155 }
156 return *(h.x509CertThumbprint)
157 }
158
159 func (h *stdHeaders) X509CertThumbprintS256() string {
160 h.mu.RLock()
161 defer h.mu.RUnlock()
162 if h.x509CertThumbprintS256 == nil {
163 return ""
164 }
165 return *(h.x509CertThumbprintS256)
166 }
167
168 func (h *stdHeaders) X509URL() string {
169 h.mu.RLock()
170 defer h.mu.RUnlock()
171 if h.x509URL == nil {
172 return ""
173 }
174 return *(h.x509URL)
175 }
176
177 func (h *stdHeaders) DecodeCtx() DecodeCtx {
178 h.mu.RLock()
179 defer h.mu.RUnlock()
180 return h.dc
181 }
182
183 func (h *stdHeaders) SetDecodeCtx(dc DecodeCtx) {
184 h.mu.Lock()
185 defer h.mu.Unlock()
186 h.dc = dc
187 }
188
189 func (h *stdHeaders) rawBuffer() []byte {
190 return h.raw
191 }
192
193 func (h *stdHeaders) makePairs() []*HeaderPair {
194 h.mu.RLock()
195 defer h.mu.RUnlock()
196 var pairs []*HeaderPair
197 if h.algorithm != nil {
198 pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)})
199 }
200 if h.contentType != nil {
201 pairs = append(pairs, &HeaderPair{Key: ContentTypeKey, Value: *(h.contentType)})
202 }
203 if h.critical != nil {
204 pairs = append(pairs, &HeaderPair{Key: CriticalKey, Value: h.critical})
205 }
206 if h.jwk != nil {
207 pairs = append(pairs, &HeaderPair{Key: JWKKey, Value: h.jwk})
208 }
209 if h.jwkSetURL != nil {
210 pairs = append(pairs, &HeaderPair{Key: JWKSetURLKey, Value: *(h.jwkSetURL)})
211 }
212 if h.keyID != nil {
213 pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)})
214 }
215 if h.typ != nil {
216 pairs = append(pairs, &HeaderPair{Key: TypeKey, Value: *(h.typ)})
217 }
218 if h.x509CertChain != nil {
219 pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: h.x509CertChain})
220 }
221 if h.x509CertThumbprint != nil {
222 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)})
223 }
224 if h.x509CertThumbprintS256 != nil {
225 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)})
226 }
227 if h.x509URL != nil {
228 pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)})
229 }
230 for k, v := range h.privateParams {
231 pairs = append(pairs, &HeaderPair{Key: k, Value: v})
232 }
233 sort.Slice(pairs, func(i, j int) bool {
234 return pairs[i].Key.(string) < pairs[j].Key.(string)
235 })
236 return pairs
237 }
238
239 func (h *stdHeaders) PrivateParams() map[string]interface{} {
240 h.mu.RLock()
241 defer h.mu.RUnlock()
242 return h.privateParams
243 }
244
245 func (h *stdHeaders) Get(name string) (interface{}, bool) {
246 h.mu.RLock()
247 defer h.mu.RUnlock()
248 switch name {
249 case AlgorithmKey:
250 if h.algorithm == nil {
251 return nil, false
252 }
253 return *(h.algorithm), true
254 case ContentTypeKey:
255 if h.contentType == nil {
256 return nil, false
257 }
258 return *(h.contentType), true
259 case CriticalKey:
260 if h.critical == nil {
261 return nil, false
262 }
263 return h.critical, true
264 case JWKKey:
265 if h.jwk == nil {
266 return nil, false
267 }
268 return h.jwk, true
269 case JWKSetURLKey:
270 if h.jwkSetURL == nil {
271 return nil, false
272 }
273 return *(h.jwkSetURL), true
274 case KeyIDKey:
275 if h.keyID == nil {
276 return nil, false
277 }
278 return *(h.keyID), true
279 case TypeKey:
280 if h.typ == nil {
281 return nil, false
282 }
283 return *(h.typ), true
284 case X509CertChainKey:
285 if h.x509CertChain == nil {
286 return nil, false
287 }
288 return h.x509CertChain, true
289 case X509CertThumbprintKey:
290 if h.x509CertThumbprint == nil {
291 return nil, false
292 }
293 return *(h.x509CertThumbprint), true
294 case X509CertThumbprintS256Key:
295 if h.x509CertThumbprintS256 == nil {
296 return nil, false
297 }
298 return *(h.x509CertThumbprintS256), true
299 case X509URLKey:
300 if h.x509URL == nil {
301 return nil, false
302 }
303 return *(h.x509URL), true
304 default:
305 v, ok := h.privateParams[name]
306 return v, ok
307 }
308 }
309
310 func (h *stdHeaders) Set(name string, value interface{}) error {
311 h.mu.Lock()
312 defer h.mu.Unlock()
313 return h.setNoLock(name, value)
314 }
315
316 func (h *stdHeaders) setNoLock(name string, value interface{}) error {
317 switch name {
318 case AlgorithmKey:
319 var acceptor jwa.SignatureAlgorithm
320 if err := acceptor.Accept(value); err != nil {
321 return errors.Wrapf(err, `invalid value for %s key`, AlgorithmKey)
322 }
323 h.algorithm = &acceptor
324 return nil
325 case ContentTypeKey:
326 if v, ok := value.(string); ok {
327 h.contentType = &v
328 return nil
329 }
330 return errors.Errorf(`invalid value for %s key: %T`, ContentTypeKey, value)
331 case CriticalKey:
332 if v, ok := value.([]string); ok {
333 h.critical = v
334 return nil
335 }
336 return errors.Errorf(`invalid value for %s key: %T`, CriticalKey, value)
337 case JWKKey:
338 if v, ok := value.(jwk.Key); ok {
339 h.jwk = v
340 return nil
341 }
342 return errors.Errorf(`invalid value for %s key: %T`, JWKKey, value)
343 case JWKSetURLKey:
344 if v, ok := value.(string); ok {
345 h.jwkSetURL = &v
346 return nil
347 }
348 return errors.Errorf(`invalid value for %s key: %T`, JWKSetURLKey, value)
349 case KeyIDKey:
350 if v, ok := value.(string); ok {
351 h.keyID = &v
352 return nil
353 }
354 return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
355 case TypeKey:
356 if v, ok := value.(string); ok {
357 h.typ = &v
358 return nil
359 }
360 return errors.Errorf(`invalid value for %s key: %T`, TypeKey, value)
361 case X509CertChainKey:
362 if v, ok := value.([]string); ok {
363 h.x509CertChain = v
364 return nil
365 }
366 return errors.Errorf(`invalid value for %s key: %T`, X509CertChainKey, value)
367 case X509CertThumbprintKey:
368 if v, ok := value.(string); ok {
369 h.x509CertThumbprint = &v
370 return nil
371 }
372 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
373 case X509CertThumbprintS256Key:
374 if v, ok := value.(string); ok {
375 h.x509CertThumbprintS256 = &v
376 return nil
377 }
378 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
379 case X509URLKey:
380 if v, ok := value.(string); ok {
381 h.x509URL = &v
382 return nil
383 }
384 return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
385 default:
386 if h.privateParams == nil {
387 h.privateParams = map[string]interface{}{}
388 }
389 h.privateParams[name] = value
390 }
391 return nil
392 }
393
394 func (h *stdHeaders) Remove(key string) error {
395 h.mu.Lock()
396 defer h.mu.Unlock()
397 switch key {
398 case AlgorithmKey:
399 h.algorithm = nil
400 case ContentTypeKey:
401 h.contentType = nil
402 case CriticalKey:
403 h.critical = nil
404 case JWKKey:
405 h.jwk = nil
406 case JWKSetURLKey:
407 h.jwkSetURL = nil
408 case KeyIDKey:
409 h.keyID = nil
410 case TypeKey:
411 h.typ = nil
412 case X509CertChainKey:
413 h.x509CertChain = nil
414 case X509CertThumbprintKey:
415 h.x509CertThumbprint = nil
416 case X509CertThumbprintS256Key:
417 h.x509CertThumbprintS256 = nil
418 case X509URLKey:
419 h.x509URL = nil
420 default:
421 delete(h.privateParams, key)
422 }
423 return nil
424 }
425
426 func (h *stdHeaders) UnmarshalJSON(buf []byte) error {
427 h.algorithm = nil
428 h.contentType = nil
429 h.critical = nil
430 h.jwk = nil
431 h.jwkSetURL = nil
432 h.keyID = nil
433 h.typ = nil
434 h.x509CertChain = nil
435 h.x509CertThumbprint = nil
436 h.x509CertThumbprintS256 = nil
437 h.x509URL = nil
438 dec := json.NewDecoder(bytes.NewReader(buf))
439 LOOP:
440 for {
441 tok, err := dec.Token()
442 if err != nil {
443 return errors.Wrap(err, `error reading token`)
444 }
445 switch tok := tok.(type) {
446 case json.Delim:
447
448
449 if tok == '}' {
450 break LOOP
451 } else if tok != '{' {
452 return errors.Errorf(`expected '{', but got '%c'`, tok)
453 }
454 case string:
455 switch tok {
456 case AlgorithmKey:
457 var decoded jwa.SignatureAlgorithm
458 if err := dec.Decode(&decoded); err != nil {
459 return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
460 }
461 h.algorithm = &decoded
462 case ContentTypeKey:
463 if err := json.AssignNextStringToken(&h.contentType, dec); err != nil {
464 return errors.Wrapf(err, `failed to decode value for key %s`, ContentTypeKey)
465 }
466 case CriticalKey:
467 var decoded []string
468 if err := dec.Decode(&decoded); err != nil {
469 return errors.Wrapf(err, `failed to decode value for key %s`, CriticalKey)
470 }
471 h.critical = decoded
472 case JWKKey:
473 var buf json.RawMessage
474 if err := dec.Decode(&buf); err != nil {
475 return errors.Wrapf(err, `failed to decode value for key %s`, JWKKey)
476 }
477 key, err := jwk.ParseKey(buf)
478 if err != nil {
479 return errors.Wrapf(err, `failed to parse JWK for key %s`, JWKKey)
480 }
481 h.jwk = key
482 case JWKSetURLKey:
483 if err := json.AssignNextStringToken(&h.jwkSetURL, dec); err != nil {
484 return errors.Wrapf(err, `failed to decode value for key %s`, JWKSetURLKey)
485 }
486 case KeyIDKey:
487 if err := json.AssignNextStringToken(&h.keyID, dec); err != nil {
488 return errors.Wrapf(err, `failed to decode value for key %s`, KeyIDKey)
489 }
490 case TypeKey:
491 if err := json.AssignNextStringToken(&h.typ, dec); err != nil {
492 return errors.Wrapf(err, `failed to decode value for key %s`, TypeKey)
493 }
494 case X509CertChainKey:
495 var decoded []string
496 if err := dec.Decode(&decoded); err != nil {
497 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertChainKey)
498 }
499 h.x509CertChain = decoded
500 case X509CertThumbprintKey:
501 if err := json.AssignNextStringToken(&h.x509CertThumbprint, dec); err != nil {
502 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintKey)
503 }
504 case X509CertThumbprintS256Key:
505 if err := json.AssignNextStringToken(&h.x509CertThumbprintS256, dec); err != nil {
506 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintS256Key)
507 }
508 case X509URLKey:
509 if err := json.AssignNextStringToken(&h.x509URL, dec); err != nil {
510 return errors.Wrapf(err, `failed to decode value for key %s`, X509URLKey)
511 }
512 default:
513 decoded, err := registry.Decode(dec, tok)
514 if err != nil {
515 return err
516 }
517 h.setNoLock(tok, decoded)
518 }
519 default:
520 return errors.Errorf(`invalid token %T`, tok)
521 }
522 }
523
524 if dc := h.dc; dc != nil {
525 if dc.CollectRaw() {
526 h.raw = buf
527 }
528 }
529 return nil
530 }
531
532 func (h stdHeaders) MarshalJSON() ([]byte, error) {
533 buf := pool.GetBytesBuffer()
534 defer pool.ReleaseBytesBuffer(buf)
535 buf.WriteByte('{')
536 enc := json.NewEncoder(buf)
537 for i, p := range h.makePairs() {
538 if i > 0 {
539 buf.WriteRune(',')
540 }
541 buf.WriteRune('"')
542 buf.WriteString(p.Key.(string))
543 buf.WriteString(`":`)
544 v := p.Value
545 switch v := v.(type) {
546 case []byte:
547 buf.WriteRune('"')
548 buf.WriteString(base64.EncodeToString(v))
549 buf.WriteRune('"')
550 default:
551 if err := enc.Encode(v); err != nil {
552 errors.Errorf(`failed to encode value for field %s`, p.Key)
553 }
554 buf.Truncate(buf.Len() - 1)
555 }
556 }
557 buf.WriteByte('}')
558 ret := make([]byte, buf.Len())
559 copy(ret, buf.Bytes())
560 return ret, nil
561 }
562
View as plain text