1
2
3
4
5
6
7 package gmtls
8
9 import (
10 "bytes"
11 "crypto"
12 "crypto/ecdsa"
13 "crypto/rsa"
14 "crypto/subtle"
15 "errors"
16 "fmt"
17 "io"
18 "strconv"
19 "sync/atomic"
20
21 "github.com/tjfoc/gmsm/sm2"
22 "github.com/tjfoc/gmsm/x509"
23 )
24
25 type clientHandshakeStateGM struct {
26 c *Conn
27 serverHello *serverHelloMsg
28 hello *clientHelloMsg
29 suite *cipherSuite
30 finishedHash finishedHash
31 masterSecret []byte
32 session *ClientSessionState
33 }
34
35 func makeClientHelloGM(config *Config) (*clientHelloMsg, error) {
36 if len(config.ServerName) == 0 && !config.InsecureSkipVerify {
37 return nil, errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config")
38 }
39
40 hello := &clientHelloMsg{
41 vers: config.GMSupport.GetVersion(),
42 compressionMethods: []uint8{compressionNone},
43 random: make([]byte, 32),
44 }
45 possibleCipherSuites := getCipherSuites(config)
46 hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites))
47
48 NextCipherSuite:
49 for _, suiteId := range possibleCipherSuites {
50 for _, suite := range config.GMSupport.cipherSuites() {
51 if suite.id != suiteId {
52 continue
53 }
54 hello.cipherSuites = append(hello.cipherSuites, suiteId)
55 continue NextCipherSuite
56 }
57 }
58
59 _, err := io.ReadFull(config.rand(), hello.random)
60 if err != nil {
61 return nil, errors.New("tls: short read from Rand: " + err.Error())
62 }
63
64 return hello, nil
65 }
66
67
68
69 func (hs *clientHandshakeStateGM) handshake() error {
70 c := hs.c
71
72
73 if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
74 return err
75 }
76
77 msg, err := c.readHandshake()
78 if err != nil {
79 return err
80 }
81
82 var ok bool
83 if hs.serverHello, ok = msg.(*serverHelloMsg); !ok {
84 c.sendAlert(alertUnexpectedMessage)
85 return unexpectedMessageError(hs.serverHello, msg)
86 }
87
88 if hs.serverHello.vers != VersionGMSSL {
89 hs.c.sendAlert(alertProtocolVersion)
90 return fmt.Errorf("tls: server selected unsupported protocol version %x, while expecting %x", hs.serverHello.vers, VersionGMSSL)
91 }
92
93 if err = hs.pickCipherSuite(); err != nil {
94 return err
95 }
96
97 isResume, err := hs.processServerHello()
98 if err != nil {
99 return err
100 }
101
102 hs.finishedHash = newFinishedHashGM(hs.suite)
103
104
105
106
107
108 if isResume || (len(c.config.Certificates) == 0 && c.config.GetClientCertificate == nil) {
109 hs.finishedHash.discardHandshakeBuffer()
110 }
111
112 hs.finishedHash.Write(hs.hello.marshal())
113 hs.finishedHash.Write(hs.serverHello.marshal())
114
115 c.buffering = true
116 if isResume {
117 if err := hs.establishKeys(); err != nil {
118 return err
119 }
120 if err := hs.readSessionTicket(); err != nil {
121 return err
122 }
123 if err := hs.readFinished(c.serverFinished[:]); err != nil {
124 return err
125 }
126 c.clientFinishedIsFirst = false
127 if err := hs.sendFinished(c.clientFinished[:]); err != nil {
128 return err
129 }
130 if _, err := c.flush(); err != nil {
131 return err
132 }
133 } else {
134 if err := hs.doFullHandshake(); err != nil {
135 return err
136 }
137 if err := hs.establishKeys(); err != nil {
138 return err
139 }
140 if err := hs.sendFinished(c.clientFinished[:]); err != nil {
141 return err
142 }
143 if _, err := c.flush(); err != nil {
144 return err
145 }
146 c.clientFinishedIsFirst = true
147 if err := hs.readSessionTicket(); err != nil {
148 return err
149 }
150 if err := hs.readFinished(c.serverFinished[:]); err != nil {
151 return err
152 }
153 }
154
155 c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random)
156 c.didResume = isResume
157 atomic.StoreUint32(&c.handshakeStatus, 1)
158
159 return nil
160 }
161
162 func (hs *clientHandshakeStateGM) pickCipherSuite() error {
163 if hs.suite = mutualCipherSuiteGM(hs.hello.cipherSuites, hs.serverHello.cipherSuite); hs.suite == nil {
164 hs.c.sendAlert(alertHandshakeFailure)
165 return errors.New("tls: server chose an unconfigured cipher suite")
166 }
167
168 hs.c.cipherSuite = hs.suite.id
169 return nil
170 }
171
172 func (hs *clientHandshakeStateGM) doFullHandshake() error {
173 c := hs.c
174
175 msg, err := c.readHandshake()
176 if err != nil {
177 return err
178 }
179 certMsg, ok := msg.(*certificateMsg)
180 if !ok || len(certMsg.certificates) == 0 {
181 c.sendAlert(alertUnexpectedMessage)
182 return unexpectedMessageError(certMsg, msg)
183 }
184
185
186
187
188
189
190
191
192 hs.finishedHash.Write(certMsg.marshal())
193
194 if c.handshakes == 0 {
195
196
197 certs := make([]*x509.Certificate, len(certMsg.certificates))
198 for i, asn1Data := range certMsg.certificates {
199 cert, err := x509.ParseCertificate(asn1Data)
200 if err != nil {
201 c.sendAlert(alertBadCertificate)
202 return errors.New("tls: failed to parse certificate from server: " + err.Error())
203 }
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229 certs[i] = cert
230 }
231
232 if !c.config.InsecureSkipVerify {
233 opts := x509.VerifyOptions{
234 Roots: c.config.RootCAs,
235 CurrentTime: c.config.time(),
236 DNSName: c.config.ServerName,
237 Intermediates: x509.NewCertPool(),
238 }
239 if opts.Roots == nil {
240 opts.Roots = x509.NewCertPool()
241 }
242
243 for _, rootca := range getCAs() {
244 opts.Roots.AddCert(rootca)
245 }
246 for i, cert := range certs {
247 if i == 0 {
248 continue
249 }
250 opts.Intermediates.AddCert(cert)
251 }
252
253 c.verifiedChains, err = certs[0].Verify(opts)
254 if err != nil {
255 c.sendAlert(alertBadCertificate)
256 return err
257 }
258 }
259
260 if c.config.VerifyPeerCertificate != nil {
261 if err := c.config.VerifyPeerCertificate(certMsg.certificates, c.verifiedChains); err != nil {
262 c.sendAlert(alertBadCertificate)
263 return err
264 }
265 }
266
267 switch certs[0].PublicKey.(type) {
268 case *sm2.PublicKey, *ecdsa.PublicKey, *rsa.PublicKey:
269 break
270 default:
271 c.sendAlert(alertUnsupportedCertificate)
272 return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", certs[0].PublicKey)
273 }
274
275 c.peerCertificates = certs
276 } else {
277
278
279
280
281
282
283 if !bytes.Equal(c.peerCertificates[0].Raw, certMsg.certificates[0]) {
284 c.sendAlert(alertBadCertificate)
285 return errors.New("tls: server's identity changed during renegotiation")
286 }
287 }
288
289 msg, err = c.readHandshake()
290 if err != nil {
291 return err
292 }
293
294 keyAgreement := hs.suite.ka(c.vers)
295 if ka, ok := keyAgreement.(*eccKeyAgreementGM); ok {
296
297
298 ka.encipherCert = c.peerCertificates[0]
299 }
300
301 skx, ok := msg.(*serverKeyExchangeMsg)
302 if ok {
303 hs.finishedHash.Write(skx.marshal())
304 err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx)
305 if err != nil {
306 c.sendAlert(alertUnexpectedMessage)
307 return err
308 }
309
310 msg, err = c.readHandshake()
311 if err != nil {
312 return err
313 }
314 }
315
316 var chainToSend *Certificate
317 var certRequested bool
318 certReq, ok := msg.(*certificateRequestMsgGM)
319 if ok {
320 certRequested = true
321 hs.finishedHash.Write(certReq.marshal())
322
323 if chainToSend, err = hs.getCertificate(certReq); err != nil || chainToSend.Certificate == nil {
324 c.sendAlert(alertInternalError)
325 return err
326 }
327
328 msg, err = c.readHandshake()
329 if err != nil {
330 return err
331 }
332 }
333
334 shd, ok := msg.(*serverHelloDoneMsg)
335 if !ok {
336 c.sendAlert(alertUnexpectedMessage)
337 return unexpectedMessageError(shd, msg)
338 }
339 hs.finishedHash.Write(shd.marshal())
340
341
342
343
344 if certRequested {
345 certMsg = new(certificateMsg)
346 certMsg.certificates = chainToSend.Certificate
347 hs.finishedHash.Write(certMsg.marshal())
348 if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
349 return err
350 }
351 }
352
353
354
355 preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, c.peerCertificates[0])
356 if err != nil {
357 c.sendAlert(alertInternalError)
358 return err
359 }
360 if ckx != nil {
361 hs.finishedHash.Write(ckx.marshal())
362 if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
363 return err
364 }
365 }
366
367 if chainToSend != nil && len(chainToSend.Certificate) > 0 {
368 certVerify := &certificateVerifyMsg{}
369
370 key, ok := chainToSend.PrivateKey.(crypto.Signer)
371 if !ok {
372 c.sendAlert(alertInternalError)
373 return fmt.Errorf("tls: client certificate private key of type %T does not implement crypto.Signer", chainToSend.PrivateKey)
374 }
375
376 digest := hs.finishedHash.client.Sum(nil)
377
378 certVerify.signature, err = key.Sign(c.config.rand(), digest, nil)
379 if err != nil {
380 c.sendAlert(alertInternalError)
381 return err
382 }
383
384 hs.finishedHash.Write(certVerify.marshal())
385 if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
386 return err
387 }
388 }
389
390 hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random)
391 if err := c.config.writeKeyLog(hs.hello.random, hs.masterSecret); err != nil {
392 c.sendAlert(alertInternalError)
393 return errors.New("tls: failed to write to key log: " + err.Error())
394 }
395
396 hs.finishedHash.discardHandshakeBuffer()
397
398 return nil
399 }
400
401 func (hs *clientHandshakeStateGM) establishKeys() error {
402 c := hs.c
403
404 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
405 keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
406 var clientCipher, serverCipher interface{}
407 var clientHash, serverHash macFunction
408 if hs.suite.cipher != nil {
409 clientCipher = hs.suite.cipher(clientKey, clientIV, false )
410 clientHash = hs.suite.mac(c.vers, clientMAC)
411 serverCipher = hs.suite.cipher(serverKey, serverIV, true )
412 serverHash = hs.suite.mac(c.vers, serverMAC)
413 } else {
414 clientCipher = hs.suite.aead(clientKey, clientIV)
415 serverCipher = hs.suite.aead(serverKey, serverIV)
416 }
417
418 c.in.prepareCipherSpec(c.vers, serverCipher, serverHash)
419 c.out.prepareCipherSpec(c.vers, clientCipher, clientHash)
420 return nil
421 }
422
423 func (hs *clientHandshakeStateGM) serverResumedSession() bool {
424
425
426 return hs.session != nil && hs.hello.sessionId != nil &&
427 bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId)
428 }
429
430 func (hs *clientHandshakeStateGM) processServerHello() (bool, error) {
431 c := hs.c
432
433 if hs.serverHello.compressionMethod != compressionNone {
434 c.sendAlert(alertUnexpectedMessage)
435 return false, errors.New("tls: server selected unsupported compression format")
436 }
437
438 if c.handshakes == 0 && hs.serverHello.secureRenegotiationSupported {
439 c.secureRenegotiation = true
440 if len(hs.serverHello.secureRenegotiation) != 0 {
441 c.sendAlert(alertHandshakeFailure)
442 return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
443 }
444 }
445
446 if c.handshakes > 0 && c.secureRenegotiation {
447 var expectedSecureRenegotiation [24]byte
448 copy(expectedSecureRenegotiation[:], c.clientFinished[:])
449 copy(expectedSecureRenegotiation[12:], c.serverFinished[:])
450 if !bytes.Equal(hs.serverHello.secureRenegotiation, expectedSecureRenegotiation[:]) {
451 c.sendAlert(alertHandshakeFailure)
452 return false, errors.New("tls: incorrect renegotiation extension contents")
453 }
454 }
455
456 clientDidNPN := hs.hello.nextProtoNeg
457 clientDidALPN := len(hs.hello.alpnProtocols) > 0
458 serverHasNPN := hs.serverHello.nextProtoNeg
459 serverHasALPN := len(hs.serverHello.alpnProtocol) > 0
460
461 if !clientDidNPN && serverHasNPN {
462 c.sendAlert(alertHandshakeFailure)
463 return false, errors.New("tls: server advertised unrequested NPN extension")
464 }
465
466 if !clientDidALPN && serverHasALPN {
467 c.sendAlert(alertHandshakeFailure)
468 return false, errors.New("tls: server advertised unrequested ALPN extension")
469 }
470
471 if serverHasNPN && serverHasALPN {
472 c.sendAlert(alertHandshakeFailure)
473 return false, errors.New("tls: server advertised both NPN and ALPN extensions")
474 }
475
476 if serverHasALPN {
477 c.clientProtocol = hs.serverHello.alpnProtocol
478 c.clientProtocolFallback = false
479 }
480 c.scts = hs.serverHello.scts
481
482 if !hs.serverResumedSession() {
483 return false, nil
484 }
485
486 if hs.session.vers != c.vers {
487 c.sendAlert(alertHandshakeFailure)
488 return false, errors.New("tls: server resumed a session with a different version")
489 }
490
491 if hs.session.cipherSuite != hs.suite.id {
492 c.sendAlert(alertHandshakeFailure)
493 return false, errors.New("tls: server resumed a session with a different cipher suite")
494 }
495
496
497 hs.masterSecret = hs.session.masterSecret
498 c.peerCertificates = hs.session.serverCertificates
499 c.verifiedChains = hs.session.verifiedChains
500 return true, nil
501 }
502
503 func (hs *clientHandshakeStateGM) readFinished(out []byte) error {
504 c := hs.c
505
506 c.readRecord(recordTypeChangeCipherSpec)
507 if c.in.err != nil {
508 return c.in.err
509 }
510
511 msg, err := c.readHandshake()
512 if err != nil {
513 return err
514 }
515 serverFinished, ok := msg.(*finishedMsg)
516 if !ok {
517 c.sendAlert(alertUnexpectedMessage)
518 return unexpectedMessageError(serverFinished, msg)
519 }
520
521 verify := hs.finishedHash.serverSum(hs.masterSecret)
522 if len(verify) != len(serverFinished.verifyData) ||
523 subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 {
524 c.sendAlert(alertHandshakeFailure)
525 return errors.New("tls: server's Finished message was incorrect")
526 }
527 hs.finishedHash.Write(serverFinished.marshal())
528 copy(out, verify)
529 return nil
530 }
531
532 func (hs *clientHandshakeStateGM) readSessionTicket() error {
533 if !hs.serverHello.ticketSupported {
534 return nil
535 }
536
537 c := hs.c
538 msg, err := c.readHandshake()
539 if err != nil {
540 return err
541 }
542 sessionTicketMsg, ok := msg.(*newSessionTicketMsg)
543 if !ok {
544 c.sendAlert(alertUnexpectedMessage)
545 return unexpectedMessageError(sessionTicketMsg, msg)
546 }
547 hs.finishedHash.Write(sessionTicketMsg.marshal())
548
549 hs.session = &ClientSessionState{
550 sessionTicket: sessionTicketMsg.ticket,
551 vers: c.vers,
552 cipherSuite: hs.suite.id,
553 masterSecret: hs.masterSecret,
554 serverCertificates: c.peerCertificates,
555 verifiedChains: c.verifiedChains,
556 }
557
558 return nil
559 }
560
561 func (hs *clientHandshakeStateGM) sendFinished(out []byte) error {
562 c := hs.c
563
564 if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
565 return err
566 }
567 if hs.serverHello.nextProtoNeg {
568 nextProto := new(nextProtoMsg)
569 proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos)
570 nextProto.proto = proto
571 c.clientProtocol = proto
572 c.clientProtocolFallback = fallback
573
574 hs.finishedHash.Write(nextProto.marshal())
575 if _, err := c.writeRecord(recordTypeHandshake, nextProto.marshal()); err != nil {
576 return err
577 }
578 }
579
580 finished := new(finishedMsg)
581 finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
582 hs.finishedHash.Write(finished.marshal())
583 if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
584 return err
585 }
586 copy(out, finished.verifyData)
587 return nil
588 }
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603 func (hs *clientHandshakeStateGM) getCertificate(certReq *certificateRequestMsgGM) (*Certificate, error) {
604 c := hs.c
605
606 if c.config.GetClientCertificate != nil {
607 var signatureSchemes []SignatureScheme
608
609 return c.config.GetClientCertificate(&CertificateRequestInfo{
610 AcceptableCAs: certReq.certificateAuthorities,
611 SignatureSchemes: signatureSchemes,
612 })
613 }
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629 findCert:
630 for i, chain := range c.config.Certificates {
631
632 for j, cert := range chain.Certificate {
633 x509Cert := chain.Leaf
634
635
636 if j != 0 || x509Cert == nil {
637 var err error
638 if x509Cert, err = x509.ParseCertificate(cert); err != nil {
639 c.sendAlert(alertInternalError)
640 return nil, errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error())
641 }
642 }
643
644 switch {
645 case x509Cert.PublicKeyAlgorithm == x509.SM2:
646 default:
647 continue findCert
648 }
649
650 if len(certReq.certificateAuthorities) == 0 {
651
652
653 return &chain, nil
654 }
655
656 for _, ca := range certReq.certificateAuthorities {
657 if bytes.Equal(x509Cert.RawIssuer, ca) {
658 return &chain, nil
659 }
660 }
661 }
662 }
663
664
665 return new(Certificate), nil
666 }
667
View as plain text