1
14
15 package gmtls
16
17 import (
18 "bytes"
19 "crypto/cipher"
20 "crypto/subtle"
21 "errors"
22 "fmt"
23 "io"
24 "net"
25 "sync"
26 "sync/atomic"
27 "time"
28
29 "github.com/tjfoc/gmsm/x509"
30 )
31
32
33
// A Conn wraps an underlying net.Conn and holds the handshake results,
// record-layer state and bookkeeping used by the methods in this file.
type Conn struct {
	// conn is the underlying transport; the address and deadline methods
	// delegate to it. isClient selects the client or server handshake.
	conn     net.Conn
	isClient bool

	// handshakeStatus is 1 once the handshake has completed (see
	// handshakeComplete). It is only accessed through sync/atomic in this
	// file.
	handshakeStatus uint32
	// handshakeMutex is held by Handshake, handleRenegotiation,
	// ConnectionState, OCSPResponse and VerifyHostname.
	handshakeMutex sync.Mutex
	handshakeErr   error  // result of the last handshake attempt; returned by later Handshake calls
	vers           uint16 // protocol version written into outgoing record headers
	haveVers       bool   // when set, incoming records must carry version vers
	config         *Config
	// handshakes counts the handshakes that completed without error; it is
	// incremented by Handshake and handleRenegotiation.
	handshakes       int
	didResume        bool   // reported as ConnectionState.DidResume
	cipherSuite      uint16 // reported as ConnectionState.CipherSuite
	ocspResponse     []byte // returned by OCSPResponse and ConnectionState
	scts             [][]byte
	peerCertificates []*x509.Certificate
	// verifiedChains is reported through ConnectionState; VerifyHostname
	// refuses to run when it is empty.
	verifiedChains [][]*x509.Certificate
	// serverName is reported as ConnectionState.ServerName.
	serverName string
	// secureRenegotiation is not referenced in this part of the file.
	// NOTE(review): presumably records whether the peer supports the
	// renegotiation extension — confirm against the handshake code.
	secureRenegotiation bool
	// ekm is handed out through ConnectionState, but only when
	// renegotiation is disabled in the config.
	ekm func(label string, context []byte, length int) ([]byte, error)

	// clientFinishedIsFirst selects which Finished value ConnectionState
	// reports as TLSUnique: clientFinished when true, serverFinished
	// otherwise.
	clientFinishedIsFirst bool

	// closeNotifyErr caches the result of sending the close_notify alert.
	closeNotifyErr error
	// closeNotifySent is true once closeNotify has attempted to send the
	// alert; Write fails with errShutdown afterwards.
	closeNotifySent bool

	// clientFinished and serverFinished back ConnectionState.TLSUnique.
	clientFinished [12]byte
	serverFinished [12]byte

	clientProtocol         string // reported as ConnectionState.NegotiatedProtocol
	clientProtocolFallback bool   // NegotiatedProtocolIsMutual is its negation

	// input/output
	in, out   halfConn
	rawInput  *block       // raw bytes read from conn that have not been decrypted yet
	input     *block       // decrypted application data waiting to be returned by Read
	hand      bytes.Buffer // handshake data waiting to be parsed by readHandshake
	buffering bool         // while true, write appends to sendBuf instead of writing to conn
	sendBuf   []byte       // buffered output, sent by flush

	// bytesSent counts bytes written to conn (see write and flush).
	// packetsSent counts the application-data records sized by
	// maxPayloadSizeForWrite.
	bytesSent   int64
	packetsSent int64

	// warnCount counts consecutive warning alerts received; readRecord
	// resets it whenever a non-alert record carrying data arrives.
	warnCount int

	// activeCall is accessed atomically. The low bit is set once Close has
	// been called; each in-flight Write adds 2.
	activeCall int32

	// tmp is scratch space; sendAlertLocked uses its first two bytes.
	tmp [16]byte
}
116
117
118
119
120
121
// LocalAddr returns the local network address of the underlying connection.
func (c *Conn) LocalAddr() net.Addr {
	return c.conn.LocalAddr()
}
125
126
// RemoteAddr returns the remote network address of the underlying connection.
func (c *Conn) RemoteAddr() net.Addr {
	return c.conn.RemoteAddr()
}
130
131
132
133
// SetDeadline forwards t to SetDeadline on the underlying connection and
// returns its result.
func (c *Conn) SetDeadline(t time.Time) error {
	return c.conn.SetDeadline(t)
}
137
138
139
// SetReadDeadline forwards t to SetReadDeadline on the underlying connection
// and returns its result.
func (c *Conn) SetReadDeadline(t time.Time) error {
	return c.conn.SetReadDeadline(t)
}
143
144
145
146
// SetWriteDeadline forwards t to SetWriteDeadline on the underlying
// connection and returns its result.
func (c *Conn) SetWriteDeadline(t time.Time) error {
	return c.conn.SetWriteDeadline(t)
}
150
151
152
// A halfConn holds the record-layer state for one direction of a connection
// (Conn has one for input and one for output). The embedded mutex guards it.
type halfConn struct {
	sync.Mutex

	err            error       // sticky error recorded by setErrorLocked
	version        uint16      // protocol version, set by prepareCipherSpec
	cipher         interface{} // active cipher: a cipher.Stream, aead or cbcMode (nil means no encryption)
	mac            macFunction // active MAC, or nil
	seq            [8]byte     // big-endian record sequence number, advanced by incSeq
	bfree          *block      // head of the free list of blocks
	additionalData [13]byte    // scratch for AEAD additional data: seq, 3 header bytes, 2 length bytes

	nextCipher interface{} // pending cipher, installed by changeCipherSpec
	nextMac    macFunction // pending MAC, installed by changeCipherSpec

	// inDigestBuf and outDigestBuf are passed to mac.MAC by decrypt and
	// encrypt respectively and then replaced with the returned MAC slice.
	inDigestBuf, outDigestBuf []byte
}
170
// setErrorLocked stores err as the half-connection's error and returns it
// unchanged so that callers can write `return hc.setErrorLocked(err)`.
// It does no locking itself.
func (hc *halfConn) setErrorLocked(err error) error {
	hc.err = err
	return err
}
175
176
177
// prepareCipherSpec records the protocol version immediately and stashes
// cipher and mac as the pending cipher state. The pending state only takes
// effect when changeCipherSpec is called.
func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) {
	hc.version = version
	hc.nextCipher = cipher
	hc.nextMac = mac
}
183
184
185
186 func (hc *halfConn) changeCipherSpec() error {
187 if hc.nextCipher == nil {
188 return alertInternalError
189 }
190 hc.cipher = hc.nextCipher
191 hc.mac = hc.nextMac
192 hc.nextCipher = nil
193 hc.nextMac = nil
194 for i := range hc.seq {
195 hc.seq[i] = 0
196 }
197 return nil
198 }
199
200
201 func (hc *halfConn) incSeq() {
202 for i := 7; i >= 0; i-- {
203 hc.seq[i]++
204 if hc.seq[i] != 0 {
205 return
206 }
207 }
208
209
210
211
212 panic("TLS: sequence number wraparound")
213 }
214
215
216
217
// extractPadding inspects the CBC padding at the end of payload without
// branching on the padding contents. The last byte of payload is the
// padding length L; the padding is valid when payload holds at least L+1
// bytes and each of the final L+1 bytes equals L.
//
// It returns the number of trailing bytes to remove and good, which is 255
// when the padding is valid and 0 otherwise. When the padding is invalid,
// toRemove is 1 (only the length byte) so that the amount of data the
// caller subsequently feeds to the MAC does not depend on an
// attacker-chosen length byte. An empty payload yields (0, 0).
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// If len(payload)-1 >= paddingLen the subtraction did not wrap and
	// good starts out as 0xff; otherwise it starts as 0.
	good = byte(int32(^t) >> 31)

	// The maximum possible padding length plus the length byte itself.
	toCheck := 256
	// The length of the padded data is public, so branching on it is fine.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff while i <= paddingLen and 0 afterwards, so only
		// the bytes that are supposed to be padding are compared.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		good &^= mask&paddingLen ^ mask&b
	}

	// AND all bits of good together and replicate the result across the
	// whole byte, giving exactly 0x00 or 0xff.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// Zero the padding length on error. This ensures that any unchecked
	// bytes are included in the MAC computation, so a caller that could be
	// made to distinguish MAC failures from padding failures does not leak
	// information about the decrypted final byte.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
