1
15
16 package gmtls
17
18 import (
19 "crypto"
20 "crypto/ecdsa"
21 "crypto/rsa"
22 "encoding/asn1"
23 "errors"
24 "fmt"
25
26 "github.com/tjfoc/gmsm/sm2"
27 )
28
29
30
31
32
33
34
35
36 func pickSignatureAlgorithm(pubkey crypto.PublicKey, peerSigAlgs, ourSigAlgs []SignatureScheme, tlsVersion uint16) (sigAlg SignatureScheme, sigType uint8, hashFunc crypto.Hash, err error) {
37 if tlsVersion < VersionTLS12 || len(peerSigAlgs) == 0 {
38
39
40
41
42
43 switch pubkey.(type) {
44 case *rsa.PublicKey:
45 if tlsVersion < VersionTLS12 {
46 return 0, signaturePKCS1v15, crypto.MD5SHA1, nil
47 } else {
48 return PKCS1WithSHA1, signaturePKCS1v15, crypto.SHA1, nil
49 }
50 case *ecdsa.PublicKey:
51 return ECDSAWithSHA1, signatureECDSA, crypto.SHA1, nil
52 case *sm2.PublicKey:
53 return SM2WITHSM3, signatureSM2, crypto.SHA1, nil
54 default:
55 return 0, 0, 0, fmt.Errorf("tls: unsupported public key: %T", pubkey)
56 }
57 }
58 for _, sigAlg := range peerSigAlgs {
59 if !isSupportedSignatureAlgorithm(sigAlg, ourSigAlgs) {
60 continue
61 }
62 hashAlg, err := lookupTLSHash(sigAlg)
63 if err != nil {
64 panic("tls: supported signature algorithm has an unknown hash function")
65 }
66 sigType := signatureFromSignatureScheme(sigAlg)
67 switch pubkey.(type) {
68 case *rsa.PublicKey:
69 if sigType == signaturePKCS1v15 || sigType == signatureRSAPSS {
70 return sigAlg, sigType, hashAlg, nil
71 }
72 case *ecdsa.PublicKey:
73 if sigType == signatureECDSA {
74 return sigAlg, sigType, hashAlg, nil
75 }
76 case *sm2.PublicKey:
77 if sigType == signatureECDSA {
78 return sigAlg, sigType, hashAlg, nil
79 }
80 default:
81 return 0, 0, 0, fmt.Errorf("tls: unsupported public key: %T", pubkey)
82 }
83 }
84 return 0, 0, 0, errors.New("tls: peer doesn't support any common signature algorithms")
85 }
86
87
88
89 func verifyHandshakeSignature(sigType uint8, pubkey crypto.PublicKey, hashFunc crypto.Hash, digest, sig []byte) error {
90 switch sigType {
91 case signatureECDSA:
92 pubKey, ok := pubkey.(*ecdsa.PublicKey)
93 if !ok {
94 return errors.New("tls: ECDSA signing requires a ECDSA public key")
95 }
96 ecdsaSig := new(ecdsaSignature)
97 if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil {
98 return err
99 }
100 if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 {
101 return errors.New("tls: ECDSA signature contained zero or negative values")
102 }
103 if pubKey.Curve == sm2.P256Sm2() {
104 sm2Public := sm2.PublicKey{
105 Curve: pubKey.Curve,
106 X: pubKey.X,
107 Y: pubKey.Y,
108 }
109 if !sm2Public.Verify(digest, sig) {
110 return errors.New("tls: SM2 verification failure")
111 }
112 } else if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) {
113 return errors.New("tls: ECDSA verification failure")
114 }
115 case signaturePKCS1v15:
116 pubKey, ok := pubkey.(*rsa.PublicKey)
117 if !ok {
118 return errors.New("tls: RSA signing requires a RSA public key")
119 }
120 if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil {
121 return err
122 }
123 case signatureRSAPSS:
124 pubKey, ok := pubkey.(*rsa.PublicKey)
125 if !ok {
126 return errors.New("tls: RSA signing requires a RSA public key")
127 }
128 signOpts := &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash}
129 if err := rsa.VerifyPSS(pubKey, hashFunc, digest, sig, signOpts); err != nil {
130 return err
131 }
132 case signatureSM2:
133 pubKey, ok := pubkey.(*sm2.PublicKey)
134 if !ok {
135 return errors.New("tls: SM2 signing requires a SM2 public key")
136 }
137 if ok := pubKey.Verify(digest, sig); !ok {
138 return errors.New("verify sm2 signature error")
139 }
140 default:
141 return errors.New("tls: unknown signature algorithm")
142 }
143 return nil
144 }
145
View as plain text