1 package keyenc
2
3 import (
4 "crypto"
5 "crypto/aes"
6 "crypto/cipher"
7 "crypto/ecdsa"
8 "crypto/rand"
9 "crypto/rsa"
10 "crypto/sha1"
11 "crypto/sha256"
12 "crypto/sha512"
13 "crypto/subtle"
14 "encoding/binary"
15 "fmt"
16 "hash"
17 "io"
18
19 "golang.org/x/crypto/curve25519"
20 "golang.org/x/crypto/pbkdf2"
21
22 "github.com/lestrrat-go/jwx/internal/ecutil"
23 "github.com/lestrrat-go/jwx/jwa"
24 contentcipher "github.com/lestrrat-go/jwx/jwe/internal/cipher"
25 "github.com/lestrrat-go/jwx/jwe/internal/concatkdf"
26 "github.com/lestrrat-go/jwx/jwe/internal/keygen"
27 "github.com/lestrrat-go/jwx/x25519"
28 "github.com/pkg/errors"
29 )
30
31 func NewNoop(alg jwa.KeyEncryptionAlgorithm, sharedkey []byte) (*Noop, error) {
32 return &Noop{
33 alg: alg,
34 sharedkey: sharedkey,
35 }, nil
36 }
37
38 func (kw *Noop) Algorithm() jwa.KeyEncryptionAlgorithm {
39 return kw.alg
40 }
41
42 func (kw *Noop) SetKeyID(v string) {
43 kw.keyID = v
44 }
45
46 func (kw *Noop) KeyID() string {
47 return kw.keyID
48 }
49
50 func (kw *Noop) Encrypt(cek []byte) (keygen.ByteSource, error) {
51 return keygen.ByteKey(kw.sharedkey), nil
52 }
53
54
55
56 func NewAES(alg jwa.KeyEncryptionAlgorithm, sharedkey []byte) (*AES, error) {
57 return &AES{
58 alg: alg,
59 sharedkey: sharedkey,
60 }, nil
61 }
62
63
64 func (kw *AES) Algorithm() jwa.KeyEncryptionAlgorithm {
65 return kw.alg
66 }
67
68 func (kw *AES) SetKeyID(v string) {
69 kw.keyID = v
70 }
71
72
73 func (kw *AES) KeyID() string {
74 return kw.keyID
75 }
76
77
// Decrypt AES-key-unwraps enckey using the shared key as the key-encryption
// key and returns the recovered content encryption key.
func (kw *AES) Decrypt(enckey []byte) ([]byte, error) {
	kek, err := aes.NewCipher(kw.sharedkey)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create cipher from shared key")
	}

	plain, unwrapErr := Unwrap(kek, enckey)
	if unwrapErr != nil {
		return nil, errors.Wrap(unwrapErr, "failed to unwrap data")
	}
	return plain, nil
}
90
91
// Encrypt AES-key-wraps cek under the shared key and returns the wrapped
// bytes as a keygen.ByteKey.
func (kw *AES) Encrypt(cek []byte) (keygen.ByteSource, error) {
	kek, err := aes.NewCipher(kw.sharedkey)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create cipher from shared key")
	}

	wrapped, wrapErr := Wrap(kek, cek)
	if wrapErr != nil {
		return nil, errors.Wrap(wrapErr, `keywrap: failed to wrap key`)
	}
	return keygen.ByteKey(wrapped), nil
}
103
104 func NewAESGCMEncrypt(alg jwa.KeyEncryptionAlgorithm, sharedkey []byte) (*AESGCMEncrypt, error) {
105 return &AESGCMEncrypt{
106 algorithm: alg,
107 sharedkey: sharedkey,
108 }, nil
109 }
110
111 func (kw AESGCMEncrypt) Algorithm() jwa.KeyEncryptionAlgorithm {
112 return kw.algorithm
113 }
114
115 func (kw *AESGCMEncrypt) SetKeyID(v string) {
116 kw.keyID = v
117 }
118
119 func (kw AESGCMEncrypt) KeyID() string {
120 return kw.keyID
121 }
122
// Encrypt seals cek with AES-GCM under the shared key, using a freshly drawn
// random nonce and no additional authenticated data. The ciphertext, nonce
// (IV) and authentication tag are returned as separate fields of a
// keygen.ByteWithIVAndTag.
func (kw AESGCMEncrypt) Encrypt(cek []byte) (keygen.ByteSource, error) {
	blockCipher, err := aes.NewCipher(kw.sharedkey)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create cipher from shared key")
	}
	gcm, err := cipher.NewGCM(blockCipher)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create gcm from cipher")
	}

	nonce := make([]byte, gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return nil, errors.Wrap(err, "failed to get random iv")
	}

	// Seal appends the authentication tag to the ciphertext; split it off.
	sealed := gcm.Seal(nil, nonce, cek, nil)
	boundary := len(sealed) - gcm.Overhead()
	return keygen.ByteWithIVAndTag{
		ByteKey: sealed[:boundary],
		IV:      nonce,
		Tag:     sealed[boundary:],
	}, nil
}
148
149 func NewPBES2Encrypt(alg jwa.KeyEncryptionAlgorithm, password []byte) (*PBES2Encrypt, error) {
150 var hashFunc func() hash.Hash
151 var keylen int
152 switch alg {
153 case jwa.PBES2_HS256_A128KW:
154 hashFunc = sha256.New
155 keylen = 16
156 case jwa.PBES2_HS384_A192KW:
157 hashFunc = sha512.New384
158 keylen = 24
159 case jwa.PBES2_HS512_A256KW:
160 hashFunc = sha512.New
161 keylen = 32
162 default:
163 return nil, errors.Errorf("unexpected key encryption algorithm %s", alg)
164 }
165 return &PBES2Encrypt{
166 algorithm: alg,
167 password: password,
168 hashFunc: hashFunc,
169 keylen: keylen,
170 }, nil
171 }
172
173 func (kw PBES2Encrypt) Algorithm() jwa.KeyEncryptionAlgorithm {
174 return kw.algorithm
175 }
176
177 func (kw *PBES2Encrypt) SetKeyID(v string) {
178 kw.keyID = v
179 }
180
181 func (kw PBES2Encrypt) KeyID() string {
182 return kw.keyID
183 }
184
185 func (kw PBES2Encrypt) Encrypt(cek []byte) (keygen.ByteSource, error) {
186 count := 10000
187 salt := make([]byte, kw.keylen)
188 _, err := io.ReadFull(rand.Reader, salt)
189 if err != nil {
190 return nil, errors.Wrap(err, "failed to get random salt")
191 }
192
193 fullsalt := []byte(kw.algorithm)
194 fullsalt = append(fullsalt, byte(0))
195 fullsalt = append(fullsalt, salt...)
196 sharedkey := pbkdf2.Key(kw.password, fullsalt, count, kw.keylen, kw.hashFunc)
197
198 block, err := aes.NewCipher(sharedkey)
199 if err != nil {
200 return nil, errors.Wrap(err, "failed to create cipher from shared key")
201 }
202 encrypted, err := Wrap(block, cek)
203 if err != nil {
204 return nil, errors.Wrap(err, `keywrap: failed to wrap key`)
205 }
206 return keygen.ByteWithSaltAndCount{
207 ByteKey: encrypted,
208 Salt: salt,
209 Count: count,
210 }, nil
211 }
212
213
// NewECDHESEncrypt creates an ECDH-ES key encrypter for the recipient public
// key keyif. Only *ecdsa.PublicKey and x25519.PublicKey are accepted; each
// gets its matching keygen generator, configured with alg, enc and keysize.
func NewECDHESEncrypt(alg jwa.KeyEncryptionAlgorithm, enc jwa.ContentEncryptionAlgorithm, keysize int, keyif interface{}) (*ECDHESEncrypt, error) {
	var (
		gen keygen.Generator
		err error
	)
	switch pub := keyif.(type) {
	case *ecdsa.PublicKey:
		gen, err = keygen.NewEcdhes(alg, enc, keysize, pub)
	case x25519.PublicKey:
		gen, err = keygen.NewX25519(alg, enc, keysize, pub)
	default:
		return nil, errors.Errorf("unexpected key type %T", keyif)
	}
	if err != nil {
		return nil, errors.Wrap(err, "failed to create key generator")
	}

	kw := &ECDHESEncrypt{algorithm: alg, generator: gen}
	return kw, nil
}

