...

Source file src/github.com/google/s2a-go/internal/record/record.go

Documentation: github.com/google/s2a-go/internal/record

     1  /*
     2   *
     3   * Copyright 2021 Google LLC
     4   *
     5   * Licensed under the Apache License, Version 2.0 (the "License");
     6   * you may not use this file except in compliance with the License.
     7   * You may obtain a copy of the License at
     8   *
     9   *     https://www.apache.org/licenses/LICENSE-2.0
    10   *
    11   * Unless required by applicable law or agreed to in writing, software
    12   * distributed under the License is distributed on an "AS IS" BASIS,
    13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    14   * See the License for the specific language governing permissions and
    15   * limitations under the License.
    16   *
    17   */
    18  
    19  // Package record implements the TLS 1.3 record protocol used by the S2A
    20  // transport credentials.
    21  package record
    22  
    23  import (
    24  	"encoding/binary"
    25  	"errors"
    26  	"fmt"
    27  	"math"
    28  	"net"
    29  	"sync"
    30  
    31  	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
    32  	"github.com/google/s2a-go/internal/record/internal/halfconn"
    33  	"github.com/google/s2a-go/internal/tokenmanager"
    34  	"google.golang.org/grpc/grpclog"
    35  )
    36  
// recordType is the `ContentType` as described in
// https://tools.ietf.org/html/rfc8446#section-5.1.
//
// In this package the outer record header always carries tlsApplicationData;
// the recordType of a record is the last non-zero byte of its decrypted inner
// plaintext (see stripPaddingAndType).
type recordType byte

const (
	alert           recordType = 21
	handshake       recordType = 22
	applicationData recordType = 23
)
    46  
// keyUpdateRequest is the `KeyUpdateRequest` as described in
// https://tools.ietf.org/html/rfc8446#section-4.6.3. It is the single-byte
// body of a KeyUpdate handshake message (see tlsHandshakeKeyUpdateMsgSize).
type keyUpdateRequest byte

// Values of keyUpdateRequest; their meaning is defined by RFC 8446,
// section 4.6.3 (update_not_requested and update_requested).
const (
	updateNotRequested keyUpdateRequest = 0
	updateRequested    keyUpdateRequest = 1
)
    55  
// alertDescription is the `AlertDescription` as described in
// https://tools.ietf.org/html/rfc8446#section-6. Only the descriptions this
// package distinguishes are declared here.
type alertDescription byte

const (
	// closeNotify is the close_notify alert. Receiving it makes Read return a
	// "received a close notify alert" error.
	closeNotify alertDescription = 0
)
    63  
    64  // sessionTicketState is used to determine whether session tickets have not yet
    65  // been received, are in the process of being received, or have finished
    66  // receiving.
    67  type sessionTicketState byte
    68  
    69  const (
    70  	ticketsNotYetReceived sessionTicketState = 0
    71  	receivingTickets      sessionTicketState = 1
    72  	notReceivingTickets   sessionTicketState = 2
    73  )
    74  
const (
	// The TLS 1.3-specific constants below (tlsRecordMaxPlaintextSize,
	// tlsRecordHeaderSize, tlsRecordTypeSize) were taken from
	// https://tools.ietf.org/html/rfc8446#section-5.1.

	// tlsRecordMaxPlaintextSize is the maximum size in bytes of the plaintext
	// in a single TLS 1.3 record.
	tlsRecordMaxPlaintextSize = 16384 // 2^14
	// tlsRecordTypeSize is the size in bytes of the TLS 1.3 record type.
	tlsRecordTypeSize = 1
	// tlsTagSize is the size in bytes of the tag of the following three
	// ciphersuites: AES-128-GCM-SHA256, AES-256-GCM-SHA384,
	// CHACHA20-POLY1305-SHA256.
	tlsTagSize = 16
	// tlsRecordMaxPayloadSize is the maximum size in bytes of the payload in a
	// single TLS 1.3 record. This is the maximum size of the plaintext plus the
	// record type byte and 16 bytes of the tag (16401 bytes in total).
	//
	// NOTE(review): this leaves no room for zero padding, whereas RFC 8446
	// section 5.2 allows ciphertexts of up to 2^14 + 256 bytes; records from a
	// peer that pads would be rejected by parseReadBuffer. Confirm that peers
	// never pad.
	tlsRecordMaxPayloadSize = tlsRecordMaxPlaintextSize + tlsRecordTypeSize + tlsTagSize
	// tlsRecordHeaderTypeSize is the size in bytes of the TLS 1.3 record
	// header type.
	tlsRecordHeaderTypeSize = 1
	// tlsRecordHeaderLegacyRecordVersionSize is the size in bytes of the TLS
	// 1.3 record header legacy record version.
	tlsRecordHeaderLegacyRecordVersionSize = 2
	// tlsRecordHeaderPayloadLengthSize is the size in bytes of the TLS 1.3
	// record header payload length.
	tlsRecordHeaderPayloadLengthSize = 2
	// tlsRecordHeaderSize is the size in bytes of the TLS 1.3 record header.
	tlsRecordHeaderSize = tlsRecordHeaderTypeSize + tlsRecordHeaderLegacyRecordVersionSize + tlsRecordHeaderPayloadLengthSize
	// tlsRecordMaxSize is the maximum size in bytes of a complete TLS 1.3
	// record: the header plus the largest allowed payload (16406 bytes).
	tlsRecordMaxSize = tlsRecordMaxPayloadSize + tlsRecordHeaderSize
	// tlsApplicationData is the application data type of the TLS 1.3 record
	// header. Every record built by this package carries it as the outer
	// type, and only records with this outer type are accepted on read.
	tlsApplicationData = 23
	// tlsLegacyRecordVersion is the value of each of the two legacy record
	// version bytes of the TLS record header, i.e. the version is 0x0303.
	tlsLegacyRecordVersion = 3
	// tlsAlertSize is the size in bytes of an alert of TLS 1.3.
	tlsAlertSize = 2
)
   114  
