1
2
3
4
5 package packet
6
7 import (
8 "bytes"
9 "crypto/rand"
10 "crypto/sha1"
11 "encoding/hex"
12 goerrors "errors"
13 "io"
14 "io/ioutil"
15 "testing"
16
17 "github.com/ProtonMail/go-crypto/openpgp/errors"
18 )
19
20
// testReader is an io.Reader that hands out its data in pieces of at most
// stride bytes per Read call, so tests can exercise short-read handling.
// It reports io.EOF in the same call that returns the final piece.
type testReader struct {
	data   []byte // bytes not yet handed out
	stride int    // upper bound on bytes returned by a single Read
}

// Read copies the smallest of stride, len(buf) and the remaining data into
// buf, consumes that many bytes, and returns io.EOF once nothing remains.
func (t *testReader) Read(buf []byte) (int, error) {
	count := len(t.data)
	if t.stride < count {
		count = t.stride
	}
	if len(buf) < count {
		count = len(buf)
	}

	copy(buf, t.data[:count])
	t.data = t.data[count:]

	if len(t.data) > 0 {
		return count, nil
	}
	return count, io.EOF
}
44
// mdcPlaintextHex is a hex-encoded literal data packet carrying
// "Hello, world!", followed by a 22-byte Modification Detection Code packet
// (a 2-byte packet header plus a 20-byte SHA-1 digest).
const mdcPlaintextHex = "cb1362000000000048656c6c6f2c20776f726c6421" + // literal data packet
	"d314" + // MDC packet header
	"c23d643f478a9a2098811fcb191e7b24b80966a1" // SHA-1 digest
