1 package jws
2
3 import (
4 "crypto"
5 "crypto/rand"
6 "crypto/rsa"
7
8 "github.com/lestrrat-go/jwx/internal/keyconv"
9 "github.com/lestrrat-go/jwx/jwa"
10 "github.com/pkg/errors"
11 )
12
13 var rsaSigners map[jwa.SignatureAlgorithm]*rsaSigner
14 var rsaVerifiers map[jwa.SignatureAlgorithm]*rsaVerifier
15
16 func init() {
17 algs := map[jwa.SignatureAlgorithm]struct {
18 Hash crypto.Hash
19 PSS bool
20 }{
21 jwa.RS256: {
22 Hash: crypto.SHA256,
23 },
24 jwa.RS384: {
25 Hash: crypto.SHA384,
26 },
27 jwa.RS512: {
28 Hash: crypto.SHA512,
29 },
30 jwa.PS256: {
31 Hash: crypto.SHA256,
32 PSS: true,
33 },
34 jwa.PS384: {
35 Hash: crypto.SHA384,
36 PSS: true,
37 },
38 jwa.PS512: {
39 Hash: crypto.SHA512,
40 PSS: true,
41 },
42 }
43
44 rsaSigners = make(map[jwa.SignatureAlgorithm]*rsaSigner)
45 rsaVerifiers = make(map[jwa.SignatureAlgorithm]*rsaVerifier)
46 for alg, item := range algs {
47 rsaSigners[alg] = &rsaSigner{
48 alg: alg,
49 hash: item.Hash,
50 pss: item.PSS,
51 }
52 rsaVerifiers[alg] = &rsaVerifier{
53 alg: alg,
54 hash: item.Hash,
55 pss: item.PSS,
56 }
57 }
58 }
59
// rsaSigner signs JWS payloads for one RSA-based algorithm. Instances are
// built in init and handed out from rsaSigners by newRSASigner.
type rsaSigner struct {
	// alg is the JWS algorithm identifier returned by Algorithm.
	alg jwa.SignatureAlgorithm
	// hash is the digest applied to the payload before it is signed.
	hash crypto.Hash
	// pss selects RSASSA-PSS padding when true, PKCS #1 v1.5 when false.
	pss bool
}
65
66 func newRSASigner(alg jwa.SignatureAlgorithm) Signer {
67 return rsaSigners[alg]
68 }
69
70 func (rs *rsaSigner) Algorithm() jwa.SignatureAlgorithm {
71 return rs.alg
72 }
73
74 func (rs *rsaSigner) Sign(payload []byte, key interface{}) ([]byte, error) {
75 if key == nil {
76 return nil, errors.New(`missing private key while signing payload`)
77 }
78
79 signer, ok := key.(crypto.Signer)
80 if !ok {
81 var privkey rsa.PrivateKey
82 if err := keyconv.RSAPrivateKey(&privkey, key); err != nil {
83 return nil, errors.Wrapf(err, `failed to retrieve rsa.PrivateKey out of %T`, key)
84 }
85 signer = &privkey
86 }
87
88 h := rs.hash.New()
89 if _, err := h.Write(payload); err != nil {
90 return nil, errors.Wrap(err, "failed to write payload to hash")
91 }
92 if rs.pss {
93 return signer.Sign(rand.Reader, h.Sum(nil), &rsa.PSSOptions{
94 Hash: rs.hash,
95 SaltLength: rsa.PSSSaltLengthEqualsHash,
96 })
97 }
98 return signer.Sign(rand.Reader, h.Sum(nil), rs.hash)
99 }
100
// rsaVerifier checks JWS signatures for one RSA-based algorithm. Instances
// are built in init and handed out from rsaVerifiers by newRSAVerifier.
type rsaVerifier struct {
	// alg is the JWS algorithm identifier this verifier was registered under.
	alg jwa.SignatureAlgorithm
	// hash is the digest applied to the payload before verification.
	hash crypto.Hash
	// pss selects rsa.VerifyPSS when true, rsa.VerifyPKCS1v15 when false.
	pss bool
}
106
107 func newRSAVerifier(alg jwa.SignatureAlgorithm) Verifier {
108 return rsaVerifiers[alg]
109 }
110
111 func (rv *rsaVerifier) Verify(payload, signature []byte, key interface{}) error {
112 if key == nil {
113 return errors.New(`missing public key while verifying payload`)
114 }
115
116 var pubkey rsa.PublicKey
117 if cs, ok := key.(crypto.Signer); ok {
118 cpub := cs.Public()
119 switch cpub := cpub.(type) {
120 case rsa.PublicKey:
121 pubkey = cpub
122 case *rsa.PublicKey:
123 pubkey = *cpub
124 default:
125 return errors.Errorf(`failed to retrieve rsa.PublicKey out of crypto.Signer %T`, key)
126 }
127 } else {
128 if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
129 return errors.Wrapf(err, `failed to retrieve rsa.PublicKey out of %T`, key)
130 }
131 }
132
133 h := rv.hash.New()
134 if _, err := h.Write(payload); err != nil {
135 return errors.Wrap(err, "failed to write payload to hash")
136 }
137
138 if rv.pss {
139 return rsa.VerifyPSS(&pubkey, rv.hash, h.Sum(nil), signature, nil)
140 }
141 return rsa.VerifyPKCS1v15(&pubkey, rv.hash, h.Sum(nil), signature)
142 }
143
View as plain text