const (
	// These are TLS 1.3 handshake-specific constants.

	// tlsHandshakeNewSessionTicketType is the handshake message type (the
	// first byte of the handshake message) of a TLS 1.3 NewSessionTicket
	// message.
	tlsHandshakeNewSessionTicketType = 4
	// tlsHandshakeKeyUpdateType is the handshake message type (the first byte
	// of the handshake message) of a TLS 1.3 KeyUpdate message.
	tlsHandshakeKeyUpdateType = 24
	// tlsHandshakeMsgTypeSize is the size in bytes of the TLS 1.3 handshake
	// message type field.
	tlsHandshakeMsgTypeSize = 1
	// tlsHandshakeLengthSize is the size in bytes of the TLS 1.3 handshake
	// message length field, which is decoded as a 24-bit big-endian integer.
	tlsHandshakeLengthSize = 3
	// tlsHandshakeKeyUpdateMsgSize is the size in bytes of the body of a TLS
	// 1.3 KeyUpdate handshake message (a single KeyUpdateRequest byte).
	tlsHandshakeKeyUpdateMsgSize = 1
	// tlsHandshakePrefixSize is the size in bytes of the prefix of the TLS 1.3
	// handshake message: tlsHandshakeMsgTypeSize + tlsHandshakeLengthSize.
	tlsHandshakePrefixSize = 4
	// tlsMaxSessionTicketSize is the maximum size of a NewSessionTicket message
	// in TLS 1.3. This is the sum of the max sizes of all the fields in the
	// NewSessionTicket struct specified in
	// https://tools.ietf.org/html/rfc8446#section-4.6.1.
	tlsMaxSessionTicketSize = 131338
)
   142  
const (
	// outBufMaxRecords is the maximum number of records that can fit in the
	// outRecordsBuf buffer, i.e. the largest batch of records that
	// writeTLSRecord encrypts before issuing one write to the network.
	outBufMaxRecords = 16
	// outBufMaxSize is the maximum size (in bytes) of the outRecordsBuf
	// buffer: room for outBufMaxRecords records of maximum size.
	outBufMaxSize = outBufMaxRecords * tlsRecordMaxSize
	// maxAllowedTickets is the maximum number of session tickets that are
	// allowed. The number of tickets are limited to ensure that the size of the
	// ticket queue does not grow indefinitely. S2A also keeps a limit on the
	// number of tickets that it caches.
	maxAllowedTickets = 5
)
   155  
// preConstructedKeyUpdateMsg holds the key update message. This is needed as an
// optimization so that the same message does not need to be constructed every
// time a key update message is sent. It is built once, when the package
// variables are initialized, by calling buildKeyUpdateRequest.
var preConstructedKeyUpdateMsg = buildKeyUpdateRequest()
   160  