// Algorithm returns the key encryption algorithm this encrypter was built with.
func (kw ECDHESEncrypt) Algorithm() jwa.KeyEncryptionAlgorithm { return kw.algorithm }

// SetKeyID records v as the key ID associated with this encrypter.
func (kw *ECDHESEncrypt) SetKeyID(v string) { kw.keyID = v }

// KeyID returns the key ID previously recorded via SetKeyID.
func (kw ECDHESEncrypt) KeyID() string { return kw.keyID }
247
248
// Encrypt performs the ECDH-ES key agreement through the configured generator.
//
// For plain ECDH-ES (direct key agreement) the generated value is returned
// as-is and cek is not used. For the ECDH-ES+AxxxKW variants, the agreed key
// is used as an AES key-encryption key to wrap cek. The wrapped bytes then
// replace the key material in the returned keygen.ByteWithECPublicKey, which
// keeps its EC public key data (presumably the ephemeral key the recipient
// needs; confirm in keygen).
func (kw ECDHESEncrypt) Encrypt(cek []byte) (keygen.ByteSource, error) {
	kg, err := kw.generator.Generate()
	if err != nil {
		return nil, errors.Wrap(err, "failed to generate key")
	}

	bwpk, ok := kg.(keygen.ByteWithECPublicKey)
	if !ok {
		return nil, errors.Errorf("key generator generated invalid key (expected ByteWithECPublicKey, got %T)", kg)
	}

	if kw.algorithm == jwa.ECDH_ES {
		return bwpk, nil
	}

	block, err := aes.NewCipher(bwpk.Bytes())
	if err != nil {
		return nil, errors.Wrap(err, "failed to generate cipher from generated key")
	}

	jek, err := Wrap(block, cek)
	if err != nil {
		return nil, errors.Wrap(err, "failed to wrap data")
	}

	bwpk.ByteKey = keygen.ByteKey(jek)
	return bwpk, nil
}
278
279
280 func NewECDHESDecrypt(keyalg jwa.KeyEncryptionAlgorithm, contentalg jwa.ContentEncryptionAlgorithm, pubkey interface{}, apu, apv []byte, privkey interface{}) *ECDHESDecrypt {
281 return &ECDHESDecrypt{
282 keyalg: keyalg,
283 contentalg: contentalg,
284 apu: apu,
285 apv: apv,
286 privkey: privkey,
287 pubkey: pubkey,
288 }
289 }
290
291
292 func (kw ECDHESDecrypt) Algorithm() jwa.KeyEncryptionAlgorithm {
293 return kw.keyalg
294 }
295
// DeriveZ computes the raw ECDH shared secret Z for a private/public key pair.
//
// Supported pairs:
//   - x25519.PrivateKey with x25519.PublicKey: curve25519.X25519 is applied
//     to the private key's seed and the peer public key.
//   - *ecdsa.PrivateKey with *ecdsa.PublicKey: the peer point must lie on the
//     private key's curve. The X coordinate of D·P is then serialized with
//     ecutil.AllocECPointBuffer and copied into a fresh slice, because the
//     helper's buffer is handed back via ecutil.ReleaseECPointBuffer.
//
// Any other private key type, or a public key of the wrong type, is an error.
func DeriveZ(privkeyif interface{}, pubkeyif interface{}) ([]byte, error) {
	switch priv := privkeyif.(type) {
	case x25519.PrivateKey:
		pub, ok := pubkeyif.(x25519.PublicKey)
		if !ok {
			return nil, errors.Errorf(`public key must be x25519.PublicKey, was: %T`, pubkeyif)
		}
		return curve25519.X25519(priv.Seed(), pub)
	case *ecdsa.PrivateKey:
		pub, ok := pubkeyif.(*ecdsa.PublicKey)
		if !ok {
			return nil, errors.Errorf(`public key must be *ecdsa.PublicKey, was: %T`, pubkeyif)
		}
		curve := priv.PublicKey.Curve
		if !curve.IsOnCurve(pub.X, pub.Y) {
			return nil, errors.New(`public key must be on the same curve as private key`)
		}

		sharedX, _ := curve.ScalarMult(pub.X, pub.Y, priv.D.Bytes())
		pooled := ecutil.AllocECPointBuffer(sharedX, priv.Curve)
		defer ecutil.ReleaseECPointBuffer(pooled)

		secret := make([]byte, len(pooled))
		copy(secret, pooled)
		return secret, nil
	default:
		return nil, errors.Errorf(`private key must be *ecdsa.PrivateKey, was: %T`, privkeyif)
	}
}
329
// DeriveECDHES derives a keysize-byte key from the ECDH agreement between
// privkey and pubkey, using the Concat KDF with SHA-256.
//
// alg, apu and apv are the AlgorithmID, PartyUInfo and PartyVInfo inputs.
// SuppPubInfo is the requested key length in bits, encoded as a 32-bit
// big-endian integer. SuppPrivInfo is empty.
func DeriveECDHES(alg, apu, apv []byte, privkey interface{}, pubkey interface{}, keysize uint32) ([]byte, error) {
	zBytes, err := DeriveZ(privkey, pubkey)
	if err != nil {
		return nil, errors.Wrap(err, "unable to determine Z")
	}

	// SuppPubInfo: key data length in bits, big-endian.
	pubinfo := make([]byte, 4)
	binary.BigEndian.PutUint32(pubinfo, keysize*8)

	kdf := concatkdf.New(crypto.SHA256, alg, zBytes, apu, apv, pubinfo, []byte{})
	key := make([]byte, keysize)
	// A single Read may legally return fewer bytes than requested; ReadFull
	// guarantees the key is completely filled or an error is reported.
	if _, err := io.ReadFull(kdf, key); err != nil {
		return nil, errors.Wrap(err, "failed to read kdf")
	}
	return key, nil
}
345
346
// Decrypt recomputes the ECDH-ES derived key and uses it to recover the
// content encryption key.
//
// For direct ECDH-ES the derived key is the content encryption key. Its
// length and the KDF AlgorithmID come from the content algorithm, and enckey
// is not used. For ECDH-ES+A128KW/A192KW/A256KW a 16/24/32-byte key is derived
// (AlgorithmID = key algorithm) and used to AES-key-unwrap enckey.
func (kw ECDHESDecrypt) Decrypt(enckey []byte) ([]byte, error) {
	algID := []byte(kw.keyalg.String())
	var keysize uint32

	switch kw.keyalg {
	case jwa.ECDH_ES_A128KW:
		keysize = 16
	case jwa.ECDH_ES_A192KW:
		keysize = 24
	case jwa.ECDH_ES_A256KW:
		keysize = 32
	case jwa.ECDH_ES:
		c, err := contentcipher.NewAES(kw.contentalg)
		if err != nil {
			return nil, errors.Wrapf(err, `failed to create content cipher for %s`, kw.contentalg)
		}
		keysize = uint32(c.KeySize())
		algID = []byte(kw.contentalg.String())
	default:
		return nil, errors.Errorf("invalid ECDH-ES key wrap algorithm (%s)", kw.keyalg)
	}

	derived, err := DeriveECDHES(algID, kw.apu, kw.apv, kw.privkey, kw.pubkey, keysize)
	if err != nil {
		return nil, errors.Wrap(err, `failed to derive ECDHES encryption key`)
	}

	if direct := kw.keyalg == jwa.ECDH_ES; direct {
		return derived, nil
	}

	kek, err := aes.NewCipher(derived)
	if err != nil {
		return nil, errors.Wrap(err, "failed to create cipher for ECDH-ES key wrap")
	}
	return Unwrap(kek, enckey)
}
390
391
// NewRSAOAEPEncrypt creates an RSA-OAEP key encrypter for pubkey. Only
// jwa.RSA_OAEP and jwa.RSA_OAEP_256 are accepted.
func NewRSAOAEPEncrypt(alg jwa.KeyEncryptionAlgorithm, pubkey *rsa.PublicKey) (*RSAOAEPEncrypt, error) {
	if alg != jwa.RSA_OAEP && alg != jwa.RSA_OAEP_256 {
		return nil, errors.Errorf("invalid RSA OAEP encrypt algorithm (%s)", alg)
	}
	e := &RSAOAEPEncrypt{alg: alg, pubkey: pubkey}
	return e, nil
}

