...
1 package common
2
3 import "testing"
4
5 func BenchmarkNTT(b *testing.B) {
6 var a Poly
7 for i := 0; i < b.N; i++ {
8 a.NTT()
9 }
10 }
11
12 func BenchmarkNTTGeneric(b *testing.B) {
13 var a Poly
14 for i := 0; i < b.N; i++ {
15 a.nttGeneric()
16 }
17 }
18
19 func BenchmarkInvNTT(b *testing.B) {
20 var a Poly
21 for i := 0; i < b.N; i++ {
22 a.InvNTT()
23 }
24 }
25
26 func BenchmarkInvNTTGeneric(b *testing.B) {
27 var a Poly
28 for i := 0; i < b.N; i++ {
29 a.invNTTGeneric()
30 }
31 }
32
33 func (p *Poly) Rand() {
34 max := uint32(Q)
35 r := randSliceUint32WithMax(uint(N), max)
36 for i := 0; i < N; i++ {
37 p[i] = int16(r[i])
38 }
39 }
40
41 func (p *Poly) RandAbsLeQ() {
42 max := 2 * uint32(Q)
43 r := randSliceUint32WithMax(uint(N), max)
44 for i := 0; i < N; i++ {
45 p[i] = int16(int32(r[i]) - int32(Q))
46 }
47 }
48
49 func TestNTTAgainstGeneric(t *testing.T) {
50 for k := 0; k < 1000; k++ {
51 var p, q1, q2 Poly
52 p.RandAbsLeQ()
53 q1 = p
54 q2 = p
55 q1.NTT()
56 q1.Detangle()
57 q2.nttGeneric()
58 if q1 != q2 {
59 t.Fatalf("NTT(%v) = \n%v \n!= %v", p, q2, q1)
60 }
61 }
62 }
63
64 func TestInvNTTAgainstGeneric(t *testing.T) {
65 for k := 0; k < 1000; k++ {
66 var p, q1, q2 Poly
67 p.RandAbsLeQ()
68 q1 = p
69 q2 = p
70 q1.Tangle()
71 q1.InvNTT()
72 q2.invNTTGeneric()
73
74 q1.Normalize()
75 q2.Normalize()
76
77 if q1 != q2 {
78 t.Fatalf("InvNTT(%v) = \n%v \n!= %v", p, q2, q1)
79 }
80 }
81 }
82
83 func TestNTT(t *testing.T) {
84 for k := 0; k < 1000; k++ {
85 var p, q Poly
86 p.RandAbsLeQ()
87 q = p
88 q.Normalize()
89 p.NTT()
90 for i := 0; i < N; i++ {
91 if p[i] > 7*Q || 7*Q < p[i] {
92 t.Fatal()
93 }
94 }
95 p.Normalize()
96 p.InvNTT()
97 for i := 0; i < N; i++ {
98 if p[i] > Q || p[i] < -Q {
99 t.Fatal()
100 }
101 }
102 p.Normalize()
103 for i := 0; i < N; i++ {
104 if int32(p[i]) != (int32(q[i])*(1<<16))%int32(Q) {
105 t.Fatal()
106 }
107 }
108 }
109 }
110
111 func TestInvNTTReductions(t *testing.T) {
112
113
114 xs := [256]int{}
115 for i := 0; i < 256; i++ {
116 xs[i] = 1
117 }
118
119 r := -1
120 for layer := 1; layer < 8; layer++ {
121 w := 1 << uint(layer)
122 i := 0
123 for i+w < 256 {
124 xs[i] = xs[i] + xs[i+w]
125 if xs[i] > 9 {
126 t.Fatal()
127 }
128 xs[i+w] = 1
129 i++
130 if i%w == 0 {
131 i += w
132 }
133 }
134 for {
135 r++
136 i := InvNTTReductions[r]
137 if i < 0 {
138 break
139 }
140 xs[i] = 1
141 }
142 }
143 }
144
View as plain text