// conn represents a secured TLS connection. It implements the net.Conn
// interface.
//
// conn defines its own Read, Write and Close. Every net.Conn method it does
// not define (for example addresses and deadlines) is promoted unchanged from
// the embedded connection.
type conn struct {
	net.Conn
	// inConn is the half connection responsible for decrypting incoming bytes.
	inConn *halfconn.S2AHalfConnection
	// outConn is the half connection responsible for encrypting outgoing bytes.
	outConn *halfconn.S2AHalfConnection
	// pendingApplicationData holds data that has been read from the connection
	// and decrypted, but has not yet been returned by Read. Records are
	// decrypted in place, so this slice aliases unusedBuf.
	pendingApplicationData []byte
	// unusedBuf holds data read from the network that has not yet been
	// decrypted. This data might not consist of a complete record. It may
	// consist of several records, the last of which could be incomplete.
	unusedBuf []byte
	// outRecordsBuf is a buffer used to store outgoing TLS records before
	// they are written to the network. NewConn allocates tlsRecordMaxSize
	// bytes and writeTLSRecord grows it on demand, up to outBufMaxSize.
	outRecordsBuf []byte
	// nextRecord stores the next record info in the unusedBuf buffer: it is
	// the suffix of unusedBuf that starts at the first byte not yet returned
	// by readFullRecord.
	nextRecord []byte
	// overheadSize is the overhead size in bytes of each TLS 1.3 record, which
	// is computed as overheadSize = header size + record type byte + tag size.
	// Note that there is no padding by zeros in the overhead calculation.
	overheadSize int
	// readMutex guards against concurrent calls to Read. This is required since
	// Close may be called during a Read.
	readMutex sync.Mutex
	// writeMutex guards against concurrent calls to Write. This is required
	// since Close may be called during a Write, and also because a key update
	// message may be written during a Read.
	writeMutex sync.Mutex
	// handshakeBuf holds handshake messages while they are being processed.
	handshakeBuf []byte
	// ticketState is the current processing state of the session tickets.
	ticketState sessionTicketState
	// sessionTickets holds the completed session tickets until they are sent to
	// the handshaker service for processing. Each entry is a raw handshake
	// message sliced from handshakeBuf, so it shares that buffer's memory.
	sessionTickets [][]byte
	// ticketSender sends session tickets to the S2A handshaker service.
	ticketSender s2aTicketSender
	// callComplete is a channel that blocks closing the record protocol until a
	// pending call to the S2A completes. Close receives from it when
	// ticketState is notReceivingTickets.
	callComplete chan bool
}
   205  
