1 package hpke
2
3 import (
4 "bytes"
5 "encoding/hex"
6 "encoding/json"
7 "fmt"
8 "io"
9 "os"
10 "testing"
11
12 "github.com/cloudflare/circl/internal/test"
13 "github.com/cloudflare/circl/kem"
14 "golang.org/x/crypto/sha3"
15 )
16
// Parameters used when generating HPKE test vectors in TestHybridKemRoundTrip.
// They never change at run time, so they are declared as constants rather
// than mutable package-level variables.
const (
	// outputTestVectorEnvironmentKey names the environment variable that, when
	// set to a file path, makes TestHybridKemRoundTrip write its generated
	// vectors to that file.
	outputTestVectorEnvironmentKey = "HPKE_TEST_VECTORS_OUT"
	// testVectorEncryptionCount is the number of sequential encryptions
	// generated for each vector.
	testVectorEncryptionCount = 257
	// testVectorExportLength is the length in bytes of each generated export.
	testVectorExportLength = 32
)
22
23 func TestVectors(t *testing.T) {
24
25
26 vectors := readFile(t, "testdata/vectors_rfc9180_5f503c5.json")
27 for i, v := range vectors {
28 t.Run(fmt.Sprintf("v%v", i), v.verify)
29 }
30 }
31
32 func (v *vector) verify(t *testing.T) {
33 m := v.ModeID
34 kem, kdf, aead := KEM(v.KemID), KDF(v.KdfID), AEAD(v.AeadID)
35 if !kem.IsValid() {
36 t.Skipf("Skipping test with unknown KEM: %x", kem)
37 }
38 if !kdf.IsValid() {
39 t.Skipf("Skipping test with unknown KDF: %x", kdf)
40 }
41 if !aead.IsValid() {
42 t.Skipf("Skipping test with unknown AEAD: %x", aead)
43 }
44 s := NewSuite(kem, kdf, aead)
45
46 sender, recv := v.getActors(t, kem.Scheme(), s)
47 sealer, opener := v.setup(t, kem.Scheme(), sender, recv, m, s)
48
49 v.checkAead(t, (sealer.(*sealContext)).encdecContext, m)
50 v.checkAead(t, (opener.(*openContext)).encdecContext, m)
51 v.checkEncryptions(t, sealer, opener, m)
52 v.checkExports(t, sealer, m)
53 v.checkExports(t, opener, m)
54 }
55
56 func (v *vector) getActors(
57 t *testing.T, dhkem kem.Scheme, s Suite,
58 ) (*Sender, *Receiver) {
59 h := s.String() + "\n"
60
61 pkR, err := dhkem.UnmarshalBinaryPublicKey(hexB(t, v.PkRm))
62 test.CheckNoErr(t, err, h+"bad public key")
63
64 skR, err := dhkem.UnmarshalBinaryPrivateKey(hexB(t, v.SkRm))
65 test.CheckNoErr(t, err, h+"bad private key")
66
67 info := hexB(t, v.Info)
68 sender, err := s.NewSender(pkR, info)
69 test.CheckNoErr(t, err, h+"err sender")
70
71 recv, err := s.NewReceiver(skR, info)
72 test.CheckNoErr(t, err, h+"err receiver")
73
74 return sender, recv
75 }
76
77 func (v *vector) setup(t *testing.T, k kem.Scheme,
78 se *Sender, re *Receiver,
79 m modeID, s Suite,
80 ) (sealer Sealer, opener Opener) {
81 seed := hexB(t, v.IkmE)
82 rd := bytes.NewReader(seed)
83
84 var enc []byte
85 var skS kem.PrivateKey
86 var pkS kem.PublicKey
87 var errS, errR, errPK, errSK error
88
89 switch v.ModeID {
90 case modeBase:
91 enc, sealer, errS = se.Setup(rd)
92 if errS == nil {
93 opener, errR = re.Setup(enc)
94 }
95
96 case modePSK:
97 psk, pskid := hexB(t, v.Psk), hexB(t, v.PskID)
98 enc, sealer, errS = se.SetupPSK(rd, psk, pskid)
99 if errS == nil {
100 opener, errR = re.SetupPSK(enc, psk, pskid)
101 }
102
103 case modeAuth:
104 skS, errSK = k.UnmarshalBinaryPrivateKey(hexB(t, v.SkSm))
105 if errSK == nil {
106 pkS, errPK = k.UnmarshalBinaryPublicKey(hexB(t, v.PkSm))
107 if errPK == nil {
108 enc, sealer, errS = se.SetupAuth(rd, skS)
109 if errS == nil {
110 opener, errR = re.SetupAuth(enc, pkS)
111 }
112 }
113 }
114
115 case modeAuthPSK:
116 psk, pskid := hexB(t, v.Psk), hexB(t, v.PskID)
117 skS, errSK = k.UnmarshalBinaryPrivateKey(hexB(t, v.SkSm))
118 if errSK == nil {
119 pkS, errPK = k.UnmarshalBinaryPublicKey(hexB(t, v.PkSm))
120 if errPK == nil {
121 enc, sealer, errS = se.SetupAuthPSK(rd, skS, psk, pskid)
122 if errS == nil {
123 opener, errR = re.SetupAuthPSK(enc, psk, pskid, pkS)
124 }
125 }
126 }
127 }
128
129 h := fmt.Sprintf("mode: %v %v\n", m, s)
130 test.CheckNoErr(t, errS, h+"error on sender setup")
131 test.CheckNoErr(t, errR, h+"error on receiver setup")
132 test.CheckNoErr(t, errSK, h+"bad private key")
133 test.CheckNoErr(t, errPK, h+"bad public key")
134
135 return sealer, opener
136 }
137
138 func (v *vector) checkAead(t *testing.T, e *encdecContext, m modeID) {
139 got := e.baseNonce
140 want := hexB(t, v.BaseNonce)
141 if !bytes.Equal(got, want) {
142 test.ReportError(t, got, want, m, e.Suite())
143 }
144
145 got = e.exporterSecret
146 want = hexB(t, v.ExporterSecret)
147 if !bytes.Equal(got, want) {
148 test.ReportError(t, got, want, m, e.Suite())
149 }
150 }
151
152 func (v *vector) checkEncryptions(
153 t *testing.T,
154 se Sealer,
155 op Opener,
156 m modeID,
157 ) {
158 for j, encv := range v.Encryptions {
159 pt := hexB(t, encv.Plaintext)
160 aad := hexB(t, encv.Aad)
161
162 ct, err := se.Seal(pt, aad)
163 test.CheckNoErr(t, err, "error on sealing")
164
165 got, err := op.Open(ct, aad)
166 test.CheckNoErr(t, err, "error on opening")
167
168 want := pt
169 if !bytes.Equal(got, want) {
170 test.ReportError(t, got, want, m, se.Suite(), j)
171 }
172 }
173 }
174
175 func (v *vector) checkExports(t *testing.T, context Context, m modeID) {
176 for j, expv := range v.Exports {
177 ctx := hexB(t, expv.ExportContext)
178 want := hexB(t, expv.ExportValue)
179
180 got := context.Export(ctx, uint(expv.ExportLength))
181 if !bytes.Equal(got, want) {
182 test.ReportError(t, got, want, m, context.Suite(), j)
183 }
184 }
185 }
186
187 func hexB(t *testing.T, x string) []byte {
188 t.Helper()
189 z, err := hex.DecodeString(x)
190 test.CheckNoErr(t, err, "")
191 return z
192 }
193
194 func readFile(t *testing.T, fileName string) []vector {
195 jsonFile, err := os.Open(fileName)
196 if err != nil {
197 t.Fatalf("File %v can not be opened. Error: %v", fileName, err)
198 }
199 defer jsonFile.Close()
200 input, err := io.ReadAll(jsonFile)
201 if err != nil {
202 t.Fatalf("File %v can not be read. Error: %v", fileName, err)
203 }
204 var vectors []vector
205 err = json.Unmarshal(input, &vectors)
206 if err != nil {
207 t.Fatalf("File %v can not be loaded. Error: %v", fileName, err)
208 }
209 return vectors
210 }
211
// encryptionVector is one entry of a test vector's "encryptions" list.
// Every field holds a hex-encoded byte string. Field order fixes the key
// order of the JSON produced by json.Marshal.
type encryptionVector struct {
	Aad        string `json:"aad"`   // additional authenticated data
	Ciphertext string `json:"ct"`    // sealed output
	Nonce      string `json:"nonce"` // nonce used for this message
	Plaintext  string `json:"pt"`    // message that is sealed
}
218
// exportVector is one entry of a test vector's "exports" list: the exporter
// context and the expected exported value (both hex-encoded), plus the
// requested output length L in bytes.
type exportVector struct {
	ExportContext string `json:"exporter_context"` // hex-encoded exporter context
	ExportLength  int    `json:"L"`                // requested export length
	ExportValue   string `json:"exported_value"`   // hex-encoded expected output
}
224
225 type vector struct {
226 ModeID uint8 `json:"mode"`
227 KemID uint16 `json:"kem_id"`
228 KdfID uint16 `json:"kdf_id"`
229 AeadID uint16 `json:"aead_id"`
230 Info string `json:"info"`
231 Ier string `json:"ier,omitempty"`
232 IkmR string `json:"ikmR"`
233 IkmE string `json:"ikmE,omitempty"`
234 SkRm string `json:"skRm"`
235 SkEm string `json:"skEm,omitempty"`
236 SkSm string `json:"skSm,omitempty"`
237 Psk string `json:"psk,omitempty"`
238 PskID string `json:"psk_id,omitempty"`
239 PkSm string `json:"pkSm,omitempty"`
240 PkRm string `json:"pkRm"`
241 PkEm string `json:"pkEm,omitempty"`
242 Enc string `json:"enc"`
243 SharedSecret string `json:"shared_secret"`
244 KeyScheduleContext string `json:"key_schedule_context"`
245 Secret string `json:"secret"`
246 Key string `json:"key"`
247 BaseNonce string `json:"base_nonce"`
248 ExporterSecret string `json:"exporter_secret"`
249 Encryptions []encryptionVector `json:"encryptions"`
250 Exports []exportVector `json:"exports"`
251 }
252
253 func generateHybridKeyPair(rnd io.Reader, h kem.Scheme) ([]byte, kem.PublicKey, kem.PrivateKey, error) {
254 seed := make([]byte, h.SeedSize())
255 _, err := rnd.Read(seed)
256 if err != nil {
257 return nil, nil, nil, err
258 }
259
260 pk, sk := h.DeriveKeyPair(seed)
261 return seed, pk, sk, nil
262 }
263
264 func mustEncodePublicKey(pk kem.PublicKey) []byte {
265 enc, err := pk.MarshalBinary()
266 if err != nil {
267 panic(err)
268 }
269 return enc
270 }
271
272 func mustEncodePrivateKey(sk kem.PrivateKey) []byte {
273 enc, err := sk.MarshalBinary()
274 if err != nil {
275 panic(err)
276 }
277 return enc
278 }
279
280 func generateEncryptions(sealer Sealer, opener Opener, msg []byte) ([]encryptionVector, error) {
281 vectors := make([]encryptionVector, testVectorEncryptionCount)
282 for i := 0; i < len(vectors); i++ {
283 aad := []byte(fmt.Sprintf("Count-%d", i))
284 innerSealer := sealer.(*sealContext)
285 nonce := innerSealer.calcNonce()
286 encrypted, err := sealer.Seal(msg, aad)
287 if err != nil {
288 return nil, err
289 }
290 decrypted, err := opener.Open(encrypted, aad)
291 if err != nil {
292 return nil, err
293 }
294 if !bytes.Equal(decrypted, msg) {
295 return nil, fmt.Errorf("Mismatch messages %d", i)
296 }
297 vectors[i] = encryptionVector{
298 Plaintext: hex.EncodeToString(msg),
299 Aad: hex.EncodeToString(aad),
300 Nonce: hex.EncodeToString(nonce),
301 Ciphertext: hex.EncodeToString(encrypted),
302 }
303 }
304
305 return vectors, nil
306 }
307
308 func generateExports(sealer Sealer, opener Opener) ([]exportVector, error) {
309 exportContexts := [][]byte{
310 []byte(""),
311 {0x00},
312 []byte("TestContext"),
313 }
314 vectors := make([]exportVector, len(exportContexts))
315 for i := 0; i < len(vectors); i++ {
316 senderValue := sealer.Export(exportContexts[i], uint(testVectorExportLength))
317 receiverValue := opener.Export(exportContexts[i], uint(testVectorExportLength))
318 if !bytes.Equal(senderValue, receiverValue) {
319 return nil, fmt.Errorf("Mismatch export values")
320 }
321 vectors[i] = exportVector{
322 ExportContext: hex.EncodeToString(exportContexts[i]),
323 ExportLength: testVectorExportLength,
324 ExportValue: hex.EncodeToString(senderValue),
325 }
326 }
327
328 return vectors, nil
329 }
330
331 func TestHybridKemRoundTrip(t *testing.T) {
332 kemID := KEM_X25519_KYBER768_DRAFT00
333 kdfID := KDF_HKDF_SHA256
334 aeadID := AEAD_AES128GCM
335 rnd := sha3.NewShake128()
336 suite := NewSuite(kemID, kdfID, aeadID)
337 msg := []byte("To the universal deployment of PQC")
338 info := []byte("Hear hear")
339 pskid := []byte("before everybody for everybody for everything")
340 psk := make([]byte, 32)
341 _, _ = rnd.Read(psk)
342
343 ikmR, pkR, skR, err := generateHybridKeyPair(rnd, kemID.Scheme())
344 if err != nil {
345 t.Error(err)
346 }
347
348 ier := make([]byte, 64)
349 _, _ = rnd.Read(ier)
350
351 receiver, err := suite.NewReceiver(skR, info)
352 if err != nil {
353 t.Error(err)
354 }
355
356 sender, err := suite.NewSender(pkR, info)
357 if err != nil {
358 t.Error(err)
359 }
360
361 generateVector := func(mode uint8) vector {
362 var (
363 err2 error
364 sealer Sealer
365 opener Opener
366 enc []byte
367 )
368 rnd2 := bytes.NewBuffer(ier)
369 switch mode {
370 case modeBase:
371 enc, sealer, err2 = sender.Setup(rnd2)
372 if err2 != nil {
373 t.Error(err2)
374 }
375 opener, err2 = receiver.Setup(enc)
376 if err2 != nil {
377 t.Error(err2)
378 }
379 case modePSK:
380 enc, sealer, err2 = sender.SetupPSK(rnd2, psk, pskid)
381 if err2 != nil {
382 t.Error(err2)
383 }
384 opener, err2 = receiver.SetupPSK(enc, psk, pskid)
385 if err2 != nil {
386 t.Error(err2)
387 }
388 default:
389 panic("unsupported mode")
390 }
391
392 if rnd2.Len() != 0 {
393 t.Fatal()
394 }
395
396 innerSealer := sealer.(*sealContext)
397
398 encryptions, err2 := generateEncryptions(sealer, opener, msg)
399 if err2 != nil {
400 t.Error(err2)
401 }
402 exports, err2 := generateExports(sealer, opener)
403 if err2 != nil {
404 t.Error(err2)
405 }
406
407 ret := vector{
408 ModeID: mode,
409 KemID: uint16(kemID),
410 KdfID: uint16(kdfID),
411 AeadID: uint16(aeadID),
412 Ier: hex.EncodeToString(ier),
413 Info: hex.EncodeToString(info),
414 IkmR: hex.EncodeToString(ikmR),
415 SkRm: hex.EncodeToString(mustEncodePrivateKey(skR)),
416 PkRm: hex.EncodeToString(mustEncodePublicKey(pkR)),
417 Enc: hex.EncodeToString(enc),
418 SharedSecret: hex.EncodeToString(innerSealer.sharedSecret),
419 KeyScheduleContext: hex.EncodeToString(innerSealer.keyScheduleContext),
420 Secret: hex.EncodeToString(innerSealer.secret),
421 Key: hex.EncodeToString(innerSealer.key),
422 BaseNonce: hex.EncodeToString(innerSealer.baseNonce),
423 ExporterSecret: hex.EncodeToString(innerSealer.exporterSecret),
424 Encryptions: encryptions,
425 Exports: exports,
426 }
427
428 if mode == modePSK {
429 ret.Psk = hex.EncodeToString(psk)
430 ret.PskID = hex.EncodeToString(pskid)
431 }
432
433 return ret
434 }
435
436 encodedVector, err := json.Marshal([]vector{
437 generateVector(modeBase),
438 generateVector(modePSK),
439 })
440 if err != nil {
441 t.Error(err)
442 }
443
444 var outputFile string
445 if outputFile = os.Getenv(outputTestVectorEnvironmentKey); len(outputFile) > 0 {
446
447 err = os.WriteFile(outputFile, encodedVector, 0o644)
448 if err != nil {
449 t.Error(err)
450 }
451 }
452 }
453
View as plain text