...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"net"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"time"
    20  
    21  	"go.mongodb.org/mongo-driver/internal/csot"
    22  	"go.mongodb.org/mongo-driver/mongo/address"
    23  	"go.mongodb.org/mongo-driver/mongo/description"
    24  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    25  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    26  	"go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
    27  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    28  )
    29  
    30  // Connection state constants.
    31  const (
    32  	connDisconnected int64 = iota
    33  	connConnected
    34  	connInitialized
    35  )
    36  
    37  var globalConnectionID uint64 = 1
    38  
    39  var (
    40  	defaultMaxMessageSize        uint32 = 48000000
    41  	errResponseTooLarge                 = errors.New("length of read message too large")
    42  	errLoadBalancedStateMismatch        = errors.New("driver attempted to initialize in load balancing mode, but the server does not support this mode")
    43  )
    44  
    45  func nextConnectionID() uint64 { return atomic.AddUint64(&globalConnectionID, 1) }
    46  
    47  type connection struct {
    48  	// state must be accessed using the atomic package and should be at the beginning of the struct.
    49  	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
    50  	// - suggested layout: https://go101.org/article/memory-layout.html
    51  	state int64
    52  
    53  	id                   string
    54  	nc                   net.Conn // When nil, the connection is closed.
    55  	addr                 address.Address
    56  	idleTimeout          time.Duration
    57  	idleDeadline         atomic.Value // Stores a time.Time
    58  	readTimeout          time.Duration
    59  	writeTimeout         time.Duration
    60  	desc                 description.Server
    61  	helloRTT             time.Duration
    62  	compressor           wiremessage.CompressorID
    63  	zliblevel            int
    64  	zstdLevel            int
    65  	connectDone          chan struct{}
    66  	config               *connectionConfig
    67  	cancelConnectContext context.CancelFunc
    68  	connectContextMade   chan struct{}
    69  	canStream            bool
    70  	currentlyStreaming   bool
    71  	connectContextMutex  sync.Mutex
    72  	cancellationListener cancellationListener
    73  	serverConnectionID   *int64 // the server's ID for this client's connection
    74  
    75  	// pool related fields
    76  	pool *pool
    77  
    78  	// TODO(GODRIVER-2824): change driverConnectionID type to int64.
    79  	driverConnectionID uint64
    80  	generation         uint64
    81  
    82  	// awaitingResponse indicates that the server response was not completely
    83  	// read before returning the connection to the pool.
    84  	awaitingResponse bool
    85  }
    86  
    87  // newConnection handles the creation of a connection. It does not connect the connection.
    88  func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
    89  	cfg := newConnectionConfig(opts...)
    90  
    91  	id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
    92  
    93  	c := &connection{
    94  		id:                   id,
    95  		addr:                 addr,
    96  		idleTimeout:          cfg.idleTimeout,
    97  		readTimeout:          cfg.readTimeout,
    98  		writeTimeout:         cfg.writeTimeout,
    99  		connectDone:          make(chan struct{}),
   100  		config:               cfg,
   101  		connectContextMade:   make(chan struct{}),
   102  		cancellationListener: newCancellListener(),
   103  	}
   104  	// Connections to non-load balanced deployments should eagerly set the generation numbers so errors encountered
   105  	// at any point during connection establishment can be processed without the connection being considered stale.
   106  	if !c.config.loadBalanced {
   107  		c.setGenerationNumber()
   108  	}
   109  	atomic.StoreInt64(&c.state, connInitialized)
   110  
   111  	return c
   112  }
   113  
   114  // DriverConnectionID returns the driver connection ID.
   115  // TODO(GODRIVER-2824): change return type to int64.
   116  func (c *connection) DriverConnectionID() uint64 {
   117  	return c.driverConnectionID
   118  }
   119  
   120  // setGenerationNumber sets the connection's generation number if a callback has been provided to do so in connection
   121  // configuration.
   122  func (c *connection) setGenerationNumber() {
   123  	if c.config.getGenerationFn != nil {
   124  		c.generation = c.config.getGenerationFn(c.desc.ServiceID)
   125  	}
   126  }
   127  
   128  // hasGenerationNumber returns true if the connection has set its generation number. If so, this indicates that the
   129  // generationNumberFn provided via the connection options has been called exactly once.
   130  func (c *connection) hasGenerationNumber() bool {
   131  	if !c.config.loadBalanced {
   132  		// The generation is known for all non-LB clusters once the connection object has been created.
   133  		return true
   134  	}
   135  
   136  	// For LB clusters, we set the generation after the initial handshake, so we know it's set if the connection
   137  	// description has been updated to reflect that it's behind an LB.
   138  	return c.desc.LoadBalanced()
   139  }
   140  
   141  // connect handles the I/O for a connection. It will dial, configure TLS, and perform initialization
   142  // handshakes. All errors returned by connect are considered "before the handshake completes" and
   143  // must be handled by calling the appropriate SDAM handshake error handler.
   144  func (c *connection) connect(ctx context.Context) (err error) {
   145  	if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
   146  		return nil
   147  	}
   148  
   149  	defer close(c.connectDone)
   150  
   151  	// If connect returns an error, set the connection status as disconnected and close the
   152  	// underlying net.Conn if it was created.
   153  	defer func() {
   154  		if err != nil {
   155  			atomic.StoreInt64(&c.state, connDisconnected)
   156  
   157  			if c.nc != nil {
   158  				_ = c.nc.Close()
   159  			}
   160  		}
   161  	}()
   162  
   163  	// Create separate contexts for dialing a connection and doing the MongoDB/auth handshakes.
   164  	//
   165  	// handshakeCtx is simply a cancellable version of ctx because there's no default timeout that needs to be applied
   166  	// to the full handshake. The cancellation allows consumers to bail out early when dialing a connection if it's no
   167  	// longer required. This is done in lock because it accesses the shared cancelConnectContext field.
   168  	//
   169  	// dialCtx is equal to handshakeCtx if connectTimeoutMS=0. Otherwise, it is derived from handshakeCtx so the
   170  	// cancellation still applies but with an added timeout to ensure the connectTimeoutMS option is applied to socket
   171  	// establishment and the TLS handshake as a whole. This is created outside of the connectContextMutex lock to avoid
   172  	// holding the lock longer than necessary.
   173  	c.connectContextMutex.Lock()
   174  	var handshakeCtx context.Context
   175  	handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
   176  	c.connectContextMutex.Unlock()
   177  
   178  	dialCtx := handshakeCtx
   179  	var dialCancel context.CancelFunc
   180  	if c.config.connectTimeout != 0 {
   181  		dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
   182  		defer dialCancel()
   183  	}
   184  
   185  	defer func() {
   186  		var cancelFn context.CancelFunc
   187  
   188  		c.connectContextMutex.Lock()
   189  		cancelFn = c.cancelConnectContext
   190  		c.cancelConnectContext = nil
   191  		c.connectContextMutex.Unlock()
   192  
   193  		if cancelFn != nil {
   194  			cancelFn()
   195  		}
   196  	}()
   197  
   198  	close(c.connectContextMade)
   199  
   200  	// Assign the result of DialContext to a temporary net.Conn to ensure that c.nc is not set in an error case.
   201  	tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
   202  	if err != nil {
   203  		return ConnectionError{Wrapped: err, init: true}
   204  	}
   205  	c.nc = tempNc
   206  
   207  	if c.config.tlsConfig != nil {
   208  		tlsConfig := c.config.tlsConfig.Clone()
   209  
   210  		// store the result of configureTLS in a separate variable than c.nc to avoid overwriting c.nc with nil in
   211  		// error cases.
   212  		ocspOpts := &ocsp.VerifyOptions{
   213  			Cache:                   c.config.ocspCache,
   214  			DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
   215  			HTTPClient:              c.config.httpClient,
   216  		}
   217  		tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
   218  		if err != nil {
   219  			return ConnectionError{Wrapped: err, init: true}
   220  		}
   221  		c.nc = tlsNc
   222  	}
   223  
   224  	// running hello and authentication is handled by a handshaker on the configuration instance.
   225  	handshaker := c.config.handshaker
   226  	if handshaker == nil {
   227  		return nil
   228  	}
   229  
   230  	var handshakeInfo driver.HandshakeInformation
   231  	handshakeStartTime := time.Now()
   232  	handshakeConn := initConnection{c}
   233  	handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
   234  	if err == nil {
   235  		// We only need to retain the Description field as the connection's description. The authentication-related
   236  		// fields in handshakeInfo are tracked by the handshaker if necessary.
   237  		c.desc = handshakeInfo.Description
   238  		c.serverConnectionID = handshakeInfo.ServerConnectionID
   239  		c.helloRTT = time.Since(handshakeStartTime)
   240  
   241  		// If the application has indicated that the cluster is load balanced, ensure the server has included serviceId
   242  		// in its handshake response to signal that it knows it's behind an LB as well.
   243  		if c.config.loadBalanced && c.desc.ServiceID == nil {
   244  			err = errLoadBalancedStateMismatch
   245  		}
   246  	}
   247  	if err == nil {
   248  		// For load-balanced connections, the generation number depends on the service ID, which isn't known until the
   249  		// initial MongoDB handshake is done. To account for this, we don't attempt to set the connection's generation
   250  		// number unless GetHandshakeInformation succeeds.
   251  		if c.config.loadBalanced {
   252  			c.setGenerationNumber()
   253  		}
   254  
   255  		// If we successfully finished the first part of the handshake and verified LB state, continue with the rest of
   256  		// the handshake.
   257  		err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
   258  	}
   259  
   260  	// We have a failed handshake here
   261  	if err != nil {
   262  		return ConnectionError{Wrapped: err, init: true}
   263  	}
   264  
   265  	if len(c.desc.Compression) > 0 {
   266  	clientMethodLoop:
   267  		for _, method := range c.config.compressors {
   268  			for _, serverMethod := range c.desc.Compression {
   269  				if method != serverMethod {
   270  					continue
   271  				}
   272  
   273  				switch strings.ToLower(method) {
   274  				case "snappy":
   275  					c.compressor = wiremessage.CompressorSnappy
   276  				case "zlib":
   277  					c.compressor = wiremessage.CompressorZLib
   278  					c.zliblevel = wiremessage.DefaultZlibLevel
   279  					if c.config.zlibLevel != nil {
   280  						c.zliblevel = *c.config.zlibLevel
   281  					}
   282  				case "zstd":
   283  					c.compressor = wiremessage.CompressorZstd
   284  					c.zstdLevel = wiremessage.DefaultZstdLevel
   285  					if c.config.zstdLevel != nil {
   286  						c.zstdLevel = *c.config.zstdLevel
   287  					}
   288  				}
   289  				break clientMethodLoop
   290  			}
   291  		}
   292  	}
   293  	return nil
   294  }
   295  
   296  func (c *connection) wait() {
   297  	if c.connectDone != nil {
   298  		<-c.connectDone
   299  	}
   300  }
   301  
   302  func (c *connection) closeConnectContext() {
   303  	<-c.connectContextMade
   304  	var cancelFn context.CancelFunc
   305  
   306  	c.connectContextMutex.Lock()
   307  	cancelFn = c.cancelConnectContext
   308  	c.cancelConnectContext = nil
   309  	c.connectContextMutex.Unlock()
   310  
   311  	if cancelFn != nil {
   312  		cancelFn()
   313  	}
   314  }
   315  
   316  func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
   317  	if originalError == nil {
   318  		return nil
   319  	}
   320  
   321  	// If there was an error and the context was cancelled, we assume it happened due to the cancellation.
   322  	if errors.Is(ctx.Err(), context.Canceled) {
   323  		return context.Canceled
   324  	}
   325  
   326  	// If there was a timeout error and the context deadline was used, we convert the error into
   327  	// context.DeadlineExceeded.
   328  	if !contextDeadlineUsed {
   329  		return originalError
   330  	}
   331  	if netErr, ok := originalError.(net.Error); ok && netErr.Timeout() {
   332  		return context.DeadlineExceeded
   333  	}
   334  
   335  	return originalError
   336  }
   337  
   338  func (c *connection) cancellationListenerCallback() {
   339  	_ = c.close()
   340  }
   341  
   342  func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
   343  	var err error
   344  	if atomic.LoadInt64(&c.state) != connConnected {
   345  		return ConnectionError{
   346  			ConnectionID: c.id,
   347  			message:      "connection is closed",
   348  		}
   349  	}
   350  
   351  	var deadline time.Time
   352  	if c.writeTimeout != 0 {
   353  		deadline = time.Now().Add(c.writeTimeout)
   354  	}
   355  
   356  	var contextDeadlineUsed bool
   357  	if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
   358  		contextDeadlineUsed = true
   359  		deadline = dl
   360  	}
   361  
   362  	if err := c.nc.SetWriteDeadline(deadline); err != nil {
   363  		return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
   364  	}
   365  
   366  	err = c.write(ctx, wm)
   367  	if err != nil {
   368  		c.close()
   369  		return ConnectionError{
   370  			ConnectionID: c.id,
   371  			Wrapped:      transformNetworkError(ctx, err, contextDeadlineUsed),
   372  			message:      "unable to write wire message to network",
   373  		}
   374  	}
   375  
   376  	return nil
   377  }
   378  
   379  func (c *connection) write(ctx context.Context, wm []byte) (err error) {
   380  	go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
   381  	defer func() {
   382  		// There is a race condition between Write and StopListening. If the context is cancelled after c.nc.Write
   383  		// succeeds, the cancellation listener could fire and close the connection. In this case, the connection has
   384  		// been invalidated but the error is nil. To account for this, overwrite the error to context.Cancelled if
   385  		// the abortedForCancellation flag was set.
   386  
   387  		if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
   388  			err = context.Canceled
   389  		}
   390  	}()
   391  
   392  	_, err = c.nc.Write(wm)
   393  	return err
   394  }
   395  
   396  // readWireMessage reads a wiremessage from the connection. The dst parameter will be overwritten.
   397  func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
   398  	if atomic.LoadInt64(&c.state) != connConnected {
   399  		return nil, ConnectionError{
   400  			ConnectionID: c.id,
   401  			message:      "connection is closed",
   402  		}
   403  	}
   404  
   405  	var deadline time.Time
   406  	if c.readTimeout != 0 {
   407  		deadline = time.Now().Add(c.readTimeout)
   408  	}
   409  
   410  	var contextDeadlineUsed bool
   411  	if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
   412  		contextDeadlineUsed = true
   413  		deadline = dl
   414  	}
   415  
   416  	if err := c.nc.SetReadDeadline(deadline); err != nil {
   417  		return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
   418  	}
   419  
   420  	dst, errMsg, err := c.read(ctx)
   421  	if err != nil {
   422  		if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) {
   423  			// If the error was a timeout error and CSOT is enabled, instead of
   424  			// closing the connection mark it as awaiting response so the pool
   425  			// can read the response before making it available to other
   426  			// operations.
   427  			c.awaitingResponse = true
   428  		} else {
   429  			// Otherwise, use the pre-CSOT behavior and close the connection
   430  			// because we don't know if there are other bytes left to read.
   431  			c.close()
   432  		}
   433  		message := errMsg
   434  		if errors.Is(err, io.EOF) {
   435  			message = "socket was unexpectedly closed"
   436  		}
   437  		return nil, ConnectionError{
   438  			ConnectionID: c.id,
   439  			Wrapped:      transformNetworkError(ctx, err, contextDeadlineUsed),
   440  			message:      message,
   441  		}
   442  	}
   443  
   444  	return dst, nil
   445  }
   446  
   447  func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
   448  	go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
   449  	defer func() {
   450  		// If the context is cancelled after we finish reading the server response, the cancellation listener could fire
   451  		// even though the socket reads succeed. To account for this, we overwrite err to be context.Canceled if the
   452  		// abortedForCancellation flag is set.
   453  
   454  		if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
   455  			errMsg = "unable to read server response"
   456  			err = context.Canceled
   457  		}
   458  	}()
   459  
   460  	// We use an array here because it only costs 4 bytes on the stack and means we'll only need to
   461  	// reslice dst once instead of twice.
   462  	var sizeBuf [4]byte
   463  
   464  	// We do a ReadFull into an array here instead of doing an opportunistic ReadAtLeast into dst
   465  	// because there might be more than one wire message waiting to be read, for example when
   466  	// reading messages from an exhaust cursor.
   467  	_, err = io.ReadFull(c.nc, sizeBuf[:])
   468  	if err != nil {
   469  		return nil, "incomplete read of message header", err
   470  	}
   471  
   472  	// read the length as an int32
   473  	size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
   474  
   475  	// In the case of a hello response where MaxMessageSize has not yet been set, use the hard-coded
   476  	// defaultMaxMessageSize instead.
   477  	maxMessageSize := c.desc.MaxMessageSize
   478  	if maxMessageSize == 0 {
   479  		maxMessageSize = defaultMaxMessageSize
   480  	}
   481  	if uint32(size) > maxMessageSize {
   482  		return nil, errResponseTooLarge.Error(), errResponseTooLarge
   483  	}
   484  
   485  	dst := make([]byte, size)
   486  	copy(dst, sizeBuf[:])
   487  
   488  	_, err = io.ReadFull(c.nc, dst[4:])
   489  	if err != nil {
   490  		return dst, "incomplete read of full message", err
   491  	}
   492  
   493  	return dst, "", nil
   494  }
   495  
   496  func (c *connection) close() error {
   497  	// Overwrite the connection state as the first step so only the first close call will execute.
   498  	if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
   499  		return nil
   500  	}
   501  
   502  	var err error
   503  	if c.nc != nil {
   504  		err = c.nc.Close()
   505  	}
   506  
   507  	return err
   508  }
   509  
   510  func (c *connection) closed() bool {
   511  	return atomic.LoadInt64(&c.state) == connDisconnected
   512  }
   513  
   514  func (c *connection) idleTimeoutExpired() bool {
   515  	now := time.Now()
   516  	if c.idleTimeout > 0 {
   517  		idleDeadline, ok := c.idleDeadline.Load().(time.Time)
   518  		if ok && now.After(idleDeadline) {
   519  			return true
   520  		}
   521  	}
   522  
   523  	return false
   524  }
   525  
   526  func (c *connection) bumpIdleDeadline() {
   527  	if c.idleTimeout > 0 {
   528  		c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
   529  	}
   530  }
   531  
   532  func (c *connection) setCanStream(canStream bool) {
   533  	c.canStream = canStream
   534  }
   535  
   536  func (c initConnection) supportsStreaming() bool {
   537  	return c.canStream
   538  }
   539  
   540  func (c *connection) setStreaming(streaming bool) {
   541  	c.currentlyStreaming = streaming
   542  }
   543  
   544  func (c *connection) getCurrentlyStreaming() bool {
   545  	return c.currentlyStreaming
   546  }
   547  
   548  func (c *connection) setSocketTimeout(timeout time.Duration) {
   549  	c.readTimeout = timeout
   550  	c.writeTimeout = timeout
   551  }
   552  
   553  func (c *connection) ID() string {
   554  	return c.id
   555  }
   556  
   557  func (c *connection) ServerConnectionID() *int64 {
   558  	return c.serverConnectionID
   559  }
   560  
   561  // initConnection is an adapter used during connection initialization. It has the minimum
   562  // functionality necessary to implement the driver.Connection interface, which is required to pass a
   563  // *connection to a Handshaker.
   564  type initConnection struct{ *connection }
   565  
   566  var _ driver.Connection = initConnection{}
   567  var _ driver.StreamerConnection = initConnection{}
   568  
   569  func (c initConnection) Description() description.Server {
   570  	if c.connection == nil {
   571  		return description.Server{}
   572  	}
   573  	return c.connection.desc
   574  }
   575  func (c initConnection) Close() error             { return nil }
   576  func (c initConnection) ID() string               { return c.id }
   577  func (c initConnection) Address() address.Address { return c.addr }
   578  func (c initConnection) Stale() bool              { return false }
   579  func (c initConnection) LocalAddress() address.Address {
   580  	if c.connection == nil || c.nc == nil {
   581  		return address.Address("0.0.0.0")
   582  	}
   583  	return address.Address(c.nc.LocalAddr().String())
   584  }
   585  func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
   586  	return c.writeWireMessage(ctx, wm)
   587  }
   588  func (c initConnection) ReadWireMessage(ctx context.Context) ([]byte, error) {
   589  	return c.readWireMessage(ctx)
   590  }
   591  func (c initConnection) SetStreaming(streaming bool) {
   592  	c.setStreaming(streaming)
   593  }
   594  func (c initConnection) CurrentlyStreaming() bool {
   595  	return c.getCurrentlyStreaming()
   596  }
   597  func (c initConnection) SupportsStreaming() bool {
   598  	return c.supportsStreaming()
   599  }
   600  
   601  // Connection implements the driver.Connection interface to allow reading and writing wire
   602  // messages and the driver.Expirable interface to allow expiring. It wraps an underlying
   603  // topology.connection to make it more goroutine-safe and nil-safe.
   604  type Connection struct {
   605  	connection    *connection
   606  	refCount      int
   607  	cleanupPoolFn func()
   608  
   609  	// cleanupServerFn resets the server state when a connection is returned to the connection pool
   610  	// via Close() or expired via Expire().
   611  	cleanupServerFn func()
   612  
   613  	mu sync.RWMutex
   614  }
   615  
   616  var _ driver.Connection = (*Connection)(nil)
   617  var _ driver.Expirable = (*Connection)(nil)
   618  var _ driver.PinnedConnection = (*Connection)(nil)
   619  
   620  // WriteWireMessage handles writing a wire message to the underlying connection.
   621  func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
   622  	c.mu.RLock()
   623  	defer c.mu.RUnlock()
   624  	if c.connection == nil {
   625  		return ErrConnectionClosed
   626  	}
   627  	return c.connection.writeWireMessage(ctx, wm)
   628  }
   629  
   630  // ReadWireMessage handles reading a wire message from the underlying connection. The dst parameter
   631  // will be overwritten with the new wire message.
   632  func (c *Connection) ReadWireMessage(ctx context.Context) ([]byte, error) {
   633  	c.mu.RLock()
   634  	defer c.mu.RUnlock()
   635  	if c.connection == nil {
   636  		return nil, ErrConnectionClosed
   637  	}
   638  	return c.connection.readWireMessage(ctx)
   639  }
   640  
   641  // CompressWireMessage handles compressing the provided wire message using the underlying
   642  // connection's compressor. The dst parameter will be overwritten with the new wire message. If
   643  // there is no compressor set on the underlying connection, then no compression will be performed.
   644  func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
   645  	c.mu.RLock()
   646  	defer c.mu.RUnlock()
   647  	if c.connection == nil {
   648  		return dst, ErrConnectionClosed
   649  	}
   650  	if c.connection.compressor == wiremessage.CompressorNoOp {
   651  		return append(dst, src...), nil
   652  	}
   653  	_, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)
   654  	if !ok {
   655  		return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
   656  	}
   657  	idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
   658  	dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
   659  	dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
   660  	dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
   661  	opts := driver.CompressionOpts{
   662  		Compressor: c.connection.compressor,
   663  		ZlibLevel:  c.connection.zliblevel,
   664  		ZstdLevel:  c.connection.zstdLevel,
   665  	}
   666  	compressed, err := driver.CompressPayload(rem, opts)
   667  	if err != nil {
   668  		return nil, err
   669  	}
   670  	dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
   671  	return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
   672  }
   673  
   674  // Description returns the server description of the server this connection is connected to.
   675  func (c *Connection) Description() description.Server {
   676  	c.mu.RLock()
   677  	defer c.mu.RUnlock()
   678  	if c.connection == nil {
   679  		return description.Server{}
   680  	}
   681  	return c.connection.desc
   682  }
   683  
   684  // Close returns this connection to the connection pool. This method may not closeConnection the underlying
   685  // socket.
   686  func (c *Connection) Close() error {
   687  	c.mu.Lock()
   688  	defer c.mu.Unlock()
   689  	if c.connection == nil || c.refCount > 0 {
   690  		return nil
   691  	}
   692  
   693  	return c.cleanupReferences()
   694  }
   695  
   696  // Expire closes this connection and will closeConnection the underlying socket.
   697  func (c *Connection) Expire() error {
   698  	c.mu.Lock()
   699  	defer c.mu.Unlock()
   700  	if c.connection == nil {
   701  		return nil
   702  	}
   703  
   704  	_ = c.connection.close()
   705  	return c.cleanupReferences()
   706  }
   707  
   708  func (c *Connection) cleanupReferences() error {
   709  	err := c.connection.pool.checkIn(c.connection)
   710  	if c.cleanupPoolFn != nil {
   711  		c.cleanupPoolFn()
   712  		c.cleanupPoolFn = nil
   713  	}
   714  	if c.cleanupServerFn != nil {
   715  		c.cleanupServerFn()
   716  		c.cleanupServerFn = nil
   717  	}
   718  	c.connection = nil
   719  	return err
   720  }
   721  
   722  // Alive returns if the connection is still alive.
   723  func (c *Connection) Alive() bool {
   724  	return c.connection != nil
   725  }
   726  
   727  // ID returns the ID of this connection.
   728  func (c *Connection) ID() string {
   729  	c.mu.RLock()
   730  	defer c.mu.RUnlock()
   731  	if c.connection == nil {
   732  		return "<closed>"
   733  	}
   734  	return c.connection.id
   735  }
   736  
   737  // ServerConnectionID returns the server connection ID of this connection.
   738  func (c *Connection) ServerConnectionID() *int64 {
   739  	if c.connection == nil {
   740  		return nil
   741  	}
   742  	return c.connection.serverConnectionID
   743  }
   744  
   745  // Stale returns if the connection is stale.
   746  func (c *Connection) Stale() bool {
   747  	c.mu.RLock()
   748  	defer c.mu.RUnlock()
   749  	return c.connection.pool.stale(c.connection)
   750  }
   751  
   752  // Address returns the address of this connection.
   753  func (c *Connection) Address() address.Address {
   754  	c.mu.RLock()
   755  	defer c.mu.RUnlock()
   756  	if c.connection == nil {
   757  		return address.Address("0.0.0.0")
   758  	}
   759  	return c.connection.addr
   760  }
   761  
   762  // LocalAddress returns the local address of the connection
   763  func (c *Connection) LocalAddress() address.Address {
   764  	c.mu.RLock()
   765  	defer c.mu.RUnlock()
   766  	if c.connection == nil || c.connection.nc == nil {
   767  		return address.Address("0.0.0.0")
   768  	}
   769  	return address.Address(c.connection.nc.LocalAddr().String())
   770  }
   771  
   772  // PinToCursor updates this connection to reflect that it is pinned to a cursor.
   773  func (c *Connection) PinToCursor() error {
   774  	return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
   775  }
   776  
   777  // PinToTransaction updates this connection to reflect that it is pinned to a transaction.
   778  func (c *Connection) PinToTransaction() error {
   779  	return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
   780  }
   781  
   782  func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
   783  	c.mu.Lock()
   784  	defer c.mu.Unlock()
   785  	if c.connection == nil {
   786  		return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason)
   787  	}
   788  
   789  	// Only use the provided callbacks for the first reference to avoid double-counting pinned connection statistics
   790  	// in the pool.
   791  	if c.refCount == 0 {
   792  		updatePoolFn()
   793  		c.cleanupPoolFn = cleanupPoolFn
   794  	}
   795  	c.refCount++
   796  	return nil
   797  }
   798  
   799  // UnpinFromCursor updates this connection to reflect that it is no longer pinned to a cursor.
   800  func (c *Connection) UnpinFromCursor() error {
   801  	return c.unpin("cursor")
   802  }
   803  
   804  // UnpinFromTransaction updates this connection to reflect that it is no longer pinned to a transaction.
   805  func (c *Connection) UnpinFromTransaction() error {
   806  	return c.unpin("transaction")
   807  }
   808  
   809  func (c *Connection) unpin(reason string) error {
   810  	c.mu.Lock()
   811  	defer c.mu.Unlock()
   812  	if c.connection == nil {
   813  		// We don't error here because the resource could have been forcefully closed via Expire.
   814  		return nil
   815  	}
   816  	if c.refCount == 0 {
   817  		return fmt.Errorf("attempted to unpin a connection from a %s, but the connection is not pinned by any resources", reason)
   818  	}
   819  
   820  	c.refCount--
   821  	return nil
   822  }
   823  
   824  // DriverConnectionID returns the driver connection ID.
   825  // TODO(GODRIVER-2824): change return type to int64.
   826  func (c *Connection) DriverConnectionID() uint64 {
   827  	return c.connection.DriverConnectionID()
   828  }
   829  
   830  func configureTLS(ctx context.Context,
   831  	tlsConnSource tlsConnectionSource,
   832  	nc net.Conn,
   833  	addr address.Address,
   834  	config *tls.Config,
   835  	ocspOpts *ocsp.VerifyOptions,
   836  ) (net.Conn, error) {
   837  	// Ensure config.ServerName is always set for SNI.
   838  	if config.ServerName == "" {
   839  		hostname := addr.String()
   840  		colonPos := strings.LastIndex(hostname, ":")
   841  		if colonPos == -1 {
   842  			colonPos = len(hostname)
   843  		}
   844  
   845  		hostname = hostname[:colonPos]
   846  		config.ServerName = hostname
   847  	}
   848  
   849  	client := tlsConnSource.Client(nc, config)
   850  	if err := clientHandshake(ctx, client); err != nil {
   851  		return nil, err
   852  	}
   853  
   854  	// Only do OCSP verification if TLS verification is requested.
   855  	if !config.InsecureSkipVerify {
   856  		if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
   857  			return nil, ocspErr
   858  		}
   859  	}
   860  	return client, nil
   861  }
   862  
   863  // TODO: Naming?
   864  
   865  // cancellListener listens for context cancellation and notifies listeners via a
   866  // callback function.
   867  type cancellListener struct {
   868  	aborted bool
   869  	done    chan struct{}
   870  }
   871  
   872  // newCancellListener constructs a cancellListener.
   873  func newCancellListener() *cancellListener {
   874  	return &cancellListener{
   875  		done: make(chan struct{}),
   876  	}
   877  }
   878  
   879  // Listen blocks until the provided context is cancelled or listening is aborted
   880  // via the StopListening function. If this detects that the context has been
   881  // cancelled (i.e. errors.Is(ctx.Err(), context.Canceled), the provided callback is
   882  // called to abort in-progress work. Even if the context expires, this function
   883  // will block until StopListening is called.
   884  func (c *cancellListener) Listen(ctx context.Context, abortFn func()) {
   885  	c.aborted = false
   886  
   887  	select {
   888  	case <-ctx.Done():
   889  		if errors.Is(ctx.Err(), context.Canceled) {
   890  			c.aborted = true
   891  			abortFn()
   892  		}
   893  
   894  		<-c.done
   895  	case <-c.done:
   896  	}
   897  }
   898  
   899  // StopListening stops the in-progress Listen call. This blocks if there is no
   900  // in-progress Listen call. This function will return true if the provided abort
   901  // callback was called when listening for cancellation on the previous context.
   902  func (c *cancellListener) StopListening() bool {
   903  	c.done <- struct{}{}
   904  	return c.aborted
   905  }
   906  

View as plain text