253
254
255
256
// extractPaddingSSL30 reads the final byte of payload as a padding length
// and returns that length plus one (to cover the length byte itself)
// together with 255. Only the length is validated, not the padding
// contents. If payload is empty, or the claimed padding does not fit in
// payload, it returns (0, 0).
func extractPaddingSSL30(payload []byte) (toRemove int, good byte) {
	size := len(payload)
	if size == 0 {
		return 0, 0
	}

	toRemove = int(payload[size-1]) + 1
	if toRemove <= size {
		return toRemove, 255
	}
	return 0, 0
}
269
// roundUp returns a rounded up to the next multiple of b; a value that is
// already a multiple of b is returned unchanged. b must be non-zero.
func roundUp(a, b int) int {
	// shortfall is the distance from a to the next multiple of b; the outer
	// modulo turns a full b into 0 when a is already aligned.
	shortfall := (b - a%b) % b
	return a + shortfall
}
273
274
// cbcMode is a cipher.BlockMode whose IV can be replaced between records.
// decrypt and encrypt call SetIV with the explicit IV carried at the start of
// a record's payload.
type cbcMode interface {
	cipher.BlockMode
	SetIV([]byte)
}
279
280
281
282
// decrypt decrypts the record held in b in place and, when a MAC is
// configured, verifies and strips it. On success it rewrites the record
// length in the header, shrinks b to the plaintext, increments the sequence
// number and returns ok == true together with prefixLen, the number of bytes
// to skip from the start of b.data (record header plus any explicit IV or
// nonce) to reach the plaintext. On failure it returns ok == false and
// alertBadRecordMAC. An unrecognised cipher type panics.
func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) {
	// Everything after the record header is the (possibly encrypted) payload.
	payload := b.data[recordHeaderLen:]

	macSize := 0
	if hc.mac != nil {
		macSize = hc.mac.Size()
	}

	paddingGood := byte(255)
	paddingLen := 0
	explicitIVLen := 0

	// Decrypt according to the concrete cipher type. A nil cipher means the
	// record is not encrypted.
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			explicitIVLen = c.explicitNonceLen()
			if len(payload) < explicitIVLen {
				return false, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitIVLen]
			payload = payload[explicitIVLen:]

			// With no explicit nonce on the wire, the sequence number is
			// used as the nonce.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}

			// Additional data: sequence number, record type and version
			// (first three header bytes), then the plaintext length.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], b.data[:3])
			n := len(payload) - c.Overhead()
			hc.additionalData[11] = byte(n >> 8)
			hc.additionalData[12] = byte(n)
			var err error
			payload, err = c.Open(payload[:0], nonce, payload, hc.additionalData[:])
			if err != nil {
				return false, 0, alertBadRecordMAC
			}
			b.resize(recordHeaderLen + explicitIVLen + len(payload))
		case cbcMode:
			blockSize := c.BlockSize()
			// TLS 1.1 and later, and GMSSL, carry a per-record IV of one
			// block at the start of the payload.
			if hc.version >= VersionTLS11 || hc.version == VersionGMSSL {
				explicitIVLen = blockSize
			}

			// The payload must be whole blocks and large enough to hold
			// the IV, the MAC and at least one padding byte.
			if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) {
				return false, 0, alertBadRecordMAC
			}

			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			c.CryptBlocks(payload, payload)
			if hc.version == VersionSSL30 {
				paddingLen, paddingGood = extractPaddingSSL30(payload)
			} else {
				paddingLen, paddingGood = extractPadding(payload)

				// The padding verdict is not acted on here: it is only
				// checked after the MAC has been computed below, and the
				// bytes following the MAC are passed to the MAC function
				// as its trailing argument.
				// NOTE(review): presumably this keeps MAC timing
				// independent of the (secret) padding length to resist
				// CBC padding-oracle attacks — confirm against
				// macFunction's implementation.
			}
		default:
			panic("unknown cipher type")
		}
	}

	// Check and strip the MAC.
	if hc.mac != nil {
		if len(payload) < macSize {
			return false, 0, alertBadRecordMAC
		}

		// n is the plaintext length; a negative value (sign bit set) is
		// clamped to 0 without branching. The header length bytes are
		// rewritten because the header is part of the MAC input.
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		b.data[3] = byte(n >> 8)
		b.data[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := hc.mac.MAC(hc.inDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// MAC mismatch and bad padding produce the same alert.
		if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 {
			return false, 0, alertBadRecordMAC
		}
		hc.inDigestBuf = localMAC

		b.resize(recordHeaderLen + explicitIVLen + n)
	}
	hc.incSeq()

	return true, recordHeaderLen + explicitIVLen, 0
}
381
382
383
384
385
386
// padToBlockSize splits payload for CBC encryption. prefix is the leading
// part of payload whose length is a whole number of blocks (it aliases
// payload). finalBlock is a freshly allocated block holding the remaining
// bytes followed by padding, where every padding byte equals the number of
// padding bytes minus one. When payload is already block-aligned, finalBlock
// consists entirely of padding.
func padToBlockSize(payload []byte, blockSize int) (prefix, finalBlock []byte) {
	tail := len(payload) % blockSize
	split := len(payload) - tail
	prefix = payload[:split]

	padByte := byte(blockSize - tail - 1)
	finalBlock = make([]byte, blockSize)
	filled := copy(finalBlock, payload[split:])
	for i := filled; i < blockSize; i++ {
		finalBlock[i] = padByte
	}
	return prefix, finalBlock
}
398
399
// encrypt protects the record held in b in place: it appends the MAC when one
// is configured, encrypts everything after the record header according to the
// active cipher, rewrites the record length in the header and increments the
// sequence number. explicitIVLen is the number of bytes directly after the
// header that hold an explicit IV or nonce already written by the caller.
// In the code shown it always returns (true, 0); an unrecognised cipher type
// panics.
func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) {
	// MAC over the sequence number, the record header and the plaintext
	// (the explicit IV is excluded), appended to the record.
	if hc.mac != nil {
		mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:], nil)

		n := len(b.data)
		b.resize(n + len(mac))
		copy(b.data[n:], mac)
		hc.outDigestBuf = mac
	}

	payload := b.data[recordHeaderLen:]

	// Encrypt according to the concrete cipher type. A nil cipher leaves the
	// record in the clear.
	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			c.XORKeyStream(payload, payload)
		case aead:
			payloadLen := len(b.data) - recordHeaderLen - explicitIVLen
			// Make room for the authentication tag before taking slices
			// of b.data, since resize may reallocate.
			b.resize(len(b.data) + c.Overhead())
			nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
			// With no explicit nonce, the sequence number is the nonce.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload := b.data[recordHeaderLen+explicitIVLen:]
			payload = payload[:payloadLen]

			// Additional data: sequence number, record type and version
			// (first three header bytes), then the plaintext length.
			copy(hc.additionalData[:], hc.seq[:])
			copy(hc.additionalData[8:], b.data[:3])
			hc.additionalData[11] = byte(payloadLen >> 8)
			hc.additionalData[12] = byte(payloadLen)

			// Seal in place; the tag lands in the space reserved above.
			c.Seal(payload[:0], nonce, payload, hc.additionalData[:])
		case cbcMode:
			blockSize := c.BlockSize()
			if explicitIVLen > 0 {
				c.SetIV(payload[:explicitIVLen])
				payload = payload[explicitIVLen:]
			}
			// prefix aliases b.data; finalBlock is a separate buffer that
			// carries the tail of the payload plus padding.
			prefix, finalBlock := padToBlockSize(payload, blockSize)
			b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock))
			c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix)
			c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock)
		default:
			panic("unknown cipher type")
		}
	}

	// Update the length field to cover the MAC, padding and tag.
	n := len(b.data) - recordHeaderLen
	b.data[3] = byte(n >> 8)
	b.data[4] = byte(n)
	hc.incSeq()

	return true, 0
}
457
458
// A block is a simple data buffer that can be chained into a list.
type block struct {
	data []byte
	off  int    // read offset into data, advanced by Read
	link *block // next block in a list (used by the halfConn free list)
}
464
465
466 func (b *block) resize(n int) {
467 if n > cap(b.data) {
468 b.reserve(n)
469 }
470 b.data = b.data[0:n]
471 }
472
473
474 func (b *block) reserve(n int) {
475 if cap(b.data) >= n {
476 return
477 }
478 m := cap(b.data)
479 if m == 0 {
480 m = 1024
481 }
482 for m < n {
483 m *= 2
484 }
485 data := make([]byte, len(b.data), m)
486 copy(data, b.data)
487 b.data = data
488 }
489
490
491
492 func (b *block) readFromUntil(r io.Reader, n int) error {
493
494 if len(b.data) >= n {
495 return nil
496 }
497
498
499 b.reserve(n)
500 for {
501 m, err := r.Read(b.data[len(b.data):cap(b.data)])
502 b.data = b.data[0 : len(b.data)+m]
503 if len(b.data) >= n {
504
505
506 break
507 }
508 if err != nil {
509 return err
510 }
511 }
512 return nil
513 }
514
515 func (b *block) Read(p []byte) (n int, err error) {
516 n = copy(p, b.data[b.off:])
517 b.off += n
518 return
519 }
520
521
522 func (hc *halfConn) newBlock() *block {
523 b := hc.bfree
524 if b == nil {
525 return new(block)
526 }
527 hc.bfree = b.link
528 b.link = nil
529 b.resize(0)
530 return b
531 }
532
533
534
535
536
// freeBlock pushes b onto the front of the half-connection's free list so
// that a later newBlock call can reuse it. The caller must not use b
// afterwards.
func (hc *halfConn) freeBlock(b *block) {
	b.link = hc.bfree
	hc.bfree = b
}
541
542
543
544
545 func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) {
546 if len(b.data) <= n {
547 return b, nil
548 }
549 bb := hc.newBlock()
550 bb.resize(len(b.data) - n)
551 copy(bb.data, b.data[n:])
552 b.data = b.data[0:n]
553 return b, bb
554 }
555
556
// RecordHeaderError is the error produced when a received record header is
// rejected (see readRecord).
type RecordHeaderError struct {
	// Msg contains a human readable string that describes the error.
	Msg string
	// RecordHeader holds the first five bytes of the raw input that
	// triggered the error (fewer bytes are copied if less input was
	// buffered).
	RecordHeader [5]byte
}
564
// Error implements the error interface; it returns Msg prefixed with "tls: ".
func (e RecordHeaderError) Error() string { return "tls: " + e.Msg }
566
567 func (c *Conn) newRecordHeaderError(msg string) (err RecordHeaderError) {
568 err.Msg = msg
569 copy(err.RecordHeader[:], c.rawInput.data)
570 return err
571 }
572
573
574
// readRecord reads the next record from the connection, decrypts it and
// dispatches on its type. want is the record type the caller expects:
// handshake and ChangeCipherSpec records may only be requested while the
// handshake is incomplete, application data only once it is complete.
//
// Handshake data is appended to c.hand, application data is stored in
// c.input, and a ChangeCipherSpec switches c.in to its pending cipher.
// Warning alerts other than close_notify are skipped and the next record is
// read instead. Errors are recorded on c.in via setErrorLocked (temporary
// network errors are returned without being recorded).
// NOTE(review): the use of setErrorLocked suggests callers must hold
// c.in's lock — confirm at the call sites.
func (c *Conn) readRecord(want recordType) error {
	// Validate the request against the handshake state before touching the
	// network.
	switch want {
	case recordTypeHandshake, recordTypeChangeCipherSpec:
		if c.handshakeComplete() {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested while not in handshake"))
		}
	case recordTypeApplicationData:
		if !c.handshakeComplete() {
			c.sendAlert(alertInternalError)
			return c.in.setErrorLocked(errors.New("tls: application data record requested while in handshake"))
		}
	default:
		c.sendAlert(alertInternalError)
		return c.in.setErrorLocked(errors.New("tls: unknown record type requested"))
	}

Again:
	if c.rawInput == nil {
		c.rawInput = c.in.newBlock()
	}
	b := c.rawInput

	// Read the record header.
	if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// Temporary network errors are returned to the caller but not
		// stored, so that a later call can retry; any other error
		// becomes the sticky input error.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	typ := recordType(b.data[0])

	// A first byte of 0x80 while expecting a handshake record is reported
	// as an SSLv2 handshake, which is not supported.
	if want == recordTypeHandshake && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError("unsupported SSLv2 handshake received"))
	}

	vers := uint16(b.data[1])<<8 | uint16(b.data[2])
	n := int(b.data[3])<<8 | int(b.data[4])
	if c.haveVers && vers != c.vers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, c.vers)
		return c.in.setErrorLocked(c.newRecordHeaderError(msg))
	}
	if n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(msg))
	}
	if !c.haveVers {
		// Before a version has been fixed, sanity-check the header: the
		// record must be an alert or of the requested type, and the
		// version field must be below 0x1000. Anything else is treated
		// as not being a TLS handshake at all.
		if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 {
			c.sendAlert(alertUnexpectedMessage)
			return c.in.setErrorLocked(c.newRecordHeaderError("first record does not look like a TLS handshake"))
		}
	}
	// Read the rest of the record. Hitting EOF here means the record was
	// truncated, so it is reported as an unexpected EOF.
	if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Process the record: b keeps exactly this record, any surplus bytes
	// already read stay in c.rawInput for the next call.
	b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n)
	ok, off, alertValue := c.in.decrypt(b)
	if !ok {
		c.in.freeBlock(b)
		return c.in.setErrorLocked(c.sendAlert(alertValue))
	}
	b.off = off
	data := b.data[b.off:]
	if len(data) > maxPlaintext {
		err := c.sendAlert(alertRecordOverflow)
		c.in.freeBlock(b)
		return c.in.setErrorLocked(err)
	}

	if typ != recordTypeAlert && len(data) > 0 {
		// A record carrying real data resets the warning alert count.
		c.warnCount = 0
	}

	switch typ {
	default:
		c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		if len(data) != 2 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// close_notify is surfaced as io.EOF regardless of alert level.
		if alert(data[1]) == alertCloseNotify {
			c.in.setErrorLocked(io.EOF)
			break
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the warning and read the next record.
			c.in.freeBlock(b)

			// Bound the number of consecutive warnings a peer can send.
			c.warnCount++
			if c.warnCount > maxWarnAlertCount {
				c.sendAlert(alertUnexpectedMessage)
				return c.in.setErrorLocked(errors.New("tls: too many warn alerts"))
			}

			goto Again
		case alertLevelError:
			c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if typ != want || len(data) != 1 || data[0] != 1 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// A cipher change is rejected while handshake bytes are still
		// waiting to be parsed.
		if c.hand.Len() > 0 {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		err := c.in.changeCipherSpec()
		if err != nil {
			c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if typ != want {
			c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
			break
		}
		// Hand the block over to Read; clearing b keeps it from being
		// freed below.
		c.input = b
		b = nil

	case recordTypeHandshake:
		// An unrequested handshake record is only tolerated on a client
		// whose config permits renegotiation.
		if typ != want && !(c.isClient && c.config.Renegotiation != RenegotiateNever) {
			return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation))
		}
		c.hand.Write(data)
	}

	if b != nil {
		c.in.freeBlock(b)
	}
	return c.in.err
}
744
745
746 func (c *Conn) sendAlertLocked(err alert) error {
747 switch err {
748 case alertNoRenegotiation, alertCloseNotify:
749 c.tmp[0] = alertLevelWarning
750 default:
751 c.tmp[0] = alertLevelError
752 }
753 c.tmp[1] = byte(err)
754
755 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
756 if err == alertCloseNotify {
757
758 return writeErr
759 }
760
761 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
762 }
763
764
// sendAlert sends an alert for err while holding c.out's lock; see
// sendAlertLocked for the return value.
func (c *Conn) sendAlert(err alert) error {
	c.out.Lock()
	defer c.out.Unlock()
	return c.sendAlertLocked(err)
}
770
const (
	// tcpMSSEstimate is the byte budget from which maxPayloadSizeForWrite
	// derives the payload size of the first application-data records on a
	// connection.
	// NOTE(review): the value presumably approximates a conservative TCP
	// maximum segment size — confirm.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes sent on a connection
	// after which maxPayloadSizeForWrite stops limiting the record size and
	// returns maxPlaintext.
	recordSizeBoostThreshold = 128 * 1024
)
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
// maxPayloadSizeForWrite returns the maximum number of plaintext bytes to
// put in the next record of type typ. explicitIVLen is the size of the
// explicit IV or nonce the record will carry.
//
// It returns maxPlaintext when dynamic record sizing is disabled in the
// config, for anything other than application data, and once
// recordSizeBoostThreshold bytes have been sent. Otherwise the first record
// is sized so that it fits within tcpMSSEstimate bytes including record
// overhead, and each following record grows by that same amount, capped at
// maxPlaintext. Every call that reaches the sizing logic increments
// c.packetsSent. An unrecognised cipher type panics.
func (c *Conn) maxPayloadSizeForWrite(typ recordType, explicitIVLen int) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract the record overheads from the budget to get the payload size.
	macSize := 0
	if c.out.mac != nil {
		macSize = c.out.mac.Size()
	}

	payloadBytes := tcpMSSEstimate - recordHeaderLen - explicitIVLen
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= macSize
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// Round down to a multiple of blockSize (the mask assumes a
			// power-of-two block size) and leave room for at least one
			// padding byte.
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC is appended before padding (see encrypt), so it
			// comes straight out of the payload budget.
			payloadBytes -= macSize
		default:
			panic("unknown cipher type")
		}
	}

	// Let the record size grow in arithmetic progression with the number of
	// records sent so far.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		return maxPlaintext // also keeps the multiplication below small
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
849
850 func (c *Conn) write(data []byte) (int, error) {
851 if c.buffering {
852 c.sendBuf = append(c.sendBuf, data...)
853 return len(data), nil
854 }
855
856 n, err := c.conn.Write(data)
857 c.bytesSent += int64(n)
858 return n, err
859 }
860
861 func (c *Conn) flush() (int, error) {
862 if len(c.sendBuf) == 0 {
863 return 0, nil
864 }
865
866 n, err := c.conn.Write(c.sendBuf)
867 c.bytesSent += int64(n)
868 c.sendBuf = nil
869 c.buffering = false
870 return n, err
871 }
872
873
874
// writeRecordLocked splits data into records of type typ, encrypts each with
// the current output cipher state and writes it out (or buffers it, see
// write). It returns the number of bytes of data consumed. After a
// ChangeCipherSpec record has been written, the output half-connection is
// switched to its pending cipher.
// The caller is expected to hold c.out's lock (see writeRecord).
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	b := c.out.newBlock()
	defer c.out.freeBlock(b)

	var n int
	for len(data) > 0 {
		explicitIVLen := 0
		explicitIVIsSeq := false

		// CBC ciphers carry an explicit per-record IV of one block in
		// TLS 1.1 and later and in GMSSL.
		var cbc cbcMode
		if c.out.version >= VersionTLS11 || c.out.version == VersionGMSSL {
			var ok bool
			if cbc, ok = c.out.cipher.(cbcMode); ok {
				explicitIVLen = cbc.BlockSize()
			}
		}
		if explicitIVLen == 0 {
			if c, ok := c.out.cipher.(aead); ok {
				explicitIVLen = c.explicitNonceLen()

				// For AEAD ciphers with an explicit nonce, the nonce is
				// not random: the sequence number is copied into it
				// below.
				// NOTE(review): presumably chosen so that a nonce is
				// never repeated under one key — confirm.
				explicitIVIsSeq = explicitIVLen > 0
			}
		}
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ, explicitIVLen); m > maxPayload {
			m = maxPayload
		}
		b.resize(recordHeaderLen + explicitIVLen + m)
		b.data[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// No version recorded on the connection yet: label the
			// record as TLS 1.0.
			vers = VersionTLS10
		}
		b.data[1] = byte(vers >> 8)
		b.data[2] = byte(vers)
		b.data[3] = byte(m >> 8)
		b.data[4] = byte(m)
		if explicitIVLen > 0 {
			explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen]
			if explicitIVIsSeq {
				copy(explicitIV, c.out.seq[:])
			} else {
				if _, err := io.ReadFull(c.config.rand(), explicitIV); err != nil {
					return n, err
				}
			}
		}
		copy(b.data[recordHeaderLen+explicitIVLen:], data)
		// encrypt's results are not checked; in this file it always
		// returns (true, 0).
		c.out.encrypt(b, explicitIVLen)
		if _, err := c.write(b.data); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
