1
16
17 package jose
18
19 import (
20 "crypto/ecdsa"
21 "crypto/rsa"
22 "errors"
23 "fmt"
24 "reflect"
25
26 "gopkg.in/square/go-jose.v2/json"
27 )
28
29
// Encrypter represents an encrypter which produces an encrypted JWE object.
type Encrypter interface {
	// Encrypt encrypts plaintext without additional authenticated data.
	Encrypt(plaintext []byte) (*JSONWebEncryption, error)
	// EncryptWithAuthData encrypts plaintext and authenticates aad as
	// additional authenticated data.
	EncryptWithAuthData(plaintext []byte, aad []byte) (*JSONWebEncryption, error)
	// Options returns the compression and extra-header options the
	// encrypter was configured with.
	Options() EncrypterOptions
}
35
36
// contentCipher is the content-encryption ("enc") primitive: it encrypts and
// decrypts the payload under a content encryption key (CEK).
type contentCipher interface {
	// keySize returns the CEK length in bytes that this cipher requires.
	keySize() int
	// encrypt encrypts plaintext under cek, covering aad, and returns the
	// resulting IV, ciphertext and tag.
	encrypt(cek []byte, aad, plaintext []byte) (*aeadParts, error)
	// decrypt recovers the plaintext from parts using cek and aad.
	decrypt(cek []byte, aad []byte, parts *aeadParts) ([]byte, error)
}
42
43
// keyGenerator supplies the content encryption key (CEK) for a message.
type keyGenerator interface {
	// keySize returns the length of the keys this generator produces.
	keySize() int
	// genKey returns a CEK plus any header values that must be merged into
	// the protected header alongside it.
	genKey() ([]byte, rawHeader, error)
}
48
49
// keyEncrypter conveys a CEK to a single recipient.
type keyEncrypter interface {
	// encryptKey returns the per-recipient info (encrypted key and header)
	// for cek under key management algorithm alg.
	encryptKey(cek []byte, alg KeyAlgorithm) (recipientInfo, error)
}
53
54
// keyDecrypter recovers a CEK from a recipient's encrypted key.
type keyDecrypter interface {
	// decryptKey returns the CEK for recipient given that recipient's merged
	// headers. Callers in this file pass a randomKeyGenerator sized for the
	// content cipher as generator; how it is used is implementation-specific.
	decryptKey(headers rawHeader, recipient *recipientInfo, generator keyGenerator) ([]byte, error)
}
58
59
// genericEncrypter is the Encrypter implementation returned by NewEncrypter
// and NewMultiEncrypter.
type genericEncrypter struct {
	// contentAlg is written to the protected header as the "enc" value.
	contentAlg     ContentEncryption
	// compressionAlg is applied to the plaintext unless it is NONE.
	compressionAlg CompressionAlgorithm
	// cipher is the content cipher for contentAlg.
	cipher         contentCipher
	// recipients holds one entry per recipient the CEK is conveyed to.
	recipients     []recipientKeyInfo
	// keyGenerator supplies the CEK for each message.
	keyGenerator   keyGenerator
	// extraHeaders are JSON-marshaled into the protected header last.
	extraHeaders   map[HeaderKey]interface{}
}
68
// recipientKeyInfo holds what is needed to convey the CEK to one recipient.
type recipientKeyInfo struct {
	// keyID, when non-empty, is written to the recipient header as headerKeyID.
	keyID        string
	// keyAlg is passed to keyEncrypter and written as headerAlgorithm.
	keyAlg       KeyAlgorithm
	// keyEncrypter produces the recipient's encrypted key.
	keyEncrypter keyEncrypter
}
74
75
// EncrypterOptions represents options that can be set on new encrypters.
type EncrypterOptions struct {
	// Compression selects the algorithm applied to the plaintext before
	// encryption; NONE disables compression.
	Compression CompressionAlgorithm

	// ExtraHeaders holds additional protected-header entries. Each value is
	// JSON-marshaled at encryption time and written after the library's own
	// entries, so it replaces any entry with the same key. The map is stored
	// by reference (not copied) by NewEncrypter and NewMultiEncrypter.
	ExtraHeaders map[HeaderKey]interface{}
}
84
85
86
87 func (eo *EncrypterOptions) WithHeader(k HeaderKey, v interface{}) *EncrypterOptions {
88 if eo.ExtraHeaders == nil {
89 eo.ExtraHeaders = map[HeaderKey]interface{}{}
90 }
91 eo.ExtraHeaders[k] = v
92 return eo
93 }
94
95
96
97 func (eo *EncrypterOptions) WithContentType(contentType ContentType) *EncrypterOptions {
98 return eo.WithHeader(HeaderContentType, contentType)
99 }
100
101
102 func (eo *EncrypterOptions) WithType(typ ContentType) *EncrypterOptions {
103 return eo.WithHeader(HeaderType, typ)
104 }
105
106
107
108
109
110
111
112
// Recipient represents an algorithm/key to encrypt messages to.
//
// For the PBES2 key algorithms, PBES2Count and PBES2Salt are copied onto the
// p2c and p2s fields of the symmetric key cipher (see addRecipient); they are
// ignored for every other algorithm.
type Recipient struct {
	// Algorithm is the key management algorithm for this recipient.
	Algorithm  KeyAlgorithm
	// Key is the recipient's key; the accepted types depend on Algorithm.
	Key        interface{}
	// KeyID, when non-empty, overrides any key ID carried by Key.
	KeyID      string
	// PBES2Count is copied to the PBES2 cipher's p2c field.
	PBES2Count int
	// PBES2Salt is copied to the PBES2 cipher's p2s field.
	PBES2Salt  []byte
}
120
121
122 func NewEncrypter(enc ContentEncryption, rcpt Recipient, opts *EncrypterOptions) (Encrypter, error) {
123 encrypter := &genericEncrypter{
124 contentAlg: enc,
125 recipients: []recipientKeyInfo{},
126 cipher: getContentCipher(enc),
127 }
128 if opts != nil {
129 encrypter.compressionAlg = opts.Compression
130 encrypter.extraHeaders = opts.ExtraHeaders
131 }
132
133 if encrypter.cipher == nil {
134 return nil, ErrUnsupportedAlgorithm
135 }
136
137 var keyID string
138 var rawKey interface{}
139 switch encryptionKey := rcpt.Key.(type) {
140 case JSONWebKey:
141 keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
142 case *JSONWebKey:
143 keyID, rawKey = encryptionKey.KeyID, encryptionKey.Key
144 case OpaqueKeyEncrypter:
145 keyID, rawKey = encryptionKey.KeyID(), encryptionKey
146 default:
147 rawKey = encryptionKey
148 }
149
150 switch rcpt.Algorithm {
151 case DIRECT:
152
153 if reflect.TypeOf(rawKey) != reflect.TypeOf([]byte{}) {
154 return nil, ErrUnsupportedKeyType
155 }
156 if encrypter.cipher.keySize() != len(rawKey.([]byte)) {
157 return nil, ErrInvalidKeySize
158 }
159 encrypter.keyGenerator = staticKeyGenerator{
160 key: rawKey.([]byte),
161 }
162 recipientInfo, _ := newSymmetricRecipient(rcpt.Algorithm, rawKey.([]byte))
163 recipientInfo.keyID = keyID
164 if rcpt.KeyID != "" {
165 recipientInfo.keyID = rcpt.KeyID
166 }
167 encrypter.recipients = []recipientKeyInfo{recipientInfo}
168 return encrypter, nil
169 case ECDH_ES:
170
171 typeOf := reflect.TypeOf(rawKey)
172 if typeOf != reflect.TypeOf(&ecdsa.PublicKey{}) {
173 return nil, ErrUnsupportedKeyType
174 }
175 encrypter.keyGenerator = ecKeyGenerator{
176 size: encrypter.cipher.keySize(),
177 algID: string(enc),
178 publicKey: rawKey.(*ecdsa.PublicKey),
179 }
180 recipientInfo, _ := newECDHRecipient(rcpt.Algorithm, rawKey.(*ecdsa.PublicKey))
181 recipientInfo.keyID = keyID
182 if rcpt.KeyID != "" {
183 recipientInfo.keyID = rcpt.KeyID
184 }
185 encrypter.recipients = []recipientKeyInfo{recipientInfo}
186 return encrypter, nil
187 default:
188
189 encrypter.keyGenerator = randomKeyGenerator{
190 size: encrypter.cipher.keySize(),
191 }
192 err := encrypter.addRecipient(rcpt)
193 return encrypter, err
194 }
195 }
196
197
198 func NewMultiEncrypter(enc ContentEncryption, rcpts []Recipient, opts *EncrypterOptions) (Encrypter, error) {
199 cipher := getContentCipher(enc)
200
201 if cipher == nil {
202 return nil, ErrUnsupportedAlgorithm
203 }
204 if rcpts == nil || len(rcpts) == 0 {
205 return nil, fmt.Errorf("square/go-jose: recipients is nil or empty")
206 }
207
208 encrypter := &genericEncrypter{
209 contentAlg: enc,
210 recipients: []recipientKeyInfo{},
211 cipher: cipher,
212 keyGenerator: randomKeyGenerator{
213 size: cipher.keySize(),
214 },
215 }
216
217 if opts != nil {
218 encrypter.compressionAlg = opts.Compression
219 encrypter.extraHeaders = opts.ExtraHeaders
220 }
221
222 for _, recipient := range rcpts {
223 err := encrypter.addRecipient(recipient)
224 if err != nil {
225 return nil, err
226 }
227 }
228
229 return encrypter, nil
230 }
231
232 func (ctx *genericEncrypter) addRecipient(recipient Recipient) (err error) {
233 var recipientInfo recipientKeyInfo
234
235 switch recipient.Algorithm {
236 case DIRECT, ECDH_ES:
237 return fmt.Errorf("square/go-jose: key algorithm '%s' not supported in multi-recipient mode", recipient.Algorithm)
238 }
239
240 recipientInfo, err = makeJWERecipient(recipient.Algorithm, recipient.Key)
241 if recipient.KeyID != "" {
242 recipientInfo.keyID = recipient.KeyID
243 }
244
245 switch recipient.Algorithm {
246 case PBES2_HS256_A128KW, PBES2_HS384_A192KW, PBES2_HS512_A256KW:
247 if sr, ok := recipientInfo.keyEncrypter.(*symmetricKeyCipher); ok {
248 sr.p2c = recipient.PBES2Count
249 sr.p2s = recipient.PBES2Salt
250 }
251 }
252
253 if err == nil {
254 ctx.recipients = append(ctx.recipients, recipientInfo)
255 }
256 return err
257 }
258
259 func makeJWERecipient(alg KeyAlgorithm, encryptionKey interface{}) (recipientKeyInfo, error) {
260 switch encryptionKey := encryptionKey.(type) {
261 case *rsa.PublicKey:
262 return newRSARecipient(alg, encryptionKey)
263 case *ecdsa.PublicKey:
264 return newECDHRecipient(alg, encryptionKey)
265 case []byte:
266 return newSymmetricRecipient(alg, encryptionKey)
267 case string:
268 return newSymmetricRecipient(alg, []byte(encryptionKey))
269 case *JSONWebKey:
270 recipient, err := makeJWERecipient(alg, encryptionKey.Key)
271 recipient.keyID = encryptionKey.KeyID
272 return recipient, err
273 }
274 if encrypter, ok := encryptionKey.(OpaqueKeyEncrypter); ok {
275 return newOpaqueKeyEncrypter(alg, encrypter)
276 }
277 return recipientKeyInfo{}, ErrUnsupportedKeyType
278 }
279
280
281 func newDecrypter(decryptionKey interface{}) (keyDecrypter, error) {
282 switch decryptionKey := decryptionKey.(type) {
283 case *rsa.PrivateKey:
284 return &rsaDecrypterSigner{
285 privateKey: decryptionKey,
286 }, nil
287 case *ecdsa.PrivateKey:
288 return &ecDecrypterSigner{
289 privateKey: decryptionKey,
290 }, nil
291 case []byte:
292 return &symmetricKeyCipher{
293 key: decryptionKey,
294 }, nil
295 case string:
296 return &symmetricKeyCipher{
297 key: []byte(decryptionKey),
298 }, nil
299 case JSONWebKey:
300 return newDecrypter(decryptionKey.Key)
301 case *JSONWebKey:
302 return newDecrypter(decryptionKey.Key)
303 }
304 if okd, ok := decryptionKey.(OpaqueKeyDecrypter); ok {
305 return &opaqueKeyDecrypter{decrypter: okd}, nil
306 }
307 return nil, ErrUnsupportedKeyType
308 }
309
310
311 func (ctx *genericEncrypter) Encrypt(plaintext []byte) (*JSONWebEncryption, error) {
312 return ctx.EncryptWithAuthData(plaintext, nil)
313 }
314
315
316 func (ctx *genericEncrypter) EncryptWithAuthData(plaintext, aad []byte) (*JSONWebEncryption, error) {
317 obj := &JSONWebEncryption{}
318 obj.aad = aad
319
320 obj.protected = &rawHeader{}
321 err := obj.protected.set(headerEncryption, ctx.contentAlg)
322 if err != nil {
323 return nil, err
324 }
325
326 obj.recipients = make([]recipientInfo, len(ctx.recipients))
327
328 if len(ctx.recipients) == 0 {
329 return nil, fmt.Errorf("square/go-jose: no recipients to encrypt to")
330 }
331
332 cek, headers, err := ctx.keyGenerator.genKey()
333 if err != nil {
334 return nil, err
335 }
336
337 obj.protected.merge(&headers)
338
339 for i, info := range ctx.recipients {
340 recipient, err := info.keyEncrypter.encryptKey(cek, info.keyAlg)
341 if err != nil {
342 return nil, err
343 }
344
345 err = recipient.header.set(headerAlgorithm, info.keyAlg)
346 if err != nil {
347 return nil, err
348 }
349
350 if info.keyID != "" {
351 err = recipient.header.set(headerKeyID, info.keyID)
352 if err != nil {
353 return nil, err
354 }
355 }
356 obj.recipients[i] = recipient
357 }
358
359 if len(ctx.recipients) == 1 {
360
361
362 obj.protected.merge(obj.recipients[0].header)
363 obj.recipients[0].header = nil
364 }
365
366 if ctx.compressionAlg != NONE {
367 plaintext, err = compress(ctx.compressionAlg, plaintext)
368 if err != nil {
369 return nil, err
370 }
371
372 err = obj.protected.set(headerCompression, ctx.compressionAlg)
373 if err != nil {
374 return nil, err
375 }
376 }
377
378 for k, v := range ctx.extraHeaders {
379 b, err := json.Marshal(v)
380 if err != nil {
381 return nil, err
382 }
383 (*obj.protected)[k] = makeRawMessage(b)
384 }
385
386 authData := obj.computeAuthData()
387 parts, err := ctx.cipher.encrypt(cek, authData, plaintext)
388 if err != nil {
389 return nil, err
390 }
391
392 obj.iv = parts.iv
393 obj.ciphertext = parts.ciphertext
394 obj.tag = parts.tag
395
396 return obj, nil
397 }
398
399 func (ctx *genericEncrypter) Options() EncrypterOptions {
400 return EncrypterOptions{
401 Compression: ctx.compressionAlg,
402 ExtraHeaders: ctx.extraHeaders,
403 }
404 }
405
406
407
408
409 func (obj JSONWebEncryption) Decrypt(decryptionKey interface{}) ([]byte, error) {
410 headers := obj.mergedHeaders(nil)
411
412 if len(obj.recipients) > 1 {
413 return nil, errors.New("square/go-jose: too many recipients in payload; expecting only one")
414 }
415
416 critical, err := headers.getCritical()
417 if err != nil {
418 return nil, fmt.Errorf("square/go-jose: invalid crit header")
419 }
420
421 if len(critical) > 0 {
422 return nil, fmt.Errorf("square/go-jose: unsupported crit header")
423 }
424
425 decrypter, err := newDecrypter(decryptionKey)
426 if err != nil {
427 return nil, err
428 }
429
430 cipher := getContentCipher(headers.getEncryption())
431 if cipher == nil {
432 return nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(headers.getEncryption()))
433 }
434
435 generator := randomKeyGenerator{
436 size: cipher.keySize(),
437 }
438
439 parts := &aeadParts{
440 iv: obj.iv,
441 ciphertext: obj.ciphertext,
442 tag: obj.tag,
443 }
444
445 authData := obj.computeAuthData()
446
447 var plaintext []byte
448 recipient := obj.recipients[0]
449 recipientHeaders := obj.mergedHeaders(&recipient)
450
451 cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
452 if err == nil {
453
454 plaintext, err = cipher.decrypt(cek, authData, parts)
455 }
456
457 if plaintext == nil {
458 return nil, ErrCryptoFailure
459 }
460
461
462 if comp := obj.protected.getCompression(); comp != "" {
463 plaintext, err = decompress(comp, plaintext)
464 }
465
466 return plaintext, err
467 }
468
469
470
471
472
473 func (obj JSONWebEncryption) DecryptMulti(decryptionKey interface{}) (int, Header, []byte, error) {
474 globalHeaders := obj.mergedHeaders(nil)
475
476 critical, err := globalHeaders.getCritical()
477 if err != nil {
478 return -1, Header{}, nil, fmt.Errorf("square/go-jose: invalid crit header")
479 }
480
481 if len(critical) > 0 {
482 return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported crit header")
483 }
484
485 decrypter, err := newDecrypter(decryptionKey)
486 if err != nil {
487 return -1, Header{}, nil, err
488 }
489
490 encryption := globalHeaders.getEncryption()
491 cipher := getContentCipher(encryption)
492 if cipher == nil {
493 return -1, Header{}, nil, fmt.Errorf("square/go-jose: unsupported enc value '%s'", string(encryption))
494 }
495
496 generator := randomKeyGenerator{
497 size: cipher.keySize(),
498 }
499
500 parts := &aeadParts{
501 iv: obj.iv,
502 ciphertext: obj.ciphertext,
503 tag: obj.tag,
504 }
505
506 authData := obj.computeAuthData()
507
508 index := -1
509 var plaintext []byte
510 var headers rawHeader
511
512 for i, recipient := range obj.recipients {
513 recipientHeaders := obj.mergedHeaders(&recipient)
514
515 cek, err := decrypter.decryptKey(recipientHeaders, &recipient, generator)
516 if err == nil {
517
518 plaintext, err = cipher.decrypt(cek, authData, parts)
519 if err == nil {
520 index = i
521 headers = recipientHeaders
522 break
523 }
524 }
525 }
526
527 if plaintext == nil || err != nil {
528 return -1, Header{}, nil, ErrCryptoFailure
529 }
530
531
532 if comp := obj.protected.getCompression(); comp != "" {
533 plaintext, err = decompress(comp, plaintext)
534 }
535
536 sanitized, err := headers.sanitized()
537 if err != nil {
538 return -1, Header{}, nil, fmt.Errorf("square/go-jose: failed to sanitize header: %v", err)
539 }
540
541 return index, sanitized, plaintext, err
542 }
543
View as plain text