1 package hpke
2
3 import (
4 "bytes"
5 "crypto/rand"
6 "fmt"
7 "testing"
8
9 "github.com/cloudflare/circl/internal/test"
10 )
11
12 func TestAeadExporter(t *testing.T) {
13 suite := Suite{kdfID: KDF_HKDF_SHA256, aeadID: AEAD_AES128GCM}
14 exporter := &encdecContext{suite: suite}
15 maxLength := uint(255 * suite.kdfID.ExtractSize())
16
17 err := test.CheckPanic(func() {
18 exporter.Export([]byte("exporter"), maxLength+1)
19 })
20 test.CheckNoErr(t, err, "exporter max size")
21 }
22
23 func setupAeadTest() (*sealContext, *openContext, error) {
24 suite := Suite{aeadID: AEAD_AES128GCM}
25 key := make([]byte, suite.aeadID.KeySize())
26 if n, err := rand.Read(key); err != nil {
27 return nil, nil, err
28 } else if n != len(key) {
29 return nil, nil, fmt.Errorf("unexpected key size: got %d; want %d", n, len(key))
30 }
31
32 aead, err := suite.aeadID.New(key)
33 if err != nil {
34 return nil, nil, err
35 }
36
37 Nn := suite.aeadID.NonceSize()
38 baseNonce := make([]byte, Nn)
39 if n, err := rand.Read(baseNonce); err != nil {
40 return nil, nil, err
41 } else if n != len(baseNonce) {
42 return nil, nil, fmt.Errorf("unexpected base nonce size: got %d; want %d", n, len(baseNonce))
43 }
44
45 sealer := &sealContext{
46 &encdecContext{
47 suite, nil, nil, nil, nil, nil, baseNonce, make([]byte, Nn), aead, make([]byte, Nn),
48 },
49 }
50 opener := &openContext{
51 &encdecContext{
52 suite, nil, nil, nil, nil, nil, baseNonce, make([]byte, Nn), aead, make([]byte, Nn),
53 },
54 }
55 return sealer, opener, nil
56 }
57
58 func TestAeadNonceUpdate(t *testing.T) {
59 sealer, opener, err := setupAeadTest()
60 test.CheckNoErr(t, err, "setup failed")
61
62 pt := []byte("plaintext")
63 aad := []byte("aad")
64
65 numAttempts := 2
66 var prevCt []byte
67 for i := 0; i < numAttempts; i++ {
68 ct, err := sealer.Seal(pt, aad)
69 if err != nil {
70 t.Fatalf("encryption failed: %s", err)
71 }
72
73 if prevCt != nil && bytes.Equal(ct, prevCt) {
74 t.Error("ciphertext matches the previous (nonce not updated)")
75 }
76
77 _, err = opener.Open(ct, aad)
78 if err != nil {
79 t.Errorf("decryption failed: %s", err)
80 }
81
82 prevCt = ct
83 }
84 }
85
86 func TestOpenPhaseMismatch(t *testing.T) {
87 sealer, opener, err := setupAeadTest()
88 test.CheckNoErr(t, err, "setup failed")
89
90 pt := []byte("plaintext")
91 aad := []byte("aad")
92
93 ct, err := sealer.Seal(pt, aad)
94 if err != nil {
95 t.Fatalf("encryption failed: %s", err)
96 }
97
98 recovered, err := opener.Open(ct, aad)
99 if err != nil {
100 t.Fatalf("decryption failed: %s", err)
101 }
102
103 if !bytes.Equal(pt, recovered) {
104 t.Fatal("Plaintext mismatch")
105 }
106
107 _, err = opener.Open(ct, aad)
108 if err == nil {
109 t.Fatal("decryption succeeded when it should have failed")
110 }
111 }
112
113 func TestSealPhaseMismatch(t *testing.T) {
114 sealer, opener, err := setupAeadTest()
115 test.CheckNoErr(t, err, "setup failed")
116
117 pt := []byte("plaintext")
118 aad := []byte("aad")
119
120 _, err = sealer.Seal(pt, aad)
121 if err != nil {
122 t.Fatalf("encryption failed: %s", err)
123 }
124
125 ct, err := sealer.Seal(pt, aad)
126 if err != nil {
127 t.Fatalf("encryption failed: %s", err)
128 }
129
130 _, err = opener.Open(ct, aad)
131 if err == nil {
132 t.Fatal("decryption succeeded when it should have failed")
133 }
134 }
135
136 func TestAeadSeqOverflow(t *testing.T) {
137 sealer, opener, err := setupAeadTest()
138 test.CheckNoErr(t, err, "setup failed")
139
140 Nn := len(sealer.baseNonce)
141 pt := []byte("plaintext")
142 aad := []byte("aad")
143
144
145 for i := 0; i < Nn; i++ {
146 sealer.sequenceNumber[i] = 0xFF
147 opener.sequenceNumber[i] = 0xFF
148 }
149 sealer.sequenceNumber[Nn-1] = 0x00
150 opener.sequenceNumber[Nn-1] = 0x00
151
152 numAttempts := 260
153 wantCorrect := 2 * 255
154 wantIncorrect := 2*numAttempts - wantCorrect
155 gotCorrect := 0
156 gotIncorrect := 0
157
158 for i := 0; i < numAttempts; i++ {
159 ct, err := sealer.Seal(pt, aad)
160 switch {
161 case ct != nil && err == nil:
162 gotCorrect++
163 case ct == nil && err != nil:
164 gotIncorrect++
165 default:
166 t.FailNow()
167 }
168
169 pt2, err := opener.Open(ct, aad)
170 switch {
171 case pt2 != nil && err == nil:
172 gotCorrect++
173 case pt2 == nil && err != nil:
174 gotIncorrect++
175 default:
176 t.FailNow()
177 }
178 }
179
180 if gotCorrect != wantCorrect {
181 test.ReportError(t, gotCorrect, wantCorrect)
182 }
183 if gotIncorrect != wantIncorrect {
184 test.ReportError(t, gotIncorrect, wantIncorrect)
185 }
186 }
187
View as plain text