// NewRSAPKCSEncrypt creates an RSAES-PKCS1-v1_5 key encrypter for pubkey.
// Only jwa.RSA1_5 is accepted.
func NewRSAPKCSEncrypt(alg jwa.KeyEncryptionAlgorithm, pubkey *rsa.PublicKey) (*RSAPKCSEncrypt, error) {
	if alg != jwa.RSA1_5 {
		return nil, errors.Errorf("invalid RSA PKCS encrypt algorithm (%s)", alg)
	}
	e := &RSAPKCSEncrypt{alg: alg, pubkey: pubkey}
	return e, nil
}
417
418
419 func (e RSAPKCSEncrypt) Algorithm() jwa.KeyEncryptionAlgorithm {
420 return e.alg
421 }
422
423 func (e *RSAPKCSEncrypt) SetKeyID(v string) {
424 e.keyID = v
425 }
426
427
428 func (e RSAPKCSEncrypt) KeyID() string {
429 return e.keyID
430 }
431
432
433 func (e RSAOAEPEncrypt) Algorithm() jwa.KeyEncryptionAlgorithm {
434 return e.alg
435 }
436
437 func (e *RSAOAEPEncrypt) SetKeyID(v string) {
438 e.keyID = v
439 }
440
441
442 func (e RSAOAEPEncrypt) KeyID() string {
443 return e.keyID
444 }
445
446
// Encrypt encrypts cek to the public key with RSAES-PKCS1-v1_5. It refuses to
// run when the configured algorithm is not jwa.RSA1_5.
func (e RSAPKCSEncrypt) Encrypt(cek []byte) (keygen.ByteSource, error) {
	switch e.alg {
	case jwa.RSA1_5:
		// supported
	default:
		return nil, errors.Errorf("invalid RSA PKCS encrypt algorithm (%s)", e.alg)
	}

	ciphertext, err := rsa.EncryptPKCS1v15(rand.Reader, e.pubkey, cek)
	if err != nil {
		return nil, errors.Wrap(err, "failed to encrypt using PKCS1v15")
	}
	return keygen.ByteKey(ciphertext), nil
}
457
458
// Encrypt encrypts cek to the public key with RSA-OAEP and an empty label.
// RSA_OAEP uses SHA-1 and RSA_OAEP_256 uses SHA-256 as the OAEP hash; any
// other algorithm is rejected.
func (e RSAOAEPEncrypt) Encrypt(cek []byte) (keygen.ByteSource, error) {
	var digest hash.Hash
	switch e.alg {
	case jwa.RSA_OAEP:
		digest = sha1.New()
	case jwa.RSA_OAEP_256:
		digest = sha256.New()
	default:
		return nil, errors.New("failed to generate key encrypter for RSA-OAEP: RSA_OAEP/RSA_OAEP_256 required")
	}

	ciphertext, err := rsa.EncryptOAEP(digest, rand.Reader, e.pubkey, cek, []byte{})
	if err != nil {
		return nil, errors.Wrap(err, `failed to OAEP encrypt`)
	}
	return keygen.ByteKey(ciphertext), nil
}
475
476
// NewRSAPKCS15Decrypt creates an RSA1_5 key decrypter for privkey.
//
// Decrypt pre-fills its output with random bytes from a generator producing
// keysize*2 bytes. NOTE(review): the doubling presumably compensates for
// callers passing half of the content cipher's key size; confirm against
// call sites.
func NewRSAPKCS15Decrypt(alg jwa.KeyEncryptionAlgorithm, privkey *rsa.PrivateKey, keysize int) *RSAPKCS15Decrypt {
	return &RSAPKCS15Decrypt{
		alg:       alg,
		privkey:   privkey,
		generator: keygen.NewRandom(keysize * 2),
	}
}

