1
2
3
4
5
6
7
8
9 package rsa
10
11 import (
12 "crypto"
13 "crypto/rand"
14 "crypto/rsa"
15 "errors"
16 "fmt"
17 "io"
18 "math"
19 "math/big"
20
21 cmath "github.com/cloudflare/circl/math"
22 )
23
24
25
26
27
28 func GenerateKey(random io.Reader, bits int) (*rsa.PrivateKey, error) {
29 p, err := cmath.SafePrime(random, bits/2)
30 if err != nil {
31 return nil, err
32 }
33
34 var q *big.Int
35 n := new(big.Int)
36 found := false
37 for !found {
38 q, err = cmath.SafePrime(random, bits-p.BitLen())
39 if err != nil {
40 return nil, err
41 }
42
43
44 if p.Cmp(q) != 0 {
45 n.Mul(p, q)
46
47 if n.BitLen() == bits {
48 found = true
49 }
50 }
51 }
52
53 one := big.NewInt(1)
54 pminus1 := new(big.Int).Sub(p, one)
55 qminus1 := new(big.Int).Sub(q, one)
56 totient := new(big.Int).Mul(pminus1, qminus1)
57
58 priv := new(rsa.PrivateKey)
59 priv.Primes = []*big.Int{p, q}
60 priv.N = n
61 priv.E = 65537
62 priv.D = new(big.Int)
63 e := big.NewInt(int64(priv.E))
64 ok := priv.D.ModInverse(e, totient)
65 if ok == nil {
66 return nil, errors.New("public key is not coprime to phi(n)")
67 }
68
69 priv.Precompute()
70
71 return priv, nil
72 }
73
74
75
76
77
// validateParams checks a (players, threshold) pair before dealing: there
// must be at least two players, and the threshold must lie in [1, players].
// It returns nil when both conditions hold and a descriptive error otherwise;
// the player count is checked first.
func validateParams(players, threshold uint) error {
	switch {
	case players < 2:
		return errors.New("rsa_threshold: Players (l) invalid: should be > 1")
	case threshold == 0, threshold > players:
		return fmt.Errorf("rsa_threshold: Threshold (k) invalid: %d < 1 || %d > %d", threshold, threshold, players)
	default:
		return nil
	}
}
87
88
89
90 func Deal(randSource io.Reader, players, threshold uint, key *rsa.PrivateKey, cache bool) ([]KeyShare, error) {
91 err := validateParams(players, threshold)
92
93 ONE := big.NewInt(1)
94
95 if err != nil {
96 return nil, err
97 }
98
99 if len(key.Primes) != 2 {
100 return nil, errors.New("multiprime rsa keys are unsupported")
101 }
102
103 p := key.Primes[0]
104 q := key.Primes[1]
105 e := int64(key.E)
106
107
108
109
110
111
112
113 var pprime big.Int
114
115 pprime.Sub(p, ONE)
116
117
118 var m big.Int
119 m.Sub(q, ONE)
120
121 m.Mul(&m, &pprime)
122
123 m.Rsh(&m, 2)
124
125
126 var d big.Int
127 _d := d.ModInverse(big.NewInt(e), &m)
128
129 if _d == nil {
130 return nil, errors.New("rsa_threshold: no ModInverse for e in Z/Zm")
131 }
132
133
134 a := make([]*big.Int, threshold)
135
136 a[0] = &d
137
138
139 for i := uint(1); i <= threshold-1; i++ {
140 a[i], err = rand.Int(randSource, &m)
141 if err != nil {
142 return nil, errors.New("rsa_threshold: unable to generate an int within [0, m)")
143 }
144 }
145
146 shares := make([]KeyShare, players)
147
148
149 for i := uint(1); i <= players; i++ {
150 shares[i-1].Players = players
151 shares[i-1].Threshold = threshold
152
153 poly := computePolynomial(threshold, a, i, &m)
154 shares[i-1].si = poly
155 shares[i-1].Index = i
156 if cache {
157 shares[i-1].get2DeltaSi(int64(players))
158 }
159 }
160
161 return shares, nil
162 }
163
// calcN returns the product p·q by value; p and q are not modified.
func calcN(p, q *big.Int) big.Int {
	// The product is freshly allocated and referenced nowhere else, so handing
	// back a shallow copy of it does not alias any live big.Int.
	return *new(big.Int).Mul(p, q)
}
170
171
// computePolynomial evaluates the polynomial with coefficients a[0..k-1],
//
//	f(x) = a[0] + a[1]·x + ... + a[k-1]·x^(k-1),
//
// modulo m and returns a freshly allocated result in [0, |m|). a must hold at
// least k coefficients, k == 0 yields 0, and m must be non-zero. Deal uses it
// to compute each share s_i = f(i) mod m.
//
// Evaluation uses Horner's rule entirely in big.Int arithmetic, so the result
// is exact for every x and k. There are no float64 powers, which lose
// precision above 2^53, and no int64 intermediates, which can overflow.
func computePolynomial(k uint, a []*big.Int, x uint, m *big.Int) *big.Int {
	// Package math is not otherwise used in this file; this reference keeps
	// its import valid. Remove both together.
	_ = math.MaxInt64

	sum := new(big.Int)
	bx := new(big.Int).SetUint64(uint64(x))
	// f(x) = a0 + x·(a1 + x·(a2 + ... + x·a_{k-1})), reduced at every step so
	// intermediates stay below m·x + m.
	for i := k; i > 0; i-- {
		sum.Mul(sum, bx)
		sum.Add(sum, a[i-1])
		sum.Mod(sum, m)
	}
	return sum
}
193
194
195 func PadHash(padder Padder, hash crypto.Hash, pub *rsa.PublicKey, msg []byte) ([]byte, error) {
196
197
198 hasher := hash.New()
199 hasher.Write(msg)
200 digest := hasher.Sum(nil)
201
202 return padder.Pad(pub, hash, digest)
203 }
204
// Signature is the byte-slice form of a combined threshold RSA signature.
// CombineSignShares returns it as the big-endian signature integer,
// left-padded with zeros to the modulus size (pub.Size() bytes).
type Signature = []byte
206
207
208 func CombineSignShares(pub *rsa.PublicKey, shares []SignShare, msg []byte) (Signature, error) {
209 players := shares[0].Players
210 threshold := shares[0].Threshold
211
212 for i := range shares {
213 if shares[i].Players != players {
214 return nil, errors.New("rsa_threshold: shares didn't have consistent players")
215 }
216 if shares[i].Threshold != threshold {
217 return nil, errors.New("rsa_threshold: shares didn't have consistent threshold")
218 }
219 }
220
221 if uint(len(shares)) < threshold {
222 return nil, errors.New("rsa_threshold: insufficient shares for the threshold")
223 }
224
225 w := big.NewInt(1)
226 delta := calculateDelta(int64(players))
227
228 for _, share := range shares {
229
230 lambda, err := computeLambda(delta, shares, 0, int64(share.Index))
231 if err != nil {
232 return nil, err
233 }
234
235 var exp big.Int
236 exp.Add(lambda, lambda)
237
238
239 abslam := big.Int{}
240 abslam.Abs(&exp)
241 var tmp big.Int
242
243 tmp.Exp(share.xi, &abslam, pub.N)
244 if abslam.Cmp(&exp) == 1 {
245 tmp.ModInverse(&tmp, pub.N)
246 }
247
248
249 w.Mul(w, &tmp).Mod(w, pub.N)
250 }
251 w.Mod(w, pub.N)
252
253
254 eprime := big.Int{}
255 eprime.Mul(delta, delta)
256 eprime.Add(&eprime, &eprime)
257 eprime.Add(&eprime, &eprime)
258
259
260 a := big.Int{}
261 b := big.Int{}
262 e := big.NewInt(int64(pub.E))
263 tmp := big.Int{}
264 tmp.GCD(&a, &b, &eprime, e)
265
266
267
268 wa := big.Int{}
269 wa.Exp(w, &a, pub.N)
270
271 x := big.Int{}
272 x.SetBytes(msg)
273 xb := big.Int{}
274 xb.Exp(&x, &b, pub.N)
275
276 y := big.Int{}
277 y.Mul(&wa, &xb).Mod(&y, pub.N)
278
279
280 ye := big.Int{}
281 ye.Exp(&y, e, pub.N)
282 if ye.Cmp(&x) != 0 {
283 return nil, errors.New("rsa: internal error")
284 }
285
286
287 sig := y.FillBytes(make([]byte, pub.Size()))
288
289 return sig, nil
290 }
291
292
293
294
295 func computeLambda(delta *big.Int, S []SignShare, i, j int64) (*big.Int, error) {
296 if i == j {
297 return nil, errors.New("rsa_threshold: i and j can't be equal by precondition")
298 }
299
300 foundi := false
301 foundj := false
302
303
304
305 num := int64(1)
306 den := int64(1)
307
308
309 for _, s := range S {
310
311 jprime := int64(s.Index)
312
313 if jprime == j {
314 foundj = true
315 continue
316 }
317 if jprime == i {
318 foundi = false
319 break
320 }
321
322 num *= i - jprime
323
324 den *= j - jprime
325 }
326
327
328 var lambda big.Int
329
330 lambda.Div(big.NewInt(num), big.NewInt(den))
331
332 lambda.Mul(delta, &lambda)
333
334 if foundi {
335 return nil, fmt.Errorf("rsa_threshold: i: %d should not be in S", i)
336 }
337
338 if !foundj {
339 return nil, fmt.Errorf("rsa_threshold: j: %d should be in S", j)
340 }
341
342 return &lambda, nil
343 }
344
View as plain text