1 package cipher
2
3 import (
4 "crypto/aes"
5 "crypto/cipher"
6 "fmt"
7
8 "github.com/lestrrat-go/jwx/jwa"
9 "github.com/lestrrat-go/jwx/jwe/internal/aescbc"
10 "github.com/lestrrat-go/jwx/jwe/internal/keygen"
11 "github.com/pkg/errors"
12 )
13
14 var gcm = &gcmFetcher{}
15 var cbc = &cbcFetcher{}
16
17 func (f gcmFetcher) Fetch(key []byte) (cipher.AEAD, error) {
18 aescipher, err := aes.NewCipher(key)
19 if err != nil {
20 return nil, errors.Wrap(err, "cipher: failed to create AES cipher for GCM")
21 }
22
23 aead, err := cipher.NewGCM(aescipher)
24 if err != nil {
25 return nil, errors.Wrap(err, `failed to create GCM for cipher`)
26 }
27 return aead, nil
28 }
29
30 func (f cbcFetcher) Fetch(key []byte) (cipher.AEAD, error) {
31 aead, err := aescbc.New(key, aes.NewCipher)
32 if err != nil {
33 return nil, errors.Wrap(err, "cipher: failed to create AES cipher for CBC")
34 }
35 return aead, nil
36 }
37
38 func (c AesContentCipher) KeySize() int {
39 return c.keysize
40 }
41
42 func (c AesContentCipher) TagSize() int {
43 return c.tagsize
44 }
45
46 func NewAES(alg jwa.ContentEncryptionAlgorithm) (*AesContentCipher, error) {
47 var keysize int
48 var tagsize int
49 var fetcher Fetcher
50 switch alg {
51 case jwa.A128GCM:
52 keysize = 16
53 tagsize = 16
54 fetcher = gcm
55 case jwa.A192GCM:
56 keysize = 24
57 tagsize = 16
58 fetcher = gcm
59 case jwa.A256GCM:
60 keysize = 32
61 tagsize = 16
62 fetcher = gcm
63 case jwa.A128CBC_HS256:
64 tagsize = 16
65 keysize = tagsize * 2
66 fetcher = cbc
67 case jwa.A192CBC_HS384:
68 tagsize = 24
69 keysize = tagsize * 2
70 fetcher = cbc
71 case jwa.A256CBC_HS512:
72 tagsize = 32
73 keysize = tagsize * 2
74 fetcher = cbc
75 default:
76 return nil, errors.Errorf("failed to create AES content cipher: invalid algorithm (%s)", alg)
77 }
78
79 return &AesContentCipher{
80 keysize: keysize,
81 tagsize: tagsize,
82 fetch: fetcher,
83 }, nil
84 }
85
86 func (c AesContentCipher) Encrypt(cek, plaintext, aad []byte) (iv, ciphertext, tag []byte, err error) {
87 var aead cipher.AEAD
88 aead, err = c.fetch.Fetch(cek)
89 if err != nil {
90 return nil, nil, nil, errors.Wrap(err, "failed to fetch AEAD")
91 }
92
93
94 defer func() {
95 if e := recover(); e != nil {
96 switch e := e.(type) {
97 case error:
98 err = e
99 default:
100 err = errors.Errorf("%s", e)
101 }
102 err = errors.Wrap(err, "failed to encrypt")
103 }
104 }()
105
106 var bs keygen.ByteSource
107 if c.NonceGenerator == nil {
108 bs, err = keygen.NewRandom(aead.NonceSize()).Generate()
109 } else {
110 bs, err = c.NonceGenerator.Generate()
111 }
112 if err != nil {
113 return nil, nil, nil, errors.Wrap(err, "failed to generate nonce")
114 }
115 iv = bs.Bytes()
116
117 combined := aead.Seal(nil, iv, plaintext, aad)
118 tagoffset := len(combined) - c.TagSize()
119
120 if tagoffset < 0 {
121 panic(fmt.Sprintf("tag offset is less than 0 (combined len = %d, tagsize = %d)", len(combined), c.TagSize()))
122 }
123
124 tag = combined[tagoffset:]
125 ciphertext = make([]byte, tagoffset)
126 copy(ciphertext, combined[:tagoffset])
127
128 return
129 }
130
131 func (c AesContentCipher) Decrypt(cek, iv, ciphertxt, tag, aad []byte) (plaintext []byte, err error) {
132 aead, err := c.fetch.Fetch(cek)
133 if err != nil {
134 return nil, errors.Wrap(err, "failed to fetch AEAD data")
135 }
136
137
138 defer func() {
139 if e := recover(); e != nil {
140 switch e := e.(type) {
141 case error:
142 err = e
143 default:
144 err = errors.Errorf("%s", e)
145 }
146 err = errors.Wrap(err, "failed to decrypt")
147 return
148 }
149 }()
150
151 combined := make([]byte, len(ciphertxt)+len(tag))
152 copy(combined, ciphertxt)
153 copy(combined[len(ciphertxt):], tag)
154
155 buf, aeaderr := aead.Open(nil, iv, combined, aad)
156 if aeaderr != nil {
157 err = errors.Wrap(aeaderr, `aead.Open failed`)
158 return
159 }
160 plaintext = buf
161 return
162 }
163
View as plain text