1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22 package jws
23
24 import (
25 "bufio"
26 "bytes"
27 "context"
28 "crypto/ecdsa"
29 "crypto/ed25519"
30 "crypto/rsa"
31 "fmt"
32 "io"
33 "io/ioutil"
34 "net/http"
35 "net/url"
36 "reflect"
37 "strings"
38 "sync"
39 "unicode"
40 "unicode/utf8"
41
42 "github.com/lestrrat-go/backoff/v2"
43 "github.com/lestrrat-go/jwx/internal/base64"
44 "github.com/lestrrat-go/jwx/internal/json"
45 "github.com/lestrrat-go/jwx/internal/pool"
46 "github.com/lestrrat-go/jwx/jwa"
47 "github.com/lestrrat-go/jwx/jwk"
48 "github.com/lestrrat-go/jwx/x25519"
49 "github.com/pkg/errors"
50 )
51
// registry is this package's custom-field registry; RegisterCustomField adds
// entries to it. NOTE(review): presumably consulted when decoding JWS headers
// so that registered fields unmarshal into their registered types — confirm in
// the headers implementation.
var registry = json.NewRegistry()
53
54 type payloadSigner struct {
55 signer Signer
56 key interface{}
57 protected Headers
58 public Headers
59 }
60
61 func (s *payloadSigner) Sign(payload []byte) ([]byte, error) {
62 return s.signer.Sign(payload, s.key)
63 }
64
65 func (s *payloadSigner) Algorithm() jwa.SignatureAlgorithm {
66 return s.signer.Algorithm()
67 }
68
69 func (s *payloadSigner) ProtectedHeader() Headers {
70 return s.protected
71 }
72
73 func (s *payloadSigner) PublicHeader() Headers {
74 return s.public
75 }
76
77 var signers = make(map[jwa.SignatureAlgorithm]Signer)
78 var muSigner = &sync.Mutex{}
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113 func Sign(payload []byte, alg jwa.SignatureAlgorithm, key interface{}, options ...SignOption) ([]byte, error) {
114 var hdrs Headers
115 var detached bool
116 for _, o := range options {
117
118 switch o.Ident() {
119 case identHeaders{}:
120 hdrs = o.Value().(Headers)
121 case identDetachedPayload{}:
122 detached = true
123 if payload != nil {
124 return nil, errors.New(`jws.Sign: payload must be nil when jws.WithDetachedPayload() is specified`)
125 }
126 payload = o.Value().([]byte)
127 }
128 }
129
130 muSigner.Lock()
131 signer, ok := signers[alg]
132 if !ok {
133 v, err := NewSigner(alg)
134 if err != nil {
135 muSigner.Unlock()
136 return nil, errors.Wrap(err, `failed to create signer`)
137 }
138 signers[alg] = v
139 signer = v
140 }
141 muSigner.Unlock()
142
143
144
145
146
147 sig := &Signature{
148 protected: hdrs,
149 detached: detached,
150 }
151 _, signature, err := sig.Sign(payload, signer, key)
152 if err != nil {
153 return nil, errors.Wrap(err, `failed sign payload`)
154 }
155
156 return signature, nil
157 }
158
159
160
161
162
163
164
165 func SignMulti(payload []byte, options ...Option) ([]byte, error) {
166 var signers []*payloadSigner
167 for _, o := range options {
168
169 switch o.Ident() {
170 case identPayloadSigner{}:
171 signers = append(signers, o.Value().(*payloadSigner))
172 }
173 }
174
175 if len(signers) == 0 {
176 return nil, errors.New(`no signers provided`)
177 }
178
179 var result Message
180
181 result.payload = payload
182
183 result.signatures = make([]*Signature, 0, len(signers))
184 for i, signer := range signers {
185 protected := signer.ProtectedHeader()
186 if protected == nil {
187 protected = NewHeaders()
188 }
189
190 if err := protected.Set(AlgorithmKey, signer.Algorithm()); err != nil {
191 return nil, errors.Wrap(err, `failed to set "alg" header`)
192 }
193
194 if key, ok := signer.key.(jwk.Key); ok {
195 if kid := key.KeyID(); kid != "" {
196 if err := protected.Set(KeyIDKey, kid); err != nil {
197 return nil, errors.Wrap(err, `failed to set "kid" header`)
198 }
199 }
200 }
201 sig := &Signature{
202 headers: signer.PublicHeader(),
203 protected: protected,
204 }
205 _, _, err := sig.Sign(payload, signer.signer, signer.key)
206 if err != nil {
207 return nil, errors.Wrapf(err, `failed to generate signature for signer #%d (alg=%s)`, i, signer.Algorithm())
208 }
209
210 result.signatures = append(result.signatures, sig)
211 }
212
213 return json.Marshal(result)
214 }
215
// verifyCtx carries the settings for one verification call, populated from
// the arguments and options of Verify or VerifyAuto.
type verifyCtx struct {
	// dst, when non-nil, receives the parsed message after a successful
	// verification.
	dst *Message
	// detachedPayload is used as the payload when the message itself carries
	// an empty one.
	detachedPayload []byte
	// alg is the verification algorithm used when useJKU is false.
	alg jwa.SignatureAlgorithm
	// key is the verification key; verifyJKU overwrites it with the key it
	// selects from the fetched JWK set.
	key interface{}
	// useJKU selects verification with a key located via the "jku" header
	// (set by VerifyAuto) instead of alg/key.
	useJKU bool
	// jwksFetcher retrieves the JWK set named by "jku".
	jwksFetcher JWKSetFetcher

	// isJSON is set by verifyJSON; it makes tryVerify skip payload decoding
	// and dst population, which verifyJSON handles itself.
	isJSON bool
}
227
228 var allowNoneWhitelist = jwk.WhitelistFunc(func(string) bool {
229 return false
230 })
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260 func VerifyAuto(buf []byte, options ...VerifyOption) ([]byte, error) {
261 var ctx verifyCtx
262
263 ctx.useJKU = true
264
265 var fetchOptions []jwk.FetchOption
266
267
268 for _, option := range options {
269 switch option.Ident() {
270 case identMessage{}:
271 ctx.dst = option.Value().(*Message)
272 case identDetachedPayload{}:
273 ctx.detachedPayload = option.Value().([]byte)
274 case identJWKSetFetcher{}:
275 ctx.jwksFetcher = option.Value().(JWKSetFetcher)
276 case identFetchWhitelist{}:
277 fetchOptions = append(fetchOptions, jwk.WithFetchWhitelist(option.Value().(jwk.Whitelist)))
278 case identFetchBackoff{}:
279 fetchOptions = append(fetchOptions, jwk.WithFetchBackoff(option.Value().(backoff.Policy)))
280 case identHTTPClient{}:
281 fetchOptions = append(fetchOptions, jwk.WithHTTPClient(option.Value().(*http.Client)))
282 }
283 }
284
285
286
287 if ctx.jwksFetcher == nil {
288 fetchOptions = append([]jwk.FetchOption{jwk.WithFetchWhitelist(allowNoneWhitelist)}, fetchOptions...)
289 ctx.jwksFetcher = NewJWKSetFetcher(fetchOptions...)
290 }
291
292 return ctx.verify(buf)
293 }
294
295
296
297
298
299
300
301
302
303
304 func Verify(buf []byte, alg jwa.SignatureAlgorithm, key interface{}, options ...VerifyOption) ([]byte, error) {
305 var ctx verifyCtx
306 ctx.alg = alg
307 ctx.key = key
308
309 for _, option := range options {
310 switch option.Ident() {
311 case identMessage{}:
312 ctx.dst = option.Value().(*Message)
313 case identDetachedPayload{}:
314 ctx.detachedPayload = option.Value().([]byte)
315 default:
316 return nil, errors.Errorf(`invalid jws.VerifyOption %q passed`, `With`+strings.TrimPrefix(fmt.Sprintf(`%T`, option.Ident()), `jws.ident`))
317 }
318 }
319
320 return ctx.verify(buf)
321 }
322
323 func (ctx *verifyCtx) verify(buf []byte) ([]byte, error) {
324 buf = bytes.TrimSpace(buf)
325 if len(buf) == 0 {
326 return nil, errors.New(`attempt to verify empty buffer`)
327 }
328
329 if buf[0] == '{' {
330 return ctx.verifyJSON(buf)
331 }
332 return ctx.verifyCompact(buf)
333 }
334
335
336
337
338
339
340
341
342
343 func VerifySet(buf []byte, set jwk.Set) ([]byte, error) {
344 n := set.Len()
345 for i := 0; i < n; i++ {
346 key, ok := set.Get(i)
347 if !ok {
348 continue
349 }
350 if key.Algorithm() == "" {
351 continue
352 }
353
354 if usage := key.KeyUsage(); usage != "" && usage != jwk.ForSignature.String() {
355 continue
356 }
357
358 buf, err := Verify(buf, jwa.SignatureAlgorithm(key.Algorithm()), key)
359 if err != nil {
360 continue
361 }
362
363 return buf, nil
364 }
365
366 return nil, errors.New(`failed to verify message with any of the keys in the jwk.Set object`)
367 }
368
// verifyJSON verifies a JWS message in JSON serialization. It returns the
// message payload as soon as any one of the message's signatures verifies,
// and copies the parsed message into ctx.dst when ctx.dst is non-nil.
// Per-signature failures are not reported individually; a single error is
// returned when no signature verifies.
func (ctx *verifyCtx) verifyJSON(signed []byte) ([]byte, error) {
	// Tells tryVerify to leave payload decoding and dst population to us.
	ctx.isJSON = true

	var m Message
	// Decode with collectRawCtx so raw header bytes are retained.
	// NOTE(review): presumably these are what rawBuffer() returns below —
	// confirm in the Message/Headers decoding code. clearRaw runs on return.
	m.SetDecodeCtx(collectRawCtx{})
	defer m.clearRaw()
	if err := json.Unmarshal(signed, &m); err != nil {
		return nil, errors.Wrap(err, `failed to unmarshal JSON message`)
	}
	m.SetDecodeCtx(nil)

	// A detached payload is only accepted when the message carries none.
	if len(m.payload) != 0 && ctx.detachedPayload != nil {
		return nil, errors.New(`can't specify detached payload for JWS with payload`)
	}

	if ctx.detachedPayload != nil {
		m.payload = ctx.detachedPayload
	}

	// Payload portion of the signing input: base64-encoded when the message's
	// b64 flag is set, otherwise used verbatim.
	var payload string
	if m.b64 {
		payload = base64.EncodeToString(m.payload)
	} else {
		payload = string(m.payload)
	}

	buf := pool.GetBytesBuffer()
	defer pool.ReleaseBytesBuffer(buf)

	for i, sig := range m.signatures {
		buf.Reset()

		// Prefer the protected header bytes exactly as received; re-marshaling
		// is only the fallback. NOTE(review): presumably because re-encoding
		// may not reproduce the bytes the signer actually signed.
		var encodedProtectedHeader string
		if rbp, ok := sig.protected.(interface{ rawBuffer() []byte }); ok {
			if raw := rbp.rawBuffer(); raw != nil {
				encodedProtectedHeader = base64.EncodeToString(raw)
			}
		}

		if encodedProtectedHeader == "" {
			protected, err := json.Marshal(sig.protected)
			if err != nil {
				return nil, errors.Wrapf(err, `failed to marshal "protected" for signature #%d`, i+1)
			}

			encodedProtectedHeader = base64.EncodeToString(protected)
		}

		// Signing input: <protected>.<payload>
		buf.WriteString(encodedProtectedHeader)
		buf.WriteByte('.')
		buf.WriteString(payload)

		if !ctx.useJKU {
			// Explicit-key mode (Verify): skip signatures whose "kid" names a
			// different jwk.Key than the one supplied.
			if hdr := sig.protected; hdr != nil && hdr.KeyID() != "" {
				if jwkKey, ok := ctx.key.(jwk.Key); ok {
					if jwkKey.KeyID() != hdr.KeyID() {
						continue
					}
				}
			}

			verifier, err := NewVerifier(ctx.alg)
			if err != nil {
				return nil, errors.Wrap(err, "failed to create verifier")
			}

			if _, err := ctx.tryVerify(verifier, sig.protected, buf.Bytes(), sig.signature, m.payload); err == nil {
				if ctx.dst != nil {
					*(ctx.dst) = m
				}
				return m.payload, nil
			}

			continue
		}

		// JKU mode (VerifyAuto): the key comes from the set named by "jku".
		if _, err := ctx.verifyJKU(sig.protected, buf.Bytes(), sig.signature, m.payload); err == nil {
			if ctx.dst != nil {
				*(ctx.dst) = m
			}
			return m.payload, nil
		}

	}
	return nil, errors.New(`could not verify with any of the signatures`)
}
456
457
458
459
460 func getB64Value(hdr Headers) bool {
461 b64raw, ok := hdr.Get("b64")
462 if !ok {
463 return true
464 }
465
466 b64, ok := b64raw.(bool)
467 if !ok {
468 return false
469 }
470 return b64
471 }
472
473 func (ctx *verifyCtx) verifyCompact(signed []byte) ([]byte, error) {
474 protected, payload, signature, err := SplitCompact(signed)
475 if err != nil {
476 return nil, errors.Wrap(err, `failed extract from compact serialization format`)
477 }
478
479 decodedSignature, err := base64.Decode(signature)
480 if err != nil {
481 return nil, errors.Wrap(err, `failed to decode signature`)
482 }
483
484 hdr := NewHeaders()
485 decodedProtected, err := base64.Decode(protected)
486 if err != nil {
487 return nil, errors.Wrap(err, `failed to decode headers`)
488 }
489
490 if err := json.Unmarshal(decodedProtected, hdr); err != nil {
491 return nil, errors.Wrap(err, `failed to decode headers`)
492 }
493
494 verifyBuf := pool.GetBytesBuffer()
495 defer pool.ReleaseBytesBuffer(verifyBuf)
496
497 verifyBuf.Write(protected)
498 verifyBuf.WriteByte('.')
499 if len(payload) == 0 && ctx.detachedPayload != nil {
500 if getB64Value(hdr) {
501 payload = base64.Encode(ctx.detachedPayload)
502 } else {
503 payload = ctx.detachedPayload
504 }
505 }
506 verifyBuf.Write(payload)
507
508 if !ctx.useJKU {
509 if hdr.KeyID() != "" {
510 if jwkKey, ok := ctx.key.(jwk.Key); ok {
511 if jwkKey.KeyID() != hdr.KeyID() {
512 return nil, errors.New(`"kid" fields do not match`)
513 }
514 }
515 }
516
517 verifier, err := NewVerifier(ctx.alg)
518 if err != nil {
519 return nil, errors.Wrap(err, "failed to create verifier")
520 }
521
522 return ctx.tryVerify(verifier, hdr, verifyBuf.Bytes(), decodedSignature, payload)
523 }
524
525 return ctx.verifyJKU(hdr, verifyBuf.Bytes(), decodedSignature, payload)
526 }
527
528
529 type JWKSetFetcher interface {
530 Fetch(string) (jwk.Set, error)
531 }
532
533
534
535
536
537
538 type SimpleJWKSetFetcher struct {
539 options []jwk.FetchOption
540 }
541
542 func NewJWKSetFetcher(options ...jwk.FetchOption) *SimpleJWKSetFetcher {
543 return &SimpleJWKSetFetcher{options: options}
544 }
545
546 func (f *SimpleJWKSetFetcher) Fetch(u string) (jwk.Set, error) {
547 return jwk.Fetch(context.TODO(), u, f.options...)
548 }
549
550 type JWKSetFetchFunc func(string) (jwk.Set, error)
551
552 func (f JWKSetFetchFunc) Fetch(u string) (jwk.Set, error) {
553 return f(u)
554 }
555
556 func (ctx *verifyCtx) verifyJKU(hdr Headers, verifyBuf, decodedSignature, payload []byte) ([]byte, error) {
557 u := hdr.JWKSetURL()
558 if u == "" {
559 return nil, errors.New(`use of "jku" field specified, but the field is empty`)
560 }
561 uo, err := url.Parse(u)
562 if err != nil {
563 return nil, errors.Wrap(err, `failed to parse "jku"`)
564 }
565 if uo.Scheme != "https" {
566 return nil, errors.New(`url in "jku" must be HTTPS`)
567 }
568
569 set, err := ctx.jwksFetcher.Fetch(u)
570 if err != nil {
571 return nil, errors.Wrapf(err, `failed to fetch "jku"`)
572 }
573
574
575
576 if hdr.KeyID() == "" {
577 return nil, errors.Errorf(`"kid" is required on the JWS message to use "jku"`)
578 }
579
580 key, ok := set.LookupKeyID(hdr.KeyID())
581 if !ok {
582 return nil, errors.New(`key specified via "kid" is not present in the JWK set specified by "jku"`)
583 }
584
585
586 algs, err := AlgorithmsForKey(key)
587 if err != nil {
588 return nil, errors.Wrapf(err, `failed to get a list of signature methods for key type %s`, key.KeyType())
589 }
590
591
592 ctx.key = key
593 hdrAlg := hdr.Algorithm()
594 for _, alg := range algs {
595
596
597 if hdrAlg != "" && hdrAlg != alg {
598 continue
599 }
600
601 verifier, err := NewVerifier(alg)
602 if err != nil {
603 return nil, errors.Wrap(err, "failed to create verifier")
604 }
605
606 if decoded, err := ctx.tryVerify(verifier, hdr, verifyBuf, decodedSignature, payload); err == nil {
607 return decoded, nil
608 }
609 }
610 return nil, errors.New(`failed to verify payload using key in "jku"`)
611 }
612
613 func (ctx *verifyCtx) tryVerify(verifier Verifier, hdr Headers, buf, decodedSignature, payload []byte) ([]byte, error) {
614 if err := verifier.Verify(buf, decodedSignature, ctx.key); err != nil {
615 return nil, errors.Wrap(err, `failed to verify message`)
616 }
617
618 var decodedPayload []byte
619
620
621
622 if !ctx.isJSON {
623
624 if !getB64Value(hdr) {
625 decodedPayload = payload
626 }
627
628 if decodedPayload == nil {
629 v, err := base64.Decode(payload)
630 if err != nil {
631 return nil, errors.Wrap(err, `message verified, failed to decode payload`)
632 }
633 decodedPayload = v
634 }
635
636
637
638 if ctx.dst != nil {
639
640 m := NewMessage()
641 m.SetPayload(decodedPayload)
642 sig := NewSignature()
643 sig.SetProtectedHeaders(hdr)
644 sig.SetSignature(decodedSignature)
645 m.AppendSignature(sig)
646
647 *(ctx.dst) = *m
648 }
649 }
650 return decodedPayload, nil
651 }
652
653
654
655
// readAll drains rdr in one call when it is an in-memory reader
// (*bytes.Reader, *bytes.Buffer or *strings.Reader), returning its contents
// and true. For any other reader type it returns (nil, false) without
// reading. A read error also yields (nil, false).
func readAll(rdr io.Reader) ([]byte, bool) {
	switch rdr.(type) {
	case *bytes.Reader, *bytes.Buffer, *strings.Reader:
		// in-memory: safe and cheap to read eagerly
	default:
		return nil, false
	}

	data, err := ioutil.ReadAll(rdr)
	if err != nil {
		return nil, false
	}
	return data, true
}
668
669
670
671 func Parse(src []byte) (*Message, error) {
672 for i := 0; i < len(src); i++ {
673 r := rune(src[i])
674 if r >= utf8.RuneSelf {
675 r, _ = utf8.DecodeRune(src)
676 }
677 if !unicode.IsSpace(r) {
678 if r == '{' {
679 return parseJSON(src)
680 }
681 return parseCompact(src)
682 }
683 }
684 return nil, errors.New("invalid byte sequence")
685 }
686
687
688
689 func ParseString(src string) (*Message, error) {
690 return Parse([]byte(src))
691 }
692
693
694
695 func ParseReader(src io.Reader) (*Message, error) {
696 if data, ok := readAll(src); ok {
697 return Parse(data)
698 }
699
700 rdr := bufio.NewReader(src)
701 var first rune
702 for {
703 r, _, err := rdr.ReadRune()
704 if err != nil {
705 return nil, errors.Wrap(err, `failed to read rune`)
706 }
707 if !unicode.IsSpace(r) {
708 first = r
709 if err := rdr.UnreadRune(); err != nil {
710 return nil, errors.Wrap(err, `failed to unread rune`)
711 }
712
713 break
714 }
715 }
716
717 var parser func(io.Reader) (*Message, error)
718 if first == '{' {
719 parser = parseJSONReader
720 } else {
721 parser = parseCompactReader
722 }
723
724 m, err := parser(rdr)
725 if err != nil {
726 return nil, errors.Wrap(err, `failed to parse jws message`)
727 }
728
729 return m, nil
730 }
731
732 func parseJSONReader(src io.Reader) (result *Message, err error) {
733 var m Message
734 if err := json.NewDecoder(src).Decode(&m); err != nil {
735 return nil, errors.Wrap(err, `failed to unmarshal jws message`)
736 }
737 return &m, nil
738 }
739
740 func parseJSON(data []byte) (result *Message, err error) {
741 var m Message
742 if err := json.Unmarshal(data, &m); err != nil {
743 return nil, errors.Wrap(err, `failed to unmarshal jws message`)
744 }
745 return &m, nil
746 }
747
748
749
// SplitCompact splits a compact-serialized JWS into its three '.'-separated
// segments: protected header, payload and signature. The segments are
// returned still encoded, as subslices of src.
//
// Exactly three segments are required. Input with fewer or more (for example
// a five-part JWE) is rejected instead of silently dropping the extra parts.
// Empty segments are allowed (e.g. an empty payload for a detached payload).
func SplitCompact(src []byte) ([]byte, []byte, []byte, error) {
	parts := bytes.Split(src, []byte("."))
	if len(parts) != 3 {
		return nil, nil, nil, errors.New(`invalid number of segments`)
	}
	return parts[0], parts[1], parts[2], nil
}
757
758
759
// SplitCompactString is SplitCompact for a string input. It returns the three
// still-encoded segments (protected header, payload, signature) as byte
// slices. Exactly three '.'-separated segments are required.
func SplitCompactString(src string) ([]byte, []byte, []byte, error) {
	parts := strings.Split(src, ".")
	if len(parts) != 3 {
		return nil, nil, nil, errors.New(`invalid number of segments`)
	}
	return []byte(parts[0]), []byte(parts[1]), []byte(parts[2]), nil
}
767
768
769
770 func SplitCompactReader(rdr io.Reader) ([]byte, []byte, []byte, error) {
771 if data, ok := readAll(rdr); ok {
772 return SplitCompact(data)
773 }
774
775 var protected []byte
776 var payload []byte
777 var signature []byte
778 var periods int
779 var state int
780
781 buf := make([]byte, 4096)
782 var sofar []byte
783
784 for {
785
786 n, err := rdr.Read(buf)
787
788 if err != nil && err != io.EOF {
789 return nil, nil, nil, errors.Wrap(err, `unexpected end of input`)
790 }
791
792
793 sofar = append(sofar, buf[:n]...)
794
795 for loop := true; loop; {
796 var i = bytes.IndexByte(sofar, '.')
797 if i == -1 && err != io.EOF {
798
799 loop = false
800 continue
801 } else if i == -1 && err == io.EOF {
802
803 i = len(sofar)
804 loop = false
805 } else {
806
807 periods++
808 }
809
810
811 switch state {
812 case 0:
813 protected = sofar[:i]
814 state++
815 case 1:
816 payload = sofar[:i]
817 state++
818 case 2:
819 signature = sofar[:i]
820 }
821
822 if len(sofar) > i {
823 sofar = sofar[i+1:]
824 }
825 }
826
827 if err == io.EOF {
828 break
829 }
830 }
831 if periods != 2 {
832 return nil, nil, nil, errors.New(`invalid number of segments`)
833 }
834
835 return protected, payload, signature, nil
836 }
837
838
839 func parseCompactReader(rdr io.Reader) (m *Message, err error) {
840 protected, payload, signature, err := SplitCompactReader(rdr)
841 if err != nil {
842 return nil, errors.Wrap(err, `invalid compact serialization format`)
843 }
844 return parse(protected, payload, signature)
845 }
846
847 func parseCompact(data []byte) (m *Message, err error) {
848 protected, payload, signature, err := SplitCompact(data)
849 if err != nil {
850 return nil, errors.Wrap(err, `invalid compact serialization format`)
851 }
852 return parse(protected, payload, signature)
853 }
854
855 func parse(protected, payload, signature []byte) (*Message, error) {
856 decodedHeader, err := base64.Decode(protected)
857 if err != nil {
858 return nil, errors.Wrap(err, `failed to decode protected headers`)
859 }
860
861 hdr := NewHeaders()
862 if err := json.Unmarshal(decodedHeader, hdr); err != nil {
863 return nil, errors.Wrap(err, `failed to parse JOSE headers`)
864 }
865
866 decodedPayload, err := base64.Decode(payload)
867 if err != nil {
868 return nil, errors.Wrap(err, `failed to decode payload`)
869 }
870
871 decodedSignature, err := base64.Decode(signature)
872 if err != nil {
873 return nil, errors.Wrap(err, `failed to decode signature`)
874 }
875
876 var msg Message
877 msg.payload = decodedPayload
878 msg.signatures = append(msg.signatures, &Signature{
879 protected: hdr,
880 signature: decodedSignature,
881 })
882 return &msg, nil
883 }
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
// RegisterCustomField associates the header field name with the type of
// object by adding it to this package's registry.
// NOTE(review): presumably headers decoded afterwards unmarshal that field
// into a value of object's type — confirm in the headers implementation.
func RegisterCustomField(name string, object interface{}) {
	registry.Register(name, object)
}
906
907
908 var rawKeyToKeyType = make(map[reflect.Type]jwa.KeyType)
909 var keyTypeToAlgorithms = make(map[jwa.KeyType][]jwa.SignatureAlgorithm)
910
911 func init() {
912 rawKeyToKeyType[reflect.TypeOf([]byte(nil))] = jwa.OctetSeq
913 rawKeyToKeyType[reflect.TypeOf(ed25519.PublicKey(nil))] = jwa.OKP
914 rawKeyToKeyType[reflect.TypeOf(rsa.PublicKey{})] = jwa.RSA
915 rawKeyToKeyType[reflect.TypeOf((*rsa.PublicKey)(nil))] = jwa.RSA
916 rawKeyToKeyType[reflect.TypeOf(ecdsa.PublicKey{})] = jwa.EC
917 rawKeyToKeyType[reflect.TypeOf((*ecdsa.PublicKey)(nil))] = jwa.EC
918
919 addAlgorithmForKeyType(jwa.OKP, jwa.EdDSA)
920 for _, alg := range []jwa.SignatureAlgorithm{jwa.HS256, jwa.HS384, jwa.HS512} {
921 addAlgorithmForKeyType(jwa.OctetSeq, alg)
922 }
923 for _, alg := range []jwa.SignatureAlgorithm{jwa.RS256, jwa.RS384, jwa.RS512, jwa.PS256, jwa.PS384, jwa.PS512} {
924 addAlgorithmForKeyType(jwa.RSA, alg)
925 }
926 for _, alg := range []jwa.SignatureAlgorithm{jwa.ES256, jwa.ES384, jwa.ES512} {
927 addAlgorithmForKeyType(jwa.EC, alg)
928 }
929 }
930
931 func addAlgorithmForKeyType(kty jwa.KeyType, alg jwa.SignatureAlgorithm) {
932 keyTypeToAlgorithms[kty] = append(keyTypeToAlgorithms[kty], alg)
933 }
934
935
936
937
938
939 func AlgorithmsForKey(key interface{}) ([]jwa.SignatureAlgorithm, error) {
940 var kty jwa.KeyType
941 switch key := key.(type) {
942 case jwk.Key:
943 kty = key.KeyType()
944 case rsa.PublicKey, *rsa.PublicKey, rsa.PrivateKey, *rsa.PrivateKey:
945 kty = jwa.RSA
946 case ecdsa.PublicKey, *ecdsa.PublicKey, ecdsa.PrivateKey, *ecdsa.PrivateKey:
947 kty = jwa.EC
948 case ed25519.PublicKey, ed25519.PrivateKey, x25519.PublicKey, x25519.PrivateKey:
949 kty = jwa.OKP
950 case []byte:
951 kty = jwa.OctetSeq
952 default:
953 return nil, errors.Errorf(`invalid key %T`, key)
954 }
955
956 algs, ok := keyTypeToAlgorithms[kty]
957 if !ok {
958 return nil, errors.Errorf(`invalid key type %q`, kty)
959 }
960 return algs, nil
961 }
962
View as plain text