1 package jwe
2
3 import (
4 "context"
5 "sync"
6
7 "github.com/lestrrat-go/jwx/internal/base64"
8 "github.com/lestrrat-go/jwx/jwa"
9 "github.com/pkg/errors"
10 )
11
12 var encryptCtxPool = sync.Pool{
13 New: func() interface{} {
14 return &encryptCtx{}
15 },
16 }
17
18 func getEncryptCtx() *encryptCtx {
19
20 return encryptCtxPool.Get().(*encryptCtx)
21 }
22
23 func releaseEncryptCtx(ctx *encryptCtx) {
24 ctx.protected = nil
25 ctx.contentEncrypter = nil
26 ctx.generator = nil
27 ctx.keyEncrypters = nil
28 ctx.compress = jwa.NoCompress
29 encryptCtxPool.Put(ctx)
30 }
31
32
33 func (e encryptCtx) Encrypt(plaintext []byte) (*Message, error) {
34 bk, err := e.generator.Generate()
35 if err != nil {
36 return nil, errors.Wrap(err, "failed to generate key")
37 }
38 cek := bk.Bytes()
39
40 if e.protected == nil {
41
42 e.protected = NewHeaders()
43 }
44
45 if err := e.protected.Set(ContentEncryptionKey, e.contentEncrypter.Algorithm()); err != nil {
46 return nil, errors.Wrap(err, `failed to set "enc" in protected header`)
47 }
48
49 compression := e.compress
50 if compression != jwa.NoCompress {
51 if err := e.protected.Set(CompressionKey, compression); err != nil {
52 return nil, errors.Wrap(err, `failed to set "zip" in protected header`)
53 }
54 }
55
56
57
58
59 recipients := make([]Recipient, len(e.keyEncrypters))
60 for i, enc := range e.keyEncrypters {
61 r := NewRecipient()
62 if err := r.Headers().Set(AlgorithmKey, enc.Algorithm()); err != nil {
63 return nil, errors.Wrap(err, "failed to set header")
64 }
65 if v := enc.KeyID(); v != "" {
66 if err := r.Headers().Set(KeyIDKey, v); err != nil {
67 return nil, errors.Wrap(err, "failed to set header")
68 }
69 }
70
71 enckey, err := enc.Encrypt(cek)
72 if err != nil {
73 return nil, errors.Wrap(err, `failed to encrypt key`)
74 }
75 if enc.Algorithm() == jwa.ECDH_ES || enc.Algorithm() == jwa.DIRECT {
76 if len(e.keyEncrypters) > 1 {
77 return nil, errors.Errorf("unable to support multiple recipients for ECDH-ES")
78 }
79 cek = enckey.Bytes()
80 } else {
81 if err := r.SetEncryptedKey(enckey.Bytes()); err != nil {
82 return nil, errors.Wrap(err, "failed to set encrypted key")
83 }
84 }
85 if hp, ok := enckey.(populater); ok {
86 if err := hp.Populate(r.Headers()); err != nil {
87 return nil, errors.Wrap(err, "failed to populate")
88 }
89 }
90 recipients[i] = r
91 }
92
93
94
95 if len(recipients) == 1 {
96 h, err := e.protected.Merge(context.TODO(), recipients[0].Headers())
97 if err != nil {
98 return nil, errors.Wrap(err, "failed to merge protected headers")
99 }
100 e.protected = h
101 }
102
103 aad, err := e.protected.Encode()
104 if err != nil {
105 return nil, errors.Wrap(err, "failed to base64 encode protected headers")
106 }
107
108 plaintext, err = compress(plaintext, compression)
109 if err != nil {
110 return nil, errors.Wrap(err, `failed to compress payload before encryption`)
111 }
112
113
114 iv, ciphertext, tag, err := e.contentEncrypter.Encrypt(cek, plaintext, aad)
115 if err != nil {
116 return nil, errors.Wrap(err, "failed to encrypt payload")
117 }
118
119 msg := NewMessage()
120
121 decodedAad, err := base64.Decode(aad)
122 if err != nil {
123 return nil, errors.Wrap(err, "failed to decode base64")
124 }
125 if err := msg.Set(AuthenticatedDataKey, decodedAad); err != nil {
126 return nil, errors.Wrapf(err, `failed to set %s`, AuthenticatedDataKey)
127 }
128 if err := msg.Set(CipherTextKey, ciphertext); err != nil {
129 return nil, errors.Wrapf(err, `failed to set %s`, CipherTextKey)
130 }
131 if err := msg.Set(InitializationVectorKey, iv); err != nil {
132 return nil, errors.Wrapf(err, `failed to set %s`, InitializationVectorKey)
133 }
134 if err := msg.Set(ProtectedHeadersKey, e.protected); err != nil {
135 return nil, errors.Wrapf(err, `failed to set %s`, ProtectedHeadersKey)
136 }
137 if err := msg.Set(RecipientsKey, recipients); err != nil {
138 return nil, errors.Wrapf(err, `failed to set %s`, RecipientsKey)
139 }
140 if err := msg.Set(TagKey, tag); err != nil {
141 return nil, errors.Wrapf(err, `failed to set %s`, TagKey)
142 }
143
144 return msg, nil
145 }
146
View as plain text