// ConnParameters holds the parameters used for creating a new conn object
// with NewConn.
type ConnParameters struct {
	// NetConn is the TCP connection to the peer. This parameter is required
	// (note that NewConn does not itself check it for nil).
	NetConn net.Conn
	// Ciphersuite is the TLS ciphersuite negotiated by the S2A handshaker
	// service. This parameter is required.
	Ciphersuite commonpb.Ciphersuite
	// TLSVersion is the TLS version number negotiated by the S2A handshaker
	// service. This parameter is required; NewConn rejects anything other
	// than TLS 1.3.
	TLSVersion commonpb.TLSVersion
	// InTrafficSecret is the traffic secret used to derive the session key for
	// the inbound direction. This parameter is required.
	InTrafficSecret []byte
	// OutTrafficSecret is the traffic secret used to derive the session key
	// for the outbound direction. This parameter is required.
	OutTrafficSecret []byte
	// UnusedBuf is the data read from the network that has not yet been
	// decrypted. This parameter is optional. If not provided, then no
	// application data was sent in the same flight of messages as the final
	// handshake message. NewConn copies it, so the caller may reuse the slice.
	UnusedBuf []byte
	// InSequence is the sequence number of the next, incoming, TLS record.
	// This parameter is required.
	InSequence uint64
	// OutSequence is the sequence number of the next, outgoing, TLS record.
	// This parameter is required.
	OutSequence uint64
	// HSAddr stores the address of the S2A handshaker service. This parameter
	// is optional. If not provided, then TLS resumption is disabled.
	HSAddr string
	// ConnectionID is the connection identifier that was created and sent by
	// S2A at the end of a handshake.
	ConnectionID uint64
	// LocalIdentity is the local identity that was used by S2A during session
	// setup and included in the session result.
	LocalIdentity *commonpb.Identity
	// EnsureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	EnsureProcessSessionTickets *sync.WaitGroup
}
   246  
   247  // NewConn creates a TLS record protocol that wraps the TCP connection.
   248  func NewConn(o *ConnParameters) (net.Conn, error) {
   249  	if o == nil {
   250  		return nil, errors.New("conn options must not be nil")
   251  	}
   252  	if o.TLSVersion != commonpb.TLSVersion_TLS1_3 {
   253  		return nil, errors.New("TLS version must be TLS 1.3")
   254  	}
   255  
   256  	inConn, err := halfconn.New(o.Ciphersuite, o.InTrafficSecret, o.InSequence)
   257  	if err != nil {
   258  		return nil, fmt.Errorf("failed to create inbound half connection: %v", err)
   259  	}
   260  	outConn, err := halfconn.New(o.Ciphersuite, o.OutTrafficSecret, o.OutSequence)
   261  	if err != nil {
   262  		return nil, fmt.Errorf("failed to create outbound half connection: %v", err)
   263  	}
   264  
   265  	// The tag size for the in/out connections should be the same.
   266  	overheadSize := tlsRecordHeaderSize + tlsRecordTypeSize + inConn.TagSize()
   267  	var unusedBuf []byte
   268  	if o.UnusedBuf == nil {
   269  		// We pre-allocate unusedBuf to be of size
   270  		// 2*tlsRecordMaxSize-1 during initialization. We only read from the
   271  		// network into unusedBuf when unusedBuf does not contain a complete
   272  		// record and the incomplete record is at most tlsRecordMaxSize-1
   273  		// (bytes). And we read at most tlsRecordMaxSize bytes of data from the
   274  		// network into unusedBuf at one time. Therefore, 2*tlsRecordMaxSize-1
   275  		// is large enough to buffer data read from the network.
   276  		unusedBuf = make([]byte, 0, 2*tlsRecordMaxSize-1)
   277  	} else {
   278  		unusedBuf = make([]byte, len(o.UnusedBuf))
   279  		copy(unusedBuf, o.UnusedBuf)
   280  	}
   281  
   282  	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
   283  	if err != nil {
   284  		grpclog.Infof("failed to create single token access token manager: %v", err)
   285  	}
   286  
   287  	s2aConn := &conn{
   288  		Conn:          o.NetConn,
   289  		inConn:        inConn,
   290  		outConn:       outConn,
   291  		unusedBuf:     unusedBuf,
   292  		outRecordsBuf: make([]byte, tlsRecordMaxSize),
   293  		nextRecord:    unusedBuf,
   294  		overheadSize:  overheadSize,
   295  		ticketState:   ticketsNotYetReceived,
   296  		// Pre-allocate the buffer for one session ticket message and the max
   297  		// plaintext size. This is the largest size that handshakeBuf will need
   298  		// to hold. The largest incomplete handshake message is the
   299  		// [handshake header size] + [max session ticket size] - 1.
   300  		// Then, tlsRecordMaxPlaintextSize is the maximum size that will be
   301  		// appended to the handshakeBuf before the handshake message is
   302  		// completed. Therefore, the buffer size below should be large enough to
   303  		// buffer any handshake messages.
   304  		handshakeBuf: make([]byte, 0, tlsHandshakePrefixSize+tlsMaxSessionTicketSize+tlsRecordMaxPlaintextSize-1),
   305  		ticketSender: &ticketSender{
   306  			hsAddr:                      o.HSAddr,
   307  			connectionID:                o.ConnectionID,
   308  			localIdentity:               o.LocalIdentity,
   309  			tokenManager:                tokenManager,
   310  			ensureProcessSessionTickets: o.EnsureProcessSessionTickets,
   311  		},
   312  		callComplete: make(chan bool),
   313  	}
   314  	return s2aConn, nil
   315  }
   316  
   317  // Read reads and decrypts a TLS 1.3 record from the underlying connection, and
   318  // copies any application data received from the peer into b. If the size of the
   319  // payload is greater than len(b), Read retains the remaining bytes in an
   320  // internal buffer, and subsequent calls to Read will read from this buffer
   321  // until it is exhausted. At most 1 TLS record worth of application data is
   322  // written to b for each call to Read.
   323  //
   324  // Note that for the user to efficiently call this method, the user should
   325  // ensure that the buffer b is allocated such that the buffer does not have any
   326  // unused segments. This can be done by calling Read via io.ReadFull, which
   327  // continually calls Read until the specified buffer has been filled. Also note
   328  // that the user should close the connection via Close() if an error is thrown
   329  // by a call to Read.
   330  func (p *conn) Read(b []byte) (n int, err error) {
   331  	p.readMutex.Lock()
   332  	defer p.readMutex.Unlock()
   333  	// Check if p.pendingApplication data has leftover application data from
   334  	// the previous call to Read.
   335  	if len(p.pendingApplicationData) == 0 {
   336  		// Read a full record from the wire.
   337  		record, err := p.readFullRecord()
   338  		if err != nil {
   339  			return 0, err
   340  		}
   341  		// Now we have a complete record, so split the header and validate it
   342  		// The TLS record is split into 2 pieces: the record header and the
   343  		// payload. The payload has the following form:
   344  		// [payload] = [ciphertext of application data]
   345  		//           + [ciphertext of record type byte]
   346  		//           + [(optionally) ciphertext of padding by zeros]
   347  		//           + [tag]
   348  		header, payload, err := splitAndValidateHeader(record)
   349  		if err != nil {
   350  			return 0, err
   351  		}
   352  		// Decrypt the ciphertext.
   353  		p.pendingApplicationData, err = p.inConn.Decrypt(payload[:0], payload, header)
   354  		if err != nil {
   355  			return 0, err
   356  		}
   357  		// Remove the padding by zeros and the record type byte from the
   358  		// p.pendingApplicationData buffer.
   359  		msgType, err := p.stripPaddingAndType()
   360  		if err != nil {
   361  			return 0, err
   362  		}
   363  		// Check that the length of the plaintext after stripping the padding
   364  		// and record type byte is under the maximum plaintext size.
   365  		if len(p.pendingApplicationData) > tlsRecordMaxPlaintextSize {
   366  			return 0, errors.New("plaintext size larger than maximum")
   367  		}
   368  		// The expected message types are application data, alert, and
   369  		// handshake. For application data, the bytes are directly copied into
   370  		// b. For an alert, the type of the alert is checked and the connection
   371  		// is closed on a close notify alert. For a handshake message, the
   372  		// handshake message type is checked. The handshake message type can be
   373  		// a key update type, for which we advance the traffic secret, and a
   374  		// new session ticket type, for which we send the received ticket to S2A
   375  		// for processing.
   376  		switch msgType {
   377  		case applicationData:
   378  			if len(p.handshakeBuf) > 0 {
   379  				return 0, errors.New("application data received while processing fragmented handshake messages")
   380  			}
   381  			if p.ticketState == receivingTickets {
   382  				p.ticketState = notReceivingTickets
   383  				grpclog.Infof("Sending session tickets to S2A.")
   384  				p.ticketSender.sendTicketsToS2A(p.sessionTickets, p.callComplete)
   385  			}
   386  		case alert:
   387  			return 0, p.handleAlertMessage()
   388  		case handshake:
   389  			if err = p.handleHandshakeMessage(); err != nil {
   390  				return 0, err
   391  			}
   392  			return 0, nil
   393  		default:
   394  			return 0, errors.New("unknown record type")
   395  		}
   396  	}
   397  	// Write as much application data as possible to b, the output buffer.
   398  	n = copy(b, p.pendingApplicationData)
   399  	p.pendingApplicationData = p.pendingApplicationData[n:]
   400  	return n, nil
   401  }
   402  
   403  // Write divides b into segments of size tlsRecordMaxPlaintextSize, builds a
   404  // TLS 1.3 record (of type "application data") from each segment, and sends
   405  // the record to the peer. It returns the number of plaintext bytes that were
   406  // successfully sent to the peer.
   407  func (p *conn) Write(b []byte) (n int, err error) {
   408  	p.writeMutex.Lock()
   409  	defer p.writeMutex.Unlock()
   410  	return p.writeTLSRecord(b, tlsApplicationData)
   411  }
   412  
   413  // writeTLSRecord divides b into segments of size maxPlaintextBytesPerRecord,
   414  // builds a TLS 1.3 record (of type recordType) from each segment, and sends
   415  // the record to the peer. It returns the number of plaintext bytes that were
   416  // successfully sent to the peer.
   417  func (p *conn) writeTLSRecord(b []byte, recordType byte) (n int, err error) {
   418  	// Create a record of only header, record type, and tag if given empty
   419  	// byte array.
   420  	if len(b) == 0 {
   421  		recordEndIndex, _, err := p.buildRecord(b, recordType, 0)
   422  		if err != nil {
   423  			return 0, err
   424  		}
   425  
   426  		// Write the bytes stored in outRecordsBuf to p.Conn. Since we return
   427  		// the number of plaintext bytes written without overhead, we will
   428  		// always return 0 while p.Conn.Write returns the entire record length.
   429  		_, err = p.Conn.Write(p.outRecordsBuf[:recordEndIndex])
   430  		return 0, err
   431  	}
   432  
   433  	numRecords := int(math.Ceil(float64(len(b)) / float64(tlsRecordMaxPlaintextSize)))
   434  	totalRecordsSize := len(b) + numRecords*p.overheadSize
   435  	partialBSize := len(b)
   436  	if totalRecordsSize > outBufMaxSize {
   437  		totalRecordsSize = outBufMaxSize
   438  		partialBSize = outBufMaxRecords * tlsRecordMaxPlaintextSize
   439  	}
   440  	if len(p.outRecordsBuf) < totalRecordsSize {
   441  		p.outRecordsBuf = make([]byte, totalRecordsSize)
   442  	}
   443  	for bStart := 0; bStart < len(b); bStart += partialBSize {
   444  		bEnd := bStart + partialBSize
   445  		if bEnd > len(b) {
   446  			bEnd = len(b)
   447  		}
   448  		partialB := b[bStart:bEnd]
   449  		recordEndIndex := 0
   450  		for len(partialB) > 0 {
   451  			recordEndIndex, partialB, err = p.buildRecord(partialB, recordType, recordEndIndex)
   452  			if err != nil {
   453  				// Return the amount of bytes written prior to the error.
   454  				return bStart, err
   455  			}
   456  		}
   457  		// Write the bytes stored in outRecordsBuf to p.Conn. If there is an
   458  		// error, calculate the total number of plaintext bytes of complete
   459  		// records successfully written to the peer and return it.
   460  		nn, err := p.Conn.Write(p.outRecordsBuf[:recordEndIndex])
   461  		if err != nil {
   462  			numberOfCompletedRecords := int(math.Floor(float64(nn) / float64(tlsRecordMaxSize)))
   463  			return bStart + numberOfCompletedRecords*tlsRecordMaxPlaintextSize, err
   464  		}
   465  	}
   466  	return len(b), nil
   467  }
   468  
   469  // buildRecord builds a TLS 1.3 record of type recordType from plaintext,
   470  // and writes the record to outRecordsBuf at recordStartIndex. The record will
   471  // have at most tlsRecordMaxPlaintextSize bytes of payload. It returns the
   472  // index of outRecordsBuf where the current record ends, as well as any
   473  // remaining plaintext bytes.
   474  func (p *conn) buildRecord(plaintext []byte, recordType byte, recordStartIndex int) (n int, remainingPlaintext []byte, err error) {
   475  	// Construct the payload, which consists of application data and record type.
   476  	dataLen := len(plaintext)
   477  	if dataLen > tlsRecordMaxPlaintextSize {
   478  		dataLen = tlsRecordMaxPlaintextSize
   479  	}
   480  	remainingPlaintext = plaintext[dataLen:]
   481  	newRecordBuf := p.outRecordsBuf[recordStartIndex:]
   482  
   483  	copy(newRecordBuf[tlsRecordHeaderSize:], plaintext[:dataLen])
   484  	newRecordBuf[tlsRecordHeaderSize+dataLen] = recordType
   485  	payload := newRecordBuf[tlsRecordHeaderSize : tlsRecordHeaderSize+dataLen+1] // 1 is for the recordType.
   486  	// Construct the header.
   487  	newRecordBuf[0] = tlsApplicationData
   488  	newRecordBuf[1] = tlsLegacyRecordVersion
   489  	newRecordBuf[2] = tlsLegacyRecordVersion
   490  	binary.BigEndian.PutUint16(newRecordBuf[3:], uint16(len(payload)+tlsTagSize))
   491  	header := newRecordBuf[:tlsRecordHeaderSize]
   492  
   493  	// Encrypt the payload using header as aad.
   494  	encryptedPayload, err := p.outConn.Encrypt(newRecordBuf[tlsRecordHeaderSize:][:0], payload, header)
   495  	if err != nil {
   496  		return 0, plaintext, err
   497  	}
   498  	recordStartIndex += len(header) + len(encryptedPayload)
   499  	return recordStartIndex, remainingPlaintext, nil
   500  }
   501  
   502  func (p *conn) Close() error {
   503  	p.readMutex.Lock()
   504  	defer p.readMutex.Unlock()
   505  	p.writeMutex.Lock()
   506  	defer p.writeMutex.Unlock()
   507  	// If p.ticketState is equal to notReceivingTickets, then S2A has
   508  	// been sent a flight of session tickets, and we must wait for the
   509  	// call to S2A to complete before closing the record protocol.
   510  	if p.ticketState == notReceivingTickets {
   511  		<-p.callComplete
   512  		grpclog.Infof("Safe to close the connection because sending tickets to S2A is (already) complete.")
   513  	}
   514  	return p.Conn.Close()
   515  }
   516  
   517  // stripPaddingAndType strips the padding by zeros and record type from
   518  // p.pendingApplicationData and returns the record type. Note that
   519  // p.pendingApplicationData should be of the form:
   520  // [application data] + [record type byte] + [trailing zeros]
   521  func (p *conn) stripPaddingAndType() (recordType, error) {
   522  	if len(p.pendingApplicationData) == 0 {
   523  		return 0, errors.New("application data had length 0")
   524  	}
   525  	i := len(p.pendingApplicationData) - 1
   526  	// Search for the index of the record type byte.
   527  	for i > 0 {
   528  		if p.pendingApplicationData[i] != 0 {
   529  			break
   530  		}
   531  		i--
   532  	}
   533  	rt := recordType(p.pendingApplicationData[i])
   534  	p.pendingApplicationData = p.pendingApplicationData[:i]
   535  	return rt, nil
   536  }
   537  
   538  // readFullRecord reads from the wire until a record is completed and returns
   539  // the full record.
   540  func (p *conn) readFullRecord() (fullRecord []byte, err error) {
   541  	fullRecord, p.nextRecord, err = parseReadBuffer(p.nextRecord, tlsRecordMaxPayloadSize)
   542  	if err != nil {
   543  		return nil, err
   544  	}
   545  	// Check whether the next record to be decrypted has been completely
   546  	// received.
   547  	if len(fullRecord) == 0 {
   548  		copy(p.unusedBuf, p.nextRecord)
   549  		p.unusedBuf = p.unusedBuf[:len(p.nextRecord)]
   550  		// Always copy next incomplete record to the beginning of the
   551  		// unusedBuf buffer and reset nextRecord to it.
   552  		p.nextRecord = p.unusedBuf
   553  	}
   554  	// Keep reading from the wire until we have a complete record.
   555  	for len(fullRecord) == 0 {
   556  		if len(p.unusedBuf) == cap(p.unusedBuf) {
   557  			tmp := make([]byte, len(p.unusedBuf), cap(p.unusedBuf)+tlsRecordMaxSize)
   558  			copy(tmp, p.unusedBuf)
   559  			p.unusedBuf = tmp
   560  		}
   561  		n, err := p.Conn.Read(p.unusedBuf[len(p.unusedBuf):min(cap(p.unusedBuf), len(p.unusedBuf)+tlsRecordMaxSize)])
   562  		if err != nil {
   563  			return nil, err
   564  		}
   565  		p.unusedBuf = p.unusedBuf[:len(p.unusedBuf)+n]
   566  		fullRecord, p.nextRecord, err = parseReadBuffer(p.unusedBuf, tlsRecordMaxPayloadSize)
   567  		if err != nil {
   568  			return nil, err
   569  		}
   570  	}
   571  	return fullRecord, nil
   572  }
   573  
   574  // parseReadBuffer parses the provided buffer and returns a full record and any
   575  // remaining bytes in that buffer. If the record is incomplete, nil is returned
   576  // for the first return value and the given byte buffer is returned for the
   577  // second return value. The length of the payload specified by the header should
   578  // not be greater than maxLen, otherwise an error is returned. Note that this
   579  // function does not allocate or copy any buffers.
   580  func parseReadBuffer(b []byte, maxLen uint16) (fullRecord, remaining []byte, err error) {
   581  	// If the header is not complete, return the provided buffer as remaining
   582  	// buffer.
   583  	if len(b) < tlsRecordHeaderSize {
   584  		return nil, b, nil
   585  	}
   586  	msgLenField := b[tlsRecordHeaderTypeSize+tlsRecordHeaderLegacyRecordVersionSize : tlsRecordHeaderSize]
   587  	length := binary.BigEndian.Uint16(msgLenField)
   588  	if length > maxLen {
   589  		return nil, nil, fmt.Errorf("record length larger than the limit %d", maxLen)
   590  	}
   591  	if len(b) < int(length)+tlsRecordHeaderSize {
   592  		// Record is not complete yet.
   593  		return nil, b, nil
   594  	}
   595  	return b[:tlsRecordHeaderSize+length], b[tlsRecordHeaderSize+length:], nil
   596  }
   597  
   598  // splitAndValidateHeader splits the header from the payload in the TLS 1.3
   599  // record and returns them. Note that the header is checked for validity, and an
   600  // error is returned when an invalid header is parsed. Also note that this
   601  // function does not allocate or copy any buffers.
   602  func splitAndValidateHeader(record []byte) (header, payload []byte, err error) {
   603  	if len(record) < tlsRecordHeaderSize {
   604  		return nil, nil, fmt.Errorf("record was smaller than the header size")
   605  	}
   606  	header = record[:tlsRecordHeaderSize]
   607  	payload = record[tlsRecordHeaderSize:]
   608  	if header[0] != tlsApplicationData {
   609  		return nil, nil, fmt.Errorf("incorrect type in the header")
   610  	}
   611  	// Check the legacy record version, which should be 0x03, 0x03.
   612  	if header[1] != 0x03 || header[2] != 0x03 {
   613  		return nil, nil, fmt.Errorf("incorrect legacy record version in the header")
   614  	}
   615  	return header, payload, nil
   616  }
   617  
   618  // handleAlertMessage handles an alert message.
   619  func (p *conn) handleAlertMessage() error {
   620  	if len(p.pendingApplicationData) != tlsAlertSize {
   621  		return errors.New("invalid alert message size")
   622  	}
   623  	alertType := p.pendingApplicationData[1]
   624  	// Clear the body of the alert message.
   625  	p.pendingApplicationData = p.pendingApplicationData[:0]
   626  	if alertType == byte(closeNotify) {
   627  		return errors.New("received a close notify alert")
   628  	}
   629  	// TODO(matthewstevenson88): Add support for more alert types.
   630  	return fmt.Errorf("received an unrecognized alert type: %v", alertType)
   631  }
   632  
   633  // parseHandshakeHeader parses a handshake message from the handshake buffer.
   634  // It returns the message type, the message length, the message, the raw message
   635  // that includes the type and length bytes and a flag indicating whether the
   636  // handshake message has been fully parsed. i.e. whether the entire handshake
   637  // message was in the handshake buffer.
   638  func (p *conn) parseHandshakeMsg() (msgType byte, msgLen uint32, msg []byte, rawMsg []byte, ok bool) {
   639  	// Handle the case where the 4 byte handshake header is fragmented.
   640  	if len(p.handshakeBuf) < tlsHandshakePrefixSize {
   641  		return 0, 0, nil, nil, false
   642  	}
   643  	msgType = p.handshakeBuf[0]
   644  	msgLen = bigEndianInt24(p.handshakeBuf[tlsHandshakeMsgTypeSize : tlsHandshakeMsgTypeSize+tlsHandshakeLengthSize])
   645  	if msgLen > uint32(len(p.handshakeBuf)-tlsHandshakePrefixSize) {
   646  		return 0, 0, nil, nil, false
   647  	}
   648  	msg = p.handshakeBuf[tlsHandshakePrefixSize : tlsHandshakePrefixSize+msgLen]
   649  	rawMsg = p.handshakeBuf[:tlsHandshakeMsgTypeSize+tlsHandshakeLengthSize+msgLen]
   650  	p.handshakeBuf = p.handshakeBuf[tlsHandshakePrefixSize+msgLen:]
   651  	return msgType, msgLen, msg, rawMsg, true
   652  }
   653  
   654  // handleHandshakeMessage handles a handshake message. Note that the first
   655  // complete handshake message from the handshake buffer is removed, if it
   656  // exists.
   657  func (p *conn) handleHandshakeMessage() error {
   658  	// Copy the pending application data to the handshake buffer. At this point,
   659  	// we are guaranteed that the pending application data contains only parts
   660  	// of a handshake message.
   661  	p.handshakeBuf = append(p.handshakeBuf, p.pendingApplicationData...)
   662  	p.pendingApplicationData = p.pendingApplicationData[:0]
   663  	// Several handshake messages may be coalesced into a single record.
   664  	// Continue reading them until the handshake buffer is empty.
   665  	for len(p.handshakeBuf) > 0 {
   666  		handshakeMsgType, msgLen, msg, rawMsg, ok := p.parseHandshakeMsg()
   667  		if !ok {
   668  			// The handshake could not be fully parsed, so read in another
   669  			// record and try again later.
   670  			break
   671  		}
   672  		switch handshakeMsgType {
   673  		case tlsHandshakeKeyUpdateType:
   674  			if msgLen != tlsHandshakeKeyUpdateMsgSize {
   675  				return errors.New("invalid handshake key update message length")
   676  			}
   677  			if len(p.handshakeBuf) != 0 {
   678  				return errors.New("key update message must be the last message of a handshake record")
   679  			}
   680  			if err := p.handleKeyUpdateMsg(msg); err != nil {
   681  				return err
   682  			}
   683  		case tlsHandshakeNewSessionTicketType:
   684  			// Ignore tickets that are received after a batch of tickets has
   685  			// been sent to S2A.
   686  			if p.ticketState == notReceivingTickets {
   687  				continue
   688  			}
   689  			if p.ticketState == ticketsNotYetReceived {
   690  				p.ticketState = receivingTickets
   691  			}
   692  			p.sessionTickets = append(p.sessionTickets, rawMsg)
   693  			if len(p.sessionTickets) == maxAllowedTickets {
   694  				p.ticketState = notReceivingTickets
   695  				grpclog.Infof("Sending session tickets to S2A.")
   696  				p.ticketSender.sendTicketsToS2A(p.sessionTickets, p.callComplete)
   697  			}
   698  		default:
   699  			return errors.New("unknown handshake message type")
   700  		}
   701  	}
   702  	return nil
   703  }
   704  
   705  func buildKeyUpdateRequest() []byte {
   706  	b := make([]byte, tlsHandshakePrefixSize+tlsHandshakeKeyUpdateMsgSize)
   707  	b[0] = tlsHandshakeKeyUpdateType
   708  	b[1] = 0
   709  	b[2] = 0
   710  	b[3] = tlsHandshakeKeyUpdateMsgSize
   711  	b[4] = byte(updateNotRequested)
   712  	return b
   713  }
   714  
   715  // handleKeyUpdateMsg handles a key update message.
   716  func (p *conn) handleKeyUpdateMsg(msg []byte) error {
   717  	keyUpdateRequest := msg[0]
   718  	if keyUpdateRequest != byte(updateNotRequested) &&
   719  		keyUpdateRequest != byte(updateRequested) {
   720  		return errors.New("invalid handshake key update message")
   721  	}
   722  	if err := p.inConn.UpdateKey(); err != nil {
   723  		return err
   724  	}
   725  	// Send a key update message back to the peer if requested.
   726  	if keyUpdateRequest == byte(updateRequested) {
   727  		p.writeMutex.Lock()
   728  		defer p.writeMutex.Unlock()
   729  		n, err := p.writeTLSRecord(preConstructedKeyUpdateMsg, byte(handshake))
   730  		if err != nil {
   731  			return err
   732  		}
   733  		if n != tlsHandshakePrefixSize+tlsHandshakeKeyUpdateMsgSize {
   734  			return errors.New("key update request message wrote less bytes than expected")
   735  		}
   736  		if err = p.outConn.UpdateKey(); err != nil {
   737  			return err
   738  		}
   739  	}
   740  	return nil
   741  }
   742  
   743  // bidEndianInt24 converts the given byte buffer of at least size 3 and
   744  // outputs the resulting 24 bit integer as a uint32. This is needed because
   745  // TLS 1.3 requires 3 byte integers, and the binary.BigEndian package does
   746  // not provide a way to transform a byte buffer into a 3 byte integer.
   747  func bigEndianInt24(b []byte) uint32 {
   748  	_ = b[2] // bounds check hint to compiler; see golang.org/issue/14808
   749  	return uint32(b[2]) | uint32(b[1])<<8 | uint32(b[0])<<16
   750  }
   751  
   752  func min(a, b int) int {
   753  	if a < b {
   754  		return a
   755  	}
   756  	return b
   757  }
   758  

View as plain text