1
15
16 package sm2
17
18
19 import (
20 "bytes"
21 "crypto"
22 "crypto/elliptic"
23 "crypto/rand"
24 "encoding/asn1"
25 "encoding/binary"
26 "errors"
27 "io"
28 "math/big"
29
30 "github.com/tjfoc/gmsm/sm3"
31 )
32
var (
	// default_uid is the user ID applied when callers pass an empty uid:
	// the 16 ASCII bytes "1234567812345678".
	default_uid = []byte("1234567812345678")

	// Ciphertext component orders understood by the mode argument of
	// Encrypt and Decrypt.
	C1C3C2 = 0
	C1C2C3 = 1
)
38
// PublicKey is an SM2 public key: the affine point (X, Y) on the embedded
// curve. The curve's methods are promoted onto the key.
type PublicKey struct {
	elliptic.Curve
	X *big.Int
	Y *big.Int
}

// PrivateKey is an SM2 private key. D is the secret scalar and the embedded
// PublicKey holds the matching public point (GenerateKey sets it to [D]G).
type PrivateKey struct {
	PublicKey
	D *big.Int
}
48
// sm2Signature is the ASN.1 SEQUENCE { r INTEGER, s INTEGER } used to
// (de)serialize SM2 signatures in PrivateKey.Sign and PublicKey.Verify.
type sm2Signature struct {
	R *big.Int
	S *big.Int
}

