1
2
3
4
5 package gmtls
6
7 import (
8 "bytes"
9 "crypto"
10 "crypto/ecdsa"
11 "crypto/rsa"
12 "crypto/subtle"
13 "errors"
14 "fmt"
15 "io"
16 "strconv"
17 "sync/atomic"
18
19 "github.com/tjfoc/gmsm/sm2"
20 "github.com/tjfoc/gmsm/x509"
21 )
22
// clientHandshakeStateGM holds the client-side state of a single GM/T 0024
// (GMSSL) handshake running on a Conn.
type clientHandshakeStateGM struct {
	// c is the connection the handshake runs on.
	c *Conn
	// serverHello is the ServerHello received from the server.
	serverHello *serverHelloMsg
	// hello is the ClientHello this client sent.
	hello *clientHelloMsg
	// suite is the negotiated cipher suite (set by pickCipherSuite).
	suite *cipherSuite
	// finishedHash is the running transcript hash used for the Finished
	// messages and the client CertificateVerify signature.
	finishedHash finishedHash
	// masterSecret is either derived during a full handshake or copied from
	// a resumed session.
	masterSecret []byte
	// session is the cached session offered for resumption, or the session
	// built from a NewSessionTicket received during this handshake.
	session *ClientSessionState
}
32
33 func makeClientHelloGM(config *Config) (*clientHelloMsg, error) {
34 if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
35 return nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
36 }
37
38 hello := &clientHelloMsg{
39 vers: config.GMSupport.GetVersion(),
40 compressionMethods: []uint8{compressionNone},
41 random: make([]byte, 32),
42 }
43 possibleCipherSuites := getCipherSuites(config)
44 hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
45
46 NextCipherSuite:
47 for _, suiteId := range possibleCipherSuites {
48 for _, suite := range config.GMSupport.cipherSuites() {
49 if suite.id != suiteId {
50 continue
51 }
52 hello.cipherSuites = append(hello.cipherSuites, suiteId)
53 continue NextCipherSuite
54 }
55 }
56
57 _, err := io.ReadFull(config.rand(), hello.random)
58 if err != nil {
59 return nil, errors.New("tls: short read from Rand: " + err.Error())
60 }
61
62 return hello, nil
63 }
64
65
66
67 func (hs *clientHandshakeStateGM) handshake() error {
68 c := hs.c
69
70
71 if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
72 return err
73 }
74
75 msg, err := c.readHandshake()
76 if err != nil {
77 return err
78 }
79
80 var ok bool
81 if hs.serverHello, ok = msg.(*serverHelloMsg); !ok {
82 c.sendAlert(alertUnexpectedMessage)
83 return unexpectedMessageError(hs.serverHello, msg)
84 }
85
86 if hs.serverHello.vers != VersionGMSSL {
87 hs.c.sendAlert(alertProtocolVersion)
88 return fmt.Errorf("tls: server selected unsupported protocol version %x, while expecting %x", hs.serverHello.vers, VersionGMSSL)
89 }
90
91 if err = hs.pickCipherSuite(); err != nil {
92 return err
93 }
94
95 isResume, err := hs.processServerHello()
96 if err != nil {
97 return err
98 }
99
100 hs.finishedHash = newFinishedHashGM(hs.suite)
101
102
103
104
105
106 if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) {
107 hs.finishedHash.discardHandshakeBuffer()
108 }
109
110 hs.finishedHash.Write(hs.hello.marshal())
111 hs.finishedHash.Write(hs.serverHello.marshal())
112
113 c.buffering = true
114 if isResume {
115 if err := hs.establishKeys(); err != nil {
116 return err
117 }
118 if err := hs.readSessionTicket(); err != nil {
119 return err
120 }
121 if err := hs.readFinished(c.serverFinished[:]); err != nil {
122 return err
123 }
124 c.clientFinishedIsFirst = false
125 if err := hs.sendFinished(c.clientFinished[:]); err != nil {
126 return err
127 }
128 if _, err := c.flush(); err != nil {
129 return err
130 }
131 } else {
132 if err := hs.doFullHandshake(); err != nil {
133 return err
134 }
135 if err := hs.establishKeys(); err != nil {
136 return err
137 }
138 if err := hs.sendFinished(c.clientFinished[:]); err != nil {
139 return err
140 }
141 if _, err := c.flush(); err != nil {
142 return err
143 }
144 c.clientFinishedIsFirst = true
145 if err := hs.readSessionTicket(); err != nil {
146 return err
147 }
148 if err := hs.readFinished(c.serverFinished[:]); err != nil {
149 return err
150 }
151 }
152
153 c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
154 c.didResume = isResume
155 atomic.StoreUint32(&c.handshakeStatus, 1)
156
157 return nil
158 }
159
160 func (hs *clientHandshakeStateGM) pickCipherSuite() error {
161 if hs.suite = mutualCipherSuiteGM(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil {
162 hs.c.sendAlert(alertHandshakeFailure)
163 return errors.New("tls: server chose an unconfigured cipher suite")
164 }
165
166 hs.c.cipherSuite = hs.suite.id
167 return nil
168 }
169
170 func (hs *clientHandshakeStateGM) doFullHandshake() error {
171 c := hs.c
172
173 msg, err := c.readHandshake()
174 if err != nil {
175 return err
176 }
177 certMsg, ok := msg.(*certificateMsg)
178 if !ok || len(certMsg.certificates) == 0 {
179 c.sendAlert(alertUnexpectedMessage)
180 return unexpectedMessageError(certMsg, msg)
181 }
182
183
184
185 if len(certMsg.certificates) < 2 {
186 c.sendAlert(alertInsufficientSecurity)
187 return fmt.Errorf("tls: length of certificates in GMT0024 must great than 2")
188 }
189
190 hs.finishedHash.Write(certMsg.marshal())
191
192 if c.handshakes == 0 {
193
194
195 certs := make([]*x509.Certificate, len(certMsg.certificates))
196 for i, asn1Data := range certMsg.certificates {
197 cert, err := x509.ParseCertificate(asn1Data)
198 if err != nil {
199 c.sendAlert(alertBadCertificate)
200 return errors.New("tls: failed to parse certificate from server: " + err.Error())
201 }
202
203 pubKey, _ := cert.PublicKey.(*ecdsa.PublicKey)
204 if pubKey.Curve != sm2.P256Sm2() {
205 c.sendAlert(alertUnsupportedCertificate)
206 return fmt.Errorf("tls: pubkey type of cert is error, expect sm2.publicKey")
207 }
208
209
210
211 switch i {
212 case 0:
213 if cert.KeyUsage == 0 || (cert.KeyUsage&(x509.KeyUsageDigitalSignature|cert.KeyUsage&x509.KeyUsageContentCommitment)) == 0 {
214 c.sendAlert(alertInsufficientSecurity)
215 return fmt.Errorf("tls: the keyusage of cert[0] does not exist or is not for KeyUsageDigitalSignature/KeyUsageContentCommitment, value:%d", cert.KeyUsage)
216 }
217 case 1:
218 if cert.KeyUsage == 0 || (cert.KeyUsage&(x509.KeyUsageDataEncipherment|x509.KeyUsageKeyEncipherment|x509.KeyUsageKeyAgreement)) == 0 {
219 c.sendAlert(alertInsufficientSecurity)
220 return fmt.Errorf("tls: the keyusage of cert[1] does not exist or is not for KeyUsageDataEncipherment/KeyUsageKeyEncipherment/KeyUsageKeyAgreement, value:%d", cert.KeyUsage)
221 }
222 }
223
224 certs[i] = cert
225 }
226
227 if !c.config.InsecureSkipVerify {
228 opts := x509.VerifyOptions{
229 Roots: c.config.RootCAs,
230 CurrentTime: c.config.time(),
231 DNSName: c.config.ServerName,
232 Intermediates: x509.NewCertPool(),
233 }
234 if opts.Roots == nil {
235 opts.Roots = x509.NewCertPool()
236 }
237
238 for _, rootca := range getCAs() {
239 opts.Roots.AddCert(rootca)
240 }
241 for i, cert := range certs {
242 c.verifiedChains, err = certs[i].Verify(opts)
243 if err != nil {
244 c.sendAlert(alertBadCertificate)
245 return err
246 }
247 if i == 0 || i == 1 {
248 continue
249 }
250 opts.Intermediates.AddCert(cert)
251 }
252
253 }
254
255 if c.config.VerifyPeerCertificate != nil {
256 if err := c.config.VerifyPeerCertificate(certMsg.certificates, c.verifiedChains); err != nil {
257 c.sendAlert(alertBadCertificate)
258 return err
259 }
260 }
261
262 switch certs[0].PublicKey.(type) {
263 case *sm2.PublicKey, *ecdsa.PublicKey, *rsa.PublicKey:
264 break
265 default:
266 c.sendAlert(alertUnsupportedCertificate)
267 return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
268 }
269
270 c.peerCertificates = certs
271 } else {
272
273
274
275
276
277
278 if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
279 c.sendAlert(alertBadCertificate)
280 return errors.New("tls: server's identity changed during renegotiation")
281 }
282 }
283
284 msg, err = c.readHandshake()
285 if err != nil {
286 return err
287 }
288
289 keyAgreement := hs.suite.ka(c.vers)
290 if ka, ok := keyAgreement.(*eccKeyAgreementGM); ok {
291 ka.encipherCert = c.peerCertificates[1]
292 }
293
294 skx, ok := msg.(*serverKeyExchangeMsg)
295 if ok {
296 hs.finishedHash.Write(skx.marshal())
297 err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx)
298 if err != nil {
299 c.sendAlert(alertUnexpectedMessage)
300 return err
301 }
302
303 msg, err = c.readHandshake()
304 if err != nil {
305 return err
306 }
307 }
308
309 var chainToSend *Certificate
310 var certRequested bool
311 certReq, ok := msg.(*certificateRequestMsgGM)
312 if ok {
313 certRequested = true
314 hs.finishedHash.Write(certReq.marshal())
315
316 if chainToSend, err = hs.getCertificate(certReq); err != nil {
317 c.sendAlert(alertInternalError)
318 return err
319 }
320
321 msg, err = c.readHandshake()
322 if err != nil {
323 return err
324 }
325 }
326
327 shd, ok := msg.(*serverHelloDoneMsg)
328 if !ok {
329 c.sendAlert(alertUnexpectedMessage)
330 return unexpectedMessageError(shd, msg)
331 }
332 hs.finishedHash.Write(shd.marshal())
333
334
335
336
337 if certRequested {
338 certMsg = new(certificateMsg)
339 certMsg.certificates = chainToSend.Certificate
340 hs.finishedHash.Write(certMsg.marshal())
341 if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
342 return err
343 }
344 }
345
346 preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[1])
347 if err != nil {
348 c.sendAlert(alertInternalError)
349 return err
350 }
351 if ckx != nil {
352 hs.finishedHash.Write(ckx.marshal())
353 if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
354 return err
355 }
356 }
357
358 if chainToSend != nil && len(chainToSend.Certificate) > 0 {
359 certVerify := &certificateVerifyMsg{}
360
361 key, ok := chainToSend.PrivateKey.(crypto.Signer)
362 if !ok {
363 c.sendAlert(alertInternalError)
364 return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey)
365 }
366
367 digest := hs.finishedHash.client.Sum(nil)
368
369 certVerify.signature, err = key.Sign(c.config.rand(), digest, nil)
370 if err != nil {
371 c.sendAlert(alertInternalError)
372 return err
373 }
374
375 hs.finishedHash.Write(certVerify.marshal())
376 if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
377 return err
378 }
379 }
380
381 hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
382 if err := c.config.writeKeyLog(hs.hello.random, hs.masterSecret); err != nil {
383 c.sendAlert(alertInternalError)
384 return errors.New("tls: failed to write to key log: " + err.Error())
385 }
386
387 hs.finishedHash.discardHandshakeBuffer()
388
389 return nil
390 }
391
392 func (hs *clientHandshakeStateGM) establishKeys() error {
393 c := hs.c
394
395 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
396 keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
397 var clientCipher, serverCipher interface{}
398 var clientHash, serverHash macFunction
399 if hs.suite.cipher != nil {
400 clientCipher = hs.suite.cipher(clientKey, clientIV, false )
401 clientHash = hs.suite.mac(c.vers, clientMAC)
402 serverCipher = hs.suite.cipher(serverKey, serverIV, true )
403 serverHash = hs.suite.mac(c.vers, serverMAC)
404 } else {
405 clientCipher = hs.suite.aead(clientKey, clientIV)
406 serverCipher = hs.suite.aead(serverKey, serverIV)
407 }
408
409 c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
410 c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
411 return nil
412 }
413
414 func (hs *clientHandshakeStateGM) serverResumedSession() bool {
415
416
417 return hs.session != nil && hs.hello.sessionId != nil &&
418 bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
419 }
420
421 func (hs *clientHandshakeStateGM) processServerHello() (bool, error) {
422 c := hs.c
423
424 if hs.serverHello.compressionMethod != compressionNone {
425 c.sendAlert(alertUnexpectedMessage)
426 return false, errors.New("tls: server selected unsupported compression format")
427 }
428
429 if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported {
430 c.secureRenegotiation = true
431 if len(hs.serverHello.secureRenegotiation) != 0 {
432 c.sendAlert(alertHandshakeFailure)
433 return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
434 }
435 }
436
437 if c.handshakes > 0 && c.secureRenegotiation {
438 var expectedSecureRenegotiation [24]byte
439 copy(expectedSecureRenegotiation[:], c.clientFinished[:])
440 copy(expectedSecureRenegotiation[12:], c.serverFinished[:])
441 if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) {
442 c.sendAlert(alertHandshakeFailure)
443 return false, errors.New("tls: incorrect renegotiation extension contents")
444 }
445 }
446
447 clientDidNPN := hs.hello.nextProtoNeg
448 clientDidALPN := len(hs.hello.alpnProtocols) > 0
449 serverHasNPN := hs.serverHello.nextProtoNeg
450 serverHasALPN := len(hs.serverHello.alpnProtocol) > 0
451
452 if !clientDidNPN && serverHasNPN {
453 c.sendAlert(alertHandshakeFailure)
454 return false, errors.New("tls: server advertised unrequested NPN extension")
455 }
456
457 if !clientDidALPN && serverHasALPN {
458 c.sendAlert(alertHandshakeFailure)
459 return false, errors.New("tls: server advertised unrequested ALPN extension")
460 }
461
462 if serverHasNPN && serverHasALPN {
463 c.sendAlert(alertHandshakeFailure)
464 return false, errors.New("tls: server advertised both NPN and ALPN extensions")
465 }
466
467 if serverHasALPN {
468 c.clientProtocol = hs.serverHello.alpnProtocol
469 c.clientProtocolFallback = false
470 }
471 c.scts = hs.serverHello.scts
472
473 if !hs.serverResumedSession() {
474 return false, nil
475 }
476
477 if hs.session.vers != c.vers {
478 c.sendAlert(alertHandshakeFailure)
479 return false, errors.New("tls: server resumed a session with a different version")
480 }
481
482 if hs.session.cipherSuite != hs.suite.id {
483 c.sendAlert(alertHandshakeFailure)
484 return false, errors.New("tls: server resumed a session with a different cipher suite")
485 }
486
487
488 hs.masterSecret = hs.session.masterSecret
489 c.peerCertificates = hs.session.serverCertificates
490 c.verifiedChains = hs.session.verifiedChains
491 return true, nil
492 }
493
494 func (hs *clientHandshakeStateGM) readFinished(out []byte) error {
495 c := hs.c
496
497 c.readRecord(recordTypeChangeCipherSpec)
498 if c.in.err != nil {
499 return c.in.err
500 }
501
502 msg, err := c.readHandshake()
503 if err != nil {
504 return err
505 }
506 serverFinished, ok := msg.(*finishedMsg)
507 if !ok {
508 c.sendAlert(alertUnexpectedMessage)
509 return unexpectedMessageError(serverFinished, msg)
510 }
511
512 verify := hs.finishedHash.serverSum(hs.masterSecret)
513 if len(verify) != len(serverFinished.verifyData) ||
514 subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
515 c.sendAlert(alertHandshakeFailure)
516 return errors.New("tls: server's Finished message was incorrect")
517 }
518 hs.finishedHash.Write(serverFinished.marshal())
519 copy(out, verify)
520 return nil
521 }
522
523 func (hs *clientHandshakeStateGM) readSessionTicket() error {
524 if !hs.serverHello.ticketSupported {
525 return nil
526 }
527
528 c := hs.c
529 msg, err := c.readHandshake()
530 if err != nil {
531 return err
532 }
533 sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
534 if !ok {
535 c.sendAlert(alertUnexpectedMessage)
536 return unexpectedMessageError(sessionTicketMsg, msg)
537 }
538 hs.finishedHash.Write(sessionTicketMsg.marshal())
539
540 hs.session = &ClientSessionState{
541 sessionTicket: sessionTicketMsg.ticket,
542 vers: c.vers,
543 cipherSuite: hs.suite.id,
544 masterSecret: hs.masterSecret,
545 serverCertificates: c.peerCertificates,
546 verifiedChains: c.verifiedChains,
547 }
548
549 return nil
550 }
551
552 func (hs *clientHandshakeStateGM) sendFinished(out []byte) error {
553 c := hs.c
554
555 if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
556 return err
557 }
558 if hs.serverHello.nextProtoNeg {
559 nextProto := new(nextProtoMsg)
560 proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
561 nextProto.proto = proto
562 c.clientProtocol = proto
563 c.clientProtocolFallback = fallback
564
565 hs.finishedHash.Write(nextProto.marshal())
566 if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil {
567 return err
568 }
569 }
570
571 finished := new(finishedMsg)
572 finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
573 hs.finishedHash.Write(finished.marshal())
574 if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
575 return err
576 }
577 copy(out, finished.verifyData)
578 return nil
579 }
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594 func (hs *clientHandshakeStateGM) getCertificate(certReq *certificateRequestMsgGM) (*Certificate, error) {
595 c := hs.c
596
597 if c.config.GetClientCertificate != nil {
598 var signatureSchemes []SignatureScheme
599
600 return c.config.GetClientCertificate(&CertificateRequestInfo{
601 AcceptableCAs: certReq.certificateAuthorities,
602 SignatureSchemes: signatureSchemes,
603 })
604 }
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620 findCert:
621 for i, chain := range c.config.Certificates {
622
623 for j, cert := range chain.Certificate {
624 x509Cert := chain.Leaf
625
626
627 if j != 0 || x509Cert == nil {
628 var err error
629 if x509Cert, err = x509.ParseCertificate(cert); err != nil {
630 c.sendAlert(alertInternalError)
631 return nil, errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
632 }
633 }
634
635 var isGMCert bool
636
637 if x509Cert.PublicKeyAlgorithm == x509.ECDSA {
638 pubKey, ok := x509Cert.PublicKey.(*ecdsa.PublicKey)
639 if ok && pubKey.Curve == sm2.P256Sm2() {
640 isGMCert = true
641 }
642 }
643
644 if !isGMCert {
645 continue findCert
646 }
647
648 if len(certReq.certificateAuthorities) == 0 {
649
650
651 return &chain, nil
652 }
653
654 for _, ca := range certReq.certificateAuthorities {
655 if bytes.Equal(x509Cert.RawIssuer, ca) {
656 return &chain, nil
657 }
658 }
659 }
660 }
661
662
663 return new(Certificate), nil
664 }
665
View as plain text