1
2
3
4 package jwk
5
6 import (
7 "bytes"
8 "context"
9 "crypto"
10 "crypto/ecdsa"
11 "crypto/ed25519"
12 "crypto/rsa"
13 "crypto/x509"
14 "encoding/pem"
15 "io"
16 "io/ioutil"
17 "math/big"
18 "net/http"
19
20 "github.com/lestrrat-go/backoff/v2"
21 "github.com/lestrrat-go/jwx/internal/base64"
22 "github.com/lestrrat-go/jwx/internal/json"
23 "github.com/lestrrat-go/jwx/jwa"
24 "github.com/lestrrat-go/jwx/x25519"
25 "github.com/pkg/errors"
26 )
27
28 var registry = json.NewRegistry()
29
// bigIntToBytes returns the absolute value of n as a big-endian byte slice,
// exactly as (*big.Int).Bytes does. A nil n is rejected with an error.
func bigIntToBytes(n *big.Int) ([]byte, error) {
	if n != nil {
		return n.Bytes(), nil
	}
	return nil, errors.New(`invalid *big.Int value`)
}
36
37
38
39
40
41
42
43
44
45
46 func New(key interface{}) (Key, error) {
47 if key == nil {
48 return nil, errors.New(`jwk.New requires a non-nil key`)
49 }
50
51 var ptr interface{}
52 switch v := key.(type) {
53 case rsa.PrivateKey:
54 ptr = &v
55 case rsa.PublicKey:
56 ptr = &v
57 case ecdsa.PrivateKey:
58 ptr = &v
59 case ecdsa.PublicKey:
60 ptr = &v
61 default:
62 ptr = v
63 }
64
65 switch rawKey := ptr.(type) {
66 case *rsa.PrivateKey:
67 k := NewRSAPrivateKey()
68 if err := k.FromRaw(rawKey); err != nil {
69 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
70 }
71 return k, nil
72 case *rsa.PublicKey:
73 k := NewRSAPublicKey()
74 if err := k.FromRaw(rawKey); err != nil {
75 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
76 }
77 return k, nil
78 case *ecdsa.PrivateKey:
79 k := NewECDSAPrivateKey()
80 if err := k.FromRaw(rawKey); err != nil {
81 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
82 }
83 return k, nil
84 case *ecdsa.PublicKey:
85 k := NewECDSAPublicKey()
86 if err := k.FromRaw(rawKey); err != nil {
87 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
88 }
89 return k, nil
90 case ed25519.PrivateKey:
91 k := NewOKPPrivateKey()
92 if err := k.FromRaw(rawKey); err != nil {
93 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
94 }
95 return k, nil
96 case ed25519.PublicKey:
97 k := NewOKPPublicKey()
98 if err := k.FromRaw(rawKey); err != nil {
99 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
100 }
101 return k, nil
102 case x25519.PrivateKey:
103 k := NewOKPPrivateKey()
104 if err := k.FromRaw(rawKey); err != nil {
105 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
106 }
107 return k, nil
108 case x25519.PublicKey:
109 k := NewOKPPublicKey()
110 if err := k.FromRaw(rawKey); err != nil {
111 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
112 }
113 return k, nil
114 case []byte:
115 k := NewSymmetricKey()
116 if err := k.FromRaw(rawKey); err != nil {
117 return nil, errors.Wrapf(err, `failed to initialize %T from %T`, k, rawKey)
118 }
119 return k, nil
120 default:
121 return nil, errors.Errorf(`invalid key type '%T' for jwk.New`, key)
122 }
123 }
124
125
126
127
128
129
130
131
132
133
134 func PublicSetOf(v Set) (Set, error) {
135 newSet := NewSet()
136
137 n := v.Len()
138 for i := 0; i < n; i++ {
139 k, ok := v.Get(i)
140 if !ok {
141 return nil, errors.New("key not found")
142 }
143 pubKey, err := PublicKeyOf(k)
144 if err != nil {
145 return nil, errors.Wrapf(err, `failed to get public key of %T`, k)
146 }
147 newSet.Add(pubKey)
148 }
149
150 return newSet, nil
151 }
152
153
154
155
156
157
158
159
160
161
162 func PublicKeyOf(v interface{}) (Key, error) {
163 if pk, ok := v.(PublicKeyer); ok {
164 return pk.PublicKey()
165 }
166
167 jk, err := New(v)
168 if err != nil {
169 return nil, errors.Wrapf(err, `failed to convert key into JWK`)
170 }
171
172 return jk.PublicKey()
173 }
174
175
176
177
178
179
180
181
182 func PublicRawKeyOf(v interface{}) (interface{}, error) {
183 if pk, ok := v.(PublicKeyer); ok {
184 pubk, err := pk.PublicKey()
185 if err != nil {
186 return nil, errors.Wrapf(err, `failed to obtain public key from %T`, v)
187 }
188
189 var raw interface{}
190 if err := pubk.Raw(&raw); err != nil {
191 return nil, errors.Wrapf(err, `failed to obtain raw key from %T`, pubk)
192 }
193 return raw, nil
194 }
195
196
197 var ptr interface{}
198 switch v := v.(type) {
199 case rsa.PrivateKey:
200 ptr = &v
201 case rsa.PublicKey:
202 ptr = &v
203 case ecdsa.PrivateKey:
204 ptr = &v
205 case ecdsa.PublicKey:
206 ptr = &v
207 default:
208 ptr = v
209 }
210
211 switch x := ptr.(type) {
212 case *rsa.PrivateKey:
213 return &x.PublicKey, nil
214 case *rsa.PublicKey:
215 return x, nil
216 case *ecdsa.PrivateKey:
217 return &x.PublicKey, nil
218 case *ecdsa.PublicKey:
219 return x, nil
220 case ed25519.PrivateKey:
221 return x.Public(), nil
222 case ed25519.PublicKey:
223 return x, nil
224 case x25519.PrivateKey:
225 return x.Public(), nil
226 case x25519.PublicKey:
227 return x, nil
228 case []byte:
229 return x, nil
230 default:
231 return nil, errors.Errorf(`invalid key type passed to PublicKeyOf (%T)`, v)
232 }
233 }
234
235
236
237
238
239
240
241
242
243
244
245
246
247 func Fetch(ctx context.Context, urlstring string, options ...FetchOption) (Set, error) {
248 res, err := fetch(ctx, urlstring, options...)
249 if err != nil {
250 return nil, err
251 }
252
253 defer res.Body.Close()
254 keyset, err := ParseReader(res.Body)
255 if err != nil {
256 return nil, errors.Wrap(err, `failed to parse JWK set`)
257 }
258 return keyset, nil
259 }
260
261 func fetch(ctx context.Context, urlstring string, options ...FetchOption) (*http.Response, error) {
262 var wl Whitelist
263 var httpcl HTTPClient = http.DefaultClient
264 bo := backoff.Null()
265 for _, option := range options {
266
267 switch option.Ident() {
268 case identHTTPClient{}:
269 httpcl = option.Value().(HTTPClient)
270 case identFetchBackoff{}:
271 bo = option.Value().(backoff.Policy)
272 case identFetchWhitelist{}:
273 wl = option.Value().(Whitelist)
274 }
275 }
276
277 if wl != nil {
278 if !wl.IsAllowed(urlstring) {
279 return nil, errors.New(`url rejected by whitelist`)
280 }
281 }
282
283 req, err := http.NewRequestWithContext(ctx, http.MethodGet, urlstring, nil)
284 if err != nil {
285 return nil, errors.Wrap(err, "failed to new request to remote JWK")
286 }
287
288 b := bo.Start(ctx)
289 var lastError error
290 for backoff.Continue(b) {
291 res, err := httpcl.Do(req)
292 if err != nil {
293 lastError = errors.Wrap(err, "failed to fetch remote JWK")
294 continue
295 }
296
297 if res.StatusCode != http.StatusOK {
298 lastError = errors.Errorf("failed to fetch remote JWK (status = %d)", res.StatusCode)
299 continue
300 }
301 return res, nil
302 }
303
304
305
306
307 if lastError == nil {
308 lastError = errors.New(`fetching remote JWK did not complete`)
309 }
310 return nil, lastError
311 }
312
313
314
315
316
317 func ParseRawKey(data []byte, rawkey interface{}) error {
318 key, err := ParseKey(data)
319 if err != nil {
320 return errors.Wrap(err, `failed to parse key`)
321 }
322
323 if err := key.Raw(rawkey); err != nil {
324 return errors.Wrap(err, `failed to assign to raw key variable`)
325 }
326
327 return nil
328 }
329
330
331
332
333 func parsePEMEncodedRawKey(src []byte) (interface{}, []byte, error) {
334 block, rest := pem.Decode(src)
335 if block == nil {
336 return nil, nil, errors.New(`failed to decode PEM data`)
337 }
338
339 switch block.Type {
340
341 case "RSA PRIVATE KEY":
342 key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
343 if err != nil {
344 return nil, nil, errors.Wrap(err, `failed to parse PKCS1 private key`)
345 }
346 return key, rest, nil
347 case "RSA PUBLIC KEY":
348 key, err := x509.ParsePKCS1PublicKey(block.Bytes)
349 if err != nil {
350 return nil, nil, errors.Wrap(err, `failed to parse PKCS1 public key`)
351 }
352 return key, rest, nil
353 case "EC PRIVATE KEY":
354 key, err := x509.ParseECPrivateKey(block.Bytes)
355 if err != nil {
356 return nil, nil, errors.Wrap(err, `failed to parse EC private key`)
357 }
358 return key, rest, nil
359 case "PUBLIC KEY":
360
361 key, err := x509.ParsePKIXPublicKey(block.Bytes)
362 if err != nil {
363 return nil, nil, errors.Wrap(err, `failed to parse PKIX public key`)
364 }
365 return key, rest, nil
366 case "PRIVATE KEY":
367 key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
368 if err != nil {
369 return nil, nil, errors.Wrap(err, `failed to parse PKCS8 private key`)
370 }
371 return key, rest, nil
372 case "CERTIFICATE":
373 cert, err := x509.ParseCertificate(block.Bytes)
374 if err != nil {
375 return nil, nil, errors.Wrap(err, `failed to parse certificate`)
376 }
377 return cert.PublicKey, rest, nil
378 default:
379 return nil, nil, errors.Errorf(`invalid PEM block type %s`, block.Type)
380 }
381 }
382
383 type setDecodeCtx struct {
384 json.DecodeCtx
385 ignoreParseError bool
386 }
387
388 func (ctx *setDecodeCtx) IgnoreParseError() bool {
389 return ctx.ignoreParseError
390 }
391
392
393
394
395
396
397
398
399
400
401
402
403 func ParseKey(data []byte, options ...ParseOption) (Key, error) {
404 var parsePEM bool
405 var localReg *json.Registry
406 for _, option := range options {
407
408 switch option.Ident() {
409 case identPEM{}:
410 parsePEM = option.Value().(bool)
411 case identLocalRegistry{}:
412
413
414
415 localReg = option.Value().(*json.Registry)
416 case identTypedField{}:
417 pair := option.Value().(typedFieldPair)
418 if localReg == nil {
419 localReg = json.NewRegistry()
420 }
421 localReg.Register(pair.Name, pair.Value)
422 case identIgnoreParseError{}:
423 return nil, errors.Errorf(`jwk.WithIgnoreParseError() cannot be used for ParseKey()`)
424 }
425 }
426
427 if parsePEM {
428 raw, _, err := parsePEMEncodedRawKey(data)
429 if err != nil {
430 return nil, errors.Wrap(err, `failed to parse PEM encoded key`)
431 }
432 return New(raw)
433 }
434
435 var hint struct {
436 Kty string `json:"kty"`
437 D json.RawMessage `json:"d"`
438 }
439
440 if err := json.Unmarshal(data, &hint); err != nil {
441 return nil, errors.Wrap(err, `failed to unmarshal JSON into key hint`)
442 }
443
444 var key Key
445 switch jwa.KeyType(hint.Kty) {
446 case jwa.RSA:
447 if len(hint.D) > 0 {
448 key = newRSAPrivateKey()
449 } else {
450 key = newRSAPublicKey()
451 }
452 case jwa.EC:
453 if len(hint.D) > 0 {
454 key = newECDSAPrivateKey()
455 } else {
456 key = newECDSAPublicKey()
457 }
458 case jwa.OctetSeq:
459 key = newSymmetricKey()
460 case jwa.OKP:
461 if len(hint.D) > 0 {
462 key = newOKPPrivateKey()
463 } else {
464 key = newOKPPublicKey()
465 }
466 default:
467 return nil, errors.Errorf(`invalid key type from JSON (%s)`, hint.Kty)
468 }
469
470 if localReg != nil {
471 dcKey, ok := key.(json.DecodeCtxContainer)
472 if !ok {
473 return nil, errors.Errorf(`typed field was requested, but the key (%T) does not support DecodeCtx`, key)
474 }
475 dc := json.NewDecodeCtx(localReg)
476 dcKey.SetDecodeCtx(dc)
477 defer func() { dcKey.SetDecodeCtx(nil) }()
478 }
479
480 if err := json.Unmarshal(data, key); err != nil {
481 return nil, errors.Wrapf(err, `failed to unmarshal JSON into key (%T)`, key)
482 }
483
484 return key, nil
485 }
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501 func Parse(src []byte, options ...ParseOption) (Set, error) {
502 var parsePEM bool
503 var localReg *json.Registry
504 var ignoreParseError bool
505 for _, option := range options {
506
507 switch option.Ident() {
508 case identPEM{}:
509 parsePEM = option.Value().(bool)
510 case identIgnoreParseError{}:
511 ignoreParseError = option.Value().(bool)
512 case identTypedField{}:
513 pair := option.Value().(typedFieldPair)
514 if localReg == nil {
515 localReg = json.NewRegistry()
516 }
517 localReg.Register(pair.Name, pair.Value)
518 }
519 }
520
521 s := NewSet()
522
523 if parsePEM {
524 src = bytes.TrimSpace(src)
525 for len(src) > 0 {
526 raw, rest, err := parsePEMEncodedRawKey(src)
527 if err != nil {
528 return nil, errors.Wrap(err, `failed to parse PEM encoded key`)
529 }
530 key, err := New(raw)
531 if err != nil {
532 return nil, errors.Wrapf(err, `failed to create jwk.Key from %T`, raw)
533 }
534 s.Add(key)
535 src = bytes.TrimSpace(rest)
536 }
537 return s, nil
538 }
539
540 if localReg != nil || ignoreParseError {
541 dcKs, ok := s.(KeyWithDecodeCtx)
542 if !ok {
543 return nil, errors.Errorf(`typed field was requested, but the key set (%T) does not support DecodeCtx`, s)
544 }
545 dc := &setDecodeCtx{
546 DecodeCtx: json.NewDecodeCtx(localReg),
547 ignoreParseError: ignoreParseError,
548 }
549 dcKs.SetDecodeCtx(dc)
550 defer func() { dcKs.SetDecodeCtx(nil) }()
551 }
552
553 if err := json.Unmarshal(src, s); err != nil {
554 return nil, errors.Wrap(err, "failed to unmarshal JWK set")
555 }
556 return s, nil
557 }
558
559
560 func ParseReader(src io.Reader, options ...ParseOption) (Set, error) {
561
562
563 buf, err := ioutil.ReadAll(src)
564 if err != nil {
565 return nil, errors.Wrap(err, `failed to read from io.Reader`)
566 }
567
568 return Parse(buf, options...)
569 }
570
571
572 func ParseString(s string, options ...ParseOption) (Set, error) {
573 return Parse([]byte(s), options...)
574 }
575
576
577
578
579 func AssignKeyID(key Key, options ...Option) error {
580 if _, ok := key.Get(KeyIDKey); ok {
581 return nil
582 }
583
584 hash := crypto.SHA256
585 for _, option := range options {
586
587 switch option.Ident() {
588 case identThumbprintHash{}:
589 hash = option.Value().(crypto.Hash)
590 }
591 }
592
593 h, err := key.Thumbprint(hash)
594 if err != nil {
595 return errors.Wrap(err, `failed to generate thumbprint`)
596 }
597
598 if err := key.Set(KeyIDKey, base64.EncodeToString(h)); err != nil {
599 return errors.Wrap(err, `failed to set "kid"`)
600 }
601
602 return nil
603 }
604
605 func cloneKey(src Key) (Key, error) {
606 var dst Key
607 switch src.(type) {
608 case RSAPrivateKey:
609 dst = NewRSAPrivateKey()
610 case RSAPublicKey:
611 dst = NewRSAPublicKey()
612 case ECDSAPrivateKey:
613 dst = NewECDSAPrivateKey()
614 case ECDSAPublicKey:
615 dst = NewECDSAPublicKey()
616 case OKPPrivateKey:
617 dst = NewOKPPrivateKey()
618 case OKPPublicKey:
619 dst = NewOKPPublicKey()
620 case SymmetricKey:
621 dst = NewSymmetricKey()
622 default:
623 return nil, errors.Errorf(`unknown key type %T`, src)
624 }
625
626 for _, pair := range src.makePairs() {
627
628 key := pair.Key.(string)
629 if err := dst.Set(key, pair.Value); err != nil {
630 return nil, errors.Wrapf(err, `failed to set %q`, key)
631 }
632 }
633 return dst, nil
634 }
635
636
637
638
639
640
641
642
643
644 func Pem(v interface{}) ([]byte, error) {
645 var set Set
646 switch v := v.(type) {
647 case Key:
648 set = NewSet()
649 set.Add(v)
650 case Set:
651 set = v
652 default:
653 return nil, errors.Errorf(`argument to Pem must be either jwk.Key or jwk.Set: %T`, v)
654 }
655
656 var ret []byte
657 for i := 0; i < set.Len(); i++ {
658 key, _ := set.Get(i)
659 typ, buf, err := asnEncode(key)
660 if err != nil {
661 return nil, errors.Wrapf(err, `failed to encode content for key #%d`, i)
662 }
663
664 var block pem.Block
665 block.Type = typ
666 block.Bytes = buf
667 ret = append(ret, pem.EncodeToMemory(&block)...)
668 }
669 return ret, nil
670 }
671
672 func asnEncode(key Key) (string, []byte, error) {
673 switch key := key.(type) {
674 case RSAPrivateKey, ECDSAPrivateKey, OKPPrivateKey:
675 var rawkey interface{}
676 if err := key.Raw(&rawkey); err != nil {
677 return "", nil, errors.Wrap(err, `failed to get raw key from jwk.Key`)
678 }
679 buf, err := x509.MarshalPKCS8PrivateKey(rawkey)
680 if err != nil {
681 return "", nil, errors.Wrap(err, `failed to marshal PKCS8`)
682 }
683 return "PRIVATE KEY", buf, nil
684 case RSAPublicKey, ECDSAPublicKey, OKPPublicKey:
685 var rawkey interface{}
686 if err := key.Raw(&rawkey); err != nil {
687 return "", nil, errors.Wrap(err, `failed to get raw key from jwk.Key`)
688 }
689 buf, err := x509.MarshalPKIXPublicKey(rawkey)
690 if err != nil {
691 return "", nil, errors.Wrap(err, `failed to marshal PKIX`)
692 }
693 return "PUBLIC KEY", buf, nil
694 default:
695 return "", nil, errors.Errorf(`unsupported key type %T`, key)
696 }
697 }
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717 func RegisterCustomField(name string, object interface{}) {
718 registry.Register(name, object)
719 }
720
View as plain text