1 package mlsbset_test
2
3 import (
4 "crypto/rand"
5 "errors"
6 "math/big"
7 "testing"
8
9 "github.com/cloudflare/circl/internal/conv"
10 "github.com/cloudflare/circl/internal/test"
11 "github.com/cloudflare/circl/math/mlsbset"
12 )
13
14 func TestExp(t *testing.T) {
15 T := uint(126)
16 for v := uint(1); v <= 5; v++ {
17 for w := uint(2); w <= 5; w++ {
18 m, err := mlsbset.New(T, v, w)
19 if err != nil {
20 test.ReportError(t, err, nil)
21 }
22 testExp(t, m)
23 }
24 }
25 }
26
27 func testExp(t *testing.T, m mlsbset.Encoder) {
28 const testTimes = 1 << 8
29 params := m.GetParams()
30 TBytes := (params.T + 7) / 8
31 topBits := (byte(1) << (params.T % 8)) - 1
32 k := make([]byte, TBytes)
33 for i := 0; i < testTimes; i++ {
34 _, _ = rand.Read(k)
35 k[0] |= 1
36 k[TBytes-1] &= topBits
37
38 c, err := m.Encode(k)
39 if err != nil {
40 test.ReportError(t, err, nil)
41 }
42
43 g := zzAdd{m.GetParams()}
44 a := c.Exp(g)
45
46 got := a.(*big.Int)
47 want := conv.BytesLe2BigInt(k)
48 if got.Cmp(want) != 0 {
49 test.ReportError(t, got, want, m)
50 }
51 }
52 }
53
54 type zzAdd struct{ set mlsbset.Params }
55
56 func (zzAdd) Identity() mlsbset.EltG { return big.NewInt(0) }
57 func (zzAdd) NewEltP() mlsbset.EltP { return new(big.Int) }
58 func (zzAdd) Sqr(x mlsbset.EltG) {
59 a := x.(*big.Int)
60 a.Add(a, a)
61 }
62
63 func (zzAdd) Mul(x mlsbset.EltG, y mlsbset.EltP) {
64 a := x.(*big.Int)
65 b := y.(*big.Int)
66 a.Add(a, b)
67 }
68
69 func (z zzAdd) ExtendedEltP() mlsbset.EltP {
70 a := big.NewInt(1)
71 a.Lsh(a, z.set.W*z.set.D)
72 return a
73 }
74
75 func (z zzAdd) Lookup(x mlsbset.EltP, idTable uint, sgnElt int32, idElt int32) {
76 a := x.(*big.Int)
77 a.SetInt64(1)
78 a.Lsh(a, z.set.E*idTable)
79 sum := big.NewInt(0)
80 for i := int(z.set.W - 2); i >= 0; i-- {
81 ui := big.NewInt(int64((idElt >> uint(i)) & 0x1))
82 sum.Add(sum, ui)
83 sum.Lsh(sum, z.set.D)
84 }
85 sum.Add(sum, big.NewInt(1))
86 a.Mul(a, sum)
87 if sgnElt == -1 {
88 a.Neg(a)
89 }
90 }
91
92 func TestEncodeErr(t *testing.T) {
93 t.Run("mArgs", func(t *testing.T) {
94 _, got := mlsbset.New(0, 0, 0)
95 want := errors.New("t>1, v>=1, w>=2")
96 if got.Error() != want.Error() {
97 test.ReportError(t, got, want)
98 }
99 })
100 t.Run("kOdd", func(t *testing.T) {
101 m, _ := mlsbset.New(16, 2, 2)
102 k := make([]byte, 2)
103 _, got := m.Encode(k)
104 want := errors.New("k must be odd")
105 if got.Error() != want.Error() {
106 test.ReportError(t, got, want)
107 }
108 })
109 t.Run("kBig", func(t *testing.T) {
110 m, _ := mlsbset.New(16, 2, 2)
111 k := make([]byte, 4)
112 _, got := m.Encode(k)
113 want := errors.New("k too big")
114 if got.Error() != want.Error() {
115 test.ReportError(t, got, want)
116 }
117 })
118 t.Run("kEmpty", func(t *testing.T) {
119 m, _ := mlsbset.New(16, 2, 2)
120 k := []byte{}
121 _, got := m.Encode(k)
122 want := errors.New("empty slice")
123 if got.Error() != want.Error() {
124 test.ReportError(t, got, want)
125 }
126 })
127 }
128
129 func BenchmarkEncode(b *testing.B) {
130 t, v, w := uint(256), uint(2), uint(3)
131 m, _ := mlsbset.New(t, v, w)
132 params := m.GetParams()
133 TBytes := (params.T + 7) / 8
134 topBits := (byte(1) << (params.T % 8)) - 1
135
136 k := make([]byte, TBytes)
137 _, _ = rand.Read(k)
138 k[0] |= 1
139 k[TBytes-1] &= topBits
140
141 c, _ := m.Encode(k)
142 g := zzAdd{params}
143
144 b.Run("Encode", func(b *testing.B) {
145 for i := 0; i < b.N; i++ {
146 _, _ = m.Encode(k)
147 }
148 })
149 b.Run("Exp", func(b *testing.B) {
150 for i := 0; i < b.N; i++ {
151 c.Exp(g)
152 }
153 })
154 }
155
View as plain text