1 package jwe
2
3 import (
4 "crypto/aes"
5 cryptocipher "crypto/cipher"
6 "crypto/ecdsa"
7 "crypto/rsa"
8 "crypto/sha256"
9 "crypto/sha512"
10 "hash"
11
12 "golang.org/x/crypto/pbkdf2"
13
14 "github.com/lestrrat-go/jwx/internal/keyconv"
15 "github.com/lestrrat-go/jwx/jwa"
16 "github.com/lestrrat-go/jwx/jwe/internal/cipher"
17 "github.com/lestrrat-go/jwx/jwe/internal/content_crypt"
18 "github.com/lestrrat-go/jwx/jwe/internal/keyenc"
19 "github.com/lestrrat-go/jwx/x25519"
20 "github.com/pkg/errors"
21 )
22
23
24
25
26 type Decrypter struct {
27 aad []byte
28 apu []byte
29 apv []byte
30 computedAad []byte
31 iv []byte
32 keyiv []byte
33 keysalt []byte
34 keytag []byte
35 tag []byte
36 privkey interface{}
37 pubkey interface{}
38 ctalg jwa.ContentEncryptionAlgorithm
39 keyalg jwa.KeyEncryptionAlgorithm
40 cipher content_crypt.Cipher
41 keycount int
42 }
43
44
45
46
47
48
49
50
51
52 func NewDecrypter(keyalg jwa.KeyEncryptionAlgorithm, ctalg jwa.ContentEncryptionAlgorithm, privkey interface{}) *Decrypter {
53 return &Decrypter{
54 ctalg: ctalg,
55 keyalg: keyalg,
56 privkey: privkey,
57 }
58 }
59
60 func (d *Decrypter) AgreementPartyUInfo(apu []byte) *Decrypter {
61 d.apu = apu
62 return d
63 }
64
65 func (d *Decrypter) AgreementPartyVInfo(apv []byte) *Decrypter {
66 d.apv = apv
67 return d
68 }
69
70 func (d *Decrypter) AuthenticatedData(aad []byte) *Decrypter {
71 d.aad = aad
72 return d
73 }
74
75 func (d *Decrypter) ComputedAuthenticatedData(aad []byte) *Decrypter {
76 d.computedAad = aad
77 return d
78 }
79
80 func (d *Decrypter) ContentEncryptionAlgorithm(ctalg jwa.ContentEncryptionAlgorithm) *Decrypter {
81 d.ctalg = ctalg
82 return d
83 }
84
85 func (d *Decrypter) InitializationVector(iv []byte) *Decrypter {
86 d.iv = iv
87 return d
88 }
89
90 func (d *Decrypter) KeyCount(keycount int) *Decrypter {
91 d.keycount = keycount
92 return d
93 }
94
95 func (d *Decrypter) KeyInitializationVector(keyiv []byte) *Decrypter {
96 d.keyiv = keyiv
97 return d
98 }
99
100 func (d *Decrypter) KeySalt(keysalt []byte) *Decrypter {
101 d.keysalt = keysalt
102 return d
103 }
104
105 func (d *Decrypter) KeyTag(keytag []byte) *Decrypter {
106 d.keytag = keytag
107 return d
108 }
109
110
111
112 func (d *Decrypter) PublicKey(pubkey interface{}) *Decrypter {
113 d.pubkey = pubkey
114 return d
115 }
116
117 func (d *Decrypter) Tag(tag []byte) *Decrypter {
118 d.tag = tag
119 return d
120 }
121
122 func (d *Decrypter) ContentCipher() (content_crypt.Cipher, error) {
123 if d.cipher == nil {
124 switch d.ctalg {
125 case jwa.A128GCM, jwa.A192GCM, jwa.A256GCM, jwa.A128CBC_HS256, jwa.A192CBC_HS384, jwa.A256CBC_HS512:
126 cipher, err := cipher.NewAES(d.ctalg)
127 if err != nil {
128 return nil, errors.Wrapf(err, `failed to build content cipher for %s`, d.ctalg)
129 }
130 d.cipher = cipher
131 default:
132 return nil, errors.Errorf(`invalid content cipher algorithm (%s)`, d.ctalg)
133 }
134 }
135
136 return d.cipher, nil
137 }
138
139 func (d *Decrypter) Decrypt(recipientKey, ciphertext []byte) (plaintext []byte, err error) {
140 cek, keyerr := d.DecryptKey(recipientKey)
141 if keyerr != nil {
142 err = errors.Wrap(keyerr, `failed to decrypt key`)
143 return
144 }
145
146 cipher, ciphererr := d.ContentCipher()
147 if ciphererr != nil {
148 err = errors.Wrap(ciphererr, `failed to fetch content crypt cipher`)
149 return
150 }
151
152 computedAad := d.computedAad
153 if d.aad != nil {
154 computedAad = append(append(computedAad, '.'), d.aad...)
155 }
156
157 plaintext, err = cipher.Decrypt(cek, d.iv, ciphertext, d.tag, computedAad)
158 if err != nil {
159 err = errors.Wrap(err, `failed to decrypt payload`)
160 return
161 }
162
163 return plaintext, nil
164 }
165
166 func (d *Decrypter) decryptSymmetricKey(recipientKey, cek []byte) ([]byte, error) {
167 switch d.keyalg {
168 case jwa.DIRECT:
169 return cek, nil
170 case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
171 var hashFunc func() hash.Hash
172 var keylen int
173 switch d.keyalg {
174 case jwa.PBES2_HS256_A128KW:
175 hashFunc = sha256.New
176 keylen = 16
177 case jwa.PBES2_HS384_A192KW:
178 hashFunc = sha512.New384
179 keylen = 24
180 case jwa.PBES2_HS512_A256KW:
181 hashFunc = sha512.New
182 keylen = 32
183 }
184 salt := []byte(d.keyalg)
185 salt = append(salt, byte(0))
186 salt = append(salt, d.keysalt...)
187 cek = pbkdf2.Key(cek, salt, d.keycount, keylen, hashFunc)
188 fallthrough
189 case jwa.A128KW, jwa.A192KW, jwa.A256KW:
190 block, err := aes.NewCipher(cek)
191 if err != nil {
192 return nil, errors.Wrap(err, `failed to create new AES cipher`)
193 }
194
195 jek, err := keyenc.Unwrap(block, recipientKey)
196 if err != nil {
197 return nil, errors.Wrap(err, `failed to unwrap key`)
198 }
199
200 return jek, nil
201 case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
202 if len(d.keyiv) != 12 {
203 return nil, errors.Errorf("GCM requires 96-bit iv, got %d", len(d.keyiv)*8)
204 }
205 if len(d.keytag) != 16 {
206 return nil, errors.Errorf("GCM requires 128-bit tag, got %d", len(d.keytag)*8)
207 }
208 block, err := aes.NewCipher(cek)
209 if err != nil {
210 return nil, errors.Wrap(err, `failed to create new AES cipher`)
211 }
212 aesgcm, err := cryptocipher.NewGCM(block)
213 if err != nil {
214 return nil, errors.Wrap(err, `failed to create new GCM wrap`)
215 }
216 ciphertext := recipientKey[:]
217 ciphertext = append(ciphertext, d.keytag...)
218 jek, err := aesgcm.Open(nil, d.keyiv, ciphertext, nil)
219 if err != nil {
220 return nil, errors.Wrap(err, `failed to decode key`)
221 }
222 return jek, nil
223 default:
224 return nil, errors.Errorf("decrypt key: unsupported algorithm %s", d.keyalg)
225 }
226 }
227
228 func (d *Decrypter) DecryptKey(recipientKey []byte) (cek []byte, err error) {
229 if d.keyalg.IsSymmetric() {
230 var ok bool
231 cek, ok = d.privkey.([]byte)
232 if !ok {
233 return nil, errors.Errorf("decrypt key: []byte is required as the key to build %s key decrypter (got %T)", d.keyalg, d.privkey)
234 }
235
236 return d.decryptSymmetricKey(recipientKey, cek)
237 }
238
239 k, err := d.BuildKeyDecrypter()
240 if err != nil {
241 return nil, errors.Wrap(err, `failed to build key decrypter`)
242 }
243
244 cek, err = k.Decrypt(recipientKey)
245 if err != nil {
246 return nil, errors.Wrap(err, `failed to decrypt key`)
247 }
248
249 return cek, nil
250 }
251
252 func (d *Decrypter) BuildKeyDecrypter() (keyenc.Decrypter, error) {
253 cipher, err := d.ContentCipher()
254 if err != nil {
255 return nil, errors.Wrap(err, `failed to fetch content crypt cipher`)
256 }
257
258 switch alg := d.keyalg; alg {
259 case jwa.RSA1_5:
260 var privkey rsa.PrivateKey
261 if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
262 return nil, errors.Wrapf(err, "*rsa.PrivateKey is required as the key to build %s key decrypter", alg)
263 }
264
265 return keyenc.NewRSAPKCS15Decrypt(alg, &privkey, cipher.KeySize()/2), nil
266 case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
267 var privkey rsa.PrivateKey
268 if err := keyconv.RSAPrivateKey(&privkey, d.privkey); err != nil {
269 return nil, errors.Wrapf(err, "*rsa.PrivateKey is required as the key to build %s key decrypter", alg)
270 }
271
272 return keyenc.NewRSAOAEPDecrypt(alg, &privkey)
273 case jwa.A128KW, jwa.A192KW, jwa.A256KW:
274 sharedkey, ok := d.privkey.([]byte)
275 if !ok {
276 return nil, errors.Errorf("[]byte is required as the key to build %s key decrypter", alg)
277 }
278
279 return keyenc.NewAES(alg, sharedkey)
280 case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
281 switch d.pubkey.(type) {
282 case x25519.PublicKey:
283 return keyenc.NewECDHESDecrypt(alg, d.ctalg, d.pubkey, d.apu, d.apv, d.privkey), nil
284 default:
285 var pubkey ecdsa.PublicKey
286 if err := keyconv.ECDSAPublicKey(&pubkey, d.pubkey); err != nil {
287 return nil, errors.Wrapf(err, "*ecdsa.PublicKey is required as the key to build %s key decrypter", alg)
288 }
289
290 var privkey ecdsa.PrivateKey
291 if err := keyconv.ECDSAPrivateKey(&privkey, d.privkey); err != nil {
292 return nil, errors.Wrapf(err, "*ecdsa.PrivateKey is required as the key to build %s key decrypter", alg)
293 }
294
295 return keyenc.NewECDHESDecrypt(alg, d.ctalg, &pubkey, d.apu, d.apv, &privkey), nil
296 }
297 default:
298 return nil, errors.Errorf(`unsupported algorithm for key decryption (%s)`, alg)
299 }
300 }
301
View as plain text