1 package common
2
3 import (
4 "bytes"
5 "crypto/rand"
6 "fmt"
7 "testing"
8 )
9
10 func (p *Poly) RandAbsLe9Q() {
11 max := 9 * uint32(Q)
12 r := randSliceUint32WithMax(uint(N), max)
13 for i := 0; i < N; i++ {
14 p[i] = int16(int32(r[i]))
15 }
16 }
17
18
19 func sModQ(x int16) int16 {
20 x = x % Q
21 if x >= (Q-1)/2 {
22 x = x - Q
23 }
24 return x
25 }
26
27 func TestDecompressMessage(t *testing.T) {
28 var m, m2 [PlaintextSize]byte
29 var p Poly
30 for i := 0; i < 1000; i++ {
31 if n, err := rand.Read(m[:]); err != nil {
32 t.Error(err)
33 } else if n != len(m) {
34 t.Fatal("short read from RNG")
35 }
36
37 p.DecompressMessage(m[:])
38 p.CompressMessageTo(m2[:])
39 if m != m2 {
40 t.Fatal()
41 }
42 }
43 }
44
45 func TestCompress(t *testing.T) {
46 for _, d := range []int{4, 5, 10, 11} {
47 d := d
48 t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
49 var p, q Poly
50 bound := (Q + (1 << uint(d))) >> uint(d+1)
51 buf := make([]byte, (N*d-1)/8+1)
52 for i := 0; i < 1000; i++ {
53 p.Rand()
54 p.CompressTo(buf, d)
55 q.Decompress(buf, d)
56 for j := 0; j < N; j++ {
57 diff := sModQ(p[j] - q[j])
58 if diff < 0 {
59 diff = -diff
60 }
61 if diff > bound {
62 t.Logf("%v\n", buf)
63 t.Fatalf("|%d - %d mod^± q| = %d > %d, j=%d",
64 p[i], q[j], diff, bound, j)
65 }
66 }
67 }
68 })
69 }
70 }
71
72 func TestCompressMessage(t *testing.T) {
73 var p Poly
74 var m [32]byte
75 ok := true
76 for i := 0; i < int(Q); i++ {
77 p[0] = int16(i)
78 p.CompressMessageTo(m[:])
79 want := byte(0)
80 if i >= 833 && i < 2497 {
81 want = 1
82 }
83 if m[0] != want {
84 ok = false
85 t.Logf("%d %d %d", i, want, m[0])
86 }
87 }
88 if !ok {
89 t.Fatal()
90 }
91 }
92
93 func TestMulHat(t *testing.T) {
94 for k := 0; k < 1000; k++ {
95 var a, b, p, ah, bh, ph Poly
96 a.RandAbsLeQ()
97 b.RandAbsLeQ()
98 b[0] = 1
99
100 ah = a
101 bh = b
102 ah.NTT()
103 bh.NTT()
104 ph.MulHat(&ah, &bh)
105 ph.BarrettReduce()
106 ph.InvNTT()
107
108 for i := 0; i < N; i++ {
109 for j := 0; j < N; j++ {
110 v := montReduce(int32(a[i]) * int32(b[j]))
111 k := i + j
112 if k >= N {
113
114 k -= N
115 v = -v
116 }
117 p[k] = barrettReduce(v + p[k])
118 }
119 }
120
121 for i := 0; i < N; i++ {
122 p[i] = int16((int32(p[i]) * ((1 << 16) % int32(Q))) % int32(Q))
123 }
124
125 p.Normalize()
126 ph.Normalize()
127 a.Normalize()
128 b.Normalize()
129
130 if p != ph {
131 t.Fatalf("%v\n%v\n%v\n%v", a, b, p, ph)
132 }
133 }
134 }
135
136 func TestAddAgainstGeneric(t *testing.T) {
137 for k := 0; k < 1000; k++ {
138 var p1, p2, a, b Poly
139 a.RandAbsLeQ()
140 b.RandAbsLeQ()
141 p1.Add(&a, &b)
142 p2.addGeneric(&a, &b)
143 if p1 != p2 {
144 t.Fatalf("Add(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
145 }
146 }
147 }
148
149 func BenchmarkAdd(b *testing.B) {
150 var p Poly
151 for i := 0; i < b.N; i++ {
152 p.Add(&p, &p)
153 }
154 }
155
156 func BenchmarkAddGeneric(b *testing.B) {
157 var p Poly
158 for i := 0; i < b.N; i++ {
159 p.addGeneric(&p, &p)
160 }
161 }
162
163 func TestSubAgainstGeneric(t *testing.T) {
164 for k := 0; k < 1000; k++ {
165 var p1, p2, a, b Poly
166 a.RandAbsLeQ()
167 b.RandAbsLeQ()
168 p1.Sub(&a, &b)
169 p2.subGeneric(&a, &b)
170 if p1 != p2 {
171 t.Fatalf("Sub(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
172 }
173 }
174 }
175
176 func BenchmarkSub(b *testing.B) {
177 var p Poly
178 for i := 0; i < b.N; i++ {
179 p.Sub(&p, &p)
180 }
181 }
182
183 func BenchmarkSubGeneric(b *testing.B) {
184 var p Poly
185 for i := 0; i < b.N; i++ {
186 p.subGeneric(&p, &p)
187 }
188 }
189
190 func TestMulHatAgainstGeneric(t *testing.T) {
191 for k := 0; k < 1000; k++ {
192 var p1, p2, a, b Poly
193 a.RandAbsLeQ()
194 b.RandAbsLeQ()
195 a2 := a
196 b2 := b
197 a2.Tangle()
198 b2.Tangle()
199 p1.MulHat(&a2, &b2)
200 p1.Detangle()
201 p2.mulHatGeneric(&a, &b)
202 if p1 != p2 {
203 t.Fatalf("MulHat(%v, %v) = \n%v \n!= %v", a, b, p1, p2)
204 }
205 }
206 }
207
208 func BenchmarkMulHat(b *testing.B) {
209 var p Poly
210 for i := 0; i < b.N; i++ {
211 p.MulHat(&p, &p)
212 }
213 }
214
215 func BenchmarkMulHatGeneric(b *testing.B) {
216 var p Poly
217 for i := 0; i < b.N; i++ {
218 p.mulHatGeneric(&p, &p)
219 }
220 }
221
222 func BenchmarkBarrettReduce(b *testing.B) {
223 var p Poly
224 for i := 0; i < b.N; i++ {
225 p.BarrettReduce()
226 }
227 }
228
229 func BenchmarkBarrettReduceGeneric(b *testing.B) {
230 var p Poly
231 for i := 0; i < b.N; i++ {
232 p.barrettReduceGeneric()
233 }
234 }
235
236 func TestBarrettReduceAgainstGeneric(t *testing.T) {
237 for k := 0; k < 1000; k++ {
238 var p1, p2, a Poly
239 a.RandAbsLe9Q()
240 p1 = a
241 p2 = a
242 p1.BarrettReduce()
243 p2.barrettReduceGeneric()
244 if p1 != p2 {
245 t.Fatalf("BarrettReduce(%v) = \n%v \n!= %v", a, p1, p2)
246 }
247 }
248 }
249
250 func BenchmarkNormalize(b *testing.B) {
251 var p Poly
252 for i := 0; i < b.N; i++ {
253 p.Normalize()
254 }
255 }
256
257 func BenchmarkNormalizeGeneric(b *testing.B) {
258 var p Poly
259 for i := 0; i < b.N; i++ {
260 p.barrettReduceGeneric()
261 }
262 }
263
264 func TestNormalizeAgainstGeneric(t *testing.T) {
265 for k := 0; k < 1000; k++ {
266 var p1, p2, a Poly
267 a.RandAbsLe9Q()
268 p1 = a
269 p2 = a
270 p1.Normalize()
271 p2.normalizeGeneric()
272 if p1 != p2 {
273 t.Fatalf("Normalize(%v) = \n%v \n!= %v", a, p1, p2)
274 }
275 }
276 }
277
// OldCompressTo is a straightforward, per-d implementation of compression.
// TestCompressFullInputFirstCoeff uses it as a reference to cross-check
// CompressTo byte for byte.
//
// Each coefficient x is mapped to ((x·2^d + ⌊q/2⌋) / q) mod 2^d, i.e. x·2^d/q
// rounded to the nearest integer (q is odd, so no ties occur) and reduced to
// d bits. The d-bit values are packed LSB-first: the first coefficient
// occupies the lowest bits of m[0], and later coefficients continue in
// increasing bit order. N·d/8 bytes of m are written, so m must be at least
// that long. Only d ∈ {4, 5, 10, 11} is supported; any other d panics.
//
// NOTE(review): coefficients are converted with uint32(p[...]), so a
// negative int16 coefficient would wrap to a huge value. Callers presumably
// pass normalized coefficients in [0, q) — confirm.
func (p *Poly) OldCompressTo(m []byte, d int) {
	switch d {
	case 4:
		// 8 coefficients × 4 bits = 4 bytes per group; two values per byte.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			for j := 0; j < 8; j++ {
				t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
					uint32(Q)) & ((1 << 4) - 1)
			}
			m[idx] = byte(t[0]) | byte(t[1]<<4)
			m[idx+1] = byte(t[2]) | byte(t[3]<<4)
			m[idx+2] = byte(t[4]) | byte(t[5]<<4)
			m[idx+3] = byte(t[6]) | byte(t[7]<<4)
			idx += 4
		}

	case 5:
		// 8 coefficients × 5 bits = 40 bits = 5 bytes per group; values
		// straddle byte boundaries, hence the paired right/left shifts.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			for j := 0; j < 8; j++ {
				t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
					uint32(Q)) & ((1 << 5) - 1)
			}
			m[idx] = byte(t[0]) | byte(t[1]<<5)
			m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
			m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
			m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
			m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
			idx += 5
		}

	case 10:
		// 4 coefficients × 10 bits = 40 bits = 5 bytes per group.
		var t [4]uint16
		idx := 0
		for i := 0; i < N/4; i++ {
			for j := 0; j < 4; j++ {
				t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
					uint32(Q)) & ((1 << 10) - 1)
			}
			m[idx] = byte(t[0])
			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
			m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
			m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
			m[idx+4] = byte(t[3] >> 2)
			idx += 5
		}
	case 11:
		// 8 coefficients × 11 bits = 88 bits = 11 bytes per group.
		var t [8]uint16
		idx := 0
		for i := 0; i < N/8; i++ {
			for j := 0; j < 8; j++ {
				t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
					uint32(Q)) & ((1 << 11) - 1)
			}
			m[idx] = byte(t[0])
			m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
			m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
			m[idx+3] = byte(t[2] >> 2)
			m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
			m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
			m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
			m[idx+7] = byte(t[5] >> 1)
			m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
			m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
			m[idx+10] = byte(t[7] >> 3)
			idx += 11
		}
	default:
		panic("unsupported d")
	}
}
351
352 func TestCompressFullInputFirstCoeff(t *testing.T) {
353 for _, d := range []int{4, 5, 10, 11} {
354 d := d
355 t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
356 var p, q Poly
357 bound := (Q + (1 << uint(d))) >> uint(d+1)
358 buf := make([]byte, (N*d-1)/8+1)
359 buf2 := make([]byte, len(buf))
360 for i := int16(0); i < Q; i++ {
361 p[0] = i
362 p.CompressTo(buf, d)
363 p.OldCompressTo(buf2, d)
364 if !bytes.Equal(buf, buf2) {
365 t.Fatalf("%d", i)
366 }
367 q.Decompress(buf, d)
368 diff := sModQ(p[0] - q[0])
369 if diff < 0 {
370 diff = -diff
371 }
372 if diff > bound {
373 t.Logf("%v\n", buf)
374 t.Fatalf("|%d - %d mod^± q| = %d > %d",
375 p[0], q[0], diff, bound)
376 }
377 }
378 })
379 }
380 }
381
View as plain text