1
2
3 package jwk
4
5 import (
6 "bytes"
7 "context"
8 "crypto/x509"
9 "fmt"
10 "sort"
11 "sync"
12
13 "github.com/lestrrat-go/iter/mapiter"
14 "github.com/lestrrat-go/jwx/internal/base64"
15 "github.com/lestrrat-go/jwx/internal/iter"
16 "github.com/lestrrat-go/jwx/internal/json"
17 "github.com/lestrrat-go/jwx/internal/pool"
18 "github.com/lestrrat-go/jwx/jwa"
19 "github.com/pkg/errors"
20 )
21
// JSON member names specific to OKP (Octet Key Pair) JWKs.
const (
	// OKPCrvKey is the member holding the curve name.
	OKPCrvKey = "crv"
	// OKPDKey is the member holding the private key bytes.
	OKPDKey = "d"
	// OKPXKey is the member holding the public key bytes.
	OKPXKey = "x"
)
27
28 type OKPPublicKey interface {
29 Key
30 FromRaw(interface{}) error
31 Crv() jwa.EllipticCurveAlgorithm
32 X() []byte
33 }
34
35 type okpPublicKey struct {
36 algorithm *string
37 crv *jwa.EllipticCurveAlgorithm
38 keyID *string
39 keyOps *KeyOperationList
40 keyUsage *string
41 x []byte
42 x509CertChain *CertificateChain
43 x509CertThumbprint *string
44 x509CertThumbprintS256 *string
45 x509URL *string
46 privateParams map[string]interface{}
47 mu *sync.RWMutex
48 dc json.DecodeCtx
49 }
50
51 func NewOKPPublicKey() OKPPublicKey {
52 return newOKPPublicKey()
53 }
54
55 func newOKPPublicKey() *okpPublicKey {
56 return &okpPublicKey{
57 mu: &sync.RWMutex{},
58 privateParams: make(map[string]interface{}),
59 }
60 }
61
62 func (h okpPublicKey) KeyType() jwa.KeyType {
63 return jwa.OKP
64 }
65
66 func (h *okpPublicKey) Algorithm() string {
67 if h.algorithm != nil {
68 return *(h.algorithm)
69 }
70 return ""
71 }
72
73 func (h *okpPublicKey) Crv() jwa.EllipticCurveAlgorithm {
74 if h.crv != nil {
75 return *(h.crv)
76 }
77 return jwa.InvalidEllipticCurve
78 }
79
80 func (h *okpPublicKey) KeyID() string {
81 if h.keyID != nil {
82 return *(h.keyID)
83 }
84 return ""
85 }
86
87 func (h *okpPublicKey) KeyOps() KeyOperationList {
88 if h.keyOps != nil {
89 return *(h.keyOps)
90 }
91 return nil
92 }
93
94 func (h *okpPublicKey) KeyUsage() string {
95 if h.keyUsage != nil {
96 return *(h.keyUsage)
97 }
98 return ""
99 }
100
101 func (h *okpPublicKey) X() []byte {
102 return h.x
103 }
104
105 func (h *okpPublicKey) X509CertChain() []*x509.Certificate {
106 if h.x509CertChain != nil {
107 return h.x509CertChain.Get()
108 }
109 return nil
110 }
111
112 func (h *okpPublicKey) X509CertThumbprint() string {
113 if h.x509CertThumbprint != nil {
114 return *(h.x509CertThumbprint)
115 }
116 return ""
117 }
118
119 func (h *okpPublicKey) X509CertThumbprintS256() string {
120 if h.x509CertThumbprintS256 != nil {
121 return *(h.x509CertThumbprintS256)
122 }
123 return ""
124 }
125
126 func (h *okpPublicKey) X509URL() string {
127 if h.x509URL != nil {
128 return *(h.x509URL)
129 }
130 return ""
131 }
132
133 func (h *okpPublicKey) makePairs() []*HeaderPair {
134 h.mu.RLock()
135 defer h.mu.RUnlock()
136
137 var pairs []*HeaderPair
138 pairs = append(pairs, &HeaderPair{Key: "kty", Value: jwa.OKP})
139 if h.algorithm != nil {
140 pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)})
141 }
142 if h.crv != nil {
143 pairs = append(pairs, &HeaderPair{Key: OKPCrvKey, Value: *(h.crv)})
144 }
145 if h.keyID != nil {
146 pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)})
147 }
148 if h.keyOps != nil {
149 pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyOps)})
150 }
151 if h.keyUsage != nil {
152 pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)})
153 }
154 if h.x != nil {
155 pairs = append(pairs, &HeaderPair{Key: OKPXKey, Value: h.x})
156 }
157 if h.x509CertChain != nil {
158 pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: *(h.x509CertChain)})
159 }
160 if h.x509CertThumbprint != nil {
161 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)})
162 }
163 if h.x509CertThumbprintS256 != nil {
164 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)})
165 }
166 if h.x509URL != nil {
167 pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)})
168 }
169 for k, v := range h.privateParams {
170 pairs = append(pairs, &HeaderPair{Key: k, Value: v})
171 }
172 return pairs
173 }
174
175 func (h *okpPublicKey) PrivateParams() map[string]interface{} {
176 return h.privateParams
177 }
178
179 func (h *okpPublicKey) Get(name string) (interface{}, bool) {
180 h.mu.RLock()
181 defer h.mu.RUnlock()
182 switch name {
183 case KeyTypeKey:
184 return h.KeyType(), true
185 case AlgorithmKey:
186 if h.algorithm == nil {
187 return nil, false
188 }
189 return *(h.algorithm), true
190 case OKPCrvKey:
191 if h.crv == nil {
192 return nil, false
193 }
194 return *(h.crv), true
195 case KeyIDKey:
196 if h.keyID == nil {
197 return nil, false
198 }
199 return *(h.keyID), true
200 case KeyOpsKey:
201 if h.keyOps == nil {
202 return nil, false
203 }
204 return *(h.keyOps), true
205 case KeyUsageKey:
206 if h.keyUsage == nil {
207 return nil, false
208 }
209 return *(h.keyUsage), true
210 case OKPXKey:
211 if h.x == nil {
212 return nil, false
213 }
214 return h.x, true
215 case X509CertChainKey:
216 if h.x509CertChain == nil {
217 return nil, false
218 }
219 return h.x509CertChain.Get(), true
220 case X509CertThumbprintKey:
221 if h.x509CertThumbprint == nil {
222 return nil, false
223 }
224 return *(h.x509CertThumbprint), true
225 case X509CertThumbprintS256Key:
226 if h.x509CertThumbprintS256 == nil {
227 return nil, false
228 }
229 return *(h.x509CertThumbprintS256), true
230 case X509URLKey:
231 if h.x509URL == nil {
232 return nil, false
233 }
234 return *(h.x509URL), true
235 default:
236 v, ok := h.privateParams[name]
237 return v, ok
238 }
239 }
240
241 func (h *okpPublicKey) Set(name string, value interface{}) error {
242 h.mu.Lock()
243 defer h.mu.Unlock()
244 return h.setNoLock(name, value)
245 }
246
247 func (h *okpPublicKey) setNoLock(name string, value interface{}) error {
248 switch name {
249 case "kty":
250 return nil
251 case AlgorithmKey:
252 switch v := value.(type) {
253 case string:
254 h.algorithm = &v
255 case fmt.Stringer:
256 tmp := v.String()
257 h.algorithm = &tmp
258 default:
259 return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value)
260 }
261 return nil
262 case OKPCrvKey:
263 if v, ok := value.(jwa.EllipticCurveAlgorithm); ok {
264 h.crv = &v
265 return nil
266 }
267 return errors.Errorf(`invalid value for %s key: %T`, OKPCrvKey, value)
268 case KeyIDKey:
269 if v, ok := value.(string); ok {
270 h.keyID = &v
271 return nil
272 }
273 return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
274 case KeyOpsKey:
275 var acceptor KeyOperationList
276 if err := acceptor.Accept(value); err != nil {
277 return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey)
278 }
279 h.keyOps = &acceptor
280 return nil
281 case KeyUsageKey:
282 switch v := value.(type) {
283 case KeyUsageType:
284 switch v {
285 case ForSignature, ForEncryption:
286 tmp := v.String()
287 h.keyUsage = &tmp
288 default:
289 return errors.Errorf(`invalid key usage type %s`, v)
290 }
291 case string:
292 h.keyUsage = &v
293 default:
294 return errors.Errorf(`invalid key usage type %s`, v)
295 }
296 case OKPXKey:
297 if v, ok := value.([]byte); ok {
298 h.x = v
299 return nil
300 }
301 return errors.Errorf(`invalid value for %s key: %T`, OKPXKey, value)
302 case X509CertChainKey:
303 var acceptor CertificateChain
304 if err := acceptor.Accept(value); err != nil {
305 return errors.Wrapf(err, `invalid value for %s key`, X509CertChainKey)
306 }
307 h.x509CertChain = &acceptor
308 return nil
309 case X509CertThumbprintKey:
310 if v, ok := value.(string); ok {
311 h.x509CertThumbprint = &v
312 return nil
313 }
314 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
315 case X509CertThumbprintS256Key:
316 if v, ok := value.(string); ok {
317 h.x509CertThumbprintS256 = &v
318 return nil
319 }
320 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
321 case X509URLKey:
322 if v, ok := value.(string); ok {
323 h.x509URL = &v
324 return nil
325 }
326 return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
327 default:
328 if h.privateParams == nil {
329 h.privateParams = map[string]interface{}{}
330 }
331 h.privateParams[name] = value
332 }
333 return nil
334 }
335
336 func (k *okpPublicKey) Remove(key string) error {
337 k.mu.Lock()
338 defer k.mu.Unlock()
339 switch key {
340 case AlgorithmKey:
341 k.algorithm = nil
342 case OKPCrvKey:
343 k.crv = nil
344 case KeyIDKey:
345 k.keyID = nil
346 case KeyOpsKey:
347 k.keyOps = nil
348 case KeyUsageKey:
349 k.keyUsage = nil
350 case OKPXKey:
351 k.x = nil
352 case X509CertChainKey:
353 k.x509CertChain = nil
354 case X509CertThumbprintKey:
355 k.x509CertThumbprint = nil
356 case X509CertThumbprintS256Key:
357 k.x509CertThumbprintS256 = nil
358 case X509URLKey:
359 k.x509URL = nil
360 default:
361 delete(k.privateParams, key)
362 }
363 return nil
364 }
365
366 func (k *okpPublicKey) Clone() (Key, error) {
367 return cloneKey(k)
368 }
369
370 func (k *okpPublicKey) DecodeCtx() json.DecodeCtx {
371 k.mu.RLock()
372 defer k.mu.RUnlock()
373 return k.dc
374 }
375
376 func (k *okpPublicKey) SetDecodeCtx(dc json.DecodeCtx) {
377 k.mu.Lock()
378 defer k.mu.Unlock()
379 k.dc = dc
380 }
381
382 func (h *okpPublicKey) UnmarshalJSON(buf []byte) error {
383 h.algorithm = nil
384 h.crv = nil
385 h.keyID = nil
386 h.keyOps = nil
387 h.keyUsage = nil
388 h.x = nil
389 h.x509CertChain = nil
390 h.x509CertThumbprint = nil
391 h.x509CertThumbprintS256 = nil
392 h.x509URL = nil
393 dec := json.NewDecoder(bytes.NewReader(buf))
394 LOOP:
395 for {
396 tok, err := dec.Token()
397 if err != nil {
398 return errors.Wrap(err, `error reading token`)
399 }
400 switch tok := tok.(type) {
401 case json.Delim:
402
403
404 if tok == '}' {
405 break LOOP
406 } else if tok != '{' {
407 return errors.Errorf(`expected '{', but got '%c'`, tok)
408 }
409 case string:
410 switch tok {
411 case KeyTypeKey:
412 val, err := json.ReadNextStringToken(dec)
413 if err != nil {
414 return errors.Wrap(err, `error reading token`)
415 }
416 if val != jwa.OKP.String() {
417 return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val)
418 }
419 case AlgorithmKey:
420 if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil {
421 return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
422 }
423 case OKPCrvKey:
424 var decoded jwa.EllipticCurveAlgorithm
425 if err := dec.Decode(&decoded); err != nil {
426 return errors.Wrapf(err, `failed to decode value for key %s`, OKPCrvKey)
427 }
428 h.crv = &decoded
429 case KeyIDKey:
430 if err := json.AssignNextStringToken(&h.keyID, dec); err != nil {
431 return errors.Wrapf(err, `failed to decode value for key %s`, KeyIDKey)
432 }
433 case KeyOpsKey:
434 var decoded KeyOperationList
435 if err := dec.Decode(&decoded); err != nil {
436 return errors.Wrapf(err, `failed to decode value for key %s`, KeyOpsKey)
437 }
438 h.keyOps = &decoded
439 case KeyUsageKey:
440 if err := json.AssignNextStringToken(&h.keyUsage, dec); err != nil {
441 return errors.Wrapf(err, `failed to decode value for key %s`, KeyUsageKey)
442 }
443 case OKPXKey:
444 if err := json.AssignNextBytesToken(&h.x, dec); err != nil {
445 return errors.Wrapf(err, `failed to decode value for key %s`, OKPXKey)
446 }
447 case X509CertChainKey:
448 var decoded CertificateChain
449 if err := dec.Decode(&decoded); err != nil {
450 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertChainKey)
451 }
452 h.x509CertChain = &decoded
453 case X509CertThumbprintKey:
454 if err := json.AssignNextStringToken(&h.x509CertThumbprint, dec); err != nil {
455 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintKey)
456 }
457 case X509CertThumbprintS256Key:
458 if err := json.AssignNextStringToken(&h.x509CertThumbprintS256, dec); err != nil {
459 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintS256Key)
460 }
461 case X509URLKey:
462 if err := json.AssignNextStringToken(&h.x509URL, dec); err != nil {
463 return errors.Wrapf(err, `failed to decode value for key %s`, X509URLKey)
464 }
465 default:
466 if dc := h.dc; dc != nil {
467 if localReg := dc.Registry(); localReg != nil {
468 decoded, err := localReg.Decode(dec, tok)
469 if err == nil {
470 h.setNoLock(tok, decoded)
471 continue
472 }
473 }
474 }
475 decoded, err := registry.Decode(dec, tok)
476 if err == nil {
477 h.setNoLock(tok, decoded)
478 continue
479 }
480 return errors.Wrapf(err, `could not decode field %s`, tok)
481 }
482 default:
483 return errors.Errorf(`invalid token %T`, tok)
484 }
485 }
486 if h.crv == nil {
487 return errors.Errorf(`required field crv is missing`)
488 }
489 if h.x == nil {
490 return errors.Errorf(`required field x is missing`)
491 }
492 return nil
493 }
494
495 func (h okpPublicKey) MarshalJSON() ([]byte, error) {
496 data := make(map[string]interface{})
497 fields := make([]string, 0, 10)
498 for _, pair := range h.makePairs() {
499 fields = append(fields, pair.Key.(string))
500 data[pair.Key.(string)] = pair.Value
501 }
502
503 sort.Strings(fields)
504 buf := pool.GetBytesBuffer()
505 defer pool.ReleaseBytesBuffer(buf)
506 buf.WriteByte('{')
507 enc := json.NewEncoder(buf)
508 for i, f := range fields {
509 if i > 0 {
510 buf.WriteRune(',')
511 }
512 buf.WriteRune('"')
513 buf.WriteString(f)
514 buf.WriteString(`":`)
515 v := data[f]
516 switch v := v.(type) {
517 case []byte:
518 buf.WriteRune('"')
519 buf.WriteString(base64.EncodeToString(v))
520 buf.WriteRune('"')
521 default:
522 if err := enc.Encode(v); err != nil {
523 return nil, errors.Wrapf(err, `failed to encode value for field %s`, f)
524 }
525 buf.Truncate(buf.Len() - 1)
526 }
527 }
528 buf.WriteByte('}')
529 ret := make([]byte, buf.Len())
530 copy(ret, buf.Bytes())
531 return ret, nil
532 }
533
534 func (h *okpPublicKey) Iterate(ctx context.Context) HeaderIterator {
535 pairs := h.makePairs()
536 ch := make(chan *HeaderPair, len(pairs))
537 go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) {
538 defer close(ch)
539 for _, pair := range pairs {
540 select {
541 case <-ctx.Done():
542 return
543 case ch <- pair:
544 }
545 }
546 }(ctx, ch, pairs)
547 return mapiter.New(ch)
548 }
549
550 func (h *okpPublicKey) Walk(ctx context.Context, visitor HeaderVisitor) error {
551 return iter.WalkMap(ctx, h, visitor)
552 }
553
554 func (h *okpPublicKey) AsMap(ctx context.Context) (map[string]interface{}, error) {
555 return iter.AsMap(ctx, h)
556 }
557
558 type OKPPrivateKey interface {
559 Key
560 FromRaw(interface{}) error
561 Crv() jwa.EllipticCurveAlgorithm
562 D() []byte
563 X() []byte
564 }
565
566 type okpPrivateKey struct {
567 algorithm *string
568 crv *jwa.EllipticCurveAlgorithm
569 d []byte
570 keyID *string
571 keyOps *KeyOperationList
572 keyUsage *string
573 x []byte
574 x509CertChain *CertificateChain
575 x509CertThumbprint *string
576 x509CertThumbprintS256 *string
577 x509URL *string
578 privateParams map[string]interface{}
579 mu *sync.RWMutex
580 dc json.DecodeCtx
581 }
582
583 func NewOKPPrivateKey() OKPPrivateKey {
584 return newOKPPrivateKey()
585 }
586
587 func newOKPPrivateKey() *okpPrivateKey {
588 return &okpPrivateKey{
589 mu: &sync.RWMutex{},
590 privateParams: make(map[string]interface{}),
591 }
592 }
593
594 func (h okpPrivateKey) KeyType() jwa.KeyType {
595 return jwa.OKP
596 }
597
598 func (h *okpPrivateKey) Algorithm() string {
599 if h.algorithm != nil {
600 return *(h.algorithm)
601 }
602 return ""
603 }
604
605 func (h *okpPrivateKey) Crv() jwa.EllipticCurveAlgorithm {
606 if h.crv != nil {
607 return *(h.crv)
608 }
609 return jwa.InvalidEllipticCurve
610 }
611
612 func (h *okpPrivateKey) D() []byte {
613 return h.d
614 }
615
616 func (h *okpPrivateKey) KeyID() string {
617 if h.keyID != nil {
618 return *(h.keyID)
619 }
620 return ""
621 }
622
623 func (h *okpPrivateKey) KeyOps() KeyOperationList {
624 if h.keyOps != nil {
625 return *(h.keyOps)
626 }
627 return nil
628 }
629
630 func (h *okpPrivateKey) KeyUsage() string {
631 if h.keyUsage != nil {
632 return *(h.keyUsage)
633 }
634 return ""
635 }
636
637 func (h *okpPrivateKey) X() []byte {
638 return h.x
639 }
640
641 func (h *okpPrivateKey) X509CertChain() []*x509.Certificate {
642 if h.x509CertChain != nil {
643 return h.x509CertChain.Get()
644 }
645 return nil
646 }
647
648 func (h *okpPrivateKey) X509CertThumbprint() string {
649 if h.x509CertThumbprint != nil {
650 return *(h.x509CertThumbprint)
651 }
652 return ""
653 }
654
655 func (h *okpPrivateKey) X509CertThumbprintS256() string {
656 if h.x509CertThumbprintS256 != nil {
657 return *(h.x509CertThumbprintS256)
658 }
659 return ""
660 }
661
662 func (h *okpPrivateKey) X509URL() string {
663 if h.x509URL != nil {
664 return *(h.x509URL)
665 }
666 return ""
667 }
668
669 func (h *okpPrivateKey) makePairs() []*HeaderPair {
670 h.mu.RLock()
671 defer h.mu.RUnlock()
672
673 var pairs []*HeaderPair
674 pairs = append(pairs, &HeaderPair{Key: "kty", Value: jwa.OKP})
675 if h.algorithm != nil {
676 pairs = append(pairs, &HeaderPair{Key: AlgorithmKey, Value: *(h.algorithm)})
677 }
678 if h.crv != nil {
679 pairs = append(pairs, &HeaderPair{Key: OKPCrvKey, Value: *(h.crv)})
680 }
681 if h.d != nil {
682 pairs = append(pairs, &HeaderPair{Key: OKPDKey, Value: h.d})
683 }
684 if h.keyID != nil {
685 pairs = append(pairs, &HeaderPair{Key: KeyIDKey, Value: *(h.keyID)})
686 }
687 if h.keyOps != nil {
688 pairs = append(pairs, &HeaderPair{Key: KeyOpsKey, Value: *(h.keyOps)})
689 }
690 if h.keyUsage != nil {
691 pairs = append(pairs, &HeaderPair{Key: KeyUsageKey, Value: *(h.keyUsage)})
692 }
693 if h.x != nil {
694 pairs = append(pairs, &HeaderPair{Key: OKPXKey, Value: h.x})
695 }
696 if h.x509CertChain != nil {
697 pairs = append(pairs, &HeaderPair{Key: X509CertChainKey, Value: *(h.x509CertChain)})
698 }
699 if h.x509CertThumbprint != nil {
700 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintKey, Value: *(h.x509CertThumbprint)})
701 }
702 if h.x509CertThumbprintS256 != nil {
703 pairs = append(pairs, &HeaderPair{Key: X509CertThumbprintS256Key, Value: *(h.x509CertThumbprintS256)})
704 }
705 if h.x509URL != nil {
706 pairs = append(pairs, &HeaderPair{Key: X509URLKey, Value: *(h.x509URL)})
707 }
708 for k, v := range h.privateParams {
709 pairs = append(pairs, &HeaderPair{Key: k, Value: v})
710 }
711 return pairs
712 }
713
714 func (h *okpPrivateKey) PrivateParams() map[string]interface{} {
715 return h.privateParams
716 }
717
718 func (h *okpPrivateKey) Get(name string) (interface{}, bool) {
719 h.mu.RLock()
720 defer h.mu.RUnlock()
721 switch name {
722 case KeyTypeKey:
723 return h.KeyType(), true
724 case AlgorithmKey:
725 if h.algorithm == nil {
726 return nil, false
727 }
728 return *(h.algorithm), true
729 case OKPCrvKey:
730 if h.crv == nil {
731 return nil, false
732 }
733 return *(h.crv), true
734 case OKPDKey:
735 if h.d == nil {
736 return nil, false
737 }
738 return h.d, true
739 case KeyIDKey:
740 if h.keyID == nil {
741 return nil, false
742 }
743 return *(h.keyID), true
744 case KeyOpsKey:
745 if h.keyOps == nil {
746 return nil, false
747 }
748 return *(h.keyOps), true
749 case KeyUsageKey:
750 if h.keyUsage == nil {
751 return nil, false
752 }
753 return *(h.keyUsage), true
754 case OKPXKey:
755 if h.x == nil {
756 return nil, false
757 }
758 return h.x, true
759 case X509CertChainKey:
760 if h.x509CertChain == nil {
761 return nil, false
762 }
763 return h.x509CertChain.Get(), true
764 case X509CertThumbprintKey:
765 if h.x509CertThumbprint == nil {
766 return nil, false
767 }
768 return *(h.x509CertThumbprint), true
769 case X509CertThumbprintS256Key:
770 if h.x509CertThumbprintS256 == nil {
771 return nil, false
772 }
773 return *(h.x509CertThumbprintS256), true
774 case X509URLKey:
775 if h.x509URL == nil {
776 return nil, false
777 }
778 return *(h.x509URL), true
779 default:
780 v, ok := h.privateParams[name]
781 return v, ok
782 }
783 }
784
785 func (h *okpPrivateKey) Set(name string, value interface{}) error {
786 h.mu.Lock()
787 defer h.mu.Unlock()
788 return h.setNoLock(name, value)
789 }
790
791 func (h *okpPrivateKey) setNoLock(name string, value interface{}) error {
792 switch name {
793 case "kty":
794 return nil
795 case AlgorithmKey:
796 switch v := value.(type) {
797 case string:
798 h.algorithm = &v
799 case fmt.Stringer:
800 tmp := v.String()
801 h.algorithm = &tmp
802 default:
803 return errors.Errorf(`invalid type for %s key: %T`, AlgorithmKey, value)
804 }
805 return nil
806 case OKPCrvKey:
807 if v, ok := value.(jwa.EllipticCurveAlgorithm); ok {
808 h.crv = &v
809 return nil
810 }
811 return errors.Errorf(`invalid value for %s key: %T`, OKPCrvKey, value)
812 case OKPDKey:
813 if v, ok := value.([]byte); ok {
814 h.d = v
815 return nil
816 }
817 return errors.Errorf(`invalid value for %s key: %T`, OKPDKey, value)
818 case KeyIDKey:
819 if v, ok := value.(string); ok {
820 h.keyID = &v
821 return nil
822 }
823 return errors.Errorf(`invalid value for %s key: %T`, KeyIDKey, value)
824 case KeyOpsKey:
825 var acceptor KeyOperationList
826 if err := acceptor.Accept(value); err != nil {
827 return errors.Wrapf(err, `invalid value for %s key`, KeyOpsKey)
828 }
829 h.keyOps = &acceptor
830 return nil
831 case KeyUsageKey:
832 switch v := value.(type) {
833 case KeyUsageType:
834 switch v {
835 case ForSignature, ForEncryption:
836 tmp := v.String()
837 h.keyUsage = &tmp
838 default:
839 return errors.Errorf(`invalid key usage type %s`, v)
840 }
841 case string:
842 h.keyUsage = &v
843 default:
844 return errors.Errorf(`invalid key usage type %s`, v)
845 }
846 case OKPXKey:
847 if v, ok := value.([]byte); ok {
848 h.x = v
849 return nil
850 }
851 return errors.Errorf(`invalid value for %s key: %T`, OKPXKey, value)
852 case X509CertChainKey:
853 var acceptor CertificateChain
854 if err := acceptor.Accept(value); err != nil {
855 return errors.Wrapf(err, `invalid value for %s key`, X509CertChainKey)
856 }
857 h.x509CertChain = &acceptor
858 return nil
859 case X509CertThumbprintKey:
860 if v, ok := value.(string); ok {
861 h.x509CertThumbprint = &v
862 return nil
863 }
864 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintKey, value)
865 case X509CertThumbprintS256Key:
866 if v, ok := value.(string); ok {
867 h.x509CertThumbprintS256 = &v
868 return nil
869 }
870 return errors.Errorf(`invalid value for %s key: %T`, X509CertThumbprintS256Key, value)
871 case X509URLKey:
872 if v, ok := value.(string); ok {
873 h.x509URL = &v
874 return nil
875 }
876 return errors.Errorf(`invalid value for %s key: %T`, X509URLKey, value)
877 default:
878 if h.privateParams == nil {
879 h.privateParams = map[string]interface{}{}
880 }
881 h.privateParams[name] = value
882 }
883 return nil
884 }
885
886 func (k *okpPrivateKey) Remove(key string) error {
887 k.mu.Lock()
888 defer k.mu.Unlock()
889 switch key {
890 case AlgorithmKey:
891 k.algorithm = nil
892 case OKPCrvKey:
893 k.crv = nil
894 case OKPDKey:
895 k.d = nil
896 case KeyIDKey:
897 k.keyID = nil
898 case KeyOpsKey:
899 k.keyOps = nil
900 case KeyUsageKey:
901 k.keyUsage = nil
902 case OKPXKey:
903 k.x = nil
904 case X509CertChainKey:
905 k.x509CertChain = nil
906 case X509CertThumbprintKey:
907 k.x509CertThumbprint = nil
908 case X509CertThumbprintS256Key:
909 k.x509CertThumbprintS256 = nil
910 case X509URLKey:
911 k.x509URL = nil
912 default:
913 delete(k.privateParams, key)
914 }
915 return nil
916 }
917
918 func (k *okpPrivateKey) Clone() (Key, error) {
919 return cloneKey(k)
920 }
921
922 func (k *okpPrivateKey) DecodeCtx() json.DecodeCtx {
923 k.mu.RLock()
924 defer k.mu.RUnlock()
925 return k.dc
926 }
927
928 func (k *okpPrivateKey) SetDecodeCtx(dc json.DecodeCtx) {
929 k.mu.Lock()
930 defer k.mu.Unlock()
931 k.dc = dc
932 }
933
934 func (h *okpPrivateKey) UnmarshalJSON(buf []byte) error {
935 h.algorithm = nil
936 h.crv = nil
937 h.d = nil
938 h.keyID = nil
939 h.keyOps = nil
940 h.keyUsage = nil
941 h.x = nil
942 h.x509CertChain = nil
943 h.x509CertThumbprint = nil
944 h.x509CertThumbprintS256 = nil
945 h.x509URL = nil
946 dec := json.NewDecoder(bytes.NewReader(buf))
947 LOOP:
948 for {
949 tok, err := dec.Token()
950 if err != nil {
951 return errors.Wrap(err, `error reading token`)
952 }
953 switch tok := tok.(type) {
954 case json.Delim:
955
956
957 if tok == '}' {
958 break LOOP
959 } else if tok != '{' {
960 return errors.Errorf(`expected '{', but got '%c'`, tok)
961 }
962 case string:
963 switch tok {
964 case KeyTypeKey:
965 val, err := json.ReadNextStringToken(dec)
966 if err != nil {
967 return errors.Wrap(err, `error reading token`)
968 }
969 if val != jwa.OKP.String() {
970 return errors.Errorf(`invalid kty value for RSAPublicKey (%s)`, val)
971 }
972 case AlgorithmKey:
973 if err := json.AssignNextStringToken(&h.algorithm, dec); err != nil {
974 return errors.Wrapf(err, `failed to decode value for key %s`, AlgorithmKey)
975 }
976 case OKPCrvKey:
977 var decoded jwa.EllipticCurveAlgorithm
978 if err := dec.Decode(&decoded); err != nil {
979 return errors.Wrapf(err, `failed to decode value for key %s`, OKPCrvKey)
980 }
981 h.crv = &decoded
982 case OKPDKey:
983 if err := json.AssignNextBytesToken(&h.d, dec); err != nil {
984 return errors.Wrapf(err, `failed to decode value for key %s`, OKPDKey)
985 }
986 case KeyIDKey:
987 if err := json.AssignNextStringToken(&h.keyID, dec); err != nil {
988 return errors.Wrapf(err, `failed to decode value for key %s`, KeyIDKey)
989 }
990 case KeyOpsKey:
991 var decoded KeyOperationList
992 if err := dec.Decode(&decoded); err != nil {
993 return errors.Wrapf(err, `failed to decode value for key %s`, KeyOpsKey)
994 }
995 h.keyOps = &decoded
996 case KeyUsageKey:
997 if err := json.AssignNextStringToken(&h.keyUsage, dec); err != nil {
998 return errors.Wrapf(err, `failed to decode value for key %s`, KeyUsageKey)
999 }
1000 case OKPXKey:
1001 if err := json.AssignNextBytesToken(&h.x, dec); err != nil {
1002 return errors.Wrapf(err, `failed to decode value for key %s`, OKPXKey)
1003 }
1004 case X509CertChainKey:
1005 var decoded CertificateChain
1006 if err := dec.Decode(&decoded); err != nil {
1007 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertChainKey)
1008 }
1009 h.x509CertChain = &decoded
1010 case X509CertThumbprintKey:
1011 if err := json.AssignNextStringToken(&h.x509CertThumbprint, dec); err != nil {
1012 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintKey)
1013 }
1014 case X509CertThumbprintS256Key:
1015 if err := json.AssignNextStringToken(&h.x509CertThumbprintS256, dec); err != nil {
1016 return errors.Wrapf(err, `failed to decode value for key %s`, X509CertThumbprintS256Key)
1017 }
1018 case X509URLKey:
1019 if err := json.AssignNextStringToken(&h.x509URL, dec); err != nil {
1020 return errors.Wrapf(err, `failed to decode value for key %s`, X509URLKey)
1021 }
1022 default:
1023 if dc := h.dc; dc != nil {
1024 if localReg := dc.Registry(); localReg != nil {
1025 decoded, err := localReg.Decode(dec, tok)
1026 if err == nil {
1027 h.setNoLock(tok, decoded)
1028 continue
1029 }
1030 }
1031 }
1032 decoded, err := registry.Decode(dec, tok)
1033 if err == nil {
1034 h.setNoLock(tok, decoded)
1035 continue
1036 }
1037 return errors.Wrapf(err, `could not decode field %s`, tok)
1038 }
1039 default:
1040 return errors.Errorf(`invalid token %T`, tok)
1041 }
1042 }
1043 if h.crv == nil {
1044 return errors.Errorf(`required field crv is missing`)
1045 }
1046 if h.d == nil {
1047 return errors.Errorf(`required field d is missing`)
1048 }
1049 if h.x == nil {
1050 return errors.Errorf(`required field x is missing`)
1051 }
1052 return nil
1053 }
1054
1055 func (h okpPrivateKey) MarshalJSON() ([]byte, error) {
1056 data := make(map[string]interface{})
1057 fields := make([]string, 0, 11)
1058 for _, pair := range h.makePairs() {
1059 fields = append(fields, pair.Key.(string))
1060 data[pair.Key.(string)] = pair.Value
1061 }
1062
1063 sort.Strings(fields)
1064 buf := pool.GetBytesBuffer()
1065 defer pool.ReleaseBytesBuffer(buf)
1066 buf.WriteByte('{')
1067 enc := json.NewEncoder(buf)
1068 for i, f := range fields {
1069 if i > 0 {
1070 buf.WriteRune(',')
1071 }
1072 buf.WriteRune('"')
1073 buf.WriteString(f)
1074 buf.WriteString(`":`)
1075 v := data[f]
1076 switch v := v.(type) {
1077 case []byte:
1078 buf.WriteRune('"')
1079 buf.WriteString(base64.EncodeToString(v))
1080 buf.WriteRune('"')
1081 default:
1082 if err := enc.Encode(v); err != nil {
1083 return nil, errors.Wrapf(err, `failed to encode value for field %s`, f)
1084 }
1085 buf.Truncate(buf.Len() - 1)
1086 }
1087 }
1088 buf.WriteByte('}')
1089 ret := make([]byte, buf.Len())
1090 copy(ret, buf.Bytes())
1091 return ret, nil
1092 }
1093
1094 func (h *okpPrivateKey) Iterate(ctx context.Context) HeaderIterator {
1095 pairs := h.makePairs()
1096 ch := make(chan *HeaderPair, len(pairs))
1097 go func(ctx context.Context, ch chan *HeaderPair, pairs []*HeaderPair) {
1098 defer close(ch)
1099 for _, pair := range pairs {
1100 select {
1101 case <-ctx.Done():
1102 return
1103 case ch <- pair:
1104 }
1105 }
1106 }(ctx, ch, pairs)
1107 return mapiter.New(ch)
1108 }
1109
1110 func (h *okpPrivateKey) Walk(ctx context.Context, visitor HeaderVisitor) error {
1111 return iter.WalkMap(ctx, h, visitor)
1112 }
1113
1114 func (h *okpPrivateKey) AsMap(ctx context.Context) (map[string]interface{}, error) {
1115 return iter.AsMap(ctx, h)
1116 }
1117
View as plain text