947
948
949
// writeRecord writes a record of type typ containing data while holding
// c.out's lock; see writeRecordLocked for the details and return values.
func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
	c.out.Lock()
	defer c.out.Unlock()

	return c.writeRecordLocked(typ, data)
}
956
957
958
959 func (c *Conn) readHandshake() (interface{}, error) {
960 for c.hand.Len() < 4 {
961 if err := c.in.err; err != nil {
962 return nil, err
963 }
964 if err := c.readRecord(recordTypeHandshake); err != nil {
965 return nil, err
966 }
967 }
968
969 data := c.hand.Bytes()
970 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
971 if n > maxHandshake {
972 c.sendAlertLocked(alertInternalError)
973 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshake))
974 }
975 for c.hand.Len() < 4+n {
976 if err := c.in.err; err != nil {
977 return nil, err
978 }
979 if err := c.readRecord(recordTypeHandshake); err != nil {
980 return nil, err
981 }
982 }
983 data = c.hand.Next(4 + n)
984 var m handshakeMessage
985 switch data[0] {
986 case typeHelloRequest:
987 m = new(helloRequestMsg)
988 case typeClientHello:
989 m = new(clientHelloMsg)
990 case typeServerHello:
991 m = new(serverHelloMsg)
992 case typeNewSessionTicket:
993 m = new(newSessionTicketMsg)
994 case typeCertificate:
995 m = new(certificateMsg)
996 case typeCertificateRequest:
997 if c.config.GMSupport != nil {
998 m = &certificateRequestMsgGM{}
999 } else {
1000 m = &certificateRequestMsg{
1001 hasSignatureAndHash: c.vers >= VersionTLS12,
1002 }
1003 }
1004 case typeCertificateStatus:
1005 m = new(certificateStatusMsg)
1006 case typeServerKeyExchange:
1007 m = new(serverKeyExchangeMsg)
1008 case typeServerHelloDone:
1009 m = new(serverHelloDoneMsg)
1010 case typeClientKeyExchange:
1011 m = new(clientKeyExchangeMsg)
1012 case typeCertificateVerify:
1013 m = &certificateVerifyMsg{
1014 hasSignatureAndHash: c.vers >= VersionTLS12,
1015 }
1016 case typeNextProtocol:
1017 m = new(nextProtoMsg)
1018 case typeFinished:
1019 m = new(finishedMsg)
1020 default:
1021 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1022 }
1023
1024
1025
1026
1027 data = append([]byte(nil), data...)
1028
1029 if !m.unmarshal(data) {
1030 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1031 }
1032 return m, nil
1033 }
1034
var (
	// errClosed is returned by Write and Close once Close has been called.
	errClosed = errors.New("tls: use of closed connection")
	// errShutdown is returned by Write after close_notify has been sent.
	errShutdown = errors.New("tls: protocol is shutdown")
)
1039
1040
// Write writes b to the connection as application data, running the
// handshake first if necessary. It returns the number of bytes of b that
// were consumed. It fails with errClosed after Close has been called and
// with errShutdown after close_notify has been sent.
func (c *Conn) Write(b []byte) (int, error) {
	// Register this call in activeCall so that Close can tell a Write is in
	// flight: the low bit is the closed flag, each active Write adds 2.
	for {
		x := atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return 0, errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x+2) {
			defer atomic.AddInt32(&c.activeCall, -2)
			break
		}
	}

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.handshakeComplete() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// For block-mode ciphers in TLS 1.0 and earlier, the first byte of b is
	// sent in a record of its own and the rest follows in further records.
	// NOTE(review): presumably the 1/n-1 record-splitting mitigation for
	// the predictable-IV weakness of CBC in those protocol versions —
	// confirm.
	var m int
	if len(b) > 1 && c.vers <= VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1096
