1
2
3 package jwe
4
5 import (
6 "bytes"
7 "context"
8 "sort"
9 "sync"
10
11 "github.com/lestrrat-go/jwx/internal/base64"
12 "github.com/lestrrat-go/jwx/internal/json"
13 "github.com/lestrrat-go/jwx/internal/pool"
14 "github.com/lestrrat-go/jwx/jwa"
15 "github.com/lestrrat-go/jwx/jwk"
16 "github.com/pkg/errors"
17 )
18
// JSON member names of the standard JWE header parameters. stdHeaders stores
// each of these in a dedicated typed field; any other name is treated as a
// private parameter.
const (
	// AgreementPartyUInfoKey is the name of the agreement PartyUInfo header ("apu").
	AgreementPartyUInfoKey = "apu"
	// AgreementPartyVInfoKey is the name of the agreement PartyVInfo header ("apv").
	AgreementPartyVInfoKey = "apv"
	// AlgorithmKey is the name of the key-encryption algorithm header ("alg").
	AlgorithmKey = "alg"
	// CompressionKey is the name of the compression algorithm header ("zip").
	CompressionKey = "zip"
	// ContentEncryptionKey is the name of the content-encryption algorithm header ("enc").
	ContentEncryptionKey = "enc"
	// ContentTypeKey is the name of the content type header ("cty").
	ContentTypeKey = "cty"
	// CriticalKey is the name of the critical-parameters list header ("crit").
	CriticalKey = "crit"
	// EphemeralPublicKeyKey is the name of the ephemeral public key header ("epk").
	EphemeralPublicKeyKey = "epk"
	// JWKKey is the name of the embedded JSON Web Key header ("jwk").
	JWKKey = "jwk"
	// JWKSetURLKey is the name of the JWK Set URL header ("jku").
	JWKSetURLKey = "jku"
	// KeyIDKey is the name of the key ID header ("kid").
	KeyIDKey = "kid"
	// TypeKey is the name of the type header ("typ").
	TypeKey = "typ"
	// X509CertChainKey is the name of the X.509 certificate chain header ("x5c").
	X509CertChainKey = "x5c"
	// X509CertThumbprintKey is the name of the X.509 certificate thumbprint header ("x5t").
	X509CertThumbprintKey = "x5t"
	// X509CertThumbprintS256Key is the name of the X.509 certificate SHA-256 thumbprint header ("x5t#S256").
	X509CertThumbprintS256Key = "x5t#S256"
	// X509URLKey is the name of the X.509 URL header ("x5u").
	X509URLKey = "x5u"
)
37
38
// Headers is the set of header parameters attached to a JWE message.
// Typed accessors expose each standard parameter; Get, Set and Remove give
// name-based access to both standard and private (non-standard) parameters.
// The implementation provided in this file is stdHeaders, created by
// NewHeaders.
type Headers interface {
	json.Marshaler
	json.Unmarshaler
	// AgreementPartyUInfo returns the "apu" parameter as raw bytes.
	AgreementPartyUInfo() []byte
	// AgreementPartyVInfo returns the "apv" parameter as raw bytes.
	AgreementPartyVInfo() []byte
	// Algorithm returns the "alg" key-encryption algorithm.
	Algorithm() jwa.KeyEncryptionAlgorithm
	// Compression returns the "zip" compression algorithm. The stdHeaders
	// implementation returns jwa.NoCompress when the parameter is unset.
	Compression() jwa.CompressionAlgorithm
	// ContentEncryption returns the "enc" content-encryption algorithm.
	ContentEncryption() jwa.ContentEncryptionAlgorithm
	// ContentType returns the "cty" parameter.
	ContentType() string
	// Critical returns the "crit" list of parameter names.
	Critical() []string
	// EphemeralPublicKey returns the "epk" key.
	EphemeralPublicKey() jwk.Key
	// JWK returns the "jwk" key.
	JWK() jwk.Key
	// JWKSetURL returns the "jku" parameter.
	JWKSetURL() string
	// KeyID returns the "kid" parameter.
	KeyID() string
	// Type returns the "typ" parameter.
	Type() string
	// X509CertChain returns the "x5c" list of certificates.
	X509CertChain() []string
	// X509CertThumbprint returns the "x5t" parameter.
	X509CertThumbprint() string
	// X509CertThumbprintS256 returns the "x5t#S256" parameter.
	X509CertThumbprintS256() string
	// X509URL returns the "x5u" parameter.
	X509URL() string
	// Iterate returns an Iterator over the header parameters.
	// NOTE(review): implemented outside this file; ordering and how ctx
	// cancellation is honored are not visible here — confirm before relying on them.
	Iterate(ctx context.Context) Iterator
	// Walk calls v for the header parameters.
	// NOTE(review): implemented outside this file; presumably stops on the
	// first error returned by the visitor — confirm.
	Walk(ctx context.Context, v Visitor) error
	// AsMap returns the header parameters as a map keyed by parameter name.
	// NOTE(review): implemented outside this file.
	AsMap(ctx context.Context) (map[string]interface{}, error)
	// Get returns the value of the named parameter and whether it is set.
	Get(string) (interface{}, bool)
	// Set stores a value under the named parameter. For standard parameters
	// the stdHeaders implementation rejects values of the wrong Go type.
	Set(string, interface{}) error
	// Remove deletes the named parameter.
	Remove(string) error
	// Encode serializes the headers.
	// NOTE(review): implemented outside this file; the exact encoding
	// (e.g. base64url of the JSON form) is not visible here — confirm.
	Encode() ([]byte, error)
	// Decode is the inverse of Encode.
	// NOTE(review): implemented outside this file.
	Decode([]byte) error

	// PrivateParams returns the parameters that are not one of the standard
	// JWE header names.
	PrivateParams() map[string]interface{}
	// Clone returns a new Headers holding the same parameters.
	// NOTE(review): implemented outside this file; whether values are
	// deep-copied is not visible here.
	Clone(context.Context) (Headers, error)
	// Copy copies parameters from this Headers into the given one.
	// NOTE(review): implemented outside this file; confirm the direction of the copy.
	Copy(context.Context, Headers) error
	// Merge combines this Headers with another and returns the result.
	// NOTE(review): implemented outside this file; precedence on conflicting
	// keys is not visible here.
	Merge(context.Context, Headers) (Headers, error)
}
75
// stdHeaders is the Headers implementation returned by NewHeaders.
//
// A nil field means "parameter not set". Scalar parameters are stored behind
// pointers so that an unset parameter can be told apart from a zero value;
// Get reports an unset parameter as (nil, false) and makePairs omits it.
//
// The typed accessors, Get, Set, Remove and makePairs take mu; UnmarshalJSON
// writes the fields without taking it. mu is a pointer allocated by
// NewHeaders, so a zero-value stdHeaders has a nil mu and its accessors
// would panic.
type stdHeaders struct {
	agreementPartyUInfo    []byte                          // "apu"
	agreementPartyVInfo    []byte                          // "apv"
	algorithm              *jwa.KeyEncryptionAlgorithm     // "alg"
	compression            *jwa.CompressionAlgorithm       // "zip"
	contentEncryption      *jwa.ContentEncryptionAlgorithm // "enc"
	contentType            *string                         // "cty"
	critical               []string                        // "crit"
	ephemeralPublicKey     jwk.Key                         // "epk"
	jwk                    jwk.Key                         // "jwk"
	jwkSetURL              *string                         // "jku"
	keyID                  *string                         // "kid"
	typ                    *string                         // "typ"
	x509CertChain          []string                        // "x5c"
	x509CertThumbprint     *string                         // "x5t"
	x509CertThumbprintS256 *string                         // "x5t#S256"
	x509URL                *string                         // "x5u"
	privateParams          map[string]interface{}          // every other parameter, keyed by name
	mu                     *sync.RWMutex                   // shared by value copies of the struct (it is a pointer)
}
96
97 func NewHeaders() Headers {
98 return &stdHeaders{
99 mu: &sync.RWMutex{},
100 privateParams: map[string]interface{}{},
101 }
102 }
103
104 func (h *stdHeaders) AgreementPartyUInfo() []byte {
105 h.mu.RLock()
106 defer h.mu.RUnlock()
107 return h.agreementPartyUInfo
108 }
109
110 func (h *stdHeaders) AgreementPartyVInfo() []byte {
111 h.mu.RLock()
112 defer h.mu.RUnlock()
113 return h.agreementPartyVInfo
114 }
115
116 func (h *stdHeaders) Algorithm() jwa.KeyEncryptionAlgorithm {
117 h.mu.RLock()
118 defer h.mu.RUnlock()
119 if h.algorithm == nil {
120 return ""
121 }
122 return *(h.algorithm)
123 }
124
125 func (h *stdHeaders) Compression() jwa.CompressionAlgorithm {
126 h.mu.RLock()
127 defer h.mu.RUnlock()
128 if h.compression == nil {
129 return jwa.NoCompress
130 }
131 return *(h.compression)
132 }
133
134 func (h *stdHeaders) ContentEncryption() jwa.ContentEncryptionAlgorithm {
135 h.mu.RLock()
136 defer h.mu.RUnlock()
137 if h.contentEncryption == nil {
138 return ""
139 }
140 return *(h.contentEncryption)
141 }
142
143 func (h *stdHeaders) ContentType() string {
144 h.mu.RLock()
145 defer h.mu.RUnlock()
146 if h.contentType == nil {
147 return ""
148 }
149 return *(h.contentType)
150 }
151
152 func (h *stdHeaders) Critical() []string {
153 h.mu.RLock()
154 defer h.mu.RUnlock()
155 return h.critical
156 }
157
158 func (h *stdHeaders) EphemeralPublicKey() jwk.Key {
159 h.mu.RLock()
160 defer h.mu.RUnlock()
161 return h.ephemeralPublicKey
162 }
163
164 func (h *stdHeaders) JWK() jwk.Key {
165 h.mu.RLock()
166 defer h.mu.RUnlock()
167 return h.jwk
168 }
169
170 func (h *stdHeaders) JWKSetURL() string {
171 h.mu.RLock()
172 defer h.mu.RUnlock()
173 if h.jwkSetURL == nil {
174 return ""
175 }
176 return *(h.jwkSetURL)
177 }
178
179 func (h *stdHeaders) KeyID() string {
180 h.mu.RLock()
181 defer h.mu.RUnlock()
182 if h.keyID == nil {
183 return ""
184 }
185 return *(h.keyID)
186 }
187
188 func (h *stdHeaders) Type() string {
189 h.mu.RLock()
190 defer h.mu.RUnlock()
191 if h.typ == nil {
192 return ""
193 }
194 return *(h.typ)
195 }
196
197 func (h *stdHeaders) X509CertChain() []string {
198 h.mu.RLock()
199 defer h.mu.RUnlock()
200 return h.x509CertChain
201 }
202
203 func (h *stdHeaders) X509CertThumbprint() string {
204 h.mu.RLock()
205 defer h.mu.RUnlock()
206 if h.x509CertThumbprint == nil {
207 return ""
208 }
209 return *(h.x509CertThumbprint)
210 }
211
212 func (h *stdHeaders) X509CertThumbprintS256() string {
213 h.mu.RLock()
214 defer h.mu.RUnlock()
215 if h.x509CertThumbprintS256 == nil {
216 return ""
217 }
218 return *(h.x509CertThumbprintS256)
219 }
220
221 func (h *stdHeaders) X509URL() string {
222 h.mu.RLock()
223 defer h.mu.RUnlock()
224 if h.x509URL == nil {
225 return ""
226 }
227 return *(h.x509URL)
228 }
229
230 func (h *stdHeaders) makePairs() []*HeaderPair {
231 h.mu.RLock()
232 defer h.mu.RUnlock()
233 var pairs []*HeaderPair
234 if h.agreementPartyUInfo != nil {
235 pairs = append(pairs, &HeaderPair{Key: AgreementPartyUInfoKey, Value: h.agreementPartyUInfo})
236 }
237 if h.agreementPartyVInfo != nil {
238 pairs = append(pairs, &HeaderPair{Key: AgreementPartyVInfoKey, Value: h.agreementPartyVInfo})
239 }
240 if h.algorithm != nil {
241 pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)})
242 }
243 if h.compression != nil {
244 pairs = append(pairs, &HeaderPair{Key: CompressionKey, Value: *(h.compression)})
245 }
246 if h.contentEncryption != nil {
247 pairs = append(pairs, &HeaderPair{Key: ContentEncryptionKey, Value: *(h.contentEncryption)})
248 }
249 if h.contentType != nil {
250 pairs = append(pairs, &HeaderPair{Key: ContentTypeKey, Value: *(h.contentType)})
251 }
252 if h.critical != nil {
253 pairs = append(pairs, &HeaderPair{Key: CriticalKey, Value: h.critical})
254 }
255 if h.ephemeralPublicKey != nil {
256 pairs = append(pairs, &HeaderPair{Key: EphemeralPublicKeyKey, Value: h.ephemeralPublicKey})
257 }
258 if h.jwk != nil {
259 pairs = append(pairs, &HeaderPair{Key: JWKKey, Value: h.jwk})
260 }
261 if h.jwkSetURL != nil {
262 pairs = append(pairs, &HeaderPair{Key: JWKSetURLKey, Value: *(h.jwkSetURL)})
263 }
264 if h.keyID != nil {
265 pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)})
266 }
267 if h.typ != nil {
268 pairs = append(pairs, &HeaderPair{Key: TypeKey, Value: *(h.typ)})
269 }
270 if h.x509CertChain != nil {
271 pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: h.x509CertChain})
272 }
273 if h.x509CertThumbprint != nil {
274 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)})
275 }
276 if h.x509CertThumbprintS256 != nil {
277 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)})
278 }
279 if h.x509URL != nil {
280 pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)})
281 }
282 for k, v := range h.privateParams {
283 pairs = append(pairs, &HeaderPair{Key: k, Value: v})
284 }
285 return pairs
286 }
287
288 func (h *stdHeaders) PrivateParams() map[string]interface{} {
289 h.mu.RLock()
290 defer h.mu.RUnlock()
291 return h.privateParams
292 }
293
294 func (h *stdHeaders) Get(name string) (interface{}, bool) {
295 h.mu.RLock()
296 defer h.mu.RUnlock()
297 switch name {
298 case AgreementPartyUInfoKey:
299 if h.agreementPartyUInfo == nil {
300 return nil, false
301 }
302 return h.agreementPartyUInfo, true
303 case AgreementPartyVInfoKey:
304 if h.agreementPartyVInfo == nil {
305 return nil, false
306 }
307 return h.agreementPartyVInfo, true
308 case AlgorithmKey:
309 if h.algorithm == nil {
310 return nil, false
311 }
312 return *(h.algorithm), true
313 case CompressionKey:
314 if h.compression == nil {
315 return nil, false
316 }
317 return *(h.compression), true
318 case ContentEncryptionKey:
319 if h.contentEncryption == nil {
320 return nil, false
321 }
322 return *(h.contentEncryption), true
323 case ContentTypeKey:
324 if h.contentType == nil {
325 return nil, false
326 }
327 return *(h.contentType), true
328 case CriticalKey:
329 if h.critical == nil {
330 return nil, false
331 }
332 return h.critical, true
333 case EphemeralPublicKeyKey:
334 if h.ephemeralPublicKey == nil {
335 return nil, false
336 }
337 return h.ephemeralPublicKey, true
338 case JWKKey:
339 if h.jwk == nil {
340 return nil, false
341 }
342 return h.jwk, true
343 case JWKSetURLKey:
344 if h.jwkSetURL == nil {
345 return nil, false
346 }
347 return *(h.jwkSetURL), true
348 case KeyIDKey:
349 if h.keyID == nil {
350 return nil, false
351 }
352 return *(h.keyID), true
353 case TypeKey:
354 if h.typ == nil {
355 return nil, false
356 }
357 return *(h.typ), true
358 case X509CertChainKey:
359 if h.x509CertChain == nil {
360 return nil, false
361 }
362 return h.x509CertChain, true
363 case X509CertThumbprintKey:
364 if h.x509CertThumbprint == nil {
365 return nil, false
366 }
367 return *(h.x509CertThumbprint), true
368 case X509CertThumbprintS256Key:
369 if h.x509CertThumbprintS256 == nil {
370 return nil, false
371 }
372 return *(h.x509CertThumbprintS256), true
373 case X509URLKey:
374 if h.x509URL == nil {
375 return nil, false
376 }
377 return *(h.x509URL), true
378 default:
379 v, ok := h.privateParams[name]
380 return v, ok
381 }
382 }
383
384 func (h *stdHeaders) Set(name string, value interface{}) error {
385 h.mu.Lock()
386 defer h.mu.Unlock()
387 return h.setNoLock(name, value)
388 }
389
390 func (h *stdHeaders) setNoLock(name string, value interface{}) error {
391 switch name {
392 case AgreementPartyUInfoKey:
393 if v, ok := value.([]byte); ok {
394 h.agreementPartyUInfo = v
395 return nil
396 }
397 return errors.Errorf(`invalid value for %s key: %T`, AgreementPartyUInfoKey, value)
398 case AgreementPartyVInfoKey:
399 if v, ok := value.([]byte); ok {
400 h.agreementPartyVInfo = v
401 return nil
402 }
403 return errors.Errorf(`invalid value for %s key: %T`, AgreementPartyVInfoKey, value)
404 case AlgorithmKey:
405 if v, ok := value.(jwa.KeyEncryptionAlgorithm); ok {
406 h.algorithm = &v
407 return nil
408 }
409 return errors.Errorf(`invalid value for %s key: %T`, AlgorithmKey, value)
410 case CompressionKey:
411 if v, ok := value.(jwa.CompressionAlgorithm); ok {
412 h.compression = &v
413 return nil
414 }
415 return errors.Errorf(`invalid value for %s key: %T`, CompressionKey, value)
416 case ContentEncryptionKey:
417 if v, ok := value.(jwa.ContentEncryptionAlgorithm); ok {
418 if v == "" {
419 return errors.New(`"enc" field cannot be an empty string`)
420 }
421 h.contentEncryption = &v
422 return nil
423 }
424 return errors.Errorf(`invalid value for %s key: %T`, ContentEncryptionKey, value)
425 case ContentTypeKey:
426 if v, ok := value.(string); ok {
427 h.contentType = &v
428 return nil
429 }
430 return errors.Errorf(`invalid value for %s key: %T`, ContentTypeKey, value)
431 case CriticalKey:
432 if v, ok := value.([]string); ok {
433 h.critical = v
434 return nil
435 }
436 return errors.Errorf(`invalid value for %s key: %T`, CriticalKey, value)
437 case EphemeralPublicKeyKey:
438 if v, ok := value.(jwk.Key); ok {
439 h.ephemeralPublicKey = v
440 return nil
441 }
442 return errors.Errorf(`invalid value for %s key: %T`, EphemeralPublicKeyKey, value)
443 case JWKKey:
444 if v, ok := value.(jwk.Key); ok {
445 h.jwk = v
446 return nil
447 }
448 return errors.Errorf(`invalid value for %s key: %T`, JWKKey, value)
449 case JWKSetURLKey:
450 if v, ok := value.(string); ok {
451 h.jwkSetURL = &v
452 return nil
453 }
454 return errors.Errorf(`invalid value for %s key: %T`, JWKSetURLKey, value)
455 case KeyIDKey:
456 if v, ok := value.(string); ok {
457 h.keyID = &v
458 return nil
459 }
460 return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
461 case TypeKey:
462 if v, ok := value.(string); ok {
463 h.typ = &v
464 return nil
465 }
466 return errors.Errorf(`invalid value for %s key: %T`, TypeKey, value)
467 case X509CertChainKey:
468 if v, ok := value.([]string); ok {
469 h.x509CertChain = v
470 return nil
471 }
472 return errors.Errorf(`invalid value for %s key: %T`, X509CertChainKey, value)
473 case X509CertThumbprintKey:
474 if v, ok := value.(string); ok {
475 h.x509CertThumbprint = &v
476 return nil
477 }
478 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
479 case X509CertThumbprintS256Key:
480 if v, ok := value.(string); ok {
481 h.x509CertThumbprintS256 = &v
482 return nil
483 }
484 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
485 case X509URLKey:
486 if v, ok := value.(string); ok {
487 h.x509URL = &v
488 return nil
489 }
490 return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
491 default:
492 if h.privateParams == nil {
493 h.privateParams = map[string]interface{}{}
494 }
495 h.privateParams[name] = value
496 }
497 return nil
498 }
499
500 func (h *stdHeaders) Remove(key string) error {
501 h.mu.Lock()
502 defer h.mu.Unlock()
503 switch key {
504 case AgreementPartyUInfoKey:
505 h.agreementPartyUInfo = nil
506 case AgreementPartyVInfoKey:
507 h.agreementPartyVInfo = nil
508 case AlgorithmKey:
509 h.algorithm = nil
510 case CompressionKey:
511 h.compression = nil
512 case ContentEncryptionKey:
513 h.contentEncryption = nil
514 case ContentTypeKey:
515 h.contentType = nil
516 case CriticalKey:
517 h.critical = nil
518 case EphemeralPublicKeyKey:
519 h.ephemeralPublicKey = nil
520 case JWKKey:
521 h.jwk = nil
522 case JWKSetURLKey:
523 h.jwkSetURL = nil
524 case KeyIDKey:
525 h.keyID = nil
526 case TypeKey:
527 h.typ = nil
528 case X509CertChainKey:
529 h.x509CertChain = nil
530 case X509CertThumbprintKey:
531 h.x509CertThumbprint = nil
532 case X509CertThumbprintS256Key:
533 h.x509CertThumbprintS256 = nil
534 case X509URLKey:
535 h.x509URL = nil
536 default:
537 delete(h.privateParams, key)
538 }
539 return nil
540 }
541
542 func (h *stdHeaders) UnmarshalJSON(buf []byte) error {
543 h.agreementPartyUInfo = nil
544 h.agreementPartyVInfo = nil
545 h.algorithm = nil
546 h.compression = nil
547 h.contentEncryption = nil
548 h.contentType = nil
549 h.critical = nil
550 h.ephemeralPublicKey = nil
551 h.jwk = nil
552 h.jwkSetURL = nil
553 h.keyID = nil
554 h.typ = nil
555 h.x509CertChain = nil
556 h.x509CertThumbprint = nil
557 h.x509CertThumbprintS256 = nil
558 h.x509URL = nil
559 dec := json.NewDecoder(bytes.NewReader(buf))
560 LOOP:
561 for {
562 tok, err := dec.Token()
563 if err != nil {
564 return errors.Wrap(err, `error reading token`)
565 }
566 switch tok := tok.(type) {
567 case json.Delim:
568
569
570 if tok == '}' {
571 break LOOP
572 } else if tok != '{' {
573 return errors.Errorf(`expected '{', but got '%c'`, tok)
574 }
575 case string:
576 switch tok {
577 case AgreementPartyUInfoKey:
578 if err := json.AssignNextBytesToken(&h.agreementPartyUInfo, dec); err != nil {
579 return errors.Wrapf(err, `failed to decode value for key %s`, AgreementPartyUInfoKey)
580 }
581 case AgreementPartyVInfoKey:
582 if err := json.AssignNextBytesToken(&h.agreementPartyVInfo, dec); err != nil {
583 return errors.Wrapf(err, `failed to decode value for key %s`, AgreementPartyVInfoKey)
584 }
585 case AlgorithmKey:
586 var decoded jwa.KeyEncryptionAlgorithm
587 if err := dec.Decode(&decoded); err != nil {
588 return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
589 }
590 h.algorithm = &decoded
591 case CompressionKey:
592 var decoded jwa.CompressionAlgorithm
593 if err := dec.Decode(&decoded); err != nil {
594 return errors.Wrapf(err, `failed to decode value for key %s`, CompressionKey)
595 }
596 h.compression = &decoded
597 case ContentEncryptionKey:
598 var decoded jwa.ContentEncryptionAlgorithm
599 if err := dec.Decode(&decoded); err != nil {
600 return errors.Wrapf(err, `failed to decode value for key %s`, ContentEncryptionKey)
601 }
602 h.contentEncryption = &decoded
603 case ContentTypeKey:
604 if err := json.AssignNextStringToken(&h.contentType, dec); err != nil {
605 return errors.Wrapf(err, `failed to decode value for key %s`, ContentTypeKey)
606 }
607 case CriticalKey:
608 var decoded []string
609 if err := dec.Decode(&decoded); err != nil {
610 return errors.Wrapf(err, `failed to decode value for key %s`, CriticalKey)
611 }
612 h.critical = decoded
613 case EphemeralPublicKeyKey:
614 var buf json.RawMessage
615 if err := dec.Decode(&buf); err != nil {
616 return errors.Wrapf(err, `failed to decode value for key %s`, EphemeralPublicKeyKey)
617 }
618 key, err := jwk.ParseKey(buf)
619 if err != nil {
620 return errors.Wrapf(err, `failed to parse JWK for key %s`, EphemeralPublicKeyKey)
621 }
622 h.ephemeralPublicKey = key
623 case JWKKey:
624 var buf json.RawMessage
625 if err := dec.Decode(&buf); err != nil {
626 return errors.Wrapf(err, `failed to decode value for key %s`, JWKKey)
627 }
628 key, err := jwk.ParseKey(buf)
629 if err != nil {
630 return errors.Wrapf(err, `failed to parse JWK for key %s`, JWKKey)
631 }
632 h.jwk = key
633 case JWKSetURLKey:
634 if err := json.AssignNextStringToken(&h.jwkSetURL, dec); err != nil {
635 return errors.Wrapf(err, `failed to decode value for key %s`, JWKSetURLKey)
636 }
637 case KeyIDKey:
638 if err := json.AssignNextStringToken(&h.keyID, dec); err != nil {
639 return errors.Wrapf(err, `failed to decode value for key %s`, KeyIDKey)
640 }
641 case TypeKey:
642 if err := json.AssignNextStringToken(&h.typ, dec); err != nil {
643 return errors.Wrapf(err, `failed to decode value for key %s`, TypeKey)
644 }
645 case X509CertChainKey:
646 var decoded []string
647 if err := dec.Decode(&decoded); err != nil {
648 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertChainKey)
649 }
650 h.x509CertChain = decoded
651 case X509CertThumbprintKey:
652 if err := json.AssignNextStringToken(&h.x509CertThumbprint, dec); err != nil {
653 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintKey)
654 }
655 case X509CertThumbprintS256Key:
656 if err := json.AssignNextStringToken(&h.x509CertThumbprintS256, dec); err != nil {
657 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintS256Key)
658 }
659 case X509URLKey:
660 if err := json.AssignNextStringToken(&h.x509URL, dec); err != nil {
661 return errors.Wrapf(err, `failed to decode value for key %s`, X509URLKey)
662 }
663 default:
664 decoded, err := registry.Decode(dec, tok)
665 if err != nil {
666 return err
667 }
668 h.setNoLock(tok, decoded)
669 }
670 default:
671 return errors.Errorf(`invalid token %T`, tok)
672 }
673 }
674 return nil
675 }
676
677 func (h stdHeaders) MarshalJSON() ([]byte, error) {
678 data := make(map[string]interface{})
679 fields := make([]string, 0, 16)
680 for _, pair := range h.makePairs() {
681 fields = append(fields, pair.Key.(string))
682 data[pair.Key.(string)] = pair.Value
683 }
684
685 sort.Strings(fields)
686 buf := pool.GetBytesBuffer()
687 defer pool.ReleaseBytesBuffer(buf)
688 buf.WriteByte('{')
689 enc := json.NewEncoder(buf)
690 for i, f := range fields {
691 if i > 0 {
692 buf.WriteRune(',')
693 }
694 buf.WriteRune('"')
695 buf.WriteString(f)
696 buf.WriteString(`":`)
697 v := data[f]
698 switch v := v.(type) {
699 case []byte:
700 buf.WriteRune('"')
701 buf.WriteString(base64.EncodeToString(v))
702 buf.WriteRune('"')
703 default:
704 if err := enc.Encode(v); err != nil {
705 errors.Errorf(`failed to encode value for field %s`, f)
706 }
707 buf.Truncate(buf.Len() - 1)
708 }
709 }
710 buf.WriteByte('}')
711 ret := make([]byte, buf.Len())
712 copy(ret, buf.Bytes())
713 return ret, nil
714 }
715
View as plain text