// sm2Cipher is the ASN.1 representation of an SM2 ciphertext used by
// CipherMarshal and CipherUnmarshal: the C1 coordinates, then the hash C3,
// then the encrypted payload C2. Field order defines the DER encoding and
// must not change.
type sm2Cipher struct {
	XCoordinate *big.Int
	YCoordinate *big.Int
	HASH        []byte
	CipherText  []byte
}
58
59
60 func (priv *PrivateKey) Public() crypto.PublicKey {
61 return &priv.PublicKey
62 }
63
var (
	// errZeroParam is returned by Sm2Sign when the curve reports a zero
	// group order N.
	errZeroParam = errors.New("zero parameter")

	// Small shared constants for big.Int arithmetic; never modify them.
	one = big.NewInt(1)
	two = big.NewInt(2)
)
67
68
69 func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
70 r, s, err := Sm2Sign(priv, msg, nil, random)
71 if err != nil {
72 return nil, err
73 }
74 return asn1.Marshal(sm2Signature{r, s})
75 }
76
77 func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
78 var sm2Sign sm2Signature
79 _, err := asn1.Unmarshal(sign, &sm2Sign)
80 if err != nil {
81 return false
82 }
83 return Sm2Verify(pub, msg, default_uid, sm2Sign.R, sm2Sign.S)
84 }
85
86 func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
87 if len(uid) == 0 {
88 uid = default_uid
89 }
90
91 za, err := ZA(pub, uid)
92 if err != nil {
93 return nil, err
94 }
95
96 e, err := msgHash(za, msg)
97 if err != nil {
98 return nil, err
99 }
100
101 return e.Bytes(), nil
102 }
103
104
105 func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) {
106 return EncryptAsn1(pub, data, random)
107 }
108
109 func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) {
110 return DecryptAsn1(priv, data)
111 }
112
113
114
115 func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) {
116 return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false)
117 }
118
119
120 func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) {
121 return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true)
122 }
123
124
125
126 func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
127 digest, err := priv.PublicKey.Sm3Digest(msg, uid)
128 if err != nil {
129 return nil, nil, err
130 }
131 e := new(big.Int).SetBytes(digest)
132 c := priv.PublicKey.Curve
133 N := c.Params().N
134 if N.Sign() == 0 {
135 return nil, nil, errZeroParam
136 }
137 var k *big.Int
138 for {
139 for {
140 k, err = randFieldElement(c, random)
141 if err != nil {
142 r = nil
143 return
144 }
145 r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
146 r.Add(r, e)
147 r.Mod(r, N)
148 if r.Sign() != 0 {
149 if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 {
150 break
151 }
152 }
153
154 }
155 rD := new(big.Int).Mul(priv.D, r)
156 s = new(big.Int).Sub(k, rD)
157 d1 := new(big.Int).Add(priv.D, one)
158 d1Inv := new(big.Int).ModInverse(d1, N)
159 s.Mul(s, d1Inv)
160 s.Mod(s, N)
161 if s.Sign() != 0 {
162 break
163 }
164 }
165 return
166 }
167 func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
168 c := pub.Curve
169 N := c.Params().N
170 one := new(big.Int).SetInt64(1)
171 if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
172 return false
173 }
174 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
175 return false
176 }
177 if len(uid) == 0 {
178 uid = default_uid
179 }
180 za, err := ZA(pub, uid)
181 if err != nil {
182 return false
183 }
184 e, err := msgHash(za, msg)
185 if err != nil {
186 return false
187 }
188 t := new(big.Int).Add(r, s)
189 t.Mod(t, N)
190 if t.Sign() == 0 {
191 return false
192 }
193 var x *big.Int
194 x1, y1 := c.ScalarBaseMult(s.Bytes())
195 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
196 x, _ = c.Add(x1, y1, x2, y2)
197
198 x.Add(x, e)
199 x.Mod(x, N)
200 return x.Cmp(r) == 0
201 }
202
203
211 func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
212 c := pub.Curve
213 N := c.Params().N
214
215 if r.Sign() <= 0 || s.Sign() <= 0 {
216 return false
217 }
218 if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
219 return false
220 }
221
222
223 t := new(big.Int).Add(r, s)
224 t.Mod(t, N)
225 if t.Sign() == 0 {
226 return false
227 }
228
229 var x *big.Int
230 x1, y1 := c.ScalarBaseMult(s.Bytes())
231 x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
232 x, _ = c.Add(x1, y1, x2, y2)
233
234 e := new(big.Int).SetBytes(hash)
235 x.Add(x, e)
236 x.Mod(x, N)
237 return x.Cmp(r) == 0
238 }
239
240
247 func Encrypt(pub *PublicKey, data []byte, random io.Reader,mode int) ([]byte, error) {
248 length := len(data)
249 for {
250 c := []byte{}
251 curve := pub.Curve
252 k, err := randFieldElement(curve, random)
253 if err != nil {
254 return nil, err
255 }
256 x1, y1 := curve.ScalarBaseMult(k.Bytes())
257 x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
258 x1Buf := x1.Bytes()
259 y1Buf := y1.Bytes()
260 x2Buf := x2.Bytes()
261 y2Buf := y2.Bytes()
262 if n := len(x1Buf); n < 32 {
263 x1Buf = append(zeroByteSlice()[:32-n], x1Buf...)
264 }
265 if n := len(y1Buf); n < 32 {
266 y1Buf = append(zeroByteSlice()[:32-n], y1Buf...)
267 }
268 if n := len(x2Buf); n < 32 {
269 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
270 }
271 if n := len(y2Buf); n < 32 {
272 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
273 }
274 c = append(c, x1Buf...)
275 c = append(c, y1Buf...)
276 tm := []byte{}
277 tm = append(tm, x2Buf...)
278 tm = append(tm, data...)
279 tm = append(tm, y2Buf...)
280 h := sm3.Sm3Sum(tm)
281 c = append(c, h...)
282 ct, ok := kdf(length, x2Buf, y2Buf)
283 if !ok {
284 continue
285 }
286 c = append(c, ct...)
287 for i := 0; i < length; i++ {
288 c[96+i] ^= data[i]
289 }
290 switch mode{
291
292 case C1C3C2:
293 return append([]byte{0x04}, c...), nil
294 case C1C2C3:
295 c1 := make([]byte, 64)
296 c2 := make([]byte, len(c) - 96)
297 c3 := make([]byte, 32)
298 copy(c1, c[:64])
299 copy(c3, c[64:96])
300 copy(c2, c[96:])
301 ciphertext := []byte{}
302 ciphertext = append(ciphertext, c1...)
303 ciphertext = append(ciphertext, c2...)
304 ciphertext = append(ciphertext, c3...)
305 return append([]byte{0x04}, ciphertext...), nil
306 default:
307 return append([]byte{0x04}, c...), nil
308 }
309 }
310 }
311
312
313
314 func Decrypt(priv *PrivateKey, data []byte,mode int) ([]byte, error) {
315 switch mode {
316 case C1C3C2:
317 data = data[1:]
318 case C1C2C3:
319 data = data[1:]
320 c1 := make([]byte, 64)
321 c2 := make([]byte, len(data) - 96)
322 c3 := make([]byte, 32)
323 copy(c1, data[:64])
324 copy(c2, data[64:len(data) - 32])
325 copy(c3, data[len(data) - 32:])
326 c := []byte{}
327 c = append(c, c1...)
328 c = append(c, c3...)
329 c = append(c, c2...)
330 data = c
331 default:
332 data = data[1:]
333 }
334 length := len(data) - 96
335 curve := priv.Curve
336 x := new(big.Int).SetBytes(data[:32])
337 y := new(big.Int).SetBytes(data[32:64])
338 x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
339 x2Buf := x2.Bytes()
340 y2Buf := y2.Bytes()
341 if n := len(x2Buf); n < 32 {
342 x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
343 }
344 if n := len(y2Buf); n < 32 {
345 y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
346 }
347 c, ok := kdf(length, x2Buf, y2Buf)
348 if !ok {
349 return nil, errors.New("Decrypt: failed to decrypt")
350 }
351 for i := 0; i < length; i++ {
352 c[i] ^= data[i+96]
353 }
354 tm := []byte{}
355 tm = append(tm, x2Buf...)
356 tm = append(tm, c...)
357 tm = append(tm, y2Buf...)
358 h := sm3.Sm3Sum(tm)
359 if bytes.Compare(h, data[64:96]) != 0 {
360 return c, errors.New("Decrypt: failed to decrypt")
361 }
362 return c, nil
363 }
364
365
366
367
368
369
370
371
372
373
374
375
376 func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) {
377 curve := P256Sm2()
378 N := curve.Params().N
379 x2hat := keXHat(rpri.PublicKey.X)
380 x2rb := new(big.Int).Mul(x2hat, rpri.D)
381 tbt := new(big.Int).Add(pri.D, x2rb)
382 tb := new(big.Int).Mod(tbt, N)
383 if !curve.IsOnCurve(rpub.X, rpub.Y) {
384 err = errors.New("Ra not on curve")
385 return
386 }
387 x1hat := keXHat(rpub.X)
388 ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes())
389 vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1)
390
391 vx, vy := curve.ScalarMult(vxt, vyt, tb.Bytes())
392 pza := pub
393 if thisISA {
394 pza = &pri.PublicKey
395 }
396 za, err := ZA(pza, ida)
397 if err != nil {
398 return
399 }
400 zero := new(big.Int)
401 if vx.Cmp(zero) == 0 || vy.Cmp(zero) == 0 {
402 err = errors.New("V is infinite")
403 }
404 pzb := pub
405 if !thisISA {
406 pzb = &pri.PublicKey
407 }
408 zb, err := ZA(pzb, idb)
409 k, ok := kdf(klen, vx.Bytes(), vy.Bytes(), za, zb)
410 if !ok {
411 err = errors.New("kdf: zero key")
412 return
413 }
414 h1 := BytesCombine(vx.Bytes(), za, zb, rpub.X.Bytes(), rpub.Y.Bytes(), rpri.X.Bytes(), rpri.Y.Bytes())
415 if !thisISA {
416 h1 = BytesCombine(vx.Bytes(), za, zb, rpri.X.Bytes(), rpri.Y.Bytes(), rpub.X.Bytes(), rpub.Y.Bytes())
417 }
418 hash := sm3.Sm3Sum(h1)
419 h2 := BytesCombine([]byte{0x02}, vy.Bytes(), hash)
420 S1 := sm3.Sm3Sum(h2)
421 h3 := BytesCombine([]byte{0x03}, vy.Bytes(), hash)
422 S2 := sm3.Sm3Sum(h3)
423 return k, S1, S2, nil
424 }
425
426 func msgHash(za, msg []byte) (*big.Int, error) {
427 e := sm3.New()
428 e.Write(za)
429 e.Write(msg)
430 return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
431 }
432
433
434 func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
435 za := sm3.New()
436 uidLen := len(uid)
437 if uidLen >= 8192 {
438 return []byte{}, errors.New("SM2: uid too large")
439 }
440 Entla := uint16(8 * uidLen)
441 za.Write([]byte{byte((Entla >> 8) & 0xFF)})
442 za.Write([]byte{byte(Entla & 0xFF)})
443 if uidLen > 0 {
444 za.Write(uid)
445 }
446 za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
447 za.Write(sm2P256.B.Bytes())
448 za.Write(sm2P256.Gx.Bytes())
449 za.Write(sm2P256.Gy.Bytes())
450
451 xBuf := pub.X.Bytes()
452 yBuf := pub.Y.Bytes()
453 if n := len(xBuf); n < 32 {
454 xBuf = append(zeroByteSlice()[:32-n], xBuf...)
455 }
456 if n := len(yBuf); n < 32 {
457 yBuf = append(zeroByteSlice()[:32-n], yBuf...)
458 }
459 za.Write(xBuf)
460 za.Write(yBuf)
461 return za.Sum(nil)[:32], nil
462 }
463
464
// zeroByteSlice returns a newly allocated slice of 32 zero bytes. Callers
// take a prefix of it to left-pad big-endian encodings to 32 bytes.
func zeroByteSlice() []byte {
	return make([]byte, 32)
}
477
478
481 func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) {
482 cipher, err := Encrypt(pub, data, rand,C1C3C2)
483 if err != nil {
484 return nil, err
485 }
486 return CipherMarshal(cipher)
487 }
488
489
492 func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) {
493 cipher, err := CipherUnmarshal(data)
494 if err != nil {
495 return nil, err
496 }
497 return Decrypt(pub, cipher,C1C3C2)
498 }
499
500
508 func CipherMarshal(data []byte) ([]byte, error) {
509 data = data[1:]
510 x := new(big.Int).SetBytes(data[:32])
511 y := new(big.Int).SetBytes(data[32:64])
512 hash := data[64:96]
513 cipherText := data[96:]
514 return asn1.Marshal(sm2Cipher{x, y, hash, cipherText})
515 }
516
517
520 func CipherUnmarshal(data []byte) ([]byte, error) {
521 var cipher sm2Cipher
522 _, err := asn1.Unmarshal(data, &cipher)
523 if err != nil {
524 return nil, err
525 }
526 x := cipher.XCoordinate.Bytes()
527 y := cipher.YCoordinate.Bytes()
528 hash := cipher.HASH
529 if err != nil {
530 return nil, err
531 }
532 cipherText := cipher.CipherText
533 if err != nil {
534 return nil, err
535 }
536 if n := len(x); n < 32 {
537 x = append(zeroByteSlice()[:32-n], x...)
538 }
539 if n := len(y); n < 32 {
540 y = append(zeroByteSlice()[:32-n], y...)
541 }
542 c := []byte{}
543 c = append(c, x...)
544 c = append(c, y...)
545 c = append(c, hash...)
546 c = append(c, cipherText...)
547 return append([]byte{0x04}, c...), nil
548 }
549
550
551
// keXHat computes x̄ = 2^127 + (|x| mod 2^127): the low 127 bits of x with
// bit 127 forced on, as used by the SM2 key exchange. x is not modified.
func keXHat(x *big.Int) (xul *big.Int) {
	const w = 127
	twoW := new(big.Int).Lsh(big.NewInt(1), w) // 2^w
	mask := new(big.Int).Sub(twoW, big.NewInt(1))
	// Only the magnitude of x participates.
	xul = new(big.Int).And(new(big.Int).Abs(x), mask)
	return xul.Add(xul, twoW)
}
568
// BytesCombine concatenates its arguments, in order, into a newly allocated
// slice; the result never aliases any input.
//
// The previous version shadowed the builtin len with a local variable and
// copied the argument slice before joining; bytes.Join does not modify its
// input, so it is passed through directly.
func BytesCombine(pBytes ...[]byte) []byte {
	return bytes.Join(pBytes, nil)
}
578
// intToBytes encodes the low 32 bits of x as 4 big-endian bytes; kdf uses
// it for its block counter.
func intToBytes(x int) []byte {
	var out [4]byte
	binary.BigEndian.PutUint32(out[:], uint32(x))
	return out[:]
}
585
586 func kdf(length int, x ...[]byte) ([]byte, bool) {
587 var c []byte
588
589 ct := 1
590 h := sm3.New()
591 for i, j := 0, (length+31)/32; i < j; i++ {
592 h.Reset()
593 for _, xx := range x {
594 h.Write(xx)
595 }
596 h.Write(intToBytes(ct))
597 hash := h.Sum(nil)
598 if i+1 == j && length%32 != 0 {
599 c = append(c, hash[:length%32]...)
600 } else {
601 c = append(c, hash...)
602 }
603 ct++
604 }
605 for i := 0; i < length; i++ {
606 if c[i] != 0 {
607 return c, true
608 }
609 }
610 return c, false
611 }
612
613 func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
614 if random == nil {
615 random = rand.Reader
616 }
617 params := c.Params()
618 b := make([]byte, params.BitSize/8+8)
619 _, err = io.ReadFull(random, b)
620 if err != nil {
621 return
622 }
623 k = new(big.Int).SetBytes(b)
624 n := new(big.Int).Sub(params.N, one)
625 k.Mod(k, n)
626 k.Add(k, one)
627 return
628 }
629
630 func GenerateKey(random io.Reader) (*PrivateKey, error) {
631 c := P256Sm2()
632 if random == nil {
633 random = rand.Reader
634 }
635 params := c.Params()
636 b := make([]byte, params.BitSize/8+8)
637 _, err := io.ReadFull(random, b)
638 if err != nil {
639 return nil, err
640 }
641
642 k := new(big.Int).SetBytes(b)
643 n := new(big.Int).Sub(params.N, two)
644 k.Mod(k, n)
645 k.Add(k, one)
646 priv := new(PrivateKey)
647 priv.PublicKey.Curve = c
648 priv.D = k
649 priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
650
651 return priv, nil
652 }
653
// zr is an io.Reader whose reads always succeed and yield only zero bytes.
type zr struct {
	io.Reader
}

// Read overwrites every byte of dst with zero and reports len(dst) bytes
// read with a nil error.
func (z *zr) Read(dst []byte) (n int, err error) {
	n = len(dst)
	for i := 0; i < n; i++ {
		dst[i] = 0
	}
	return n, nil
}

// zeroReader is a ready-to-use zr value.
var zeroReader = &zr{}
666
// getLastBit returns the least-significant bit of a, as 0 or 1.
func getLastBit(a *big.Int) uint {
	if a.Bit(0) == 1 {
		return 1
	}
	return 0
}
670
671
672 func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) {
673 return Decrypt(priv, msg,C1C3C2)
674 }
675
View as plain text