// Algorithm returns the key encryption algorithm this decrypter was built with.
func (d RSAPKCS15Decrypt) Algorithm() jwa.KeyEncryptionAlgorithm { return d.alg }
490
491
492 func (d RSAPKCS15Decrypt) Decrypt(enckey []byte) ([]byte, error) {
493
494 defer func() {
495
496
497
498
499
500
501 _ = recover()
502 }()
503
504
505 expectedlen := d.privkey.PublicKey.N.BitLen() / 8
506 if expectedlen != len(enckey) {
507
508
509
510 return nil, fmt.Errorf(
511 "input size for key decrypt is incorrect (expected %d, got %d)",
512 expectedlen,
513 len(enckey),
514 )
515 }
516
517 var err error
518
519 bk, err := d.generator.Generate()
520 if err != nil {
521 return nil, errors.New("failed to generate key")
522 }
523 cek := bk.Bytes()
524
525
526
527
528
529 err = rsa.DecryptPKCS1v15SessionKey(rand.Reader, d.privkey, enckey, cek)
530 if err != nil {
531 return nil, errors.Wrap(err, "failed to decrypt via PKCS1v15")
532 }
533
534 return cek, nil
535 }
536
537
// NewRSAOAEPDecrypt creates an RSA-OAEP key decrypter for privkey. Only
// jwa.RSA_OAEP and jwa.RSA_OAEP_256 are accepted.
func NewRSAOAEPDecrypt(alg jwa.KeyEncryptionAlgorithm, privkey *rsa.PrivateKey) (*RSAOAEPDecrypt, error) {
	if alg != jwa.RSA_OAEP && alg != jwa.RSA_OAEP_256 {
		return nil, errors.Errorf("invalid RSA OAEP decrypt algorithm (%s)", alg)
	}
	d := &RSAOAEPDecrypt{alg: alg, privkey: privkey}
	return d, nil
}

