1
18
19
20
21 package record
22
23 import (
24 "encoding/binary"
25 "errors"
26 "fmt"
27 "math"
28 "net"
29 "sync"
30
31 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
32 "github.com/google/s2a-go/internal/record/internal/halfconn"
33 "github.com/google/s2a-go/internal/tokenmanager"
34 "google.golang.org/grpc/grpclog"
35 )
36
37
38
// recordType is the one-byte TLS record content type. Read switches on it
// after stripPaddingAndType recovers it from the end of the decrypted payload.
type recordType byte

const (
	// alert marks a record that Read passes to handleAlertMessage.
	alert recordType = 21
	// handshake marks a record that Read passes to handleHandshakeMessage.
	handshake recordType = 22
	// applicationData marks a record whose plaintext Read returns to the
	// caller.
	applicationData recordType = 23
)
46
47
48
// keyUpdateRequest is the single-byte body of a key update handshake message,
// as validated by handleKeyUpdateMsg.
type keyUpdateRequest byte

const (
	// updateNotRequested is the value buildKeyUpdateRequest places in the
	// message it builds. When received, only the inbound key is updated.
	updateNotRequested keyUpdateRequest = 0
	// updateRequested, when received, makes handleKeyUpdateMsg also send a
	// key update message and update the outbound key.
	updateRequested keyUpdateRequest = 1
)
55
56
57
// alertDescription is the value of the second byte of an alert message, which
// handleAlertMessage inspects.
type alertDescription byte

const (
	// closeNotify is the only alert description handleAlertMessage
	// distinguishes; every other value is reported as unrecognized.
	closeNotify alertDescription = 0
)
63
64
65
66
// sessionTicketState tracks how far a conn has progressed in collecting
// session tickets and handing them to the ticket sender.
type sessionTicketState byte

