1
2
3 package packet
4
5 import (
6 "bytes"
7 "crypto/cipher"
8 "encoding/binary"
9 "io"
10
11 "github.com/ProtonMail/go-crypto/openpgp/errors"
12 )
13
14
// aeadCrypter holds the state shared by the chunked AEAD encrypter and
// decrypter defined in this file.
type aeadCrypter struct {
	// aead seals and opens every chunk as well as the final tag.
	aead cipher.AEAD
	// chunkSize is the largest number of plaintext bytes sealed as one chunk.
	chunkSize int
	// initialNonce is the base value from which each chunk nonce is derived.
	initialNonce []byte
	// associatedData is passed as additional data on every Seal/Open call;
	// for some packet types the chunk index is appended to it first.
	associatedData []byte
	// chunkIndex is a big-endian counter, incremented after each chunk.
	chunkIndex []byte
	// packetTag selects how nonces and associated data are constructed.
	packetTag packetType
	// bytesProcessed is the running total of plaintext bytes sealed or
	// opened; it is authenticated by the final tag.
	bytesProcessed int
	// buffer holds plaintext awaiting processing: bytes not yet sealed when
	// encrypting, bytes not yet returned to the caller when decrypting.
	buffer bytes.Buffer
}
25
26
27
28
29 func (wo *aeadCrypter) computeNextNonce() (nonce []byte) {
30 if wo.packetTag == packetTypeSymmetricallyEncryptedIntegrityProtected {
31 return append(wo.initialNonce, wo.chunkIndex...)
32 }
33
34 nonce = make([]byte, len(wo.initialNonce))
35 copy(nonce, wo.initialNonce)
36 offset := len(wo.initialNonce) - 8
37 for i := 0; i < 8; i++ {
38 nonce[i+offset] ^= wo.chunkIndex[i]
39 }
40 return
41 }
42
43
44
45 func (wo *aeadCrypter) incrementIndex() error {
46 index := wo.chunkIndex
47 if len(index) == 0 {
48 return errors.AEADError("Index has length 0")
49 }
50 for i := len(index) - 1; i >= 0; i-- {
51 if index[i] < 255 {
52 index[i]++
53 return nil
54 }
55 index[i] = 0
56 }
57 return errors.AEADError("cannot further increment index")
58 }
59
60
61
// aeadDecrypter reads chunked AEAD ciphertext from reader and exposes the
// decrypted plaintext through Read.
type aeadDecrypter struct {
	aeadCrypter
	// reader is the source of the ciphertext.
	reader io.Reader
	// peekedBytes carries the trailing tag-length bytes of the data read so
	// far; once reader is exhausted they are validated as the final tag.
	// NOTE(review): presumably pre-filled with the first tag-length bytes by
	// the constructor, which is not in this part of the file — confirm.
	peekedBytes []byte
	// eof is set once the final tag has been validated.
	eof bool
}
68
69
70
71
72 func (ar *aeadDecrypter) Read(dst []byte) (n int, err error) {
73
74 if ar.buffer.Len() > 0 {
75 return ar.buffer.Read(dst)
76 }
77
78
79 if ar.eof {
80 return 0, io.EOF
81 }
82
83
84 tagLen := ar.aead.Overhead()
85 cipherChunkBuf := new(bytes.Buffer)
86 _, errRead := io.CopyN(cipherChunkBuf, ar.reader, int64(ar.chunkSize+tagLen))
87 cipherChunk := cipherChunkBuf.Bytes()
88 if errRead != nil && errRead != io.EOF {
89 return 0, errRead
90 }
91 decrypted, errChunk := ar.openChunk(cipherChunk)
92 if errChunk != nil {
93 return 0, errChunk
94 }
95
96
97 if len(dst) < len(decrypted) {
98 n = copy(dst, decrypted[:len(dst)])
99 ar.buffer.Write(decrypted[len(dst):])
100 } else {
101 n = copy(dst, decrypted)
102 }
103
104
105 if errRead == io.EOF {
106 errChunk := ar.validateFinalTag(ar.peekedBytes)
107 if errChunk != nil {
108 return n, errChunk
109 }
110 ar.eof = true
111 }
112 return
113 }
114
115
116
117
// Close is a no-op that always returns nil; it does not touch the
// underlying reader.
func (ar *aeadDecrypter) Close() (err error) {
	return nil
}
121
122
123
124
125 func (ar *aeadDecrypter) openChunk(data []byte) ([]byte, error) {
126 tagLen := ar.aead.Overhead()
127
128 chunkExtra := append(ar.peekedBytes, data...)
129
130 chunk := chunkExtra[:len(chunkExtra)-tagLen]
131 ar.peekedBytes = chunkExtra[len(chunkExtra)-tagLen:]
132
133 adata := ar.associatedData
134 if ar.aeadCrypter.packetTag == packetTypeAEADEncrypted {
135 adata = append(ar.associatedData, ar.chunkIndex...)
136 }
137
138 nonce := ar.computeNextNonce()
139 plainChunk, err := ar.aead.Open(nil, nonce, chunk, adata)
140 if err != nil {
141 return nil, err
142 }
143 ar.bytesProcessed += len(plainChunk)
144 if err = ar.aeadCrypter.incrementIndex(); err != nil {
145 return nil, err
146 }
147 return plainChunk, nil
148 }
149
150
151
152 func (ar *aeadDecrypter) validateFinalTag(tag []byte) error {
153
154 amountBytes := make([]byte, 8)
155 binary.BigEndian.PutUint64(amountBytes, uint64(ar.bytesProcessed))
156
157 adata := ar.associatedData
158 if ar.aeadCrypter.packetTag == packetTypeAEADEncrypted {
159
160 adata = append(ar.associatedData, ar.chunkIndex...)
161 }
162
163
164 adata = append(adata, amountBytes...)
165 nonce := ar.computeNextNonce()
166 _, err := ar.aead.Open(nil, nonce, tag, adata)
167 if err != nil {
168 return err
169 }
170 return nil
171 }
172
173
174
// aeadEncrypter buffers plaintext passed to Write, seals it in chunks of
// chunkSize bytes, and writes the resulting ciphertext to writer.
type aeadEncrypter struct {
	aeadCrypter
	// writer receives the sealed chunks and the final tag; it is closed by
	// Close.
	writer io.WriteCloser
}
179
180
181
182
183 func (aw *aeadEncrypter) Write(plaintextBytes []byte) (n int, err error) {
184
185 n, err = aw.buffer.Write(plaintextBytes)
186 if err != nil {
187 return n, err
188 }
189
190 for aw.buffer.Len() >= aw.chunkSize {
191 plainChunk := aw.buffer.Next(aw.chunkSize)
192 encryptedChunk, err := aw.sealChunk(plainChunk)
193 if err != nil {
194 return n, err
195 }
196 _, err = aw.writer.Write(encryptedChunk)
197 if err != nil {
198 return n, err
199 }
200 }
201 return
202 }
203
204
205
206
207 func (aw *aeadEncrypter) Close() (err error) {
208
209
210 if aw.buffer.Len() > 0 || aw.bytesProcessed == 0 {
211 plainChunk := aw.buffer.Bytes()
212 lastEncryptedChunk, err := aw.sealChunk(plainChunk)
213 if err != nil {
214 return err
215 }
216 _, err = aw.writer.Write(lastEncryptedChunk)
217 if err != nil {
218 return err
219 }
220 }
221
222
223 adata := aw.associatedData
224
225 if aw.aeadCrypter.packetTag == packetTypeAEADEncrypted {
226
227 adata = append(aw.associatedData, aw.chunkIndex...)
228 }
229
230
231 amountBytes := make([]byte, 8)
232 binary.BigEndian.PutUint64(amountBytes, uint64(aw.bytesProcessed))
233 adata = append(adata, amountBytes...)
234
235 nonce := aw.computeNextNonce()
236 finalTag := aw.aead.Seal(nil, nonce, nil, adata)
237 _, err = aw.writer.Write(finalTag)
238 if err != nil {
239 return err
240 }
241 return aw.writer.Close()
242 }
243
244
245 func (aw *aeadEncrypter) sealChunk(data []byte) ([]byte, error) {
246 if len(data) > aw.chunkSize {
247 return nil, errors.AEADError("chunk exceeds maximum length")
248 }
249 if aw.associatedData == nil {
250 return nil, errors.AEADError("can't seal without headers")
251 }
252 adata := aw.associatedData
253 if aw.aeadCrypter.packetTag == packetTypeAEADEncrypted {
254 adata = append(aw.associatedData, aw.chunkIndex...)
255 }
256
257 nonce := aw.computeNextNonce()
258 encrypted := aw.aead.Seal(nil, nonce, data, adata)
259 aw.bytesProcessed += len(data)
260 if err := aw.aeadCrypter.incrementIndex(); err != nil {
261 return nil, err
262 }
263 return encrypted, nil
264 }
265
View as plain text