// Algorithm returns the key encryption algorithm this decrypter was built with.
func (d RSAOAEPDecrypt) Algorithm() jwa.KeyEncryptionAlgorithm { return d.alg }
555
556
// Decrypt recovers the content encryption key from an RSA-OAEP encrypted key
// with an empty label. RSA_OAEP uses SHA-1 and RSA_OAEP_256 uses SHA-256;
// errors from rsa.DecryptOAEP are returned unwrapped.
func (d RSAOAEPDecrypt) Decrypt(enckey []byte) ([]byte, error) {
	var newDigest func() hash.Hash
	switch d.alg {
	case jwa.RSA_OAEP:
		newDigest = sha1.New
	case jwa.RSA_OAEP_256:
		newDigest = sha256.New
	default:
		return nil, errors.New("failed to generate key encrypter for RSA-OAEP: RSA_OAEP/RSA_OAEP_256 required")
	}
	return rsa.DecryptOAEP(newDigest(), rand.Reader, d.privkey, enckey, []byte{})
}
569
570
571
572 func (d DirectDecrypt) Decrypt() ([]byte, error) {
573 cek := make([]byte, len(d.Key))
574 copy(cek, d.Key)
575 return cek, nil
576 }
577
578 var keywrapDefaultIV = []byte{0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6, 0xa6}
579
580 const keywrapChunkLen = 8
581
582 func Wrap(kek cipher.Block, cek []byte) ([]byte, error) {
583 if len(cek)%8 != 0 {
584 return nil, errors.New(`keywrap input must be 8 byte blocks`)
585 }
586
587 n := len(cek) / keywrapChunkLen
588 r := make([][]byte, n)
589
590 for i := 0; i < n; i++ {
591 r[i] = make([]byte, keywrapChunkLen)
592 copy(r[i], cek[i*keywrapChunkLen:])
593 }
594
595 buffer := make([]byte, keywrapChunkLen*2)
596 tBytes := make([]byte, keywrapChunkLen)
597 copy(buffer, keywrapDefaultIV)
598
599 for t := 0; t < 6*n; t++ {
600 copy(buffer[keywrapChunkLen:], r[t%n])
601
602 kek.Encrypt(buffer, buffer)
603
604 binary.BigEndian.PutUint64(tBytes, uint64(t+1))
605
606 for i := 0; i < keywrapChunkLen; i++ {
607 buffer[i] = buffer[i] ^ tBytes[i]
608 }
609 copy(r[t%n], buffer[keywrapChunkLen:])
610 }
611
612 out := make([]byte, (n+1)*keywrapChunkLen)
613 copy(out, buffer[:keywrapChunkLen])
614 for i := range r {
615 copy(out[(i+1)*8:], r[i])
616 }
617
618 return out, nil
619 }
620
621 func Unwrap(block cipher.Block, ciphertxt []byte) ([]byte, error) {
622 if len(ciphertxt)%keywrapChunkLen != 0 {
623 return nil, errors.Errorf(`keyunwrap input must be %d byte blocks`, keywrapChunkLen)
624 }
625
626 n := (len(ciphertxt) / keywrapChunkLen) - 1
627 r := make([][]byte, n)
628
629 for i := range r {
630 r[i] = make([]byte, keywrapChunkLen)
631 copy(r[i], ciphertxt[(i+1)*keywrapChunkLen:])
632 }
633
634 buffer := make([]byte, keywrapChunkLen*2)
635 tBytes := make([]byte, keywrapChunkLen)
636 copy(buffer[:keywrapChunkLen], ciphertxt[:keywrapChunkLen])
637
638 for t := 6*n - 1; t >= 0; t-- {
639 binary.BigEndian.PutUint64(tBytes, uint64(t+1))
640
641 for i := 0; i < keywrapChunkLen; i++ {
642 buffer[i] = buffer[i] ^ tBytes[i]
643 }
644 copy(buffer[keywrapChunkLen:], r[t%n])
645
646 block.Decrypt(buffer, buffer)
647
648 copy(r[t%n], buffer[keywrapChunkLen:])
649 }
650
651 if subtle.ConstantTimeCompare(buffer[:keywrapChunkLen], keywrapDefaultIV) == 0 {
652 return nil, errors.New("key unwrap: failed to unwrap key")
653 }
654
655 out := make([]byte, n*keywrapChunkLen)
656 for i := range r {
657 copy(out[i*keywrapChunkLen:], r[i])
658 }
659
660 return out, nil
661 }
662
View as plain text