1
2
3
4
5
6
7 package gmtls
8
9 import (
10 "crypto"
11 "crypto/ecdsa"
12 "crypto/rsa"
13 "crypto/subtle"
14 "errors"
15 "fmt"
16 "io"
17 "sync/atomic"
18
19 "github.com/tjfoc/gmsm/sm2"
20 "github.com/tjfoc/gmsm/x509"
21 )
22
23
24
// serverHandshakeStateGM carries the state of one server-side GM handshake
// while it is in progress. It is created by serverHandshakeGM and threaded
// through the read/send helpers below.
type serverHandshakeStateGM struct {
	c                     *Conn            // connection being handshaked
	clientHello           *clientHelloMsg  // ClientHello as received from the peer
	hello                 *serverHelloMsg  // ServerHello being built and sent
	suite                 *cipherSuite     // negotiated (or resumed) cipher suite
	sessionState          *sessionState    // state decrypted from the client's ticket when resuming
	finishedHash          finishedHash     // running hash over the handshake messages
	masterSecret          []byte           // master secret (derived, or restored from the ticket)
	certsFromClient       [][]byte         // raw client certificates; stored into issued session tickets
	cert                  *Certificate     // server certificate selected for this client
	cachedClientHelloInfo *ClientHelloInfo // memoised result of clientHelloInfo()
}
37
38
39 func (c *Conn) serverHandshakeGM() error {
40
41
42 c.config.serverInitOnce.Do(func() { c.config.serverInit(nil) })
43
44 hs := serverHandshakeStateGM{
45 c: c,
46 }
47 isResume, err := hs.readClientHello()
48 if err != nil {
49 return err
50 }
51
52
53 c.buffering = true
54 if isResume {
55
56 if err := hs.doResumeHandshake(); err != nil {
57 return err
58 }
59 if err := hs.establishKeys(); err != nil {
60 return err
61 }
62
63
64
65 if hs.hello.ticketSupported {
66 if err := hs.sendSessionTicket(); err != nil {
67 return err
68 }
69 }
70 if err := hs.sendFinished(c.serverFinished[:]); err != nil {
71 return err
72 }
73 if _, err := c.flush(); err != nil {
74 return err
75 }
76 c.clientFinishedIsFirst = false
77 if err := hs.readFinished(nil); err != nil {
78 return err
79 }
80 c.didResume = true
81 } else {
82
83
84 if err := hs.doFullHandshake(); err != nil {
85 return err
86 }
87 if err := hs.establishKeys(); err != nil {
88 return err
89 }
90 if err := hs.readFinished(c.clientFinished[:]); err != nil {
91 return err
92 }
93 c.clientFinishedIsFirst = true
94 c.buffering = true
95 if err := hs.sendSessionTicket(); err != nil {
96 return err
97 }
98 if err := hs.sendFinished(nil); err != nil {
99 return err
100 }
101 if _, err := c.flush(); err != nil {
102 return err
103 }
104 }
105
106 c.ekm = ekmFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random)
107 atomic.StoreUint32(&c.handshakeStatus, 1)
108
109 return nil
110 }
111
112
113
// readClientHello reads and validates the client's ClientHello and prepares
// hs.hello (the ServerHello) from it.
//
// In order, it:
//   - optionally swaps in a per-client Config via GetConfigForClient;
//   - negotiates the protocol version (c.vers);
//   - requires that null compression is offered;
//   - fills the server random, renegotiation flag, SNI, and ALPN/NPN;
//   - selects the server certificate (hs.cert);
//   - tries ticket-based resumption (returns isResume == true on success);
//   - otherwise picks a cipher suite and enforces TLS_FALLBACK_SCSV.
//
// On failure an alert is sent (except for plain read errors) and the error
// is returned.
func (hs *serverHandshakeStateGM) readClientHello() (isResume bool, err error) {
	c := hs.c

	msg, err := c.readHandshake()
	if err != nil {
		return false, err
	}
	var ok bool
	hs.clientHello, ok = msg.(*clientHelloMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return false, unexpectedMessageError(hs.clientHello, msg)
	}

	// Let the application substitute a Config chosen from the ClientHello.
	// The replacement is initialised once, inheriting from the current config.
	if c.config.GetConfigForClient != nil {
		if newConfig, err := c.config.GetConfigForClient(hs.clientHelloInfo()); err != nil {
			c.sendAlert(alertInternalError)
			return false, err
		} else if newConfig != nil {
			newConfig.serverInitOnce.Do(func() { newConfig.serverInit(c.config) })
			c.config = newConfig
		}
	}

	c.vers, ok = c.config.mutualVersion(hs.clientHello.vers)
	if !ok {
		c.sendAlert(alertProtocolVersion)
		return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers)
	}
	c.haveVers = true

	hs.hello = new(serverHelloMsg)

	// Only null compression is supported, so the client must offer it.
	foundCompression := false

	for _, compression := range hs.clientHello.compressionMethods {
		if compression == compressionNone {
			foundCompression = true
			break
		}
	}

	if !foundCompression {
		c.sendAlert(alertHandshakeFailure)
		return false, errors.New("tls: client does not support uncompressed connections")
	}

	hs.hello.vers = c.vers
	hs.hello.random = make([]byte, 32)
	_, err = io.ReadFull(c.config.rand(), hs.hello.random)
	if err != nil {
		c.sendAlert(alertInternalError)
		return false, err
	}

	// On an initial handshake the renegotiation_info payload must be empty.
	if len(hs.clientHello.secureRenegotiation) != 0 {
		c.sendAlert(alertHandshakeFailure)
		return false, errors.New("tls: initial handshake had non-empty renegotiation extension")
	}

	hs.hello.secureRenegotiationSupported = hs.clientHello.secureRenegotiationSupported
	hs.hello.compressionMethod = compressionNone
	if len(hs.clientHello.serverName) > 0 {
		c.serverName = hs.clientHello.serverName
	}

	// Prefer ALPN. Fall back to NPN only when the client sent no ALPN list.
	if len(hs.clientHello.alpnProtocols) > 0 {
		if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback {
			hs.hello.alpnProtocol = selectedProto
			c.clientProtocol = selectedProto
		}
	} else {
		// Only advertise NPN when the server actually has protocols to offer.
		if hs.clientHello.nextProtoNeg && len(c.config.NextProtos) > 0 {
			hs.hello.nextProtoNeg = true
			hs.hello.nextProtos = c.config.NextProtos
		}
	}

	hs.cert, err = c.config.getCertificate(hs.clientHelloInfo())
	if err != nil {
		c.sendAlert(alertInternalError)
		return false, err
	}

	// NOTE(review): for GM the chain presumably carries both the signing and
	// the encryption certificate, yet only a non-empty chain is enforced
	// here. Confirm whether at least two entries should be required.
	if len(hs.cert.Certificate) < 1 {
		c.sendAlert(alertInternalError)
		return false, fmt.Errorf("tls: amount of server certificates must be greater than 0")
	}
	if hs.clientHello.scts {
		hs.hello.scts = hs.cert.SignedCertificateTimestamps
	}

	// Resumption decides the suite from the ticket, so selection is skipped.
	if hs.checkForResumption() {
		return true, nil
	}

	// Choose whose ordering wins based on PreferServerCipherSuites.
	var preferenceList, supportedList []uint16
	if c.config.PreferServerCipherSuites {
		preferenceList = getCipherSuites(c.config)
		supportedList = hs.clientHello.cipherSuites
	} else {
		preferenceList = hs.clientHello.cipherSuites
		supportedList = getCipherSuites(c.config)
	}

	for _, id := range preferenceList {
		if hs.setCipherSuite(id, supportedList, c.vers) {
			break
		}
	}

	if hs.suite == nil {
		c.sendAlert(alertHandshakeFailure)
		return false, errors.New("tls: no cipher suite supported by both client and server")
	}

	// TLS_FALLBACK_SCSV: a client signalling a version-fallback retry while
	// offering less than the server's maximum indicates a downgrade.
	for _, id := range hs.clientHello.cipherSuites {
		if id == TLS_FALLBACK_SCSV {
			// The client is doing a fallback connection.
			if hs.clientHello.vers < c.config.maxVersion() {
				c.sendAlert(alertInappropriateFallback)
				return false, errors.New("tls: client using inappropriate protocol fallback")
			}
			break
		}
	}

	return false, nil
}
274
275
276 func (hs *serverHandshakeStateGM) checkForResumption() bool {
277 c := hs.c
278
279 if c.config.SessionTicketsDisabled {
280 return false
281 }
282
283 var ok bool
284 var sessionTicket = append([]uint8{}, hs.clientHello.sessionTicket...)
285 if hs.sessionState, ok = c.decryptTicket(sessionTicket); !ok {
286 return false
287 }
288
289
290 if c.vers != hs.sessionState.vers {
291 return false
292 }
293
294 cipherSuiteOk := false
295
296 for _, id := range hs.clientHello.cipherSuites {
297 if id == hs.sessionState.cipherSuite {
298 cipherSuiteOk = true
299 break
300 }
301 }
302 if !cipherSuiteOk {
303 return false
304 }
305
306
307 if !hs.setCipherSuite(hs.sessionState.cipherSuite, c.config.cipherSuites(), hs.sessionState.vers) {
308 return false
309 }
310
311 sessionHasClientCerts := len(hs.sessionState.certificates) != 0
312 needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert
313 if needClientCerts && !sessionHasClientCerts {
314 return false
315 }
316 if sessionHasClientCerts && c.config.ClientAuth == NoClientCert {
317 return false
318 }
319
320 return true
321 }
322
323 func (hs *serverHandshakeStateGM) doResumeHandshake() error {
324 c := hs.c
325
326 hs.hello.cipherSuite = hs.suite.id
327
328
329 hs.hello.sessionId = hs.clientHello.sessionId
330 hs.hello.ticketSupported = hs.sessionState.usedOldKey
331 hs.finishedHash = newFinishedHash(c.vers, hs.suite)
332 hs.finishedHash.discardHandshakeBuffer()
333 hs.finishedHash.Write(hs.clientHello.marshal())
334 hs.finishedHash.Write(hs.hello.marshal())
335 if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
336 return err
337 }
338
339 if len(hs.sessionState.certificates) > 0 {
340 if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil {
341 return err
342 }
343 }
344
345 hs.masterSecret = hs.sessionState.masterSecret
346
347 return nil
348 }
349
350 func (hs *serverHandshakeStateGM) doFullHandshake() error {
351 c := hs.c
352
353 if hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 {
354 hs.hello.ocspStapling = true
355 }
356
357 hs.hello.ticketSupported = hs.clientHello.ticketSupported && !c.config.SessionTicketsDisabled
358 hs.hello.cipherSuite = hs.suite.id
359
360 hs.finishedHash = newFinishedHashGM(hs.suite)
361 if c.config.ClientAuth == NoClientCert {
362
363
364 hs.finishedHash.discardHandshakeBuffer()
365 }
366 hs.finishedHash.Write(hs.clientHello.marshal())
367 hs.finishedHash.Write(hs.hello.marshal())
368 if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
369 return err
370 }
371
372 certMsg := new(certificateMsg)
373 certMsg.certificates = hs.cert.Certificate
374 hs.finishedHash.Write(certMsg.marshal())
375 if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
376 return err
377 }
378
379 if hs.hello.ocspStapling {
380 certStatus := new(certificateStatusMsg)
381 certStatus.statusType = statusTypeOCSP
382 certStatus.response = hs.cert.OCSPStaple
383 hs.finishedHash.Write(certStatus.marshal())
384 if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
385 return err
386 }
387 }
388
389 keyAgreement := hs.suite.ka(c.vers)
390 skx, err := keyAgreement.generateServerKeyExchange(c.config, hs.cert, hs.cert, hs.clientHello, hs.hello)
391 if err != nil {
392 c.sendAlert(alertHandshakeFailure)
393 return err
394 }
395 if skx != nil {
396 hs.finishedHash.Write(skx.marshal())
397 if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
398 return err
399 }
400 }
401
402 if c.config.ClientAuth >= RequestClientCert {
403
404 certReq := new(certificateRequestMsgGM)
405 certReq.certificateTypes = []byte{
406 byte(certTypeRSASign),
407 byte(certTypeECDSASign),
408 }
409
410
411
412
413
414
415
416
417
418
419 if c.config.ClientCAs != nil {
420 certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
421 }
422 hs.finishedHash.Write(certReq.marshal())
423 if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
424 return err
425 }
426 }
427
428 helloDone := new(serverHelloDoneMsg)
429 hs.finishedHash.Write(helloDone.marshal())
430 if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
431 return err
432 }
433
434 if _, err := c.flush(); err != nil {
435 return err
436 }
437
438 var pub crypto.PublicKey
439
440 msg, err := c.readHandshake()
441 if err != nil {
442 fmt.Println("readHandshake error:", err)
443 return err
444 }
445
446 var ok bool
447
448
449 if c.config.ClientAuth >= RequestClientCert {
450 if certMsg, ok = msg.(*certificateMsg); !ok {
451 c.sendAlert(alertUnexpectedMessage)
452 return unexpectedMessageError(certMsg, msg)
453 }
454 hs.finishedHash.Write(certMsg.marshal())
455
456 if len(certMsg.certificates) == 0 {
457
458 switch c.config.ClientAuth {
459 case RequireAnyClientCert, RequireAndVerifyClientCert:
460 c.sendAlert(alertBadCertificate)
461 return errors.New("tls: client didn't provide a certificate")
462 }
463 }
464
465 pub, err = hs.processCertsFromClient(certMsg.certificates)
466 if err != nil {
467 return err
468 }
469
470 msg, err = c.readHandshake()
471 if err != nil {
472 return err
473 }
474 }
475
476
477 ckx, ok := msg.(*clientKeyExchangeMsg)
478 if !ok {
479 c.sendAlert(alertUnexpectedMessage)
480 return unexpectedMessageError(ckx, msg)
481 }
482 hs.finishedHash.Write(ckx.marshal())
483
484 preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
485 if err != nil {
486 c.sendAlert(alertHandshakeFailure)
487 return err
488 }
489 hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random)
490 if err := c.config.writeKeyLog(hs.clientHello.random, hs.masterSecret); err != nil {
491 c.sendAlert(alertInternalError)
492 return err
493 }
494
495
496
497
498
499
500
501 if len(c.peerCertificates) > 0 {
502 msg, err = c.readHandshake()
503 if err != nil {
504 return err
505 }
506 certVerify, ok := msg.(*certificateVerifyMsg)
507 if !ok {
508 c.sendAlert(alertUnexpectedMessage)
509 return unexpectedMessageError(certVerify, msg)
510 }
511
512
513 _, sigType, hashFunc, err := pickSignatureAlgorithm(pub, []SignatureScheme{certVerify.signatureAlgorithm}, supportedSignatureAlgorithms, c.vers)
514 if err != nil {
515 c.sendAlert(alertIllegalParameter)
516 return err
517 }
518
519 var digest []byte
520 if digest, err = hs.finishedHash.hashForClientCertificate(sigType, hashFunc, hs.masterSecret); err == nil {
521 err = verifyHandshakeSignature(sigType, pub, hashFunc, digest, certVerify.signature)
522 }
523 if err != nil {
524 c.sendAlert(alertBadCertificate)
525 return errors.New("tls: could not validate signature of connection nonces: " + err.Error())
526 }
527
528 hs.finishedHash.Write(certVerify.marshal())
529 }
530
531 hs.finishedHash.discardHandshakeBuffer()
532
533 return nil
534 }
535
536 func (hs *serverHandshakeStateGM) establishKeys() error {
537 c := hs.c
538
539 clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV :=
540 keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen)
541
542 var clientCipher, serverCipher interface{}
543 var clientHash, serverHash macFunction
544
545 if hs.suite.aead == nil {
546 clientCipher = hs.suite.cipher(clientKey, clientIV, true )
547 clientHash = hs.suite.mac(c.vers, clientMAC)
548 serverCipher = hs.suite.cipher(serverKey, serverIV, false )
549 serverHash = hs.suite.mac(c.vers, serverMAC)
550 } else {
551 clientCipher = hs.suite.aead(clientKey, clientIV)
552 serverCipher = hs.suite.aead(serverKey, serverIV)
553 }
554
555 c.in.prepareCipherSpec(c.vers, clientCipher, clientHash)
556 c.out.prepareCipherSpec(c.vers, serverCipher, serverHash)
557
558 return nil
559 }
560
561 func (hs *serverHandshakeStateGM) readFinished(out []byte) error {
562 c := hs.c
563
564 c.readRecord(recordTypeChangeCipherSpec)
565 if c.in.err != nil {
566 return c.in.err
567 }
568
569 if hs.hello.nextProtoNeg {
570 msg, err := c.readHandshake()
571 if err != nil {
572 return err
573 }
574 nextProto, ok := msg.(*nextProtoMsg)
575 if !ok {
576 c.sendAlert(alertUnexpectedMessage)
577 return unexpectedMessageError(nextProto, msg)
578 }
579 hs.finishedHash.Write(nextProto.marshal())
580 c.clientProtocol = nextProto.proto
581 }
582
583 msg, err := c.readHandshake()
584 if err != nil {
585 return err
586 }
587 clientFinished, ok := msg.(*finishedMsg)
588 if !ok {
589 c.sendAlert(alertUnexpectedMessage)
590 return unexpectedMessageError(clientFinished, msg)
591 }
592
593 verify := hs.finishedHash.clientSum(hs.masterSecret)
594 if len(verify) != len(clientFinished.verifyData) ||
595 subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 {
596 c.sendAlert(alertHandshakeFailure)
597 return errors.New("tls: client's Finished message is incorrect")
598 }
599
600 hs.finishedHash.Write(clientFinished.marshal())
601 copy(out, verify)
602 return nil
603 }
604
605 func (hs *serverHandshakeStateGM) sendSessionTicket() error {
606 if !hs.hello.ticketSupported {
607 return nil
608 }
609
610 c := hs.c
611 m := new(newSessionTicketMsg)
612
613 var err error
614 state := sessionState{
615 vers: c.vers,
616 cipherSuite: hs.suite.id,
617 masterSecret: hs.masterSecret,
618 certificates: hs.certsFromClient,
619 }
620 m.ticket, err = c.encryptTicket(&state)
621 if err != nil {
622 return err
623 }
624
625 hs.finishedHash.Write(m.marshal())
626 if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
627 return err
628 }
629
630 return nil
631 }
632
633 func (hs *serverHandshakeStateGM) sendFinished(out []byte) error {
634 c := hs.c
635
636 if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
637 return err
638 }
639
640 finished := new(finishedMsg)
641 finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
642 hs.finishedHash.Write(finished.marshal())
643 if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
644 return err
645 }
646
647 c.cipherSuite = hs.suite.id
648 copy(out, finished.verifyData)
649
650 return nil
651 }
652
653
654
655
656 func (hs *serverHandshakeStateGM) processCertsFromClient(certificates [][]byte) (crypto.PublicKey, error) {
657 c := hs.c
658
659 hs.certsFromClient = certificates
660 certs := make([]*x509.Certificate, len(certificates))
661 var err error
662 for i, asn1Data := range certificates {
663 if certs[i], err = x509.ParseCertificate(asn1Data); err != nil {
664 c.sendAlert(alertBadCertificate)
665 return nil, errors.New("tls: failed to parse client certificate: " + err.Error())
666 }
667 }
668
669 if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 {
670 opts := x509.VerifyOptions{
671 Roots: c.config.ClientCAs,
672 CurrentTime: c.config.time(),
673 Intermediates: x509.NewCertPool(),
674 KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth},
675 }
676
677 for _, cert := range certs[1:] {
678 opts.Intermediates.AddCert(cert)
679 }
680
681 chains, err := certs[0].Verify(opts)
682 if err != nil {
683 c.sendAlert(alertBadCertificate)
684 return nil, errors.New("tls: failed to verify client's certificate: " + err.Error())
685 }
686
687 c.verifiedChains = chains
688 }
689
690 if c.config.VerifyPeerCertificate != nil {
691 if err := c.config.VerifyPeerCertificate(certificates, c.verifiedChains); err != nil {
692 c.sendAlert(alertBadCertificate)
693 return nil, err
694 }
695 }
696
697 if len(certs) == 0 {
698 return nil, nil
699 }
700
701 var pub crypto.PublicKey
702 switch key := certs[0].PublicKey.(type) {
703 case *ecdsa.PublicKey, *rsa.PublicKey, *sm2.PublicKey:
704 pub = key
705 default:
706 c.sendAlert(alertUnsupportedCertificate)
707 return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey)
708 }
709 c.peerCertificates = certs
710 return pub, nil
711 }
712
713
714
715
716 func (hs *serverHandshakeStateGM) setCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16) bool {
717 for _, supported := range supportedCipherSuites {
718 if id == supported {
719 var candidate *cipherSuite
720
721 for _, s := range gmCipherSuites {
722 if s.id == id {
723 candidate = s
724 break
725 }
726 }
727 if candidate == nil {
728 continue
729 }
730 if version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 {
731 continue
732 }
733 hs.suite = candidate
734 return true
735 }
736 }
737 return false
738 }
739
740 func (hs *serverHandshakeStateGM) clientHelloInfo() *ClientHelloInfo {
741 if hs.cachedClientHelloInfo != nil {
742 return hs.cachedClientHelloInfo
743 }
744
745 var supportedVersions []uint16
746 if hs.clientHello.vers > VersionTLS12 {
747 supportedVersions = suppVersArray[:]
748 } else if hs.clientHello.vers >= VersionSSL30 {
749 supportedVersions = suppVersArray[VersionTLS12-hs.clientHello.vers:]
750 }
751
752 hs.cachedClientHelloInfo = &ClientHelloInfo{
753 CipherSuites: hs.clientHello.cipherSuites,
754 ServerName: hs.clientHello.serverName,
755 SupportedCurves: hs.clientHello.supportedCurves,
756 SupportedPoints: hs.clientHello.supportedPoints,
757 SignatureSchemes: hs.clientHello.supportedSignatureAlgorithms,
758 SupportedProtos: hs.clientHello.alpnProtocols,
759 SupportedVersions: supportedVersions,
760 Conn: hs.c.conn,
761 }
762
763 return hs.cachedClientHelloInfo
764 }
765
View as plain text