const (
	// ticketsNotYetReceived is the state NewConn starts a conn in.
	ticketsNotYetReceived sessionTicketState = 0
	// receivingTickets is entered by handleHandshakeMessage when the first
	// new-session-ticket message arrives.
	receivingTickets sessionTicketState = 1
	// notReceivingTickets is entered once the collected tickets have been
	// handed to the ticket sender, either after maxAllowedTickets tickets or
	// when application data arrives while tickets were being received. Later
	// tickets are ignored, and Close waits on callComplete in this state.
	notReceivingTickets sessionTicketState = 2
)
74
// Sizes and fixed values describing the TLS record layout used by this file.
const (
	// tlsRecordMaxPlaintextSize is the largest number of plaintext bytes
	// buildRecord places in one record and the largest plaintext Read accepts
	// from one record.
	tlsRecordMaxPlaintextSize = 16384
	// tlsRecordTypeSize is the size of the record type byte that buildRecord
	// appends after the plaintext before encrypting.
	tlsRecordTypeSize = 1
	// tlsTagSize is added by buildRecord to the payload length written in the
	// record header. NOTE(review): presumably the AEAD tag length; NewConn
	// uses inConn.TagSize() for the overhead calculation, confirm they agree.
	tlsTagSize = 16
	// tlsRecordMaxPayloadSize is the largest payload length accepted in a
	// record header; readFullRecord passes it to parseReadBuffer as the limit.
	tlsRecordMaxPayloadSize = tlsRecordMaxPlaintextSize + tlsRecordTypeSize + tlsTagSize
	// tlsRecordHeaderTypeSize is the size of the type field, the first byte
	// of the record header.
	tlsRecordHeaderTypeSize = 1
	// tlsRecordHeaderLegacyRecordVersionSize is the size of the legacy record
	// version field that follows the type field in the header.
	tlsRecordHeaderLegacyRecordVersionSize = 2
	// tlsRecordHeaderPayloadLengthSize is the size of the big-endian payload
	// length field that ends the header.
	tlsRecordHeaderPayloadLengthSize = 2
	// tlsRecordHeaderSize is the total size of the record header.
	tlsRecordHeaderSize = tlsRecordHeaderTypeSize + tlsRecordHeaderLegacyRecordVersionSize + tlsRecordHeaderPayloadLengthSize
	// tlsRecordMaxSize is the size of the largest complete record: header
	// plus the largest payload.
	tlsRecordMaxSize = tlsRecordMaxPayloadSize + tlsRecordHeaderSize
	// tlsApplicationData is the value buildRecord writes to the header type
	// byte of every outgoing record, the value splitAndValidateHeader
	// requires there, and the record type Write passes to writeTLSRecord.
	tlsApplicationData = 23
	// tlsLegacyRecordVersion is the value buildRecord writes to both bytes of
	// the header's legacy record version field.
	tlsLegacyRecordVersion = 3
	// tlsAlertSize is the exact length handleAlertMessage requires of an
	// alert message.
	tlsAlertSize = 2
)
114
// Sizes and fixed values describing the handshake messages handled after the
// handshake, as parsed by parseHandshakeMsg and handleHandshakeMessage.
const (
	// tlsHandshakeNewSessionTicketType is the message type byte of a
	// new-session-ticket handshake message.
	tlsHandshakeNewSessionTicketType = 4
	// tlsHandshakeKeyUpdateType is the message type byte of a key update
	// handshake message.
	tlsHandshakeKeyUpdateType = 24
	// tlsHandshakeMsgTypeSize is the size of the message type field that
	// starts every handshake message.
	tlsHandshakeMsgTypeSize = 1
	// tlsHandshakeLengthSize is the size of the big-endian message length
	// field that follows the type field.
	tlsHandshakeLengthSize = 3
	// tlsHandshakeKeyUpdateMsgSize is the required body length of a key
	// update message.
	tlsHandshakeKeyUpdateMsgSize = 1
	// tlsHandshakePrefixSize is the number of bytes (type plus length) that
	// precede the body of a handshake message.
	tlsHandshakePrefixSize = 4
	// tlsMaxSessionTicketSize is only used by NewConn to size the initial
	// capacity of handshakeBuf. NOTE(review): presumably the largest
	// new-session-ticket body expected from a peer; confirm the derivation.
	tlsMaxSessionTicketSize = 131338
)
142
// Limits applied by writeTLSRecord and handleHandshakeMessage.
const (
	// outBufMaxRecords is the largest number of full-size records
	// writeTLSRecord assembles before issuing a single write to the
	// underlying connection.
	outBufMaxRecords = 16
	// outBufMaxSize is the size to which writeTLSRecord caps outRecordsBuf.
	outBufMaxSize = outBufMaxRecords * tlsRecordMaxSize
	// maxAllowedTickets is the number of collected session tickets at which
	// handleHandshakeMessage hands them to the ticket sender and stops
	// collecting.
	maxAllowedTickets = 5
)
155
156
157
158
// preConstructedKeyUpdateMsg is the key update handshake message, built once
// at package initialization, that handleKeyUpdateMsg sends when the peer
// requests a key update.
var preConstructedKeyUpdateMsg = buildKeyUpdateRequest()
160
161
162
// conn is the net.Conn implementation returned by NewConn. It embeds the
// underlying connection and overrides Read, Write and Close so that data is
// exchanged as encrypted TLS records.
type conn struct {
	net.Conn
	// inConn decrypts incoming records in Read; its key is advanced by
	// handleKeyUpdateMsg.
	inConn *halfconn.S2AHalfConnection
	// outConn encrypts outgoing records in buildRecord; its key is advanced
	// by handleKeyUpdateMsg after a key update message has been sent.
	outConn *halfconn.S2AHalfConnection
	// pendingApplicationData holds decrypted plaintext that Read has not yet
	// copied to its caller.
	pendingApplicationData []byte
	// unusedBuf holds bytes read from the underlying connection that
	// readFullRecord has not yet returned as part of a full record.
	unusedBuf []byte
	// outRecordsBuf is the scratch buffer in which buildRecord assembles and
	// encrypts outgoing records; writeTLSRecord grows it when needed.
	outRecordsBuf []byte
	// nextRecord is the portion of the read buffer that follows the record
	// most recently returned by readFullRecord.
	nextRecord []byte
	// overheadSize is the per-record overhead (header, record type byte and
	// tag) computed in NewConn; writeTLSRecord uses it to size outRecordsBuf.
	overheadSize int
	// readMutex serializes Read; Close also acquires it.
	readMutex sync.Mutex
	// writeMutex serializes Write; it is also acquired by Close and by
	// handleKeyUpdateMsg when it answers a key update request.
	writeMutex sync.Mutex
	// handshakeBuf accumulates handshake-record plaintext until
	// parseHandshakeMsg can extract complete messages from it.
	handshakeBuf []byte
	// ticketState records the progress of session ticket collection.
	ticketState sessionTicketState
	// sessionTickets holds the raw new-session-ticket messages (prefix and
	// body) collected by handleHandshakeMessage.
	sessionTickets [][]byte
	// ticketSender receives the collected tickets via sendTicketsToS2A.
	ticketSender s2aTicketSender
	// callComplete is passed to sendTicketsToS2A; Close receives from it
	// before closing the connection once tickets have been handed over.
	callComplete chan bool
}
205
206
// ConnParameters holds the inputs NewConn needs to build a record-protocol
// connection.
type ConnParameters struct {
	// NetConn is the underlying connection that records are read from and
	// written to.
	NetConn net.Conn
	// Ciphersuite is passed to halfconn.New for both directions.
	Ciphersuite commonpb.Ciphersuite
	// TLSVersion must be commonpb.TLSVersion_TLS1_3; NewConn rejects any
	// other value.
	TLSVersion commonpb.TLSVersion
	// InTrafficSecret is the secret passed to halfconn.New for the inbound
	// half connection.
	InTrafficSecret []byte
	// OutTrafficSecret is the secret passed to halfconn.New for the outbound
	// half connection.
	OutTrafficSecret []byte
	// UnusedBuf, when non-nil, is copied and used as the initial contents of
	// the read buffer, so Read consumes these bytes before reading from
	// NetConn. NOTE(review): presumably bytes already read from NetConn
	// during the handshake; confirm against the caller.
	UnusedBuf []byte
	// InSequence is the starting sequence number passed to halfconn.New for
	// the inbound half connection.
	InSequence uint64
	// OutSequence is the starting sequence number passed to halfconn.New for
	// the outbound half connection.
	OutSequence uint64
	// HSAddr is forwarded to the ticket sender as hsAddr.
	HSAddr string
	// ConnectionID is forwarded to the ticket sender as connectionID.
	ConnectionID uint64
	// LocalIdentity is forwarded to the ticket sender as localIdentity.
	LocalIdentity *commonpb.Identity
	// EnsureProcessSessionTickets is forwarded to the ticket sender as
	// ensureProcessSessionTickets.
	EnsureProcessSessionTickets *sync.WaitGroup
}
246
247
248 func NewConn(o *ConnParameters) (net.Conn, error) {
249 if o == nil {
250 return nil, errors.New("conn options must not be nil")
251 }
252 if o.TLSVersion != commonpb.TLSVersion_TLS1_3 {
253 return nil, errors.New("TLS version must be TLS 1.3")
254 }
255
256 inConn, err := halfconn.New(o.Ciphersuite, o.InTrafficSecret, o.InSequence)
257 if err != nil {
258 return nil, fmt.Errorf("failed to create inbound half connection: %v", err)
259 }
260 outConn, err := halfconn.New(o.Ciphersuite, o.OutTrafficSecret, o.OutSequence)
261 if err != nil {
262 return nil, fmt.Errorf("failed to create outbound half connection: %v", err)
263 }
264
265
266 overheadSize := tlsRecordHeaderSize + tlsRecordTypeSize + inConn.TagSize()
267 var unusedBuf []byte
268 if o.UnusedBuf == nil {
269
270
271
272
273
274
275
276 unusedBuf = make([]byte, 0, 2*tlsRecordMaxSize-1)
277 } else {
278 unusedBuf = make([]byte, len(o.UnusedBuf))
279 copy(unusedBuf, o.UnusedBuf)
280 }
281
282 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
283 if err != nil {
284 grpclog.Infof("failed to create single token access token manager: %v", err)
285 }
286
287 s2aConn := &conn{
288 Conn: o.NetConn,
289 inConn: inConn,
290 outConn: outConn,
291 unusedBuf: unusedBuf,
292 outRecordsBuf: make([]byte, tlsRecordMaxSize),
293 nextRecord: unusedBuf,
294 overheadSize: overheadSize,
295 ticketState: ticketsNotYetReceived,
296
297
298
299
300
301
302
303
304 handshakeBuf: make([]byte, 0, tlsHandshakePrefixSize+tlsMaxSessionTicketSize+tlsRecordMaxPlaintextSize-1),
305 ticketSender: &ticketSender{
306 hsAddr: o.HSAddr,
307 connectionID: o.ConnectionID,
308 localIdentity: o.LocalIdentity,
309 tokenManager: tokenManager,
310 ensureProcessSessionTickets: o.EnsureProcessSessionTickets,
311 },
312 callComplete: make(chan bool),
313 }
314 return s2aConn, nil
315 }
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330 func (p *conn) Read(b []byte) (n int, err error) {
331 p.readMutex.Lock()
332 defer p.readMutex.Unlock()
333
334
335 if len(p.pendingApplicationData) == 0 {
336
337 record, err := p.readFullRecord()
338 if err != nil {
339 return 0, err
340 }
341
342
343
344
345
346
347
348 header, payload, err := splitAndValidateHeader(record)
349 if err != nil {
350 return 0, err
351 }
352
353 p.pendingApplicationData, err = p.inConn.Decrypt(payload[:0], payload, header)
354 if err != nil {
355 return 0, err
356 }
357
358
359 msgType, err := p.stripPaddingAndType()
360 if err != nil {
361 return 0, err
362 }
363
364
365 if len(p.pendingApplicationData) > tlsRecordMaxPlaintextSize {
366 return 0, errors.New("plaintext size larger than maximum")
367 }
368
369
370
371
372
373
374
375
376 switch msgType {
377 case applicationData:
378 if len(p.handshakeBuf) > 0 {
379 return 0, errors.New("application data received while processing fragmented handshake messages")
380 }
381 if p.ticketState == receivingTickets {
382 p.ticketState = notReceivingTickets
383 grpclog.Infof("Sending session tickets to S2A.")
384 p.ticketSender.sendTicketsToS2A(p.sessionTickets, p.callComplete)
385 }
386 case alert:
387 return 0, p.handleAlertMessage()
388 case handshake:
389 if err = p.handleHandshakeMessage(); err != nil {
390 return 0, err
391 }
392 return 0, nil
393 default:
394 return 0, errors.New("unknown record type")
395 }
396 }
397
398 n = copy(b, p.pendingApplicationData)
399 p.pendingApplicationData = p.pendingApplicationData[n:]
400 return n, nil
401 }
402
403
404
405
406
407 func (p *conn) Write(b []byte) (n int, err error) {
408 p.writeMutex.Lock()
409 defer p.writeMutex.Unlock()
410 return p.writeTLSRecord(b, tlsApplicationData)
411 }
412
413
414
415
416
417 func (p *conn) writeTLSRecord(b []byte, recordType byte) (n int, err error) {
418
419
420 if len(b) == 0 {
421 recordEndIndex, _, err := p.buildRecord(b, recordType, 0)
422 if err != nil {
423 return 0, err
424 }
425
426
427
428
429 _, err = p.Conn.Write(p.outRecordsBuf[:recordEndIndex])
430 return 0, err
431 }
432
433 numRecords := int(math.Ceil(float64(len(b)) / float64(tlsRecordMaxPlaintextSize)))
434 totalRecordsSize := len(b) + numRecords*p.overheadSize
435 partialBSize := len(b)
436 if totalRecordsSize > outBufMaxSize {
437 totalRecordsSize = outBufMaxSize
438 partialBSize = outBufMaxRecords * tlsRecordMaxPlaintextSize
439 }
440 if len(p.outRecordsBuf) < totalRecordsSize {
441 p.outRecordsBuf = make([]byte, totalRecordsSize)
442 }
443 for bStart := 0; bStart < len(b); bStart += partialBSize {
444 bEnd := bStart + partialBSize
445 if bEnd > len(b) {
446 bEnd = len(b)
447 }
448 partialB := b[bStart:bEnd]
449 recordEndIndex := 0
450 for len(partialB) > 0 {
451 recordEndIndex, partialB, err = p.buildRecord(partialB, recordType, recordEndIndex)
452 if err != nil {
453
454 return bStart, err
455 }
456 }
457
458
459
460 nn, err := p.Conn.Write(p.outRecordsBuf[:recordEndIndex])
461 if err != nil {
462 numberOfCompletedRecords := int(math.Floor(float64(nn) / float64(tlsRecordMaxSize)))
463 return bStart + numberOfCompletedRecords*tlsRecordMaxPlaintextSize, err
464 }
465 }
466 return len(b), nil
467 }
468
469
470
471
472
473
474 func (p *conn) buildRecord(plaintext []byte, recordType byte, recordStartIndex int) (n int, remainingPlaintext []byte, err error) {
475
476 dataLen := len(plaintext)
477 if dataLen > tlsRecordMaxPlaintextSize {
478 dataLen = tlsRecordMaxPlaintextSize
479 }
480 remainingPlaintext = plaintext[dataLen:]
481 newRecordBuf := p.outRecordsBuf[recordStartIndex:]
482
483 copy(newRecordBuf[tlsRecordHeaderSize:], plaintext[:dataLen])
484 newRecordBuf[tlsRecordHeaderSize+dataLen] = recordType
485 payload := newRecordBuf[tlsRecordHeaderSize : tlsRecordHeaderSize+dataLen+1]
486
487 newRecordBuf[0] = tlsApplicationData
488 newRecordBuf[1] = tlsLegacyRecordVersion
489 newRecordBuf[2] = tlsLegacyRecordVersion
490 binary.BigEndian.PutUint16(newRecordBuf[3:], uint16(len(payload)+tlsTagSize))
491 header := newRecordBuf[:tlsRecordHeaderSize]
492
493
494 encryptedPayload, err := p.outConn.Encrypt(newRecordBuf[tlsRecordHeaderSize:][:0], payload, header)
495 if err != nil {
496 return 0, plaintext, err
497 }
498 recordStartIndex += len(header) + len(encryptedPayload)
499 return recordStartIndex, remainingPlaintext, nil
500 }
501
502 func (p *conn) Close() error {
503 p.readMutex.Lock()
504 defer p.readMutex.Unlock()
505 p.writeMutex.Lock()
506 defer p.writeMutex.Unlock()
507
508
509
510 if p.ticketState == notReceivingTickets {
511 <-p.callComplete
512 grpclog.Infof("Safe to close the connection because sending tickets to S2A is (already) complete.")
513 }
514 return p.Conn.Close()
515 }
516
517
518
519
520
521 func (p *conn) stripPaddingAndType() (recordType, error) {
522 if len(p.pendingApplicationData) == 0 {
523 return 0, errors.New("application data had length 0")
524 }
525 i := len(p.pendingApplicationData) - 1
526
527 for i > 0 {
528 if p.pendingApplicationData[i] != 0 {
529 break
530 }
531 i--
532 }
533 rt := recordType(p.pendingApplicationData[i])
534 p.pendingApplicationData = p.pendingApplicationData[:i]
535 return rt, nil
536 }
537
538
539
540 func (p *conn) readFullRecord() (fullRecord []byte, err error) {
541 fullRecord, p.nextRecord, err = parseReadBuffer(p.nextRecord, tlsRecordMaxPayloadSize)
542 if err != nil {
543 return nil, err
544 }
545
546
547 if len(fullRecord) == 0 {
548 copy(p.unusedBuf, p.nextRecord)
549 p.unusedBuf = p.unusedBuf[:len(p.nextRecord)]
550
551
552 p.nextRecord = p.unusedBuf
553 }
554
555 for len(fullRecord) == 0 {
556 if len(p.unusedBuf) == cap(p.unusedBuf) {
557 tmp := make([]byte, len(p.unusedBuf), cap(p.unusedBuf)+tlsRecordMaxSize)
558 copy(tmp, p.unusedBuf)
559 p.unusedBuf = tmp
560 }
561 n, err := p.Conn.Read(p.unusedBuf[len(p.unusedBuf):min(cap(p.unusedBuf), len(p.unusedBuf)+tlsRecordMaxSize)])
562 if err != nil {
563 return nil, err
564 }
565 p.unusedBuf = p.unusedBuf[:len(p.unusedBuf)+n]
566 fullRecord, p.nextRecord, err = parseReadBuffer(p.unusedBuf, tlsRecordMaxPayloadSize)
567 if err != nil {
568 return nil, err
569 }
570 }
571 return fullRecord, nil
572 }
573
574
575
576
577
578
579
580 func parseReadBuffer(b []byte, maxLen uint16) (fullRecord, remaining []byte, err error) {
581
582
583 if len(b) < tlsRecordHeaderSize {
584 return nil, b, nil
585 }
586 msgLenField := b[tlsRecordHeaderTypeSize+tlsRecordHeaderLegacyRecordVersionSize : tlsRecordHeaderSize]
587 length := binary.BigEndian.Uint16(msgLenField)
588 if length > maxLen {
589 return nil, nil, fmt.Errorf("record length larger than the limit %d", maxLen)
590 }
591 if len(b) < int(length)+tlsRecordHeaderSize {
592
593 return nil, b, nil
594 }
595 return b[:tlsRecordHeaderSize+length], b[tlsRecordHeaderSize+length:], nil
596 }
597
598
599
600
601
602 func splitAndValidateHeader(record []byte) (header, payload []byte, err error) {
603 if len(record) < tlsRecordHeaderSize {
604 return nil, nil, fmt.Errorf("record was smaller than the header size")
605 }
606 header = record[:tlsRecordHeaderSize]
607 payload = record[tlsRecordHeaderSize:]
608 if header[0] != tlsApplicationData {
609 return nil, nil, fmt.Errorf("incorrect type in the header")
610 }
611
612 if header[1] != 0x03 || header[2] != 0x03 {
613 return nil, nil, fmt.Errorf("incorrect legacy record version in the header")
614 }
615 return header, payload, nil
616 }
617
618
619 func (p *conn) handleAlertMessage() error {
620 if len(p.pendingApplicationData) != tlsAlertSize {
621 return errors.New("invalid alert message size")
622 }
623 alertType := p.pendingApplicationData[1]
624
625 p.pendingApplicationData = p.pendingApplicationData[:0]
626 if alertType == byte(closeNotify) {
627 return errors.New("received a close notify alert")
628 }
629
630 return fmt.Errorf("received an unrecognized alert type: %v", alertType)
631 }
632
633
634
635
636
637
638 func (p *conn) parseHandshakeMsg() (msgType byte, msgLen uint32, msg []byte, rawMsg []byte, ok bool) {
639
640 if len(p.handshakeBuf) < tlsHandshakePrefixSize {
641 return 0, 0, nil, nil, false
642 }
643 msgType = p.handshakeBuf[0]
644 msgLen = bigEndianInt24(p.handshakeBuf[tlsHandshakeMsgTypeSize : tlsHandshakeMsgTypeSize+tlsHandshakeLengthSize])
645 if msgLen > uint32(len(p.handshakeBuf)-tlsHandshakePrefixSize) {
646 return 0, 0, nil, nil, false
647 }
648 msg = p.handshakeBuf[tlsHandshakePrefixSize : tlsHandshakePrefixSize+msgLen]
649 rawMsg = p.handshakeBuf[:tlsHandshakeMsgTypeSize+tlsHandshakeLengthSize+msgLen]
650 p.handshakeBuf = p.handshakeBuf[tlsHandshakePrefixSize+msgLen:]
651 return msgType, msgLen, msg, rawMsg, true
652 }
653
654
655
656
657 func (p *conn) handleHandshakeMessage() error {
658
659
660
661 p.handshakeBuf = append(p.handshakeBuf, p.pendingApplicationData...)
662 p.pendingApplicationData = p.pendingApplicationData[:0]
663
664
665 for len(p.handshakeBuf) > 0 {
666 handshakeMsgType, msgLen, msg, rawMsg, ok := p.parseHandshakeMsg()
667 if !ok {
668
669
670 break
671 }
672 switch handshakeMsgType {
673 case tlsHandshakeKeyUpdateType:
674 if msgLen != tlsHandshakeKeyUpdateMsgSize {
675 return errors.New("invalid handshake key update message length")
676 }
677 if len(p.handshakeBuf) != 0 {
678 return errors.New("key update message must be the last message of a handshake record")
679 }
680 if err := p.handleKeyUpdateMsg(msg); err != nil {
681 return err
682 }
683 case tlsHandshakeNewSessionTicketType:
684
685
686 if p.ticketState == notReceivingTickets {
687 continue
688 }
689 if p.ticketState == ticketsNotYetReceived {
690 p.ticketState = receivingTickets
691 }
692 p.sessionTickets = append(p.sessionTickets, rawMsg)
693 if len(p.sessionTickets) == maxAllowedTickets {
694 p.ticketState = notReceivingTickets
695 grpclog.Infof("Sending session tickets to S2A.")
696 p.ticketSender.sendTicketsToS2A(p.sessionTickets, p.callComplete)
697 }
698 default:
699 return errors.New("unknown handshake message type")
700 }
701 }
702 return nil
703 }
704
705 func buildKeyUpdateRequest() []byte {
706 b := make([]byte, tlsHandshakePrefixSize+tlsHandshakeKeyUpdateMsgSize)
707 b[0] = tlsHandshakeKeyUpdateType
708 b[1] = 0
709 b[2] = 0
710 b[3] = tlsHandshakeKeyUpdateMsgSize
711 b[4] = byte(updateNotRequested)
712 return b
713 }
714
715
716 func (p *conn) handleKeyUpdateMsg(msg []byte) error {
717 keyUpdateRequest := msg[0]
718 if keyUpdateRequest != byte(updateNotRequested) &&
719 keyUpdateRequest != byte(updateRequested) {
720 return errors.New("invalid handshake key update message")
721 }
722 if err := p.inConn.UpdateKey(); err != nil {
723 return err
724 }
725
726 if keyUpdateRequest == byte(updateRequested) {
727 p.writeMutex.Lock()
728 defer p.writeMutex.Unlock()
729 n, err := p.writeTLSRecord(preConstructedKeyUpdateMsg, byte(handshake))
730 if err != nil {
731 return err
732 }
733 if n != tlsHandshakePrefixSize+tlsHandshakeKeyUpdateMsgSize {
734 return errors.New("key update request message wrote less bytes than expected")
735 }
736 if err = p.outConn.UpdateKey(); err != nil {
737 return err
738 }
739 }
740 return nil
741 }
742
743
744
745
746
// bigEndianInt24 decodes the first three bytes of b as a big-endian unsigned
// 24-bit integer. It panics if b holds fewer than three bytes.
func bigEndianInt24(b []byte) uint32 {
	_ = b[2] // Fail early, with a single bounds check, on short input.

	var value uint32
	for _, octet := range b[:3] {
		value = value<<8 | uint32(octet)
	}
	return value
}
751
// min returns the smaller of a and b.
func min(a, b int) int {
	if b <= a {
		return b
	}
	return a
}
758
View as plain text