46
47 func TestMDCReader(t *testing.T) {
48 mdcPlaintext, _ := hex.DecodeString(mdcPlaintextHex)
49 for stride := 1; stride < len(mdcPlaintext)/2; stride++ {
50 r := &testReader{data: mdcPlaintext, stride: stride}
51 mdcReader := &seMDCReader{in: r, h: sha1.New()}
52 body, err := ioutil.ReadAll(mdcReader)
53 if err != nil {
54 t.Errorf("stride: %d, error: %s", stride, err)
55 continue
56 }
57 if !bytes.Equal(body, mdcPlaintext[:len(mdcPlaintext)-22]) {
58 t.Errorf("stride: %d: bad contents %x", stride, body)
59 continue
60 }
61
62 err = mdcReader.Close()
63 if err != nil {
64 t.Errorf("stride: %d, error on Close: %s", stride, err)
65 }
66 }
67
68 mdcPlaintext[15] ^= 80
69
70 r := &testReader{data: mdcPlaintext, stride: 2}
71 mdcReader := &seMDCReader{in: r, h: sha1.New()}
72 _, err := ioutil.ReadAll(mdcReader)
73 if err != nil {
74 t.Errorf("corruption test, error: %s", err)
75 return
76 }
77 err = mdcReader.Close()
78 if err == nil {
79 t.Error("corruption: no error")
80 } else if !goerrors.Is(err, errors.ErrMDCHashMismatch) {
81 t.Errorf("corruption: expected SignatureError, got: %s", err)
82 }
83 }
84
85 func TestSerializeMdc(t *testing.T) {
86 buf := bytes.NewBuffer(nil)
87 c := CipherAES128
88 key := make([]byte, c.KeySize())
89
90 cipherSuite := CipherSuite{
91 Cipher: c,
92 Mode: AEADModeOCB,
93 }
94
95 w, err := SerializeSymmetricallyEncrypted(buf, c, false, cipherSuite, key, nil)
96 if err != nil {
97 t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
98 return
99 }
100
101 contents := []byte("hello world\n")
102
103 w.Write(contents)
104 w.Close()
105
106 p, err := Read(buf)
107 if err != nil {
108 t.Errorf("error from Read: %s", err)
109 return
110 }
111
112 se, ok := p.(*SymmetricallyEncrypted)
113 if !ok {
114 t.Errorf("didn't read a *SymmetricallyEncrypted")
115 return
116 }
117
118 r, err := se.Decrypt(c, key)
119 if err != nil {
120 t.Errorf("error from Decrypt: %s", err)
121 return
122 }
123
124 contentsCopy := bytes.NewBuffer(nil)
125 _, err = io.Copy(contentsCopy, r)
126 if err != nil {
127 t.Errorf("error from io.Copy: %s", err)
128 return
129 }
130 if !bytes.Equal(contentsCopy.Bytes(), contents) {
131 t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
132 }
133 }
134
// Known-answer vector for an AEAD-protected SymmetricallyEncrypted packet
// (version byte 0x02, AES-128, GCM, chunk size byte 0x06), all hex-encoded.
const (
	// aeadHexKey is the 16-byte session key that decrypts aeadHexSeipd.
	aeadHexKey = "1936fc8568980274bb900d8319360c77"

	// aeadHexSeipd is the complete packet, header included.
	aeadHexSeipd = "d26902070306fcb94490bcb98bbdc9d106c6090266940f72e89edc21b5596b1576b101ed0f9ffc6fc6d65bbfd24dcd0790966e6d1e85a30053784cb1d8b6a0699ef12155a7b2ad6258531b57651fd7777912fa95e35d9b40216f69a4c248db28ff4331f1632907399e6ff9"

	// aeadHexPlainText is the expected decryption: a literal data packet
	// holding "Hello, world!" followed by what appears to be a padding packet.
	aeadHexPlainText = "cb1362000000000048656c6c6f2c20776f726c6421d50e1ce2269a9eddef81032172b7ed7c"

	// aeadExpectedSalt is the 32-byte salt carried in the packet header.
	aeadExpectedSalt = "fcb94490bcb98bbdc9d106c6090266940f72e89edc21b5596b1576b101ed0f9f"
)
139
140 func TestAeadRfcVector(t *testing.T) {
141 key, err := hex.DecodeString(aeadHexKey)
142 if err != nil {
143 t.Errorf("error in decoding key: %s", err)
144 }
145
146 packet, err := hex.DecodeString(aeadHexSeipd)
147 if err != nil {
148 t.Errorf("error in decoding packet: %s", err)
149 }
150
151 plainText, err := hex.DecodeString(aeadHexPlainText)
152 if err != nil {
153 t.Errorf("error in decoding plaintext: %s", err)
154 }
155
156 expectedSalt, err := hex.DecodeString(aeadExpectedSalt)
157 if err != nil {
158 t.Errorf("error in decoding salt: %s", err)
159 }
160
161 buf := bytes.NewBuffer(packet)
162 p, err := Read(buf)
163 if err != nil {
164 t.Errorf("error from Read: %s", err)
165 return
166 }
167
168 se, ok := p.(*SymmetricallyEncrypted)
169 if !ok {
170 t.Errorf("didn't read a *SymmetricallyEncrypted")
171 return
172 }
173
174 if se.Version != symmetricallyEncryptedVersionAead {
175 t.Errorf("found wrong version, want: %d, got: %d", symmetricallyEncryptedVersionAead, se.Version)
176 }
177
178 if se.Cipher != CipherAES128 {
179 t.Errorf("found wrong cipher, want: %d, got: %d", CipherAES128, se.Cipher)
180 }
181
182 if se.Mode != AEADModeGCM {
183 t.Errorf("found wrong mode, want: %d, got: %d", AEADModeGCM, se.Mode)
184 }
185
186 if !bytes.Equal(se.Salt[:], expectedSalt) {
187 t.Errorf("found wrong salt, want: %x, got: %x", expectedSalt, se.Salt)
188 }
189
190 if se.ChunkSizeByte != 0x06 {
191 t.Errorf("found wrong chunk size byte, want: %d, got: %d", 0x06, se.ChunkSizeByte)
192 }
193
194 aeadReader, err := se.Decrypt(CipherFunction(0), key)
195 if err != nil {
196 t.Errorf("error from Decrypt: %s", err)
197 return
198 }
199
200 decrypted, err := ioutil.ReadAll(aeadReader)
201 if err != nil {
202 t.Errorf("error when reading: %s", err)
203 return
204 }
205
206 err = aeadReader.Close()
207 if err != nil {
208 t.Errorf("error when closing reader: %s", err)
209 return
210 }
211
212 if !bytes.Equal(decrypted, plainText) {
213 t.Errorf("contents not equal got: %x want: %x", decrypted, plainText)
214 }
215 }
216
217 func TestAeadEncryptDecrypt(t *testing.T) {
218 ciphers := map[string]CipherFunction{
219 "AES128": CipherAES128,
220 "AES192": CipherAES192,
221 "AES256": CipherAES256,
222 }
223
224 modes := map[string]AEADMode{
225 "EAX": AEADModeEAX,
226 "OCB": AEADModeOCB,
227 "GCM": AEADModeGCM,
228 }
229
230 for cipherName, cipher := range ciphers {
231 t.Run(cipherName, func(t *testing.T) {
232 for modeName, mode := range modes {
233 t.Run(modeName, func(t *testing.T) {
234 testSerializeAead(t, CipherSuite{Cipher: cipher, Mode: mode})
235 })
236 }
237 })
238 }
239 }
240
241 func testSerializeAead(t *testing.T, cipherSuite CipherSuite) {
242 buf := bytes.NewBuffer(nil)
243 key := make([]byte, cipherSuite.Cipher.KeySize())
244 _, _ = rand.Read(key)
245
246 w, err := SerializeSymmetricallyEncrypted(buf, CipherFunction(0), true, cipherSuite, key, &Config{AEADConfig: &AEADConfig{}})
247 if err != nil {
248 t.Errorf("error from SerializeSymmetricallyEncrypted: %s", err)
249 return
250 }
251
252 contents := []byte("hello world\n")
253
254 w.Write(contents)
255 w.Close()
256
257 p, err := Read(buf)
258 if err != nil {
259 t.Errorf("error from Read: %s", err)
260 return
261 }
262
263 se, ok := p.(*SymmetricallyEncrypted)
264 if !ok {
265 t.Errorf("didn't read a *SymmetricallyEncrypted")
266 return
267 }
268
269 if se.Version != symmetricallyEncryptedVersionAead {
270 t.Errorf("found wrong version, want: %d, got: %d", symmetricallyEncryptedVersionAead, se.Version)
271 }
272
273 if se.Cipher != cipherSuite.Cipher {
274 t.Errorf("found wrong cipher, want: %d, got: %d", cipherSuite.Cipher, se.Cipher)
275 }
276
277 if se.Mode != cipherSuite.Mode {
278 t.Errorf("found wrong mode, want: %d, got: %d", cipherSuite.Mode, se.Mode)
279 }
280
281 r, err := se.Decrypt(CipherFunction(0), key)
282 if err != nil {
283 t.Errorf("error from Decrypt: %s", err)
284 return
285 }
286
287 contentsCopy := bytes.NewBuffer(nil)
288 _, err = io.Copy(contentsCopy, r)
289 if err != nil {
290 t.Errorf("error from io.Copy: %s", err)
291 return
292 }
293 if !bytes.Equal(contentsCopy.Bytes(), contents) {
294 t.Errorf("contents not equal got: %x want: %x", contentsCopy.Bytes(), contents)
295 }
296 }
297
View as plain text