1
2
3
4
5
6
7 package gmtls
8
9 import (
10 "crypto"
11 "crypto/ecdsa"
12 "crypto/rsa"
13 "crypto/subtle"
14 "errors"
15 "fmt"
16 "io"
17 "sync/atomic"
18
19 "github.com/tjfoc/gmsm/sm2"
20 "github.com/tjfoc/gmsm/x509"
21 )
22
23
24
// serverHandshakeStateGM holds the server-side state of one GM handshake
// running on c. It is created by serverHandshakeGM and filled in step by
// step as handshake messages are read and written.
type serverHandshakeStateGM struct {
	c            *Conn
	clientHello  *clientHelloMsg // the ClientHello received from the peer
	hello        *serverHelloMsg // the ServerHello being built for the reply
	suite        *cipherSuite    // the negotiated (or resumed) cipher suite
	sessionState *sessionState   // decrypted session ticket, set when resuming
	// finishedHash accumulates the handshake messages for the Finished
	// messages (and for CertificateVerify in a full handshake).
	finishedHash    finishedHash
	masterSecret    []byte
	certsFromClient [][]byte // raw client certificate chain; stored in issued session tickets
	// cert holds the server certificates taken from the config; cert[0] is
	// used for signing and cert[1] for encipherment (key exchange).
	cert                  []Certificate
	cachedClientHelloInfo *ClientHelloInfo // lazily built by clientHelloInfo
}
37
38
// serverHandshakeGM runs the server side of a GM handshake on c.
//
// It reads the ClientHello and then performs either an abbreviated
// handshake (when the client's session ticket is accepted) or a full
// handshake. On success it derives the exported keying material (c.ekm)
// and marks the handshake as complete by storing 1 in c.handshakeStatus.
func (c *Conn) serverHandshakeGM() error {
	// Initialize the server-side parts of the config exactly once.
	c.config.serverInitOnce.Do(func() { c.config.serverInit(nil) })

	hs := serverHandshakeStateGM{
		c: c,
	}
	isResume, err := hs.readClientHello()
	if err != nil {
		return err
	}

	// Batch outgoing handshake records; they are pushed out by the explicit
	// c.flush() calls below (assumes c.buffering defers writes until flush).
	c.buffering = true
	if isResume {
		// Abbreviated handshake: the session ticket was accepted, so the
		// server sends its Finished first and then reads the client's.
		if err := hs.doResumeHandshake(); err != nil {
			return err
		}
		if err := hs.establishKeys(); err != nil {
			return err
		}
		// During resumption hs.hello.ticketSupported is set from
		// sessionState.usedOldKey (in doResumeHandshake), i.e. a refreshed
		// ticket is issued only when the old one used an old ticket key.
		if hs.hello.ticketSupported {
			if err := hs.sendSessionTicket(); err != nil {
				return err
			}
		}
		if err := hs.sendFinished(c.serverFinished[:]); err != nil {
			return err
		}
		if _, err := c.flush(); err != nil {
			return err
		}
		c.clientFinishedIsFirst = false
		if err := hs.readFinished(nil); err != nil {
			return err
		}
		c.didResume = true
	} else {
		// Full handshake: no usable session ticket was presented. The
		// client sends its Finished first; the server answers with an
		// optional new session ticket and its own Finished.
		if err := hs.doFullHandshake(); err != nil {
			return err
		}
		if err := hs.establishKeys(); err != nil {
			return err
		}
		if err := hs.readFinished(c.clientFinished[:]); err != nil {
			return err
		}
		c.clientFinishedIsFirst = true
		// Re-enable buffering before writing the ticket and Finished.
		c.buffering = true
		if err := hs.sendSessionTicket(); err != nil {
			return err
		}
		if err := hs.sendFinished(nil); err != nil {
			return err
		}
		if _, err := c.flush(); err != nil {
			return err
		}
	}

	c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
	atomic.StoreUint32(&c.handshakeStatus, 1)

	return nil
}
111
112
113
114 func (hs *serverHandshakeStateGM) readClientHello() (isResume bool, err error) {
115 c := hs.c
116
117 msg, err := c.readHandshake()
118 if err != nil {
119 return false, err
120 }
121 var ok bool
122 hs.clientHello, ok = msg.(*clientHelloMsg)
123 if !ok {
124 c.sendAlert(alertUnexpectedMessage)
125 return false, unexpectedMessageError(hs.clientHello, msg)
126 }
127
128 if c.config.GetConfigForClient != nil {
129 if newConfig, err := c.config.GetConfigForClient(hs.clientHelloInfo()); err != nil {
130 c.sendAlert(alertInternalError)
131 return false, err
132 } else if newConfig != nil {
133 newConfig.serverInitOnce.Do(func() { newConfig.serverInit(c.config) })
134 c.config = newConfig
135 }
136 }
137
138 c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
139 if !ok {
140 c.sendAlert(alertProtocolVersion)
141 return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
142 }
143 c.haveVers = true
144
145 hs.hello = new(serverHelloMsg)
146
147 foundCompression := false
148
149 for _, compression := range hs.clientHello.compressionMethods {
150 if compression == compressionNone {
151 foundCompression = true
152 break
153 }
154 }
155
156 if !foundCompression {
157 c.sendAlert(alertHandshakeFailure)
158 return false, errors.New("tls: client does not support uncompressed connections")
159 }
160
161 hs.hello.vers = c.vers
162 hs.hello.random = make([]byte, 32)
163 _, err = io.ReadFull(c.config.rand(), hs.hello.random)
164 if err != nil {
165 c.sendAlert(alertInternalError)
166 return false, err
167 }
168
169 if len(hs.clientHello.secureRenegotiation) != 0 {
170 c.sendAlert(alertHandshakeFailure)
171 return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
172 }
173
174 hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
175 hs.hello.compressionMethod = compressionNone
176 if len(hs.clientHello.serverName) > 0 {
177 c.serverName = hs.clientHello.serverName
178 }
179
180 if len(hs.clientHello.alpnProtocols) > 0 {
181 if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback {
182 hs.hello.alpnProtocol = selectedProto
183 c.clientProtocol = selectedProto
184 }
185 } else {
186
187
188
189
190 if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
191 hs.hello.nextProtoNeg = true
192 hs.hello.nextProtos = c.config.NextProtos
193 }
194 }
195
196
197 c.config.getCertificate(hs.clientHelloInfo())
198 hs.cert = c.config.Certificates
199
200
201 if len(hs.cert) < 2 {
202 c.sendAlert(alertInternalError)
203 return false, fmt.Errorf("tls: amount of server certificates must be greater than 2, which will sign and encipher respectively")
204 }
205
206 if hs.clientHello.scts {
207 hs.hello.scts = hs.cert[0].SignedCertificateTimestamps
208 }
209
210 if hs.checkForResumption() {
211 return true, nil
212 }
213
214 var preferenceList, supportedList []uint16
215 if c.config.PreferServerCipherSuites {
216 preferenceList = getCipherSuites(c.config)
217 supportedList = hs.clientHello.cipherSuites
218 } else {
219 preferenceList = hs.clientHello.cipherSuites
220 supportedList = getCipherSuites(c.config)
221 }
222
223 for _, id := range preferenceList {
224 if hs.setCipherSuite(id, supportedList, c.vers) {
225 break
226 }
227 }
228
229 if hs.suite == nil {
230 c.sendAlert(alertHandshakeFailure)
231 return false, errors.New("tls: no cipher suite supported by both client and server")
232 }
233
234
235 for _, id := range hs.clientHello.cipherSuites {
236 if id == TLS_FALLBACK_SCSV {
237
238 if hs.clientHello.vers < c.config.maxVersion() {
239 c.sendAlert(alertInappropriateFallback)
240 return false, errors.New("tls: client using inappropriate protocol fallback")
241 }
242 break
243 }
244 }
245
246 return false, nil
247 }
248
249
250 func (hs *serverHandshakeStateGM) checkForResumption() bool {
251 c := hs.c
252
253 if c.config.SessionTicketsDisabled {
254 return false
255 }
256
257 var ok bool
258 var sessionTicket = append([]uint8{}, hs.clientHello.sessionTicket...)
259 if hs.sessionState, ok = c.decryptTicket(sessionTicket); !ok {
260 return false
261 }
262
263
264 if c.vers != hs.sessionState.vers {
265 return false
266 }
267
268 cipherSuiteOk := false
269
270 for _, id := range hs.clientHello.cipherSuites {
271 if id == hs.sessionState.cipherSuite {
272 cipherSuiteOk = true
273 break
274 }
275 }
276 if !cipherSuiteOk {
277 return false
278 }
279
280
281 if !hs.setCipherSuite(hs.sessionState.cipherSuite, c.config.cipherSuites(), hs.sessionState.vers) {
282 return false
283 }
284
285 sessionHasClientCerts := len(hs.sessionState.certificates) != 0
286 needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert
287 if needClientCerts && !sessionHasClientCerts {
288 return false
289 }
290 if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
291 return false
292 }
293
294 return true
295 }
296
297 func (hs *serverHandshakeStateGM) doResumeHandshake() error {
298 c := hs.c
299
300 hs.hello.cipherSuite = hs.suite.id
301
302
303 hs.hello.sessionId = hs.clientHello.sessionId
304 hs.hello.ticketSupported = hs.sessionState.usedOldKey
305 hs.finishedHash = newFinishedHash(c.vers, hs.suite)
306 hs.finishedHash.discardHandshakeBuffer()
307 hs.finishedHash.Write(hs.clientHello.marshal())
308 hs.finishedHash.Write(hs.hello.marshal())
309 if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
310 return err
311 }
312
313 if len(hs.sessionState.certificates) > 0 {
314 if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
315 return err
316 }
317 }
318
319 hs.masterSecret = hs.sessionState.masterSecret
320
321 return nil
322 }
323
324 func (hs *serverHandshakeStateGM) doFullHandshake() error {
325 c := hs.c
326
327 if hs.clientHello.ocspStapling && len(hs.cert[0].OCSPStaple) > 0 {
328 hs.hello.ocspStapling = true
329 }
330
331 hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
332 hs.hello.cipherSuite = hs.suite.id
333
334 hs.finishedHash = newFinishedHashGM(hs.suite)
335 if c.config.ClientAuth == NoClientCert {
336
337
338 hs.finishedHash.discardHandshakeBuffer()
339 }
340 hs.finishedHash.Write(hs.clientHello.marshal())
341 hs.finishedHash.Write(hs.hello.marshal())
342 if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
343 return err
344 }
345
346 certMsg := new(certificateMsg)
347
348 for i := 0; i < len(hs.cert); i++ {
349 certMsg.certificates = append(certMsg.certificates, hs.cert[i].Certificate...)
350 }
351 hs.finishedHash.Write(certMsg.marshal())
352 if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
353 return err
354 }
355
356 if hs.hello.ocspStapling {
357 certStatus := new(certificateStatusMsg)
358 certStatus.statusType = statusTypeOCSP
359 certStatus.response = hs.cert[0].OCSPStaple
360 hs.finishedHash.Write(certStatus.marshal())
361 if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
362 return err
363 }
364 }
365
366 keyAgreement := hs.suite.ka(c.vers)
367 skx, err := keyAgreement.generateServerKeyExchange(c.config, &hs.cert[0], &hs.cert[1], hs.clientHello, hs.hello)
368 if err != nil {
369 c.sendAlert(alertHandshakeFailure)
370 return err
371 }
372 if skx != nil {
373 hs.finishedHash.Write(skx.marshal())
374 if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
375 return err
376 }
377 }
378
379 if c.config.ClientAuth >= RequestClientCert {
380
381 certReq := new(certificateRequestMsgGM)
382 certReq.certificateTypes = []byte{
383 byte(certTypeRSASign),
384 byte(certTypeECDSASign),
385 }
386
387
388
389
390
391
392
393
394
395
396 if c.config.ClientCAs != nil {
397 certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
398 }
399 hs.finishedHash.Write(certReq.marshal())
400 if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
401 return err
402 }
403 }
404
405 helloDone := new(serverHelloDoneMsg)
406 hs.finishedHash.Write(helloDone.marshal())
407 if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
408 return err
409 }
410
411 if _, err := c.flush(); err != nil {
412 return err
413 }
414
415 var pub crypto.PublicKey
416
417 msg, err := c.readHandshake()
418 if err != nil {
419 fmt.Println("readHandshake error:", err)
420 return err
421 }
422
423 var ok bool
424
425
426 if c.config.ClientAuth >= RequestClientCert {
427 if certMsg, ok = msg.(*certificateMsg); !ok {
428 c.sendAlert(alertUnexpectedMessage)
429 return unexpectedMessageError(certMsg, msg)
430 }
431 hs.finishedHash.Write(certMsg.marshal())
432
433 if len(certMsg.certificates) == 0 {
434
435 switch c.config.ClientAuth {
436 case RequireAnyClientCert, RequireAndVerifyClientCert:
437 c.sendAlert(alertBadCertificate)
438 return errors.New("tls: client didn't provide a certificate")
439 }
440 }
441
442 pub, err = hs.processCertsFromClient(certMsg.certificates)
443 if err != nil {
444 return err
445 }
446
447 msg, err = c.readHandshake()
448 if err != nil {
449 return err
450 }
451 }
452
453
454 ckx, ok := msg.(*clientKeyExchangeMsg)
455 if !ok {
456 c.sendAlert(alertUnexpectedMessage)
457 return unexpectedMessageError(ckx, msg)
458 }
459 hs.finishedHash.Write(ckx.marshal())
460
461 preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, &hs.cert[1], ckx, c.vers)
462 if err != nil {
463 c.sendAlert(alertHandshakeFailure)
464 return err
465 }
466 hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
467 if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
468 c.sendAlert(alertInternalError)
469 return err
470 }
471
472
473
474
475
476
477
478 if len(c.peerCertificates) > 0 {
479 msg, err = c.readHandshake()
480 if err != nil {
481 return err
482 }
483 certVerify, ok := msg.(*certificateVerifyMsg)
484 if !ok {
485 c.sendAlert(alertUnexpectedMessage)
486 return unexpectedMessageError(certVerify, msg)
487 }
488
489
490 _, sigType, hashFunc, err := pickSignatureAlgorithm(pub, []SignatureScheme{certVerify.signatureAlgorithm}, supportedSignatureAlgorithms, c.vers)
491 if err != nil {
492 c.sendAlert(alertIllegalParameter)
493 return err
494 }
495
496 var digest []byte
497 if digest, err = hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret); err == nil {
498 err = verifyHandshakeSignature(sigType, pub, hashFunc, digest, certVerify.signature)
499 }
500 if err != nil {
501 c.sendAlert(alertBadCertificate)
502 return errors.New("tls: could not validate signature of connection nonces: " + err.Error())
503 }
504
505 hs.finishedHash.Write(certVerify.marshal())
506 }
507
508 hs.finishedHash.discardHandshakeBuffer()
509
510 return nil
511 }
512
513 func (hs *serverHandshakeStateGM) establishKeys() error {
514 c := hs.c
515
516 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
517 keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
518
519 var clientCipher, serverCipher interface{}
520 var clientHash, serverHash macFunction
521
522 if hs.suite.aead == nil {
523 clientCipher = hs.suite.cipher(clientKey, clientIV, true )
524 clientHash = hs.suite.mac(c.vers, clientMAC)
525 serverCipher = hs.suite.cipher(serverKey, serverIV, false )
526 serverHash = hs.suite.mac(c.vers, serverMAC)
527 } else {
528 clientCipher = hs.suite.aead(clientKey, clientIV)
529 serverCipher = hs.suite.aead(serverKey, serverIV)
530 }
531
532 c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
533 c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
534
535 return nil
536 }
537
538 func (hs *serverHandshakeStateGM) readFinished(out []byte) error {
539 c := hs.c
540
541 c.readRecord(recordTypeChangeCipherSpec)
542 if c.in.err != nil {
543 return c.in.err
544 }
545
546 if hs.hello.nextProtoNeg {
547 msg, err := c.readHandshake()
548 if err != nil {
549 return err
550 }
551 nextProto, ok := msg.(*nextProtoMsg)
552 if !ok {
553 c.sendAlert(alertUnexpectedMessage)
554 return unexpectedMessageError(nextProto, msg)
555 }
556 hs.finishedHash.Write(nextProto.marshal())
557 c.clientProtocol = nextProto.proto
558 }
559
560 msg, err := c.readHandshake()
561 if err != nil {
562 return err
563 }
564 clientFinished, ok := msg.(*finishedMsg)
565 if !ok {
566 c.sendAlert(alertUnexpectedMessage)
567 return unexpectedMessageError(clientFinished, msg)
568 }
569
570 verify := hs.finishedHash.clientSum(hs.masterSecret)
571 if len(verify) != len(clientFinished.verifyData) ||
572 subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
573 c.sendAlert(alertHandshakeFailure)
574 return errors.New("tls: client's Finished message is incorrect")
575 }
576
577 hs.finishedHash.Write(clientFinished.marshal())
578 copy(out, verify)
579 return nil
580 }
581
582 func (hs *serverHandshakeStateGM) sendSessionTicket() error {
583 if !hs.hello.ticketSupported {
584 return nil
585 }
586
587 c := hs.c
588 m := new(newSessionTicketMsg)
589
590 var err error
591 state := sessionState{
592 vers: c.vers,
593 cipherSuite: hs.suite.id,
594 masterSecret: hs.masterSecret,
595 certificates: hs.certsFromClient,
596 }
597 m.ticket, err = c.encryptTicket(&state)
598 if err != nil {
599 return err
600 }
601
602 hs.finishedHash.Write(m.marshal())
603 if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
604 return err
605 }
606
607 return nil
608 }
609
610 func (hs *serverHandshakeStateGM) sendFinished(out []byte) error {
611 c := hs.c
612
613 if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
614 return err
615 }
616
617 finished := new(finishedMsg)
618 finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
619 hs.finishedHash.Write(finished.marshal())
620 if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
621 return err
622 }
623
624 c.cipherSuite = hs.suite.id
625 copy(out, finished.verifyData)
626
627 return nil
628 }
629
630
631
632
// processCertsFromClient parses the client's certificate chain (one
// DER-encoded certificate per element, leaf first). When c.config.ClientAuth
// is VerifyClientCertIfGiven or stricter and a chain was sent, it verifies
// the leaf against c.config.ClientCAs for client authentication, using the
// remaining certificates as intermediates. It then runs
// c.config.VerifyPeerCertificate, if set.
//
// The raw chain is always stored in hs.certsFromClient (it is later placed
// in issued session tickets). When the chain is non-empty, the parsed
// certificates are stored in c.peerCertificates and the leaf's public key is
// returned; it must be an ECDSA, RSA or SM2 key. With an empty chain it
// returns (nil, nil). An alert is sent to the peer before any error is
// returned.
func (hs *serverHandshakeStateGM) processCertsFromClient(certificates [][]byte) (crypto.PublicKey, error) {
	c := hs.c

	hs.certsFromClient = certificates
	certs := make([]*x509.Certificate, len(certificates))
	var err error
	for i, asn1Data := range certificates {
		if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
			c.sendAlert(alertBadCertificate)
			return nil, errors.New("tls: failed to parse client certificate: " + err.Error())
		}
	}

	// Chain verification only happens when the policy asks for it and the
	// client actually sent certificates.
	if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
		opts := x509.VerifyOptions{
			Roots:         c.config.ClientCAs,
			CurrentTime:   c.config.time(),
			Intermediates: x509.NewCertPool(),
			KeyUsages:     []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
		}

		// Everything after the leaf is treated as an intermediate.
		for _, cert := range certs[1:] {
			opts.Intermediates.AddCert(cert)
		}

		chains, err := certs[0].Verify(opts)
		if err != nil {
			c.sendAlert(alertBadCertificate)
			return nil, errors.New("tls: failed to verify client's certificate: " + err.Error())
		}

		c.verifiedChains = chains
	}

	// The application callback sees the raw chain and whatever chains were
	// verified above (c.verifiedChains is not set here when verification was
	// skipped).
	if c.config.VerifyPeerCertificate != nil {
		if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
			c.sendAlert(alertBadCertificate)
			return nil, err
		}
	}

	if len(certs) == 0 {
		return nil, nil
	}

	var pub crypto.PublicKey
	switch key := certs[0].PublicKey.(type) {
	case *ecdsa.PublicKey, *rsa.PublicKey, *sm2.PublicKey:
		pub = key
	default:
		c.sendAlert(alertUnsupportedCertificate)
		return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey)
	}
	c.peerCertificates = certs
	return pub, nil
}
689
690
691
692
693 func (hs *serverHandshakeStateGM) setCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16) bool {
694 for _, supported := range supportedCipherSuites {
695 if id == supported {
696 var candidate *cipherSuite
697
698 for _, s := range gmCipherSuites {
699 if s.id == id {
700 candidate = s
701 break
702 }
703 }
704 if candidate == nil {
705 continue
706 }
707 if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 {
708 continue
709 }
710 hs.suite = candidate
711 return true
712 }
713 }
714 return false
715 }
716
717 func (hs *serverHandshakeStateGM) clientHelloInfo() *ClientHelloInfo {
718 if hs.cachedClientHelloInfo != nil {
719 return hs.cachedClientHelloInfo
720 }
721
722 supportedVersions := []uint16{VersionGMSSL}
723
724
725
726
727
728
729
730
731
732
733 hs.cachedClientHelloInfo = &ClientHelloInfo{
734 CipherSuites: hs.clientHello.cipherSuites,
735 ServerName: hs.clientHello.serverName,
736 SupportedCurves: hs.clientHello.supportedCurves,
737 SupportedPoints: hs.clientHello.supportedPoints,
738 SignatureSchemes: hs.clientHello.supportedSignatureAlgorithms,
739 SupportedProtos: hs.clientHello.alpnProtocols,
740 SupportedVersions: supportedVersions,
741 Conn: hs.c.conn,
742 }
743
744 return hs.cachedClientHelloInfo
745 }
746
View as plain text