1
2
3
4 package jwe
5
6 import (
7 "bytes"
8 "crypto/ecdsa"
9 "crypto/rsa"
10 "io"
11 "io/ioutil"
12
13 "github.com/lestrrat-go/jwx/internal/base64"
14 "github.com/lestrrat-go/jwx/internal/json"
15 "github.com/lestrrat-go/jwx/internal/keyconv"
16 "github.com/lestrrat-go/jwx/jwk"
17
18 "github.com/lestrrat-go/jwx/jwa"
19 "github.com/lestrrat-go/jwx/jwe/internal/content_crypt"
20 "github.com/lestrrat-go/jwx/jwe/internal/keyenc"
21 "github.com/lestrrat-go/jwx/jwe/internal/keygen"
22 "github.com/lestrrat-go/jwx/x25519"
23 "github.com/pkg/errors"
24 )
25
26 var registry = json.NewRegistry()
27
28
29
30
31
32 func Encrypt(payload []byte, keyalg jwa.KeyEncryptionAlgorithm, key interface{}, contentalg jwa.ContentEncryptionAlgorithm, compressalg jwa.CompressionAlgorithm, options ...EncryptOption) ([]byte, error) {
33 var protected Headers
34 for _, option := range options {
35
36 switch option.Ident() {
37 case identProtectedHeader{}:
38 protected = option.Value().(Headers)
39 }
40 }
41 if protected == nil {
42 protected = NewHeaders()
43 }
44
45 contentcrypt, err := content_crypt.NewGeneric(contentalg)
46 if err != nil {
47 return nil, errors.Wrap(err, `failed to create AES encrypter`)
48 }
49
50 var keyID string
51 if jwkKey, ok := key.(jwk.Key); ok {
52 keyID = jwkKey.KeyID()
53
54 var raw interface{}
55 if err := jwkKey.Raw(&raw); err != nil {
56 return nil, errors.Wrapf(err, `failed to retrieve raw key out of %T`, key)
57 }
58
59 key = raw
60 }
61
62 var enc keyenc.Encrypter
63 switch keyalg {
64 case jwa.RSA1_5:
65 var pubkey rsa.PublicKey
66 if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
67 return nil, errors.Wrapf(err, "failed to generate public key from key (%T)", key)
68 }
69
70 enc, err = keyenc.NewRSAPKCSEncrypt(keyalg, &pubkey)
71 if err != nil {
72 return nil, errors.Wrap(err, "failed to create RSA PKCS encrypter")
73 }
74 case jwa.RSA_OAEP, jwa.RSA_OAEP_256:
75 var pubkey rsa.PublicKey
76 if err := keyconv.RSAPublicKey(&pubkey, key); err != nil {
77 return nil, errors.Wrapf(err, "failed to generate public key from key (%T)", key)
78 }
79
80 enc, err = keyenc.NewRSAOAEPEncrypt(keyalg, &pubkey)
81 if err != nil {
82 return nil, errors.Wrap(err, "failed to create RSA OAEP encrypter")
83 }
84 case jwa.A128KW, jwa.A192KW, jwa.A256KW,
85 jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW,
86 jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
87 sharedkey, ok := key.([]byte)
88 if !ok {
89 return nil, errors.New("invalid key: []byte required")
90 }
91 switch keyalg {
92 case jwa.A128KW, jwa.A192KW, jwa.A256KW:
93 enc, err = keyenc.NewAES(keyalg, sharedkey)
94 case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
95 enc, err = keyenc.NewPBES2Encrypt(keyalg, sharedkey)
96 default:
97 enc, err = keyenc.NewAESGCMEncrypt(keyalg, sharedkey)
98 }
99 if err != nil {
100 return nil, errors.Wrap(err, "failed to create key wrap encrypter")
101 }
102
103
104
105
106 case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
107 var keysize int
108 switch keyalg {
109 case jwa.ECDH_ES:
110
111
112
113 keysize = contentcrypt.KeySize()
114 case jwa.ECDH_ES_A128KW:
115 keysize = 16
116 case jwa.ECDH_ES_A192KW:
117 keysize = 24
118 case jwa.ECDH_ES_A256KW:
119 keysize = 32
120 }
121
122 switch key := key.(type) {
123 case x25519.PublicKey:
124 enc, err = keyenc.NewECDHESEncrypt(keyalg, contentalg, keysize, key)
125 default:
126 var pubkey ecdsa.PublicKey
127 if err := keyconv.ECDSAPublicKey(&pubkey, key); err != nil {
128 return nil, errors.Wrapf(err, "failed to generate public key from key (%T)", key)
129 }
130 enc, err = keyenc.NewECDHESEncrypt(keyalg, contentalg, keysize, &pubkey)
131 }
132 if err != nil {
133 return nil, errors.Wrap(err, "failed to create ECDHS key wrap encrypter")
134 }
135 case jwa.DIRECT:
136 sharedkey, ok := key.([]byte)
137 if !ok {
138 return nil, errors.New("invalid key: []byte required")
139 }
140 enc, _ = keyenc.NewNoop(keyalg, sharedkey)
141 default:
142 return nil, errors.Errorf(`invalid key encryption algorithm (%s)`, keyalg)
143 }
144
145 if keyID != "" {
146 enc.SetKeyID(keyID)
147 }
148
149 keysize := contentcrypt.KeySize()
150 encctx := getEncryptCtx()
151 defer releaseEncryptCtx(encctx)
152
153 encctx.protected = protected
154 encctx.contentEncrypter = contentcrypt
155 encctx.generator = keygen.NewRandom(keysize)
156 encctx.keyEncrypters = []keyenc.Encrypter{enc}
157 encctx.compress = compressalg
158 msg, err := encctx.Encrypt(payload)
159 if err != nil {
160 return nil, errors.Wrap(err, "failed to encrypt payload")
161 }
162
163 return Compact(msg)
164 }
165
166
167
168
169
170
171
172 type DecryptCtx interface {
173 Algorithm() jwa.KeyEncryptionAlgorithm
174 SetAlgorithm(jwa.KeyEncryptionAlgorithm)
175 Key() interface{}
176 SetKey(interface{})
177 Message() *Message
178 SetMessage(*Message)
179 }
180
181 type decryptCtx struct {
182 alg jwa.KeyEncryptionAlgorithm
183 key interface{}
184 msg *Message
185 }
186
187 func (ctx *decryptCtx) Algorithm() jwa.KeyEncryptionAlgorithm {
188 return ctx.alg
189 }
190
191 func (ctx *decryptCtx) SetAlgorithm(v jwa.KeyEncryptionAlgorithm) {
192 ctx.alg = v
193 }
194
195 func (ctx *decryptCtx) Key() interface{} {
196 return ctx.key
197 }
198
199 func (ctx *decryptCtx) SetKey(v interface{}) {
200 ctx.key = v
201 }
202
203 func (ctx *decryptCtx) Message() *Message {
204 return ctx.msg
205 }
206
207 func (ctx *decryptCtx) SetMessage(m *Message) {
208 ctx.msg = m
209 }
210
211
212
213
214
215
216 func Decrypt(buf []byte, alg jwa.KeyEncryptionAlgorithm, key interface{}, options ...DecryptOption) ([]byte, error) {
217 var ctx decryptCtx
218 ctx.key = key
219 ctx.alg = alg
220
221 var dst *Message
222 var postParse PostParser
223
224 for _, option := range options {
225 switch option.Ident() {
226 case identMessage{}:
227 dst = option.Value().(*Message)
228 case identPostParser{}:
229 postParse = option.Value().(PostParser)
230 }
231 }
232
233 msg, err := parseJSONOrCompact(buf, true)
234 if err != nil {
235 return nil, errors.Wrap(err, "failed to parse buffer for Decrypt")
236 }
237
238 ctx.msg = msg
239 if postParse != nil {
240 if err := postParse.PostParse(&ctx); err != nil {
241 return nil, errors.Wrap(err, `failed to execute PostParser hook`)
242 }
243 }
244
245 payload, err := doDecryptCtx(&ctx)
246 if err != nil {
247 return nil, errors.Wrap(err, `failed to decrypt message`)
248 }
249
250 if dst != nil {
251 *dst = *msg
252 dst.rawProtectedHeaders = nil
253 dst.storeProtectedHeaders = false
254 }
255
256 return payload, nil
257 }
258
259
260
261 func Parse(buf []byte) (*Message, error) {
262 return parseJSONOrCompact(buf, false)
263 }
264
265 func parseJSONOrCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
266 buf = bytes.TrimSpace(buf)
267 if len(buf) == 0 {
268 return nil, errors.New("empty buffer")
269 }
270
271 if buf[0] == '{' {
272 return parseJSON(buf, storeProtectedHeaders)
273 }
274 return parseCompact(buf, storeProtectedHeaders)
275 }
276
277
278 func ParseString(s string) (*Message, error) {
279 return Parse([]byte(s))
280 }
281
282
283 func ParseReader(src io.Reader) (*Message, error) {
284 buf, err := ioutil.ReadAll(src)
285 if err != nil {
286 return nil, errors.Wrap(err, `failed to read from io.Reader`)
287 }
288 return Parse(buf)
289 }
290
291 func parseJSON(buf []byte, storeProtectedHeaders bool) (*Message, error) {
292 m := NewMessage()
293 m.storeProtectedHeaders = storeProtectedHeaders
294 if err := json.Unmarshal(buf, &m); err != nil {
295 return nil, errors.Wrap(err, "failed to parse JSON")
296 }
297 return m, nil
298 }
299
300 func parseCompact(buf []byte, storeProtectedHeaders bool) (*Message, error) {
301 parts := bytes.Split(buf, []byte{'.'})
302 if len(parts) != 5 {
303 return nil, errors.Errorf(`compact JWE format must have five parts (%d)`, len(parts))
304 }
305
306 hdrbuf, err := base64.Decode(parts[0])
307 if err != nil {
308 return nil, errors.Wrap(err, `failed to parse first part of compact form`)
309 }
310
311 protected := NewHeaders()
312 if err := json.Unmarshal(hdrbuf, protected); err != nil {
313 return nil, errors.Wrap(err, "failed to parse header JSON")
314 }
315
316 ivbuf, err := base64.Decode(parts[2])
317 if err != nil {
318 return nil, errors.Wrap(err, "failed to base64 decode iv")
319 }
320
321 ctbuf, err := base64.Decode(parts[3])
322 if err != nil {
323 return nil, errors.Wrap(err, "failed to base64 decode content")
324 }
325
326 tagbuf, err := base64.Decode(parts[4])
327 if err != nil {
328 return nil, errors.Wrap(err, "failed to base64 decode tag")
329 }
330
331 m := NewMessage()
332 if err := m.Set(CipherTextKey, ctbuf); err != nil {
333 return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
334 }
335 if err := m.Set(InitializationVectorKey, ivbuf); err != nil {
336 return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
337 }
338 if err := m.Set(ProtectedHeadersKey, protected); err != nil {
339 return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
340 }
341
342 if err := m.makeDummyRecipient(string(parts[1]), protected); err != nil {
343 return nil, errors.Wrap(err, `failed to setup recipient`)
344 }
345
346 if err := m.Set(TagKey, tagbuf); err != nil {
347 return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
348 }
349
350 if storeProtectedHeaders {
351
352 m.rawProtectedHeaders = parts[0]
353 }
354
355 return m, nil
356 }
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375 func RegisterCustomField(name string, object interface{}) {
376 registry.Register(name, object)
377 }
378
View as plain text