...
1 package expander_test
2
3 import (
4 "bytes"
5 "crypto"
6 _ "crypto/sha256"
7 _ "crypto/sha512"
8 "encoding/hex"
9 "encoding/json"
10 "fmt"
11 "os"
12 "path/filepath"
13 "strconv"
14 "testing"
15
16 "github.com/cloudflare/circl/expander"
17 "github.com/cloudflare/circl/internal/test"
18 "github.com/cloudflare/circl/xof"
19 )
20
21 func TestExpander(t *testing.T) {
22 fileNames, err := filepath.Glob("./testdata/*.json")
23 if err != nil {
24 t.Fatal(err)
25 }
26
27 for _, fileName := range fileNames {
28 f, err := os.Open(fileName)
29 if err != nil {
30 t.Fatal(err)
31 }
32 dec := json.NewDecoder(f)
33 var v vectorExpanderSuite
34 err = dec.Decode(&v)
35 if err != nil {
36 t.Fatal(err)
37 }
38 f.Close()
39
40 t.Run(v.Name+"/"+v.Hash, func(t *testing.T) { testExpander(t, &v) })
41 }
42 }
43
44 func testExpander(t *testing.T, vs *vectorExpanderSuite) {
45 var exp expander.Expander
46 switch vs.Hash {
47 case "SHA256":
48 exp = expander.NewExpanderMD(crypto.SHA256, []byte(vs.DST))
49 case "SHA512":
50 exp = expander.NewExpanderMD(crypto.SHA512, []byte(vs.DST))
51 case "SHAKE128":
52 exp = expander.NewExpanderXOF(xof.SHAKE128, vs.K, []byte(vs.DST))
53 case "SHAKE256":
54 exp = expander.NewExpanderXOF(xof.SHAKE256, vs.K, []byte(vs.DST))
55 default:
56 t.Skip("hash not supported: " + vs.Hash)
57 }
58
59 for i, v := range vs.Tests {
60 lenBytes, err := strconv.ParseUint(v.Len, 0, 64)
61 if err != nil {
62 t.Fatal(err)
63 }
64
65 got := exp.Expand([]byte(v.Msg), uint(lenBytes))
66 want, err := hex.DecodeString(v.UniformBytes)
67 if err != nil {
68 t.Fatal(err)
69 }
70
71 if !bytes.Equal(got, want) {
72 test.ReportError(t, got, want, i)
73 }
74 }
75 }
76
// vectorExpanderSuite is the decoded form of one JSON test-vector file from
// ./testdata. JSON keys are bound through the struct tags below.
type vectorExpanderSuite struct {
	// DST is the domain separation tag, passed to the expander as raw bytes.
	DST string `json:"DST"`
	// Hash selects the expander in testExpander, e.g. "SHA256" or "SHAKE128".
	Hash string `json:"hash"`
	// Name is combined with Hash to form the subtest name.
	Name string `json:"name"`
	// K is forwarded to expander.NewExpanderXOF; MD-based hashes ignore it.
	K uint `json:"k"`
	// Tests holds the individual vectors of this suite.
	Tests []struct {
		// DstPrime is decoded but not checked by testExpander.
		DstPrime string `json:"DST_prime"`
		// Len is the requested output length in bytes. testExpander parses it
		// with strconv base 0, so prefixed forms such as "0x20" are accepted.
		Len string `json:"len_in_bytes"`
		// Msg is the input message, passed to Expand as raw bytes.
		Msg string `json:"msg"`
		// MsgPrime is decoded but not checked by testExpander.
		MsgPrime string `json:"msg_prime"`
		// UniformBytes is the hex-encoded expected output of Expand.
		UniformBytes string `json:"uniform_bytes"`
	} `json:"tests"`
}
90
91 func BenchmarkExpander(b *testing.B) {
92 in := []byte("input")
93 dst := []byte("dst")
94
95 for _, v := range []struct {
96 Name string
97 Exp expander.Expander
98 }{
99 {"XMD", expander.NewExpanderMD(crypto.SHA256, dst)},
100 {"XOF", expander.NewExpanderXOF(xof.SHAKE128, 0, dst)},
101 } {
102 exp := v.Exp
103 for l := 8; l <= 10; l++ {
104 max := int64(1) << uint(l)
105
106 b.Run(fmt.Sprintf("%v/%v", v.Name, max), func(b *testing.B) {
107 b.SetBytes(max)
108 b.ResetTimer()
109 for i := 0; i < b.N; i++ {
110 exp.Expand(in, uint(max))
111 }
112 })
113 }
114 }
115 }
116
View as plain text