1 package jwe
2
3 import (
4 "context"
5 "crypto/ecdsa"
6 "fmt"
7
8 "github.com/lestrrat-go/jwx/internal/json"
9 "github.com/lestrrat-go/jwx/internal/pool"
10 "github.com/lestrrat-go/jwx/jwk"
11
12 "github.com/lestrrat-go/jwx/internal/base64"
13 "github.com/lestrrat-go/jwx/jwa"
14 "github.com/pkg/errors"
15 )
16
17
18 func NewRecipient() Recipient {
19 return &stdRecipient{
20 headers: NewHeaders(),
21 }
22 }
23
24 func (r *stdRecipient) SetHeaders(h Headers) error {
25 r.headers = h
26 return nil
27 }
28
29 func (r *stdRecipient) SetEncryptedKey(v []byte) error {
30 r.encryptedKey = v
31 return nil
32 }
33
34 func (r *stdRecipient) Headers() Headers {
35 return r.headers
36 }
37
38 func (r *stdRecipient) EncryptedKey() []byte {
39 return r.encryptedKey
40 }
41
42 type recipientMarshalProxy struct {
43 Headers Headers `json:"header"`
44 EncryptedKey string `json:"encrypted_key"`
45 }
46
47 func (r *stdRecipient) UnmarshalJSON(buf []byte) error {
48 var proxy recipientMarshalProxy
49 proxy.Headers = NewHeaders()
50 if err := json.Unmarshal(buf, &proxy); err != nil {
51 return errors.Wrap(err, `failed to unmarshal json into recipient`)
52 }
53
54 r.headers = proxy.Headers
55 decoded, err := base64.DecodeString(proxy.EncryptedKey)
56 if err != nil {
57 return errors.Wrap(err, `failed to decode "encrypted_key"`)
58 }
59 r.encryptedKey = decoded
60 return nil
61 }
62
63 func (r *stdRecipient) MarshalJSON() ([]byte, error) {
64 buf := pool.GetBytesBuffer()
65 defer pool.ReleaseBytesBuffer(buf)
66
67 buf.WriteString(`{"header":`)
68 hdrbuf, err := r.headers.MarshalJSON()
69 if err != nil {
70 return nil, errors.Wrap(err, `failed to marshal recipient header`)
71 }
72 buf.Write(hdrbuf)
73 buf.WriteString(`,"encrypted_key":"`)
74 buf.WriteString(base64.EncodeToString(r.encryptedKey))
75 buf.WriteString(`"}`)
76
77 ret := make([]byte, buf.Len())
78 copy(ret, buf.Bytes())
79 return ret, nil
80 }
81
82
83 func NewMessage() *Message {
84 return &Message{}
85 }
86
87 func (m *Message) AuthenticatedData() []byte {
88 return m.authenticatedData
89 }
90
91 func (m *Message) CipherText() []byte {
92 return m.cipherText
93 }
94
95 func (m *Message) InitializationVector() []byte {
96 return m.initializationVector
97 }
98
99 func (m *Message) Tag() []byte {
100 return m.tag
101 }
102
103 func (m *Message) ProtectedHeaders() Headers {
104 return m.protectedHeaders
105 }
106
107 func (m *Message) Recipients() []Recipient {
108 return m.recipients
109 }
110
111 func (m *Message) UnprotectedHeaders() Headers {
112 return m.unprotectedHeaders
113 }
114
// Member names used by the JWE JSON serialization of a Message.
const (
	AuthenticatedDataKey  = "aad"
	CipherTextKey         = "ciphertext"
	ProtectedHeadersKey   = "protected"
	UnprotectedHeadersKey = "unprotected"
	RecipientsKey         = "recipients"

	// Per-recipient members; in the flattened serialization they appear at
	// the top level of the message.
	HeadersKey      = "header"
	EncryptedKeyKey = "encrypted_key"

	// InitializationVectorKey and TagKey name top-level message members and
	// are also looked up as header parameters by the AES-GCM key-wrap
	// decryption path.
	InitializationVectorKey = "iv"
	TagKey                  = "tag"

	// Header parameters consumed by PBES2 key decryption: salt and
	// iteration count.
	SaltKey  = "p2s"
	CountKey = "p2c"
)
128
129 func (m *Message) Set(k string, v interface{}) error {
130 switch k {
131 case AuthenticatedDataKey:
132 buf, ok := v.([]byte)
133 if !ok {
134 return errors.Errorf(`invalid value %T for %s key`, v, AuthenticatedDataKey)
135 }
136 m.authenticatedData = buf
137 case CipherTextKey:
138 buf, ok := v.([]byte)
139 if !ok {
140 return errors.Errorf(`invalid value %T for %s key`, v, CipherTextKey)
141 }
142 m.cipherText = buf
143 case InitializationVectorKey:
144 buf, ok := v.([]byte)
145 if !ok {
146 return errors.Errorf(`invalid value %T for %s key`, v, InitializationVectorKey)
147 }
148 m.initializationVector = buf
149 case ProtectedHeadersKey:
150 cv, ok := v.(Headers)
151 if !ok {
152 return errors.Errorf(`invalid value %T for %s key`, v, ProtectedHeadersKey)
153 }
154 m.protectedHeaders = cv
155 case RecipientsKey:
156 cv, ok := v.([]Recipient)
157 if !ok {
158 return errors.Errorf(`invalid value %T for %s key`, v, RecipientsKey)
159 }
160 m.recipients = cv
161 case TagKey:
162 buf, ok := v.([]byte)
163 if !ok {
164 return errors.Errorf(`invalid value %T for %s key`, v, TagKey)
165 }
166 m.tag = buf
167 case UnprotectedHeadersKey:
168 cv, ok := v.(Headers)
169 if !ok {
170 return errors.Errorf(`invalid value %T for %s key`, v, UnprotectedHeadersKey)
171 }
172 m.unprotectedHeaders = cv
173 default:
174 if m.unprotectedHeaders == nil {
175 m.unprotectedHeaders = NewHeaders()
176 }
177 return m.unprotectedHeaders.Set(k, v)
178 }
179 return nil
180 }
181
182 type messageMarshalProxy struct {
183 AuthenticatedData string `json:"aad,omitempty"`
184 CipherText string `json:"ciphertext"`
185 InitializationVector string `json:"iv,omitempty"`
186 ProtectedHeaders json.RawMessage `json:"protected"`
187 Recipients []json.RawMessage `json:"recipients,omitempty"`
188 Tag string `json:"tag,omitempty"`
189 UnprotectedHeaders Headers `json:"unprotected,omitempty"`
190
191
192
193 Headers json.RawMessage `json:"header,omitempty"`
194 EncryptedKey string `json:"encrypted_key,omitempty"`
195 }
196
197 func (m *Message) MarshalJSON() ([]byte, error) {
198
199
200 buf := pool.GetBytesBuffer()
201 defer pool.ReleaseBytesBuffer(buf)
202 enc := json.NewEncoder(buf)
203 fmt.Fprintf(buf, `{`)
204
205 var wrote bool
206 if aad := m.AuthenticatedData(); len(aad) > 0 {
207 wrote = true
208 fmt.Fprintf(buf, `%#v:`, AuthenticatedDataKey)
209 if err := enc.Encode(base64.EncodeToString(aad)); err != nil {
210 return nil, errors.Wrapf(err, `failed to encode %s field`, AuthenticatedDataKey)
211 }
212 }
213 if cipherText := m.CipherText(); len(cipherText) > 0 {
214 if wrote {
215 fmt.Fprintf(buf, `,`)
216 }
217 wrote = true
218 fmt.Fprintf(buf, `%#v:`, CipherTextKey)
219 if err := enc.Encode(base64.EncodeToString(cipherText)); err != nil {
220 return nil, errors.Wrapf(err, `failed to encode %s field`, CipherTextKey)
221 }
222 }
223
224 if iv := m.InitializationVector(); len(iv) > 0 {
225 if wrote {
226 fmt.Fprintf(buf, `,`)
227 }
228 wrote = true
229 fmt.Fprintf(buf, `%#v:`, InitializationVectorKey)
230 if err := enc.Encode(base64.EncodeToString(iv)); err != nil {
231 return nil, errors.Wrapf(err, `failed to encode %s field`, InitializationVectorKey)
232 }
233 }
234
235 if h := m.ProtectedHeaders(); h != nil {
236 encodedHeaders, err := h.Encode()
237 if err != nil {
238 return nil, errors.Wrap(err, `failed to encode protected headers`)
239 }
240
241 if len(encodedHeaders) > 2 {
242 if wrote {
243 fmt.Fprintf(buf, `,`)
244 }
245 wrote = true
246 fmt.Fprintf(buf, `%#v:%#v`, ProtectedHeadersKey, string(encodedHeaders))
247 }
248 }
249
250 if recipients := m.Recipients(); len(recipients) > 0 {
251 if wrote {
252 fmt.Fprintf(buf, `,`)
253 }
254 if len(recipients) == 1 {
255 fmt.Fprintf(buf, `%#v:`, HeadersKey)
256 if err := enc.Encode(recipients[0].Headers()); err != nil {
257 return nil, errors.Wrapf(err, `failed to encode %s field`, HeadersKey)
258 }
259 if ek := recipients[0].EncryptedKey(); len(ek) > 0 {
260 fmt.Fprintf(buf, `,%#v:`, EncryptedKeyKey)
261 if err := enc.Encode(base64.EncodeToString(ek)); err != nil {
262 return nil, errors.Wrapf(err, `failed to encode %s field`, EncryptedKeyKey)
263 }
264 }
265 } else {
266 fmt.Fprintf(buf, `%#v:`, RecipientsKey)
267 if err := enc.Encode(recipients); err != nil {
268 return nil, errors.Wrapf(err, `failed to encode %s field`, RecipientsKey)
269 }
270 }
271 }
272
273 if tag := m.Tag(); len(tag) > 0 {
274 if wrote {
275 fmt.Fprintf(buf, `,`)
276 }
277 fmt.Fprintf(buf, `%#v:`, TagKey)
278 if err := enc.Encode(base64.EncodeToString(tag)); err != nil {
279 return nil, errors.Wrapf(err, `failed to encode %s field`, TagKey)
280 }
281 }
282
283 if h := m.UnprotectedHeaders(); h != nil {
284 unprotected, err := json.Marshal(h)
285 if err != nil {
286 return nil, errors.Wrap(err, `failed to encode unprotected headers`)
287 }
288
289 if len(unprotected) > 2 {
290 fmt.Fprintf(buf, `,%#v:%#v`, UnprotectedHeadersKey, string(unprotected))
291 }
292 }
293 fmt.Fprintf(buf, `}`)
294
295 ret := make([]byte, buf.Len())
296 copy(ret, buf.Bytes())
297 return ret, nil
298 }
299
300 func (m *Message) UnmarshalJSON(buf []byte) error {
301 var proxy messageMarshalProxy
302 proxy.UnprotectedHeaders = NewHeaders()
303
304 if err := json.Unmarshal(buf, &proxy); err != nil {
305 return errors.Wrap(err, `failed to unmashal JSON into message`)
306 }
307
308
309 var protectedHeadersStr string
310 if err := json.Unmarshal(proxy.ProtectedHeaders, &protectedHeadersStr); err != nil {
311 return errors.Wrap(err, `failed to decode protected headers (1)`)
312 }
313
314
315 protectedHeadersRaw, err := base64.DecodeString(protectedHeadersStr)
316 if err != nil {
317 return errors.Wrap(err, "failed to base64 decoded protected headers buffer")
318 }
319
320 h := NewHeaders()
321 if err := json.Unmarshal(protectedHeadersRaw, h); err != nil {
322 return errors.Wrap(err, `failed to decode protected headers (2)`)
323 }
324
325
326
327 if proxy.Headers != nil || len(proxy.EncryptedKey) > 0 {
328 recipient := NewRecipient()
329 hdrs := NewHeaders()
330 if err := json.Unmarshal(proxy.Headers, hdrs); err != nil {
331 return errors.Wrap(err, `failed to decode headers field`)
332 }
333
334 if err := recipient.SetHeaders(hdrs); err != nil {
335 return errors.Wrap(err, `failed to set new headers`)
336 }
337
338 if v := proxy.EncryptedKey; len(v) > 0 {
339 buf, err := base64.DecodeString(v)
340 if err != nil {
341 return errors.Wrap(err, `failed to decode encrypted key`)
342 }
343 if err := recipient.SetEncryptedKey(buf); err != nil {
344 return errors.Wrap(err, `failed to set encrypted key`)
345 }
346 }
347
348 m.recipients = append(m.recipients, recipient)
349 } else {
350 for i, recipientbuf := range proxy.Recipients {
351 recipient := NewRecipient()
352 if err := json.Unmarshal(recipientbuf, recipient); err != nil {
353 return errors.Wrapf(err, `failed to decode recipient at index %d`, i)
354 }
355
356 m.recipients = append(m.recipients, recipient)
357 }
358 }
359
360 if src := proxy.AuthenticatedData; len(src) > 0 {
361 v, err := base64.DecodeString(src)
362 if err != nil {
363 return errors.Wrap(err, `failed to decode "aad"`)
364 }
365 m.authenticatedData = v
366 }
367
368 if src := proxy.CipherText; len(src) > 0 {
369 v, err := base64.DecodeString(src)
370 if err != nil {
371 return errors.Wrap(err, `failed to decode "ciphertext"`)
372 }
373 m.cipherText = v
374 }
375
376 if src := proxy.InitializationVector; len(src) > 0 {
377 v, err := base64.DecodeString(src)
378 if err != nil {
379 return errors.Wrap(err, `failed to decode "iv"`)
380 }
381 m.initializationVector = v
382 }
383
384 if src := proxy.Tag; len(src) > 0 {
385 v, err := base64.DecodeString(src)
386 if err != nil {
387 return errors.Wrap(err, `failed to decode "tag"`)
388 }
389 m.tag = v
390 }
391
392 m.protectedHeaders = h
393 if m.storeProtectedHeaders {
394
395 m.rawProtectedHeaders = base64.Encode(protectedHeadersRaw)
396 }
397
398 if iz, ok := proxy.UnprotectedHeaders.(isZeroer); ok {
399 if !iz.isZero() {
400 m.unprotectedHeaders = proxy.UnprotectedHeaders
401 }
402 }
403
404 if len(m.recipients) == 0 {
405 if err := m.makeDummyRecipient(proxy.EncryptedKey, m.protectedHeaders); err != nil {
406 return errors.Wrap(err, `failed to setup recipient`)
407 }
408 }
409
410 return nil
411 }
412
// makeDummyRecipient installs a single synthetic recipient on m. Its headers
// are a clone of protected with the ContentEncryptionKey entry removed, and
// its encrypted key is enckeybuf base64-decoded. The recipient list is
// replaced via m.Set(RecipientsKey, ...).
func (m *Message) makeDummyRecipient(enckeybuf string, protected Headers) error {
	recipientHeaders, err := protected.Clone(context.TODO())
	if err != nil {
		return errors.Wrap(err, `failed to clone headers`)
	}

	if err = recipientHeaders.Remove(ContentEncryptionKey); err != nil {
		return errors.Wrapf(err, "failed to remove %#v from public header", ContentEncryptionKey)
	}

	encryptedKey, err := base64.DecodeString(enckeybuf)
	if err != nil {
		return errors.Wrap(err, `failed to decode encrypted key`)
	}

	dummy := &stdRecipient{
		headers:      recipientHeaders,
		encryptedKey: encryptedKey,
	}
	if err = m.Set(RecipientsKey, []Recipient{dummy}); err != nil {
		return errors.Wrapf(err, `failed to set %s`, RecipientsKey)
	}
	return nil
}
440
441
442
443
444
445
446
447
448
449
450 func (m *Message) Decrypt(alg jwa.KeyEncryptionAlgorithm, key interface{}) ([]byte, error) {
451 var ctx decryptCtx
452 ctx.alg = alg
453 ctx.key = key
454 ctx.msg = m
455
456 return doDecryptCtx(&ctx)
457 }
458
// doDecryptCtx decrypts dctx.msg with dctx.alg and dctx.key.
//
// Each recipient whose key-management algorithm equals dctx.alg is tried in
// turn. The algorithm comes from the recipient's own header or, if absent
// there, from the protected/shared unprotected headers. The plaintext of the
// first recipient that decrypts successfully is returned, and decompressed
// when "zip" is DEF. A failure for one recipient is remembered and the next
// is tried. Missing or malformed algorithm-specific header parameters (epk,
// iv/tag, p2s/p2c) abort immediately. If dctx.key is a jwk.Key, its raw key
// is used.
func doDecryptCtx(dctx *decryptCtx) ([]byte, error) {
	m := dctx.msg
	alg := dctx.alg
	key := dctx.key

	if jwkKey, ok := key.(jwk.Key); ok {
		var raw interface{}
		if err := jwkKey.Raw(&raw); err != nil {
			return nil, errors.Wrapf(err, `failed to retrieve raw key from %T`, key)
		}
		key = raw
	}

	ctx := context.TODO()
	// h = protected headers + shared unprotected headers. Each recipient's
	// own headers are merged onto a copy of it inside the loop.
	h, err := m.protectedHeaders.Clone(ctx)
	if err != nil {
		return nil, errors.Wrap(err, `failed to copy protected headers`)
	}
	h, err = h.Merge(ctx, m.unprotectedHeaders)
	if err != nil {
		return nil, errors.Wrap(err, "failed to merge headers for message decryption")
	}

	enc := m.protectedHeaders.ContentEncryption()
	var aad []byte
	if aadContainer := m.authenticatedData; aadContainer != nil {
		aad = base64.Encode(aadContainer)
	}

	// Prefer the protected header as retained during parsing; otherwise
	// re-encode it.
	computedAad := m.rawProtectedHeaders
	if len(computedAad) == 0 {
		computedAad, err = m.protectedHeaders.Encode()
		if err != nil {
			return nil, errors.Wrap(err, "failed to encode protected headers")
		}
	}

	// NOTE(review): dec is shared by all loop iterations. Values set for one
	// recipient (e.g. AgreementPartyUInfo/VInfo, which are only set when
	// non-empty) are not cleared before trying the next one.
	dec := NewDecrypter(alg, enc, key).
		AuthenticatedData(aad).
		ComputedAuthenticatedData(computedAad).
		InitializationVector(m.initializationVector).
		Tag(m.tag)

	recipients := m.recipients
	if len(recipients) == 0 {
		r := NewRecipient()
		if err := r.SetHeaders(m.protectedHeaders); err != nil {
			return nil, errors.Wrap(err, `failed to set headers to recipient`)
		}
		recipients = append(recipients, r)
	}

	var plaintext []byte
	// Success is tracked explicitly. A nil/empty plaintext can be a valid
	// result, and a recipient that fails after producing output (e.g. during
	// decompression) must not be reported as a success.
	var decrypted bool
	var lastError error

	for _, recipient := range recipients {
		// "alg" may be carried in the recipient's own header or, per
		// RFC 7516, in the protected / shared unprotected header.
		ralg := recipient.Headers().Algorithm()
		if ralg == "" {
			ralg = h.Algorithm()
		}
		if ralg != alg {
			continue
		}

		h2, err := h.Clone(ctx)
		if err != nil {
			lastError = errors.Wrap(err, `failed to copy headers (1)`)
			continue
		}

		h2, err = h2.Merge(ctx, recipient.Headers())
		if err != nil {
			lastError = errors.Wrap(err, `failed to copy headers (2)`)
			continue
		}

		switch alg {
		case jwa.ECDH_ES, jwa.ECDH_ES_A128KW, jwa.ECDH_ES_A192KW, jwa.ECDH_ES_A256KW:
			epkif, ok := h2.Get(EphemeralPublicKeyKey)
			if !ok {
				return nil, errors.New("failed to get 'epk' field")
			}
			switch epk := epkif.(type) {
			case jwk.ECDSAPublicKey:
				var pubkey ecdsa.PublicKey
				if err := epk.Raw(&pubkey); err != nil {
					return nil, errors.Wrap(err, "failed to get public key")
				}
				dec.PublicKey(&pubkey)
			case jwk.OKPPublicKey:
				var pubkey interface{}
				if err := epk.Raw(&pubkey); err != nil {
					return nil, errors.Wrap(err, "failed to get public key")
				}
				dec.PublicKey(pubkey)
			default:
				return nil, errors.Errorf("unexpected 'epk' type %T for alg %s", epkif, alg)
			}

			if apu := h2.AgreementPartyUInfo(); len(apu) > 0 {
				dec.AgreementPartyUInfo(apu)
			}

			if apv := h2.AgreementPartyVInfo(); len(apv) > 0 {
				dec.AgreementPartyVInfo(apv)
			}
		case jwa.A128GCMKW, jwa.A192GCMKW, jwa.A256GCMKW:
			ivB64, ok := h2.Get(InitializationVectorKey)
			if !ok {
				return nil, errors.New("failed to get 'iv' field")
			}
			ivB64Str, ok := ivB64.(string)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'iv': %T", ivB64)
			}
			tagB64, ok := h2.Get(TagKey)
			if !ok {
				return nil, errors.New("failed to get 'tag' field")
			}
			tagB64Str, ok := tagB64.(string)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'tag': %T", tagB64)
			}
			iv, err := base64.DecodeString(ivB64Str)
			if err != nil {
				return nil, errors.Wrap(err, "failed to b64-decode 'iv'")
			}
			tag, err := base64.DecodeString(tagB64Str)
			if err != nil {
				return nil, errors.Wrap(err, "failed to b64-decode 'tag'")
			}
			dec.KeyInitializationVector(iv)
			dec.KeyTag(tag)
		case jwa.PBES2_HS256_A128KW, jwa.PBES2_HS384_A192KW, jwa.PBES2_HS512_A256KW:
			saltB64, ok := h2.Get(SaltKey)
			if !ok {
				return nil, errors.New("failed to get 'p2s' field")
			}
			saltB64Str, ok := saltB64.(string)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'p2s': %T", saltB64)
			}

			count, ok := h2.Get(CountKey)
			if !ok {
				return nil, errors.New("failed to get 'p2c' field")
			}
			countFlt, ok := count.(float64)
			if !ok {
				return nil, errors.Errorf("unexpected type for 'p2c': %T", count)
			}
			// NOTE(review): p2c comes straight from the message and is not
			// bounded, so a crafted message can make key derivation
			// arbitrarily expensive. Consider enforcing a maximum count.
			salt, err := base64.DecodeString(saltB64Str)
			if err != nil {
				return nil, errors.Wrap(err, "failed to b64-decode 'salt'")
			}
			dec.KeySalt(salt)
			dec.KeyCount(int(countFlt))
		}

		out, err := dec.Decrypt(recipient.EncryptedKey(), m.cipherText)
		if err != nil {
			lastError = errors.Wrap(err, `failed to decrypt`)
			continue
		}

		if h2.Compression() == jwa.Deflate {
			out, err = uncompress(out)
			if err != nil {
				lastError = errors.Wrap(err, `failed to uncompress payload`)
				continue
			}
		}

		plaintext = out
		decrypted = true
		break
	}

	if !decrypted {
		if lastError != nil {
			return nil, errors.Errorf(`failed to find matching recipient to decrypt key (last error = %s)`, lastError)
		}
		return nil, errors.New("failed to find matching recipient")
	}

	return plaintext, nil
}
649
View as plain text