1
2
3 package internal
4
5 import (
6 cryptoRand "crypto/rand"
7 "crypto/subtle"
8 "io"
9
10 "github.com/cloudflare/circl/internal/sha3"
11 "github.com/cloudflare/circl/sign/dilithium/internal/common"
12 )
13
14 const (
15
16
17 PolyLeqEtaSize = (common.N * DoubleEtaBits) / 8
18
19
20 Beta = Tau * Eta
21
22
23 Gamma1 = 1 << Gamma1Bits
24
25
26 PolyLeGamma1Size = (Gamma1Bits + 1) * common.N / 8
27
28
29 Alpha = 2 * Gamma2
30
31
32 PrivateKeySize = 32 + 32 + 32 + PolyLeqEtaSize*(L+K) + common.PolyT0Size*K
33
34
35 PublicKeySize = 32 + common.PolyT1Size*K
36
37
38 SignatureSize = L*PolyLeGamma1Size + Omega + K + 32
39
40
41 PolyW1Size = (common.N * (common.QBits - Gamma1Bits)) / 8
42 )
43
44
45 type PublicKey struct {
46 rho [32]byte
47 t1 VecK
48
49
50 t1p [common.PolyT1Size * K]byte
51 A *Mat
52 tr *[32]byte
53 }
54
55
56 type PrivateKey struct {
57 rho [32]byte
58 key [32]byte
59 s1 VecL
60 s2 VecK
61 t0 VecK
62 tr [32]byte
63
64
65 A Mat
66 s1h VecL
67 s2h VecK
68 t0h VecK
69 }
70
71 type unpackedSignature struct {
72 z VecL
73 hint VecK
74 c [32]byte
75 }
76
77
78 func (sig *unpackedSignature) Pack(buf []byte) {
79 copy(buf[:], sig.c[:])
80 sig.z.PackLeGamma1(buf[32:])
81 sig.hint.PackHint(buf[32+L*PolyLeGamma1Size:])
82 }
83
84
85
86
87 func (sig *unpackedSignature) Unpack(buf []byte) bool {
88 if len(buf) < SignatureSize {
89 return false
90 }
91 copy(sig.c[:], buf[:])
92 sig.z.UnpackLeGamma1(buf[32:])
93 if sig.z.Exceeds(Gamma1 - Beta) {
94 return false
95 }
96 if !sig.hint.UnpackHint(buf[32+L*PolyLeGamma1Size:]) {
97 return false
98 }
99 return true
100 }
101
102
103 func (pk *PublicKey) Pack(buf *[PublicKeySize]byte) {
104 copy(buf[:32], pk.rho[:])
105 copy(buf[32:], pk.t1p[:])
106 }
107
108
109 func (pk *PublicKey) Unpack(buf *[PublicKeySize]byte) {
110 copy(pk.rho[:], buf[:32])
111 copy(pk.t1p[:], buf[32:])
112
113 pk.t1.UnpackT1(pk.t1p[:])
114 pk.A = new(Mat)
115 pk.A.Derive(&pk.rho)
116
117
118 pk.tr = new([32]byte)
119 h := sha3.NewShake256()
120 _, _ = h.Write(buf[:])
121 _, _ = h.Read(pk.tr[:])
122 }
123
124
125 func (sk *PrivateKey) Pack(buf *[PrivateKeySize]byte) {
126 copy(buf[:32], sk.rho[:])
127 copy(buf[32:64], sk.key[:])
128 copy(buf[64:96], sk.tr[:])
129 offset := 96
130 sk.s1.PackLeqEta(buf[offset:])
131 offset += PolyLeqEtaSize * L
132 sk.s2.PackLeqEta(buf[offset:])
133 offset += PolyLeqEtaSize * K
134 sk.t0.PackT0(buf[offset:])
135 }
136
137
138 func (sk *PrivateKey) Unpack(buf *[PrivateKeySize]byte) {
139 copy(sk.rho[:], buf[:32])
140 copy(sk.key[:], buf[32:64])
141 copy(sk.tr[:], buf[64:96])
142 offset := 96
143 sk.s1.UnpackLeqEta(buf[offset:])
144 offset += PolyLeqEtaSize * L
145 sk.s2.UnpackLeqEta(buf[offset:])
146 offset += PolyLeqEtaSize * K
147 sk.t0.UnpackT0(buf[offset:])
148
149
150 sk.A.Derive(&sk.rho)
151 sk.t0h = sk.t0
152 sk.t0h.NTT()
153 sk.s1h = sk.s1
154 sk.s1h.NTT()
155 sk.s2h = sk.s2
156 sk.s2h.NTT()
157 }
158
159
160
161 func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) {
162 var seed [32]byte
163 if rand == nil {
164 rand = cryptoRand.Reader
165 }
166 _, err := io.ReadFull(rand, seed[:])
167 if err != nil {
168 return nil, nil, err
169 }
170 pk, sk := NewKeyFromSeed(&seed)
171 return pk, sk, nil
172 }
173
174
175 func NewKeyFromSeed(seed *[common.SeedSize]byte) (*PublicKey, *PrivateKey) {
176 var eSeed [128]byte
177 var pk PublicKey
178 var sk PrivateKey
179 var sSeed [64]byte
180
181 h := sha3.NewShake256()
182 _, _ = h.Write(seed[:])
183 _, _ = h.Read(eSeed[:])
184
185 copy(pk.rho[:], eSeed[:32])
186 copy(sSeed[:], eSeed[32:96])
187 copy(sk.key[:], eSeed[96:])
188 copy(sk.rho[:], pk.rho[:])
189
190 sk.A.Derive(&pk.rho)
191
192 for i := uint16(0); i < L; i++ {
193 PolyDeriveUniformLeqEta(&sk.s1[i], &sSeed, i)
194 }
195
196 for i := uint16(0); i < K; i++ {
197 PolyDeriveUniformLeqEta(&sk.s2[i], &sSeed, i+L)
198 }
199
200 sk.s1h = sk.s1
201 sk.s1h.NTT()
202 sk.s2h = sk.s2
203 sk.s2h.NTT()
204
205 sk.computeT0andT1(&sk.t0, &pk.t1)
206
207 sk.t0h = sk.t0
208 sk.t0h.NTT()
209
210
211 pk.t1.PackT1(pk.t1p[:])
212 pk.A = &sk.A
213
214
215 var packedPk [PublicKeySize]byte
216 pk.Pack(&packedPk)
217
218
219 h.Reset()
220 _, _ = h.Write(packedPk[:])
221 _, _ = h.Read(sk.tr[:])
222
223
224 pk.tr = &sk.tr
225
226 return &pk, &sk
227 }
228
229
230 func (sk *PrivateKey) computeT0andT1(t0, t1 *VecK) {
231 var t VecK
232
233
234 for i := 0; i < K; i++ {
235 PolyDotHat(&t[i], &sk.A[i], &sk.s1h)
236 t[i].ReduceLe2Q()
237 t[i].InvNTT()
238 }
239 t.Add(&t, &sk.s2)
240 t.Normalize()
241
242
243 t.Power2Round(t0, t1)
244 }
245
246
247 func Verify(pk *PublicKey, msg []byte, signature []byte) bool {
248 var sig unpackedSignature
249 var mu [64]byte
250 var zh VecL
251 var Az, Az2dct1, w1 VecK
252 var ch common.Poly
253 var cp [32]byte
254 var w1Packed [PolyW1Size * K]byte
255
256
257
258 if !sig.Unpack(signature) {
259 return false
260 }
261
262
263 h := sha3.NewShake256()
264 _, _ = h.Write(pk.tr[:])
265 _, _ = h.Write(msg)
266 _, _ = h.Read(mu[:])
267
268
269 zh = sig.z
270 zh.NTT()
271
272 for i := 0; i < K; i++ {
273 PolyDotHat(&Az[i], &pk.A[i], &zh)
274 }
275
276
277
278
279
280 Az2dct1.MulBy2toD(&pk.t1)
281 Az2dct1.NTT()
282 PolyDeriveUniformBall(&ch, &sig.c)
283 ch.NTT()
284 for i := 0; i < K; i++ {
285 Az2dct1[i].MulHat(&Az2dct1[i], &ch)
286 }
287 Az2dct1.Sub(&Az, &Az2dct1)
288 Az2dct1.ReduceLe2Q()
289 Az2dct1.InvNTT()
290 Az2dct1.NormalizeAssumingLe2Q()
291
292
293
294
295
296 w1.UseHint(&Az2dct1, &sig.hint)
297 w1.PackW1(w1Packed[:])
298
299
300 h.Reset()
301 _, _ = h.Write(mu[:])
302 _, _ = h.Write(w1Packed[:])
303 _, _ = h.Read(cp[:])
304
305 return sig.c == cp
306 }
307
308
309
310
// SignTo signs msg with sk and writes the packed signature into signature.
//
// signature must be at least SignatureSize bytes long; SignTo panics
// otherwise. Signing is deterministic. The randomness ρ' used to sample y is
// derived only from sk.key and μ, so the same key and message yield the same
// signature. Each loop iteration is one attempt: it samples y, builds a
// candidate signature and restarts when any rejection check fails. After 575
// failed attempts it panics, treating that as a sign that something is wrong.
func SignTo(sk *PrivateKey, msg []byte, signature []byte) {
	var mu, rhop [64]byte
	var w1Packed [PolyW1Size * K]byte
	var y, yh VecL
	var w, w0, w1, w0mcs2, ct0, w0mcs2pct0 VecK
	var ch common.Poly
	var yNonce uint16
	var sig unpackedSignature

	if len(signature) < SignatureSize {
		panic("Signature does not fit in that byteslice")
	}

	// μ = SHAKE-256(tr ‖ msg), 64 bytes.
	h := sha3.NewShake256()
	_, _ = h.Write(sk.tr[:])
	_, _ = h.Write(msg)
	_, _ = h.Read(mu[:])

	// ρ' = SHAKE-256(key ‖ μ), 64 bytes; seeds the sampling of y.
	h.Reset()
	_, _ = h.Write(sk.key[:])
	_, _ = h.Write(mu[:])
	_, _ = h.Read(rhop[:])

	// Rejection-sampling loop; each iteration is one signing attempt.
	attempt := 0
	for {
		attempt++
		if attempt >= 576 {
			// 575 consecutive attempts have been rejected. The panic
			// message states this should only happen with probability
			// about 2^-128.
			panic("This should only happen 1 in 2^{128}: something is wrong.")
		}

		// Sample y from ρ' and the current nonce, then advance the nonce
		// by L so the next attempt uses fresh nonces.
		VecLDeriveUniformLeGamma1(&y, &rhop, yNonce)
		yNonce += uint16(L)

		// w = A·y, computed row by row in the NTT domain.
		yh = y
		yh.NTT()
		for i := 0; i < K; i++ {
			PolyDotHat(&w[i], &sk.A[i], &yh)
			w[i].ReduceLe2Q()
			w[i].InvNTT()
		}

		// Normalize w and decompose it into (w0, w1).
		w.NormalizeAssumingLe2Q()
		w.Decompose(&w0, &w1)

		// c̃ = SHAKE-256(μ ‖ packed w1), written directly into sig.c.
		w1.PackW1(w1Packed[:])
		h.Reset()
		_, _ = h.Write(mu[:])
		_, _ = h.Write(w1Packed[:])
		_, _ = h.Read(sig.c[:])

		// ch is the challenge polynomial derived from c̃, in NTT form.
		PolyDeriveUniformBall(&ch, &sig.c)
		ch.NTT()

		// w0 − c·s2: reject if it exceeds the bound γ2 − β.
		for i := 0; i < K; i++ {
			w0mcs2[i].MulHat(&ch, &sk.s2h[i])
			w0mcs2[i].InvNTT()
		}
		w0mcs2.Sub(&w0, &w0mcs2)
		w0mcs2.Normalize()

		if w0mcs2.Exceeds(Gamma2 - Beta) {
			continue
		}

		// z = y + c·s1, computed directly in sig.z.
		for i := 0; i < L; i++ {
			sig.z[i].MulHat(&ch, &sk.s1h[i])
			sig.z[i].InvNTT()
		}
		sig.z.Add(&sig.z, &y)
		sig.z.Normalize()

		// Reject if z exceeds γ1 − β, the same bound unpackedSignature.Unpack
		// enforces on verification.
		if sig.z.Exceeds(Gamma1 - Beta) {
			continue
		}

		// c·t0, computed from the cached NTT form of t0.
		for i := 0; i < K; i++ {
			ct0[i].MulHat(&ch, &sk.t0h[i])
			ct0[i].InvNTT()
		}
		ct0.NormalizeAssumingLe2Q()

		// Reject if c·t0 exceeds γ2.
		if ct0.Exceeds(Gamma2) {
			continue
		}

		// Build the hint from w0 − c·s2 + c·t0 and w1. Reject if MakeHint
		// reports more than Omega hints set (hintPop).
		w0mcs2pct0.Add(&w0mcs2, &ct0)
		w0mcs2pct0.NormalizeAssumingLe2Q()
		hintPop := sig.hint.MakeHint(&w0mcs2pct0, &w1)
		if hintPop > Omega {
			continue
		}

		// All checks passed: sig holds c̃, z and the hint.
		break
	}

	sig.Pack(signature[:])
}
439
440
441 func (sk *PrivateKey) Public() *PublicKey {
442 var t0 VecK
443 pk := &PublicKey{
444 rho: sk.rho,
445 A: &sk.A,
446 tr: &sk.tr,
447 }
448 sk.computeT0andT1(&t0, &pk.t1)
449 pk.t1.PackT1(pk.t1p[:])
450 return pk
451 }
452
453
454 func (pk *PublicKey) Equal(other *PublicKey) bool {
455 return pk.rho == other.rho && pk.t1 == other.t1
456 }
457
458
459 func (sk *PrivateKey) Equal(other *PrivateKey) bool {
460 ret := (subtle.ConstantTimeCompare(sk.rho[:], other.rho[:]) &
461 subtle.ConstantTimeCompare(sk.key[:], other.key[:]) &
462 subtle.ConstantTimeCompare(sk.tr[:], other.tr[:]))
463
464 acc := uint32(0)
465 for i := 0; i < L; i++ {
466 for j := 0; j < common.N; j++ {
467 acc |= sk.s1[i][j] ^ other.s1[i][j]
468 }
469 }
470 for i := 0; i < K; i++ {
471 for j := 0; j < common.N; j++ {
472 acc |= sk.s2[i][j] ^ other.s2[i][j]
473 acc |= sk.t0[i][j] ^ other.t0[i][j]
474 }
475 }
476 return (ret & subtle.ConstantTimeEq(int32(acc), 0)) == 1
477 }
478
View as plain text