1
15
16 package gmtls
17
18 import (
19 "crypto"
20 "crypto/elliptic"
21 "crypto/md5"
22 "crypto/rsa"
23 "crypto/sha1"
24 "errors"
25 "io"
26 "math/big"
27
28 "github.com/tjfoc/gmsm/x509"
29
30 "golang.org/x/crypto/curve25519"
31 )
32
// Sentinel errors returned when a peer's ClientKeyExchange or
// ServerKeyExchange message cannot be parsed or fails validation.
var (
	errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message")
	errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message")
)
35
36
37
// rsaKeyAgreement implements static-RSA key exchange. The client encrypts the
// pre-master secret with PKCS #1 v1.5 to the RSA key in the server
// certificate, and the server produces no ServerKeyExchange message.
type rsaKeyAgreement struct{}
39
40 func (ka rsaKeyAgreement) generateServerKeyExchange(config *Config, signCert, cipherCert *Certificate,
41 clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
42 return nil, nil
43 }
44
45 func (ka rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
46 if len(ckx.ciphertext) < 2 {
47 return nil, errClientKeyExchange
48 }
49
50 ciphertext := ckx.ciphertext
51 if version != VersionSSL30 {
52 ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1])
53 if ciphertextLen != len(ckx.ciphertext)-2 {
54 return nil, errClientKeyExchange
55 }
56 ciphertext = ckx.ciphertext[2:]
57 }
58 priv, ok := cert.PrivateKey.(crypto.Decrypter)
59 if !ok {
60 return nil, errors.New("tls: certificate private key does not implement crypto.Decrypter")
61 }
62
63 preMasterSecret, err := priv.Decrypt(config.rand(), ciphertext, &rsa.PKCS1v15DecryptOptions{SessionKeyLen: 48})
64 if err != nil {
65 return nil, err
66 }
67
68
69
70
71
72
73 return preMasterSecret, nil
74 }
75
76 func (ka rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
77 return errors.New("tls: unexpected ServerKeyExchange")
78 }
79
80 func (ka rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
81 preMasterSecret := make([]byte, 48)
82 preMasterSecret[0] = byte(clientHello.vers >> 8)
83 preMasterSecret[1] = byte(clientHello.vers)
84 _, err := io.ReadFull(config.rand(), preMasterSecret[2:])
85 if err != nil {
86 return nil, nil, err
87 }
88
89 encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret)
90 if err != nil {
91 return nil, nil, err
92 }
93 ckx := new(clientKeyExchangeMsg)
94 ckx.ciphertext = make([]byte, len(encrypted)+2)
95 ckx.ciphertext[0] = byte(len(encrypted) >> 8)
96 ckx.ciphertext[1] = byte(len(encrypted))
97 copy(ckx.ciphertext[2:], encrypted)
98 return preMasterSecret, ckx, nil
99 }
100
101
// sha1Hash returns the SHA-1 digest of the concatenation of slices.
func sha1Hash(slices [][]byte) []byte {
	d := sha1.New()
	for i := range slices {
		d.Write(slices[i])
	}
	return d.Sum(make([]byte, 0, sha1.Size))
}
109
110
111
// md5SHA1Hash returns the 36-byte value MD5(data) || SHA-1(data), where data
// is the concatenation of slices. hashForServerKeyExchange uses it for
// non-ECDSA signatures before TLS 1.2.
func md5SHA1Hash(slices [][]byte) []byte {
	m, s := md5.New(), sha1.New()
	// Feed both digests in a single pass; hash writes never fail.
	both := io.MultiWriter(m, s)
	for _, part := range slices {
		both.Write(part)
	}
	out := make([]byte, 0, md5.Size+sha1.Size)
	out = m.Sum(out)
	return s.Sum(out)
}
122
123
124
125
126 func hashForServerKeyExchange(sigType uint8, hashFunc crypto.Hash, version uint16, slices ...[]byte) ([]byte, error) {
127 if version >= VersionTLS12 {
128 h := hashFunc.New()
129 for _, slice := range slices {
130 h.Write(slice)
131 }
132 digest := h.Sum(nil)
133 return digest, nil
134 }
135 if sigType == signatureECDSA {
136 return sha1Hash(slices), nil
137 }
138 return md5SHA1Hash(slices), nil
139 }
140
141 func curveForCurveID(id CurveID) (elliptic.Curve, bool) {
142 switch id {
143 case CurveP256:
144 return elliptic.P256(), true
145 case CurveP384:
146 return elliptic.P384(), true
147 case CurveP521:
148 return elliptic.P521(), true
149 default:
150 return nil, false
151 }
152
153 }
154
155
156
157
158
159 type ecdheKeyAgreement struct {
160 version uint16
161 isRSA bool
162 privateKey []byte
163 curveid CurveID
164
165
166
167 publicKey []byte
168
169
170 x, y *big.Int
171 }
172
173 func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, signCert, cipherCert *Certificate,
174 clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) {
175 preferredCurves := config.curvePreferences()
176
177 NextCandidate:
178 for _, candidate := range preferredCurves {
179 for _, c := range clientHello.supportedCurves {
180 if candidate == c {
181 ka.curveid = c
182 break NextCandidate
183 }
184 }
185 }
186
187 if ka.curveid == 0 {
188 return nil, errors.New("tls: no supported elliptic curves offered")
189 }
190
191 var ecdhePublic []byte
192
193 if ka.curveid == X25519 {
194 var scalar, public [32]byte
195 if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
196 return nil, err
197 }
198
199 curve25519.ScalarBaseMult(&public, &scalar)
200 ka.privateKey = scalar[:]
201 ecdhePublic = public[:]
202 } else {
203 curve, ok := curveForCurveID(ka.curveid)
204 if !ok {
205 return nil, errors.New("tls: preferredCurves includes unsupported curve")
206 }
207
208 var x, y *big.Int
209 var err error
210 ka.privateKey, x, y, err = elliptic.GenerateKey(curve, config.rand())
211 if err != nil {
212 return nil, err
213 }
214 ecdhePublic = elliptic.Marshal(curve, x, y)
215 }
216
217
218 serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic))
219 serverECDHParams[0] = 3
220 serverECDHParams[1] = byte(ka.curveid >> 8)
221 serverECDHParams[2] = byte(ka.curveid)
222 serverECDHParams[3] = byte(len(ecdhePublic))
223 copy(serverECDHParams[4:], ecdhePublic)
224
225 priv, ok := signCert.PrivateKey.(crypto.Signer)
226 if !ok {
227 return nil, errors.New("tls: certificate private key does not implement crypto.Signer")
228 }
229
230 signatureAlgorithm, sigType, hashFunc, err := pickSignatureAlgorithm(priv.Public(), clientHello.supportedSignatureAlgorithms, supportedSignatureAlgorithms, ka.version)
231 if err != nil {
232 return nil, err
233 }
234 if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
235 return nil, errors.New("tls: certificate cannot be used with the selected cipher suite")
236 }
237
238 digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, hello.random, serverECDHParams)
239 if err != nil {
240 return nil, err
241 }
242
243 signOpts := crypto.SignerOpts(hashFunc)
244 if sigType == signatureRSAPSS {
245 signOpts = &rsa.PSSOptions{SaltLength: rsa.PSSSaltLengthEqualsHash, Hash: hashFunc}
246 }
247 sig, err := priv.Sign(config.rand(), digest, signOpts)
248 if err != nil {
249 return nil, errors.New("tls: failed to sign ECDHE parameters: " + err.Error())
250 }
251
252 skx := new(serverKeyExchangeMsg)
253 sigAndHashLen := 0
254 if ka.version >= VersionTLS12 {
255 sigAndHashLen = 2
256 }
257 skx.key = make([]byte, len(serverECDHParams)+sigAndHashLen+2+len(sig))
258 copy(skx.key, serverECDHParams)
259 k := skx.key[len(serverECDHParams):]
260 if ka.version >= VersionTLS12 {
261 k[0] = byte(signatureAlgorithm >> 8)
262 k[1] = byte(signatureAlgorithm)
263 k = k[2:]
264 }
265 k[0] = byte(len(sig) >> 8)
266 k[1] = byte(len(sig))
267 copy(k[2:], sig)
268
269 return skx, nil
270 }
271
272 func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) {
273 if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 {
274 return nil, errClientKeyExchange
275 }
276
277 if ka.curveid == X25519 {
278 if len(ckx.ciphertext) != 1+32 {
279 return nil, errClientKeyExchange
280 }
281
282 var theirPublic, sharedKey, scalar [32]byte
283 copy(theirPublic[:], ckx.ciphertext[1:])
284 copy(scalar[:], ka.privateKey)
285 curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
286 return sharedKey[:], nil
287 }
288
289 curve, ok := curveForCurveID(ka.curveid)
290 if !ok {
291 panic("internal error")
292 }
293 x, y := elliptic.Unmarshal(curve, ckx.ciphertext[1:])
294 if x == nil {
295 return nil, errClientKeyExchange
296 }
297 x, _ = curve.ScalarMult(x, y, ka.privateKey)
298 preMasterSecret := make([]byte, (curve.Params().BitSize+7)>>3)
299 xBytes := x.Bytes()
300 copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
301
302 return preMasterSecret, nil
303 }
304
305 func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error {
306 if len(skx.key) < 4 {
307 return errServerKeyExchange
308 }
309 if skx.key[0] != 3 {
310 return errors.New("tls: server selected unsupported curve")
311 }
312 ka.curveid = CurveID(skx.key[1])<<8 | CurveID(skx.key[2])
313
314 publicLen := int(skx.key[3])
315 if publicLen+4 > len(skx.key) {
316 return errServerKeyExchange
317 }
318 serverECDHParams := skx.key[:4+publicLen]
319 publicKey := serverECDHParams[4:]
320
321 sig := skx.key[4+publicLen:]
322 if len(sig) < 2 {
323 return errServerKeyExchange
324 }
325
326 if ka.curveid == X25519 {
327 if len(publicKey) != 32 {
328 return errors.New("tls: bad X25519 public value")
329 }
330 ka.publicKey = publicKey
331 } else {
332 curve, ok := curveForCurveID(ka.curveid)
333 if !ok {
334 return errors.New("tls: server selected unsupported curve")
335 }
336 ka.x, ka.y = elliptic.Unmarshal(curve, publicKey)
337 if ka.x == nil {
338 return errServerKeyExchange
339 }
340 }
341
342 var signatureAlgorithm SignatureScheme
343 if ka.version >= VersionTLS12 {
344
345 signatureAlgorithm = SignatureScheme(sig[0])<<8 | SignatureScheme(sig[1])
346 sig = sig[2:]
347 if len(sig) < 2 {
348 return errServerKeyExchange
349 }
350 }
351 _, sigType, hashFunc, err := pickSignatureAlgorithm(cert.PublicKey, []SignatureScheme{signatureAlgorithm}, clientHello.supportedSignatureAlgorithms, ka.version)
352 if err != nil {
353 return err
354 }
355 if (sigType == signaturePKCS1v15 || sigType == signatureRSAPSS) != ka.isRSA {
356 return errServerKeyExchange
357 }
358
359 sigLen := int(sig[0])<<8 | int(sig[1])
360 if sigLen+2 != len(sig) {
361 return errServerKeyExchange
362 }
363 sig = sig[2:]
364
365 digest, err := hashForServerKeyExchange(sigType, hashFunc, ka.version, clientHello.random, serverHello.random, serverECDHParams)
366 if err != nil {
367 return err
368 }
369 return verifyHandshakeSignature(sigType, cert.PublicKey, hashFunc, digest, sig)
370 }
371
372 func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) {
373 if ka.curveid == 0 {
374 return nil, nil, errors.New("tls: missing ServerKeyExchange message")
375 }
376
377 var serialized, preMasterSecret []byte
378
379 if ka.curveid == X25519 {
380 var ourPublic, theirPublic, sharedKey, scalar [32]byte
381
382 if _, err := io.ReadFull(config.rand(), scalar[:]); err != nil {
383 return nil, nil, err
384 }
385
386 copy(theirPublic[:], ka.publicKey)
387 curve25519.ScalarBaseMult(&ourPublic, &scalar)
388 curve25519.ScalarMult(&sharedKey, &scalar, &theirPublic)
389 serialized = ourPublic[:]
390 preMasterSecret = sharedKey[:]
391 } else {
392 curve, ok := curveForCurveID(ka.curveid)
393 if !ok {
394 panic("internal error")
395 }
396 priv, mx, my, err := elliptic.GenerateKey(curve, config.rand())
397 if err != nil {
398 return nil, nil, err
399 }
400 x, _ := curve.ScalarMult(ka.x, ka.y, priv)
401 preMasterSecret = make([]byte, (curve.Params().BitSize+7)>>3)
402 xBytes := x.Bytes()
403 copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes)
404
405 serialized = elliptic.Marshal(curve, mx, my)
406 }
407
408 ckx := new(clientKeyExchangeMsg)
409 ckx.ciphertext = make([]byte, 1+len(serialized))
410 ckx.ciphertext[0] = byte(len(serialized))
411 copy(ckx.ciphertext[1:], serialized)
412
413 return preMasterSecret, ckx, nil
414 }
415
View as plain text