...

Source file src/github.com/tjfoc/gmsm/sm2/sm2.go

Documentation: github.com/tjfoc/gmsm/sm2

     1  /*
     2  Copyright Suzhou Tongji Fintech Research Institute 2017 All Rights Reserved.
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7  	http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  package sm2
    17  
    18  // reference to ecdsa
    19  import (
    20  	"bytes"
    21  	"crypto"
    22  	"crypto/elliptic"
    23  	"crypto/rand"
    24  	"encoding/asn1"
    25  	"encoding/binary"
    26  	"errors"
    27  	"io"
    28  	"math/big"
    29  
    30  	"github.com/tjfoc/gmsm/sm3"
    31  )
    32  
    33  var (
    34  	default_uid = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38}
    35  	C1C3C2=0
    36  	C1C2C3=1
    37  )
    38  
    39  type PublicKey struct {
    40  	elliptic.Curve
    41  	X, Y *big.Int
    42  }
    43  
    44  type PrivateKey struct {
    45  	PublicKey
    46  	D *big.Int
    47  }
    48  
    49  type sm2Signature struct {
    50  	R, S *big.Int
    51  }
    52  type sm2Cipher struct {
    53  	XCoordinate *big.Int
    54  	YCoordinate *big.Int
    55  	HASH        []byte
    56  	CipherText  []byte
    57  }
    58  
    59  // The SM2's private key contains the public key
    60  func (priv *PrivateKey) Public() crypto.PublicKey {
    61  	return &priv.PublicKey
    62  }
    63  
    64  var errZeroParam = errors.New("zero parameter")
    65  var one = new(big.Int).SetInt64(1)
    66  var two = new(big.Int).SetInt64(2)
    67  
    68  // sign format = 30 + len(z) + 02 + len(r) + r + 02 + len(s) + s, z being what follows its size, ie 02+len(r)+r+02+len(s)+s
    69  func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerOpts) ([]byte, error) {
    70  	r, s, err := Sm2Sign(priv, msg, nil, random)
    71  	if err != nil {
    72  		return nil, err
    73  	}
    74  	return asn1.Marshal(sm2Signature{r, s})
    75  }
    76  
    77  func (pub *PublicKey) Verify(msg []byte, sign []byte) bool {
    78  	var sm2Sign sm2Signature
    79  	_, err := asn1.Unmarshal(sign, &sm2Sign)
    80  	if err != nil {
    81  		return false
    82  	}
    83  	return Sm2Verify(pub, msg, default_uid, sm2Sign.R, sm2Sign.S)
    84  }
    85  
    86  func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) {
    87  	if len(uid) == 0 {
    88  		uid = default_uid
    89  	}
    90  
    91  	za, err := ZA(pub, uid)
    92  	if err != nil {
    93  		return nil, err
    94  	}
    95  
    96  	e, err := msgHash(za, msg)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	return e.Bytes(), nil
   102  }
   103  
   104  //****************************Encryption algorithm****************************//
   105  func (pub *PublicKey) EncryptAsn1(data []byte, random io.Reader) ([]byte, error) {
   106  	return EncryptAsn1(pub, data, random)
   107  }
   108  
   109  func (priv *PrivateKey) DecryptAsn1(data []byte) ([]byte, error) {
   110  	return DecryptAsn1(priv, data)
   111  }
   112  
   113  //**************************Key agreement algorithm**************************//
   114  // KeyExchangeB 协商第二部,用户B调用, 返回共享密钥k
   115  func KeyExchangeB(klen int, ida, idb []byte, priB *PrivateKey, pubA *PublicKey, rpri *PrivateKey, rpubA *PublicKey) (k, s1, s2 []byte, err error) {
   116  	return keyExchange(klen, ida, idb, priB, pubA, rpri, rpubA, false)
   117  }
   118  
   119  // KeyExchangeA 协商第二部,用户A调用,返回共享密钥k
   120  func KeyExchangeA(klen int, ida, idb []byte, priA *PrivateKey, pubB *PublicKey, rpri *PrivateKey, rpubB *PublicKey) (k, s1, s2 []byte, err error) {
   121  	return keyExchange(klen, ida, idb, priA, pubB, rpri, rpubB, true)
   122  }
   123  
   124  //****************************************************************************//
   125  
   126  func Sm2Sign(priv *PrivateKey, msg, uid []byte, random io.Reader) (r, s *big.Int, err error) {
   127  	digest, err := priv.PublicKey.Sm3Digest(msg, uid)
   128  	if err != nil {
   129  		return nil, nil, err
   130  	}
   131  	e := new(big.Int).SetBytes(digest)
   132  	c := priv.PublicKey.Curve
   133  	N := c.Params().N
   134  	if N.Sign() == 0 {
   135  		return nil, nil, errZeroParam
   136  	}
   137  	var k *big.Int
   138  	for { // 调整算法细节以实现SM2
   139  		for {
   140  			k, err = randFieldElement(c, random)
   141  			if err != nil {
   142  				r = nil
   143  				return
   144  			}
   145  			r, _ = priv.Curve.ScalarBaseMult(k.Bytes())
   146  			r.Add(r, e)
   147  			r.Mod(r, N)
   148  			if r.Sign() != 0 {
   149  				if t := new(big.Int).Add(r, k); t.Cmp(N) != 0 {
   150  					break
   151  				}
   152  			}
   153  
   154  		}
   155  		rD := new(big.Int).Mul(priv.D, r)
   156  		s = new(big.Int).Sub(k, rD)
   157  		d1 := new(big.Int).Add(priv.D, one)
   158  		d1Inv := new(big.Int).ModInverse(d1, N)
   159  		s.Mul(s, d1Inv)
   160  		s.Mod(s, N)
   161  		if s.Sign() != 0 {
   162  			break
   163  		}
   164  	}
   165  	return
   166  }
   167  func Sm2Verify(pub *PublicKey, msg, uid []byte, r, s *big.Int) bool {
   168  	c := pub.Curve
   169  	N := c.Params().N
   170  	one := new(big.Int).SetInt64(1)
   171  	if r.Cmp(one) < 0 || s.Cmp(one) < 0 {
   172  		return false
   173  	}
   174  	if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
   175  		return false
   176  	}
   177  	if len(uid) == 0 {
   178  		uid = default_uid
   179  	}
   180  	za, err := ZA(pub, uid)
   181  	if err != nil {
   182  		return false
   183  	}
   184  	e, err := msgHash(za, msg)
   185  	if err != nil {
   186  		return false
   187  	}
   188  	t := new(big.Int).Add(r, s)
   189  	t.Mod(t, N)
   190  	if t.Sign() == 0 {
   191  		return false
   192  	}
   193  	var x *big.Int
   194  	x1, y1 := c.ScalarBaseMult(s.Bytes())
   195  	x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
   196  	x, _ = c.Add(x1, y1, x2, y2)
   197  
   198  	x.Add(x, e)
   199  	x.Mod(x, N)
   200  	return x.Cmp(r) == 0
   201  }
   202  
   203  /*
   204      za, err := ZA(pub, uid)
   205  	if err != nil {
   206  		return
   207  	}
   208  	e, err := msgHash(za, msg)
   209  	hash=e.getBytes()
   210  */
   211  func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool {
   212  	c := pub.Curve
   213  	N := c.Params().N
   214  
   215  	if r.Sign() <= 0 || s.Sign() <= 0 {
   216  		return false
   217  	}
   218  	if r.Cmp(N) >= 0 || s.Cmp(N) >= 0 {
   219  		return false
   220  	}
   221  
   222  	// 调整算法细节以实现SM2
   223  	t := new(big.Int).Add(r, s)
   224  	t.Mod(t, N)
   225  	if t.Sign() == 0 {
   226  		return false
   227  	}
   228  
   229  	var x *big.Int
   230  	x1, y1 := c.ScalarBaseMult(s.Bytes())
   231  	x2, y2 := c.ScalarMult(pub.X, pub.Y, t.Bytes())
   232  	x, _ = c.Add(x1, y1, x2, y2)
   233  
   234  	e := new(big.Int).SetBytes(hash)
   235  	x.Add(x, e)
   236  	x.Mod(x, N)
   237  	return x.Cmp(r) == 0
   238  }
   239  
   240  /*
   241   * sm2密文结构如下:
   242   *  x
   243   *  y
   244   *  hash
   245   *  CipherText
   246   */
   247  func Encrypt(pub *PublicKey, data []byte, random io.Reader,mode int) ([]byte, error) {
   248  	length := len(data)
   249  	for {
   250  		c := []byte{}
   251  		curve := pub.Curve
   252  		k, err := randFieldElement(curve, random)
   253  		if err != nil {
   254  			return nil, err
   255  		}
   256  		x1, y1 := curve.ScalarBaseMult(k.Bytes())
   257  		x2, y2 := curve.ScalarMult(pub.X, pub.Y, k.Bytes())
   258  		x1Buf := x1.Bytes()
   259  		y1Buf := y1.Bytes()
   260  		x2Buf := x2.Bytes()
   261  		y2Buf := y2.Bytes()
   262  		if n := len(x1Buf); n < 32 {
   263  			x1Buf = append(zeroByteSlice()[:32-n], x1Buf...)
   264  		}
   265  		if n := len(y1Buf); n < 32 {
   266  			y1Buf = append(zeroByteSlice()[:32-n], y1Buf...)
   267  		}
   268  		if n := len(x2Buf); n < 32 {
   269  			x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
   270  		}
   271  		if n := len(y2Buf); n < 32 {
   272  			y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
   273  		}
   274  		c = append(c, x1Buf...) // x分量
   275  		c = append(c, y1Buf...) // y分量
   276  		tm := []byte{}
   277  		tm = append(tm, x2Buf...)
   278  		tm = append(tm, data...)
   279  		tm = append(tm, y2Buf...)
   280  		h := sm3.Sm3Sum(tm)
   281  		c = append(c, h...)
   282  		ct, ok := kdf(length, x2Buf, y2Buf) // 密文
   283  		if !ok {
   284  			continue
   285  		}
   286  		c = append(c, ct...)
   287  		for i := 0; i < length; i++ {
   288  			c[96+i] ^= data[i]
   289  		}
   290  		switch mode{
   291  	
   292  		case C1C3C2:
   293  			return append([]byte{0x04}, c...), nil
   294  		case C1C2C3:
   295  			c1 := make([]byte, 64)
   296  			c2 := make([]byte, len(c) - 96)
   297  			c3 := make([]byte, 32)
   298  			copy(c1, c[:64])//x1,y1
   299  			copy(c3, c[64:96])//hash
   300  			copy(c2, c[96:])//密文
   301  			ciphertext := []byte{}
   302  			ciphertext = append(ciphertext, c1...)
   303  			ciphertext = append(ciphertext, c2...)
   304  			ciphertext = append(ciphertext, c3...)
   305  			return append([]byte{0x04}, ciphertext...), nil
   306      	default:
   307  			return append([]byte{0x04}, c...), nil
   308  	}
   309  }
   310  }
   311  
   312  
   313  
   314  func Decrypt(priv *PrivateKey, data []byte,mode int) ([]byte, error) {
   315  	switch mode {
   316  	case C1C3C2:
   317  		data = data[1:]
   318  	case  C1C2C3:
   319  		data = data[1:]
   320  		c1 := make([]byte, 64)
   321  		c2 := make([]byte, len(data) - 96)
   322  		c3 := make([]byte, 32)
   323  		copy(c1, data[:64])//x1,y1
   324  		copy(c2, data[64:len(data) - 32])//密文
   325  		copy(c3, data[len(data) - 32:])//hash
   326  		c := []byte{}
   327  		c = append(c, c1...)
   328  		c = append(c, c3...)
   329  		c = append(c, c2...)
   330  		data = c
   331  	default:
   332  		data = data[1:]
   333  	}
   334  	length := len(data) - 96
   335  	curve := priv.Curve
   336  	x := new(big.Int).SetBytes(data[:32])
   337  	y := new(big.Int).SetBytes(data[32:64])
   338  	x2, y2 := curve.ScalarMult(x, y, priv.D.Bytes())
   339  	x2Buf := x2.Bytes()
   340  	y2Buf := y2.Bytes()
   341  	if n := len(x2Buf); n < 32 {
   342  		x2Buf = append(zeroByteSlice()[:32-n], x2Buf...)
   343  	}
   344  	if n := len(y2Buf); n < 32 {
   345  		y2Buf = append(zeroByteSlice()[:32-n], y2Buf...)
   346  	}
   347  	c, ok := kdf(length, x2Buf, y2Buf)
   348  	if !ok {
   349  		return nil, errors.New("Decrypt: failed to decrypt")
   350  	}
   351  	for i := 0; i < length; i++ {
   352  		c[i] ^= data[i+96]
   353  	}
   354  	tm := []byte{}
   355  	tm = append(tm, x2Buf...)
   356  	tm = append(tm, c...)
   357  	tm = append(tm, y2Buf...)
   358  	h := sm3.Sm3Sum(tm)
   359  	if bytes.Compare(h, data[64:96]) != 0 {
   360  		return c, errors.New("Decrypt: failed to decrypt")
   361  	}
   362  	return c, nil
   363  }
   364  
   365  
   366  
   367  // keyExchange 为SM2密钥交换算法的第二部和第三步复用部分,协商的双方均调用此函数计算共同的字节串
   368  // klen: 密钥长度
   369  // ida, idb: 协商双方的标识,ida为密钥协商算法发起方标识,idb为响应方标识
   370  // pri: 函数调用者的密钥
   371  // pub: 对方的公钥
   372  // rpri: 函数调用者生成的临时SM2密钥
   373  // rpub: 对方发来的临时SM2公钥
   374  // thisIsA: 如果是A调用,文档中的协商第三步,设置为true,否则设置为false
   375  // 返回 k 为klen长度的字节串
   376  func keyExchange(klen int, ida, idb []byte, pri *PrivateKey, pub *PublicKey, rpri *PrivateKey, rpub *PublicKey, thisISA bool) (k, s1, s2 []byte, err error) {
   377  	curve := P256Sm2()
   378  	N := curve.Params().N
   379  	x2hat := keXHat(rpri.PublicKey.X)
   380  	x2rb := new(big.Int).Mul(x2hat, rpri.D)
   381  	tbt := new(big.Int).Add(pri.D, x2rb)
   382  	tb := new(big.Int).Mod(tbt, N)
   383  	if !curve.IsOnCurve(rpub.X, rpub.Y) {
   384  		err = errors.New("Ra not on curve")
   385  		return
   386  	}
   387  	x1hat := keXHat(rpub.X)
   388  	ramx1, ramy1 := curve.ScalarMult(rpub.X, rpub.Y, x1hat.Bytes())
   389  	vxt, vyt := curve.Add(pub.X, pub.Y, ramx1, ramy1)
   390  
   391  	vx, vy := curve.ScalarMult(vxt, vyt, tb.Bytes())
   392  	pza := pub
   393  	if thisISA {
   394  		pza = &pri.PublicKey
   395  	}
   396  	za, err := ZA(pza, ida)
   397  	if err != nil {
   398  		return
   399  	}
   400  	zero := new(big.Int)
   401  	if vx.Cmp(zero) == 0 || vy.Cmp(zero) == 0 {
   402  		err = errors.New("V is infinite")
   403  	}
   404  	pzb := pub
   405  	if !thisISA {
   406  		pzb = &pri.PublicKey
   407  	}
   408  	zb, err := ZA(pzb, idb)
   409  	k, ok := kdf(klen, vx.Bytes(), vy.Bytes(), za, zb)
   410  	if !ok {
   411  		err = errors.New("kdf: zero key")
   412  		return
   413  	}
   414  	h1 := BytesCombine(vx.Bytes(), za, zb, rpub.X.Bytes(), rpub.Y.Bytes(), rpri.X.Bytes(), rpri.Y.Bytes())
   415  	if !thisISA {
   416  		h1 = BytesCombine(vx.Bytes(), za, zb, rpri.X.Bytes(), rpri.Y.Bytes(), rpub.X.Bytes(), rpub.Y.Bytes())
   417  	}
   418  	hash := sm3.Sm3Sum(h1)
   419  	h2 := BytesCombine([]byte{0x02}, vy.Bytes(), hash)
   420  	S1 := sm3.Sm3Sum(h2)
   421  	h3 := BytesCombine([]byte{0x03}, vy.Bytes(), hash)
   422  	S2 := sm3.Sm3Sum(h3)
   423  	return k, S1, S2, nil
   424  }
   425  
   426  func msgHash(za, msg []byte) (*big.Int, error) {
   427  	e := sm3.New()
   428  	e.Write(za)
   429  	e.Write(msg)
   430  	return new(big.Int).SetBytes(e.Sum(nil)[:32]), nil
   431  }
   432  
   433  // ZA = H256(ENTLA || IDA || a || b || xG || yG || xA || yA)
   434  func ZA(pub *PublicKey, uid []byte) ([]byte, error) {
   435  	za := sm3.New()
   436  	uidLen := len(uid)
   437  	if uidLen >= 8192 {
   438  		return []byte{}, errors.New("SM2: uid too large")
   439  	}
   440  	Entla := uint16(8 * uidLen)
   441  	za.Write([]byte{byte((Entla >> 8) & 0xFF)})
   442  	za.Write([]byte{byte(Entla & 0xFF)})
   443  	if uidLen > 0 {
   444  		za.Write(uid)
   445  	}
   446  	za.Write(sm2P256ToBig(&sm2P256.a).Bytes())
   447  	za.Write(sm2P256.B.Bytes())
   448  	za.Write(sm2P256.Gx.Bytes())
   449  	za.Write(sm2P256.Gy.Bytes())
   450  
   451  	xBuf := pub.X.Bytes()
   452  	yBuf := pub.Y.Bytes()
   453  	if n := len(xBuf); n < 32 {
   454  		xBuf = append(zeroByteSlice()[:32-n], xBuf...)
   455  	}
   456  	if n := len(yBuf); n < 32 {
   457  		yBuf = append(zeroByteSlice()[:32-n], yBuf...)
   458  	}
   459  	za.Write(xBuf)
   460  	za.Write(yBuf)
   461  	return za.Sum(nil)[:32], nil
   462  }
   463  
   464  // 32byte
   465  func zeroByteSlice() []byte {
   466  	return []byte{
   467  		0, 0, 0, 0,
   468  		0, 0, 0, 0,
   469  		0, 0, 0, 0,
   470  		0, 0, 0, 0,
   471  		0, 0, 0, 0,
   472  		0, 0, 0, 0,
   473  		0, 0, 0, 0,
   474  		0, 0, 0, 0,
   475  	}
   476  }
   477  
   478  /*
   479  sm2加密,返回asn.1编码格式的密文内容
   480  */
   481  func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) {
   482  	cipher, err := Encrypt(pub, data, rand,C1C3C2)
   483  	if err != nil {
   484  		return nil, err
   485  	}
   486  	return CipherMarshal(cipher)
   487  }
   488  
   489  /*
   490  sm2解密,解析asn.1编码格式的密文内容
   491  */
   492  func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) {
   493  	cipher, err := CipherUnmarshal(data)
   494  	if err != nil {
   495  		return nil, err
   496  	}
   497  	return Decrypt(pub, cipher,C1C3C2)
   498  }
   499  
   500  /*
   501  *sm2密文转asn.1编码格式
   502  *sm2密文结构如下:
   503  *  x
   504  *  y
   505  *  hash
   506  *  CipherText
   507   */
   508  func CipherMarshal(data []byte) ([]byte, error) {
   509  	data = data[1:]
   510  	x := new(big.Int).SetBytes(data[:32])
   511  	y := new(big.Int).SetBytes(data[32:64])
   512  	hash := data[64:96]
   513  	cipherText := data[96:]
   514  	return asn1.Marshal(sm2Cipher{x, y, hash, cipherText})
   515  }
   516  
   517  /*
   518  sm2密文asn.1编码格式转C1|C3|C2拼接格式
   519  */
   520  func CipherUnmarshal(data []byte) ([]byte, error) {
   521  	var cipher sm2Cipher
   522  	_, err := asn1.Unmarshal(data, &cipher)
   523  	if err != nil {
   524  		return nil, err
   525  	}
   526  	x := cipher.XCoordinate.Bytes()
   527  	y := cipher.YCoordinate.Bytes()
   528  	hash := cipher.HASH
   529  	if err != nil {
   530  		return nil, err
   531  	}
   532  	cipherText := cipher.CipherText
   533  	if err != nil {
   534  		return nil, err
   535  	}
   536  	if n := len(x); n < 32 {
   537  		x = append(zeroByteSlice()[:32-n], x...)
   538  	}
   539  	if n := len(y); n < 32 {
   540  		y = append(zeroByteSlice()[:32-n], y...)
   541  	}
   542  	c := []byte{}
   543  	c = append(c, x...)          // x分量
   544  	c = append(c, y...)          // y分
   545  	c = append(c, hash...)       // x分量
   546  	c = append(c, cipherText...) // y分
   547  	return append([]byte{0x04}, c...), nil
   548  }
   549  
   550  // keXHat 计算 x = 2^w + (x & (2^w-1))
   551  // 密钥协商算法辅助函数
   552  func keXHat(x *big.Int) (xul *big.Int) {
   553  	buf := x.Bytes()
   554  	for i := 0; i < len(buf)-16; i++ {
   555  		buf[i] = 0
   556  	}
   557  	if len(buf) >= 16 {
   558  		c := buf[len(buf)-16]
   559  		buf[len(buf)-16] = c & 0x7f
   560  	}
   561  
   562  	r := new(big.Int).SetBytes(buf)
   563  	_2w := new(big.Int).SetBytes([]byte{
   564  		0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
   565  		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00})
   566  	return r.Add(r, _2w)
   567  }
   568  
   569  func BytesCombine(pBytes ...[]byte) []byte {
   570  	len := len(pBytes)
   571  	s := make([][]byte, len)
   572  	for index := 0; index < len; index++ {
   573  		s[index] = pBytes[index]
   574  	}
   575  	sep := []byte("")
   576  	return bytes.Join(s, sep)
   577  }
   578  
   579  func intToBytes(x int) []byte {
   580  	var buf = make([]byte, 4)
   581  
   582  	binary.BigEndian.PutUint32(buf, uint32(x))
   583  	return buf
   584  }
   585  
   586  func kdf(length int, x ...[]byte) ([]byte, bool) {
   587  	var c []byte
   588  
   589  	ct := 1
   590  	h := sm3.New()
   591  	for i, j := 0, (length+31)/32; i < j; i++ {
   592  		h.Reset()
   593  		for _, xx := range x {
   594  			h.Write(xx)
   595  		}
   596  		h.Write(intToBytes(ct))
   597  		hash := h.Sum(nil)
   598  		if i+1 == j && length%32 != 0 {
   599  			c = append(c, hash[:length%32]...)
   600  		} else {
   601  			c = append(c, hash...)
   602  		}
   603  		ct++
   604  	}
   605  	for i := 0; i < length; i++ {
   606  		if c[i] != 0 {
   607  			return c, true
   608  		}
   609  	}
   610  	return c, false
   611  }
   612  
   613  func randFieldElement(c elliptic.Curve, random io.Reader) (k *big.Int, err error) {
   614  	if random == nil {
   615  		random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it.
   616  	}
   617  	params := c.Params()
   618  	b := make([]byte, params.BitSize/8+8)
   619  	_, err = io.ReadFull(random, b)
   620  	if err != nil {
   621  		return
   622  	}
   623  	k = new(big.Int).SetBytes(b)
   624  	n := new(big.Int).Sub(params.N, one)
   625  	k.Mod(k, n)
   626  	k.Add(k, one)
   627  	return
   628  }
   629  
   630  func GenerateKey(random io.Reader) (*PrivateKey, error) {
   631  	c := P256Sm2()
   632  	if random == nil {
   633  		random = rand.Reader //If there is no external trusted random source,please use rand.Reader to instead of it.
   634  	}
   635  	params := c.Params()
   636  	b := make([]byte, params.BitSize/8+8)
   637  	_, err := io.ReadFull(random, b)
   638  	if err != nil {
   639  		return nil, err
   640  	}
   641  
   642  	k := new(big.Int).SetBytes(b)
   643  	n := new(big.Int).Sub(params.N, two)
   644  	k.Mod(k, n)
   645  	k.Add(k, one)
   646  	priv := new(PrivateKey)
   647  	priv.PublicKey.Curve = c
   648  	priv.D = k
   649  	priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
   650  
   651  	return priv, nil
   652  }
   653  
   654  type zr struct {
   655  	io.Reader
   656  }
   657  
   658  func (z *zr) Read(dst []byte) (n int, err error) {
   659  	for i := range dst {
   660  		dst[i] = 0
   661  	}
   662  	return len(dst), nil
   663  }
   664  
   665  var zeroReader = &zr{}
   666  
   667  func getLastBit(a *big.Int) uint {
   668  	return a.Bit(0)
   669  }
   670  
   671  // crypto.Decrypter
   672  func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) {
   673  	return Decrypt(priv, msg,C1C3C2)
   674  }
   675  

View as plain text