1 package internal
2
3 import (
4 cryptoRand "crypto/rand"
5 "crypto/subtle"
6 "io"
7
8 "github.com/cloudflare/circl/internal/sha3"
9 "github.com/cloudflare/circl/sign/dilithium/internal/common"
10 )
11
12 const (
13
14
15 PolyLeqEtaSize = (common.N * DoubleEtaBits) / 8
16
17
18 Beta = Tau * Eta
19
20
21 Gamma1 = 1 << Gamma1Bits
22
23
24 PolyLeGamma1Size = (Gamma1Bits + 1) * common.N / 8
25
26
27 Alpha = 2 * Gamma2
28
29
30 PrivateKeySize = 32 + 32 + 32 + PolyLeqEtaSize*(L+K) + common.PolyT0Size*K
31
32
33 PublicKeySize = 32 + common.PolyT1Size*K
34
35
36 SignatureSize = L*PolyLeGamma1Size + Omega + K + 32
37
38
39 PolyW1Size = (common.N * (common.QBits - Gamma1Bits)) / 8
40 )
41
42
// PublicKey is an unpacked public key.
//
// rho and t1 are the key material that Pack serializes (t1 via its packed
// form t1p). A and tr are derived values cached for verification; they may
// point into the PrivateKey the public key was obtained from (see
// NewKeyFromSeed and PrivateKey.Public).
type PublicKey struct {
	rho [32]byte // ρ: seed from which A is derived
	t1  VecK

	// Cached values.
	t1p [common.PolyT1Size * K]byte // t1 in packed form, as written by Pack
	A   *Mat                        // matrix derived from rho
	tr  *[32]byte                   // SHAKE-256 hash of the packed public key
}
52
53
// PrivateKey is an unpacked private key.
//
// The first group of fields is what Pack serializes; the second group is
// recomputed from it by Unpack (and by NewKeyFromSeed) and cached for
// signing.
type PrivateKey struct {
	rho [32]byte // ρ: seed from which A is derived
	key [32]byte // secret seed; SignTo hashes it with μ to derive ρ′
	s1  VecL
	s2  VecK
	t0  VecK
	tr  [32]byte // SHAKE-256 hash of the packed public key

	// Cached values.
	A   Mat  // matrix derived from rho
	s1h VecL // s1 in the NTT domain
	s2h VecK // s2 in the NTT domain
	t0h VecK // t0 in the NTT domain
}
68
// unpackedSignature is a signature in unpacked form. Pack serializes it as
// c ‖ z ‖ hint, taking SignatureSize bytes in total.
type unpackedSignature struct {
	z    VecL     // response vector; rejected when it exceeds γ₁-β
	hint VecK     // hint produced by MakeHint in SignTo, consumed by UseHint in Verify
	c    [32]byte // challenge seed, expanded by PolyDeriveUniformBall
}
74
75
76 func (sig *unpackedSignature) Pack(buf []byte) {
77 copy(buf[:], sig.c[:])
78 sig.z.PackLeGamma1(buf[32:])
79 sig.hint.PackHint(buf[32+L*PolyLeGamma1Size:])
80 }
81
82
83
84
85 func (sig *unpackedSignature) Unpack(buf []byte) bool {
86 if len(buf) < SignatureSize {
87 return false
88 }
89 copy(sig.c[:], buf[:])
90 sig.z.UnpackLeGamma1(buf[32:])
91 if sig.z.Exceeds(Gamma1 - Beta) {
92 return false
93 }
94 if !sig.hint.UnpackHint(buf[32+L*PolyLeGamma1Size:]) {
95 return false
96 }
97 return true
98 }
99
100
101 func (pk *PublicKey) Pack(buf *[PublicKeySize]byte) {
102 copy(buf[:32], pk.rho[:])
103 copy(buf[32:], pk.t1p[:])
104 }
105
106
107 func (pk *PublicKey) Unpack(buf *[PublicKeySize]byte) {
108 copy(pk.rho[:], buf[:32])
109 copy(pk.t1p[:], buf[32:])
110
111 pk.t1.UnpackT1(pk.t1p[:])
112 pk.A = new(Mat)
113 pk.A.Derive(&pk.rho)
114
115
116 pk.tr = new([32]byte)
117 h := sha3.NewShake256()
118 _, _ = h.Write(buf[:])
119 _, _ = h.Read(pk.tr[:])
120 }
121
122
123 func (sk *PrivateKey) Pack(buf *[PrivateKeySize]byte) {
124 copy(buf[:32], sk.rho[:])
125 copy(buf[32:64], sk.key[:])
126 copy(buf[64:96], sk.tr[:])
127 offset := 96
128 sk.s1.PackLeqEta(buf[offset:])
129 offset += PolyLeqEtaSize * L
130 sk.s2.PackLeqEta(buf[offset:])
131 offset += PolyLeqEtaSize * K
132 sk.t0.PackT0(buf[offset:])
133 }
134
135
136 func (sk *PrivateKey) Unpack(buf *[PrivateKeySize]byte) {
137 copy(sk.rho[:], buf[:32])
138 copy(sk.key[:], buf[32:64])
139 copy(sk.tr[:], buf[64:96])
140 offset := 96
141 sk.s1.UnpackLeqEta(buf[offset:])
142 offset += PolyLeqEtaSize * L
143 sk.s2.UnpackLeqEta(buf[offset:])
144 offset += PolyLeqEtaSize * K
145 sk.t0.UnpackT0(buf[offset:])
146
147
148 sk.A.Derive(&sk.rho)
149 sk.t0h = sk.t0
150 sk.t0h.NTT()
151 sk.s1h = sk.s1
152 sk.s1h.NTT()
153 sk.s2h = sk.s2
154 sk.s2h.NTT()
155 }
156
157
158
159 func GenerateKey(rand io.Reader) (*PublicKey, *PrivateKey, error) {
160 var seed [32]byte
161 if rand == nil {
162 rand = cryptoRand.Reader
163 }
164 _, err := io.ReadFull(rand, seed[:])
165 if err != nil {
166 return nil, nil, err
167 }
168 pk, sk := NewKeyFromSeed(&seed)
169 return pk, sk, nil
170 }
171
172
173 func NewKeyFromSeed(seed *[common.SeedSize]byte) (*PublicKey, *PrivateKey) {
174 var eSeed [128]byte
175 var pk PublicKey
176 var sk PrivateKey
177 var sSeed [64]byte
178
179 h := sha3.NewShake256()
180 _, _ = h.Write(seed[:])
181 _, _ = h.Read(eSeed[:])
182
183 copy(pk.rho[:], eSeed[:32])
184 copy(sSeed[:], eSeed[32:96])
185 copy(sk.key[:], eSeed[96:])
186 copy(sk.rho[:], pk.rho[:])
187
188 sk.A.Derive(&pk.rho)
189
190 for i := uint16(0); i < L; i++ {
191 PolyDeriveUniformLeqEta(&sk.s1[i], &sSeed, i)
192 }
193
194 for i := uint16(0); i < K; i++ {
195 PolyDeriveUniformLeqEta(&sk.s2[i], &sSeed, i+L)
196 }
197
198 sk.s1h = sk.s1
199 sk.s1h.NTT()
200 sk.s2h = sk.s2
201 sk.s2h.NTT()
202
203 sk.computeT0andT1(&sk.t0, &pk.t1)
204
205 sk.t0h = sk.t0
206 sk.t0h.NTT()
207
208
209 pk.t1.PackT1(pk.t1p[:])
210 pk.A = &sk.A
211
212
213 var packedPk [PublicKeySize]byte
214 pk.Pack(&packedPk)
215
216
217 h.Reset()
218 _, _ = h.Write(packedPk[:])
219 _, _ = h.Read(sk.tr[:])
220
221
222 pk.tr = &sk.tr
223
224 return &pk, &sk
225 }
226
227
228 func (sk *PrivateKey) computeT0andT1(t0, t1 *VecK) {
229 var t VecK
230
231
232 for i := 0; i < K; i++ {
233 PolyDotHat(&t[i], &sk.A[i], &sk.s1h)
234 t[i].ReduceLe2Q()
235 t[i].InvNTT()
236 }
237 t.Add(&t, &sk.s2)
238 t.Normalize()
239
240
241 t.Power2Round(t0, t1)
242 }
243
244
245 func Verify(pk *PublicKey, msg []byte, signature []byte) bool {
246 var sig unpackedSignature
247 var mu [64]byte
248 var zh VecL
249 var Az, Az2dct1, w1 VecK
250 var ch common.Poly
251 var cp [32]byte
252 var w1Packed [PolyW1Size * K]byte
253
254
255
256 if !sig.Unpack(signature) {
257 return false
258 }
259
260
261 h := sha3.NewShake256()
262 _, _ = h.Write(pk.tr[:])
263 _, _ = h.Write(msg)
264 _, _ = h.Read(mu[:])
265
266
267 zh = sig.z
268 zh.NTT()
269
270 for i := 0; i < K; i++ {
271 PolyDotHat(&Az[i], &pk.A[i], &zh)
272 }
273
274
275
276
277
278 Az2dct1.MulBy2toD(&pk.t1)
279 Az2dct1.NTT()
280 PolyDeriveUniformBall(&ch, &sig.c)
281 ch.NTT()
282 for i := 0; i < K; i++ {
283 Az2dct1[i].MulHat(&Az2dct1[i], &ch)
284 }
285 Az2dct1.Sub(&Az, &Az2dct1)
286 Az2dct1.ReduceLe2Q()
287 Az2dct1.InvNTT()
288 Az2dct1.NormalizeAssumingLe2Q()
289
290
291
292
293
294 w1.UseHint(&Az2dct1, &sig.hint)
295 w1.PackW1(w1Packed[:])
296
297
298 h.Reset()
299 _, _ = h.Write(mu[:])
300 _, _ = h.Write(w1Packed[:])
301 _, _ = h.Read(cp[:])
302
303 return sig.c == cp
304 }
305
306
307
308
309 func SignTo(sk *PrivateKey, msg []byte, signature []byte) {
310 var mu, rhop [64]byte
311 var w1Packed [PolyW1Size * K]byte
312 var y, yh VecL
313 var w, w0, w1, w0mcs2, ct0, w0mcs2pct0 VecK
314 var ch common.Poly
315 var yNonce uint16
316 var sig unpackedSignature
317
318 if len(signature) < SignatureSize {
319 panic("Signature does not fit in that byteslice")
320 }
321
322
323 h := sha3.NewShake256()
324 _, _ = h.Write(sk.tr[:])
325 _, _ = h.Write(msg)
326 _, _ = h.Read(mu[:])
327
328
329 h.Reset()
330 _, _ = h.Write(sk.key[:])
331 _, _ = h.Write(mu[:])
332 _, _ = h.Read(rhop[:])
333
334
335 attempt := 0
336 for {
337 attempt++
338 if attempt >= 576 {
339
340
341
342 panic("This should only happen 1 in 2^{128}: something is wrong.")
343 }
344
345
346 VecLDeriveUniformLeGamma1(&y, &rhop, yNonce)
347 yNonce += uint16(L)
348
349
350 yh = y
351 yh.NTT()
352 for i := 0; i < K; i++ {
353 PolyDotHat(&w[i], &sk.A[i], &yh)
354 w[i].ReduceLe2Q()
355 w[i].InvNTT()
356 }
357
358
359 w.NormalizeAssumingLe2Q()
360 w.Decompose(&w0, &w1)
361
362
363 w1.PackW1(w1Packed[:])
364 h.Reset()
365 _, _ = h.Write(mu[:])
366 _, _ = h.Write(w1Packed[:])
367 _, _ = h.Read(sig.c[:])
368
369 PolyDeriveUniformBall(&ch, &sig.c)
370 ch.NTT()
371
372
373
374
375
376
377
378 for i := 0; i < K; i++ {
379 w0mcs2[i].MulHat(&ch, &sk.s2h[i])
380 w0mcs2[i].InvNTT()
381 }
382 w0mcs2.Sub(&w0, &w0mcs2)
383 w0mcs2.Normalize()
384
385 if w0mcs2.Exceeds(Gamma2 - Beta) {
386 continue
387 }
388
389
390 for i := 0; i < L; i++ {
391 sig.z[i].MulHat(&ch, &sk.s1h[i])
392 sig.z[i].InvNTT()
393 }
394 sig.z.Add(&sig.z, &y)
395 sig.z.Normalize()
396
397
398 if sig.z.Exceeds(Gamma1 - Beta) {
399 continue
400 }
401
402
403 for i := 0; i < K; i++ {
404 ct0[i].MulHat(&ch, &sk.t0h[i])
405 ct0[i].InvNTT()
406 }
407 ct0.NormalizeAssumingLe2Q()
408
409
410 if ct0.Exceeds(Gamma2) {
411 continue
412 }
413
414
415
416
417
418
419
420
421
422
423
424
425 w0mcs2pct0.Add(&w0mcs2, &ct0)
426 w0mcs2pct0.NormalizeAssumingLe2Q()
427 hintPop := sig.hint.MakeHint(&w0mcs2pct0, &w1)
428 if hintPop > Omega {
429 continue
430 }
431
432 break
433 }
434
435 sig.Pack(signature[:])
436 }
437
438
439 func (sk *PrivateKey) Public() *PublicKey {
440 var t0 VecK
441 pk := &PublicKey{
442 rho: sk.rho,
443 A: &sk.A,
444 tr: &sk.tr,
445 }
446 sk.computeT0andT1(&t0, &pk.t1)
447 pk.t1.PackT1(pk.t1p[:])
448 return pk
449 }
450
451
452 func (pk *PublicKey) Equal(other *PublicKey) bool {
453 return pk.rho == other.rho && pk.t1 == other.t1
454 }
455
456
457 func (sk *PrivateKey) Equal(other *PrivateKey) bool {
458 ret := (subtle.ConstantTimeCompare(sk.rho[:], other.rho[:]) &
459 subtle.ConstantTimeCompare(sk.key[:], other.key[:]) &
460 subtle.ConstantTimeCompare(sk.tr[:], other.tr[:]))
461
462 acc := uint32(0)
463 for i := 0; i < L; i++ {
464 for j := 0; j < common.N; j++ {
465 acc |= sk.s1[i][j] ^ other.s1[i][j]
466 }
467 }
468 for i := 0; i < K; i++ {
469 for j := 0; j < common.N; j++ {
470 acc |= sk.s2[i][j] ^ other.s2[i][j]
471 acc |= sk.t0[i][j] ^ other.t0[i][j]
472 }
473 }
474 return (ret & subtle.ConstantTimeEq(int32(acc), 0)) == 1
475 }
476
View as plain text