1097
// handleRenegotiation processes a handshake message received after the
// initial handshake. Only a HelloRequest is accepted. Servers always answer
// with a no_renegotiation alert; clients do so too unless the config's
// Renegotiation setting permits another handshake, in which case a new
// client handshake is run under handshakeMutex.
// NOTE(review): it reads via readHandshake without taking c.in's lock; the
// visible caller (Read) holds that lock — confirm no other callers exist.
func (c *Conn) handleRenegotiation() error {
	msg, err := c.readHandshake()
	if err != nil {
		return err
	}

	_, ok := msg.(*helloRequestMsg)
	if !ok {
		c.sendAlert(alertUnexpectedMessage)
		return alertUnexpectedMessage
	}

	if !c.isClient {
		return c.sendAlert(alertNoRenegotiation)
	}

	switch c.config.Renegotiation {
	case RenegotiateNever:
		return c.sendAlert(alertNoRenegotiation)
	case RenegotiateOnceAsClient:
		// c.handshakes already counts the initial handshake, so a value
		// above 1 means one renegotiation has taken place.
		if c.handshakes > 1 {
			return c.sendAlert(alertNoRenegotiation)
		}
	case RenegotiateFreelyAsClient:
		// Ok.
	default:
		c.sendAlert(alertInternalError)
		return errors.New("tls: unknown Renegotiation value")
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Mark the handshake as in progress again before re-running it.
	atomic.StoreUint32(&c.handshakeStatus, 0)
	if c.handshakeErr = c.clientHandshake(); c.handshakeErr == nil {
		c.handshakes++
	}
	return c.handshakeErr
}
1137
1138
1139
// Read reads decrypted application data into b, running the handshake first
// if necessary. A zero-length b returns immediately after the handshake
// without reading. If more than maxConsecutiveEmptyRecords empty records
// arrive in a row, it gives up with io.ErrNoProgress.
func (c *Conn) Read(b []byte) (n int, err error) {
	if err = c.Handshake(); err != nil {
		return
	}
	if len(b) == 0 {
		// Nothing to read into; return (0, nil) once the handshake has
		// succeeded.
		return
	}

	c.in.Lock()
	defer c.in.Unlock()

	// Empty application-data records make no progress for the caller, so
	// bound how many are tolerated back to back.
	const maxConsecutiveEmptyRecords = 100
	for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ {
		for c.input == nil && c.in.err == nil {
			if err := c.readRecord(recordTypeApplicationData); err != nil {
				// Includes temporary errors that readRecord returns
				// without storing them in c.in.err.
				return 0, err
			}
			if c.hand.Len() > 0 {
				// Handshake bytes arrived after the handshake: the peer
				// is asking for a renegotiation.
				if err := c.handleRenegotiation(); err != nil {
					return 0, err
				}
			}
		}
		if err := c.in.err; err != nil {
			return 0, err
		}

		n, err = c.input.Read(b)
		if c.input.off >= len(c.input.data) {
			// The record has been fully consumed; recycle its block.
			c.in.freeBlock(c.input)
			c.input = nil
		}

		// If data was returned, the current record is exhausted, and the
		// raw input already holds the start of an alert record, process
		// that alert now so that its outcome (for example io.EOF from
		// close_notify) is returned together with the data instead of on
		// the next call.
		if ri := c.rawInput; ri != nil &&
			n != 0 && err == nil &&
			c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert {
			if recErr := c.readRecord(recordTypeApplicationData); recErr != nil {
				err = recErr
			}
		}

		if n != 0 || err != nil {
			return n, err
		}
	}

	return 0, io.ErrNoProgress
}
1206
1207
// Close closes the connection. A second call returns errClosed. If the
// handshake has completed and no Write is in flight, a close_notify alert is
// sent before the underlying connection is closed. An error from closing the
// underlying connection takes precedence over one from sending the alert.
func (c *Conn) Close() error {
	// Atomically set the closed flag (low bit of activeCall), remembering
	// the previous value in x.
	var x int32
	for {
		x = atomic.LoadInt32(&c.activeCall)
		if x&1 != 0 {
			return errClosed
		}
		if atomic.CompareAndSwapInt32(&c.activeCall, x, x|1) {
			break
		}
	}
	if x != 0 {
		// At least one Write is in flight (each adds 2 to activeCall).
		// Sending close_notify would need c.out's lock, which such a
		// Write holds while writing, so just close the underlying
		// connection instead.
		return c.conn.Close()
	}

	var alertErr error

	if c.handshakeComplete() {
		alertErr = c.closeNotify()
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1241
// errEarlyCloseWrite is returned by CloseWrite when the handshake has not
// completed.
var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1243
1244
1245
1246
1247 func (c *Conn) CloseWrite() error {
1248 if !c.handshakeComplete() {
1249 return errEarlyCloseWrite
1250 }
1251
1252 return c.closeNotify()
1253 }
1254
1255 func (c *Conn) closeNotify() error {
1256 c.out.Lock()
1257 defer c.out.Unlock()
1258
1259 if !c.closeNotifySent {
1260 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1261 c.closeNotifySent = true
1262 }
1263 return c.closeNotifyErr
1264 }
1265
1266
1267
1268
1269
1270 func (c *Conn) Handshake() error {
1271 c.handshakeMutex.Lock()
1272 defer c.handshakeMutex.Unlock()
1273
1274 if err := c.handshakeErr; err != nil {
1275 return err
1276 }
1277 if c.handshakeComplete() {
1278 return nil
1279 }
1280
1281 c.in.Lock()
1282 defer c.in.Unlock()
1283
1284 if c.isClient {
1285 c.handshakeErr = c.clientHandshake()
1286 } else {
1287 if c.config.GMSupport == nil {
1288
1289 c.handshakeErr = c.serverHandshake()
1290 } else if c.config.GMSupport.IsAutoSwitchMode() {
1291
1292 c.handshakeErr = c.serverHandshakeAutoSwitch()
1293 } else {
1294
1295 c.handshakeErr = c.serverHandshakeGM()
1296 }
1297 }
1298 if c.handshakeErr == nil {
1299 c.handshakes++
1300 } else {
1301
1302
1303 c.flush()
1304 fmt.Println("handshake error :", c.handshakeErr)
1305 }
1306
1307 if c.handshakeErr == nil && !c.handshakeComplete() {
1308 panic("handshake should have had a result.")
1309 }
1310
1311 return c.handshakeErr
1312 }
1313
1314
// ConnectionState returns a snapshot of the connection's handshake results,
// taken under handshakeMutex. HandshakeComplete and ServerName are always
// filled in; every other field is only populated once the handshake has
// completed.
func (c *Conn) ConnectionState() ConnectionState {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	var state ConnectionState
	state.HandshakeComplete = c.handshakeComplete()
	state.ServerName = c.serverName

	if state.HandshakeComplete {
		state.Version = c.vers
		state.NegotiatedProtocol = c.clientProtocol
		state.DidResume = c.didResume
		state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback
		state.CipherSuite = c.cipherSuite
		state.PeerCertificates = c.peerCertificates
		state.VerifiedChains = c.verifiedChains
		state.SignedCertificateTimestamps = c.scts
		state.OCSPResponse = c.ocspResponse
		// TLSUnique is left unset for resumed sessions; otherwise it is
		// whichever Finished value clientFinishedIsFirst selects.
		if !c.didResume {
			if c.clientFinishedIsFirst {
				state.TLSUnique = c.clientFinished[:]
			} else {
				state.TLSUnique = c.serverFinished[:]
			}
		}
		// The connection's own keying-material exporter is only exposed
		// when renegotiation is disabled; otherwise the
		// noExportedKeyingMaterial placeholder is installed.
		if c.config.Renegotiation != RenegotiateNever {
			state.ekm = noExportedKeyingMaterial
		} else {
			state.ekm = c.ekm
		}
	}

	return state
}
1349
1350
1351
// OCSPResponse returns the OCSP response stored on the connection (nil if
// none has been stored), read under handshakeMutex.
func (c *Conn) OCSPResponse() []byte {
	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	return c.ocspResponse
}
1358
1359
1360
1361
1362 func (c *Conn) VerifyHostname(host string) error {
1363 c.handshakeMutex.Lock()
1364 defer c.handshakeMutex.Unlock()
1365 if !c.isClient {
1366 return errors.New("tls: VerifyHostname called on TLS server connection")
1367 }
1368 if !c.handshakeComplete() {
1369 return errors.New("tls: handshake has not yet been performed")
1370 }
1371 if len(c.verifiedChains) == 0 {
1372 return errors.New("tls: handshake did not verify certificate chain")
1373 }
1374 return c.peerCertificates[0].VerifyHostname(host)
1375 }
1376
// handshakeComplete reports whether the handshake has completed, i.e.
// whether handshakeStatus, read atomically, equals 1.
func (c *Conn) handshakeComplete() bool {
	return atomic.LoadUint32(&c.handshakeStatus) == 1
}
1380
View as plain text