...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/connection_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"context"
    11  	"crypto/tls"
    12  	"errors"
    13  	"math/rand"
    14  	"net"
    15  	"sync"
    16  	"sync/atomic"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/google/go-cmp/cmp"
    21  	"go.mongodb.org/mongo-driver/internal/assert"
    22  	"go.mongodb.org/mongo-driver/mongo/address"
    23  	"go.mongodb.org/mongo-driver/mongo/description"
    24  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    25  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    26  )
    27  
    28  type testHandshaker struct {
    29  	getHandshakeInformation func(context.Context, address.Address, driver.Connection) (driver.HandshakeInformation, error)
    30  	finishHandshake         func(context.Context, driver.Connection) error
    31  }
    32  
    33  // GetHandshakeInformation implements the Handshaker interface.
    34  func (th *testHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
    35  	if th.getHandshakeInformation != nil {
    36  		return th.getHandshakeInformation(ctx, addr, conn)
    37  	}
    38  	return driver.HandshakeInformation{}, nil
    39  }
    40  
    41  // FinishHandshake implements the Handshaker interface.
    42  func (th *testHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
    43  	if th.finishHandshake != nil {
    44  		return th.finishHandshake(ctx, conn)
    45  	}
    46  	return nil
    47  }
    48  
    49  var _ driver.Handshaker = &testHandshaker{}
    50  
    51  func TestConnection(t *testing.T) {
    52  	t.Run("connection", func(t *testing.T) {
    53  		t.Run("newConnection", func(t *testing.T) {
    54  			t.Run("no default idle timeout", func(t *testing.T) {
    55  				conn := newConnection(address.Address(""))
    56  				wantTimeout := time.Duration(0)
    57  				assert.Equal(t, wantTimeout, conn.idleTimeout, "expected idle timeout %v, got %v", wantTimeout,
    58  					conn.idleTimeout)
    59  			})
    60  		})
    61  		t.Run("connect", func(t *testing.T) {
    62  			t.Run("dialer error", func(t *testing.T) {
    63  				err := errors.New("dialer error")
    64  				var want error = ConnectionError{Wrapped: err, init: true}
    65  				conn := newConnection(address.Address(""), WithDialer(func(Dialer) Dialer {
    66  					return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, err })
    67  				}))
    68  				got := conn.connect(context.Background())
    69  				if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
    70  					t.Errorf("errors do not match. got %v; want %v", got, want)
    71  				}
    72  				connState := atomic.LoadInt64(&conn.state)
    73  				assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
    74  			})
    75  			t.Run("handshaker error", func(t *testing.T) {
    76  				err := errors.New("handshaker error")
    77  				var want error = ConnectionError{Wrapped: err, init: true}
    78  				conn := newConnection(address.Address(""),
    79  					WithHandshaker(func(Handshaker) Handshaker {
    80  						return &testHandshaker{
    81  							finishHandshake: func(context.Context, driver.Connection) error {
    82  								return err
    83  							},
    84  						}
    85  					}),
    86  					WithDialer(func(Dialer) Dialer {
    87  						return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
    88  							return &net.TCPConn{}, nil
    89  						})
    90  					}),
    91  				)
    92  				got := conn.connect(context.Background())
    93  				if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
    94  					t.Errorf("errors do not match. got %v; want %v", got, want)
    95  				}
    96  				connState := atomic.LoadInt64(&conn.state)
    97  				assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
    98  			})
    99  			t.Run("context is not pinned by connect", func(t *testing.T) {
   100  				// connect creates a cancel-able version of the context passed to it and stores the CancelFunc on the
   101  				// connection. The CancelFunc must be set to nil once the connection has been established so the driver
   102  				// does not pin the memory associated with the context for the connection's lifetime.
   103  
   104  				t.Run("connect succeeds", func(t *testing.T) {
   105  					// In the case where connect finishes successfully, it unpins the CancelFunc.
   106  
   107  					conn := newConnection(address.Address(""),
   108  						WithDialer(func(Dialer) Dialer {
   109  							return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   110  								return &net.TCPConn{}, nil
   111  							})
   112  						}),
   113  						WithHandshaker(func(Handshaker) Handshaker {
   114  							return &testHandshaker{}
   115  						}),
   116  					)
   117  
   118  					err := conn.connect(context.Background())
   119  					assert.Nil(t, err, "error establishing connection: %v", err)
   120  					assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared")
   121  				})
   122  				t.Run("connect cancelled", func(t *testing.T) {
   123  					// In the case where connection establishment is cancelled, the closeConnectContext function
   124  					// unpins the CancelFunc.
   125  
   126  					// Create a connection that will block in connect until doneChan is closed. This prevents
   127  					// connect from succeeding and unpinning the CancelFunc.
   128  					doneChan := make(chan struct{})
   129  					conn := newConnection(address.Address(""),
   130  						WithDialer(func(Dialer) Dialer {
   131  							return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   132  								<-doneChan
   133  								return &net.TCPConn{}, nil
   134  							})
   135  						}),
   136  						WithHandshaker(func(Handshaker) Handshaker {
   137  							return &testHandshaker{}
   138  						}),
   139  					)
   140  
   141  					// Call connect in a goroutine because it will block.
   142  					var wg sync.WaitGroup
   143  					wg.Add(1)
   144  					go func() {
   145  						defer wg.Done()
   146  						_ = conn.connect(context.Background())
   147  					}()
   148  
   149  					// Simulate cancelling connection establishment and assert that this clears the CancelFunc.
   150  					conn.closeConnectContext()
   151  					assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared")
   152  					close(doneChan)
   153  					wg.Wait()
   154  				})
   155  			})
   156  			t.Run("tls", func(t *testing.T) {
   157  				t.Run("connection source is set to default if unspecified", func(t *testing.T) {
   158  					conn := newConnection(address.Address(""))
   159  					assert.NotNil(t, conn.config.tlsConnectionSource, "expected tlsConnectionSource to be set but was not")
   160  				})
   161  				t.Run("server name", func(t *testing.T) {
   162  					testCases := []struct {
   163  						name               string
   164  						addr               address.Address
   165  						cfg                *tls.Config
   166  						expectedServerName string
   167  					}{
   168  						{"set to connection address if empty", "localhost:27017", &tls.Config{}, "localhost"},
   169  						{"left alone if non-empty", "localhost:27017", &tls.Config{ServerName: "other"}, "other"},
   170  					}
   171  					for _, tc := range testCases {
   172  						t.Run(tc.name, func(t *testing.T) {
   173  							var sentCfg *tls.Config
   174  							var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
   175  								sentCfg = cfg
   176  								return tls.Client(nc, cfg)
   177  							}
   178  
   179  							connOpts := []ConnectionOption{
   180  								WithDialer(func(Dialer) Dialer {
   181  									return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   182  										return &net.TCPConn{}, nil
   183  									})
   184  								}),
   185  								WithHandshaker(func(Handshaker) Handshaker {
   186  									return &testHandshaker{}
   187  								}),
   188  								WithTLSConfig(func(*tls.Config) *tls.Config {
   189  									return tc.cfg
   190  								}),
   191  								withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
   192  									return testTLSConnectionSource
   193  								}),
   194  							}
   195  							conn := newConnection(tc.addr, connOpts...)
   196  
   197  							_ = conn.connect(context.Background())
   198  							assert.NotNil(t, sentCfg, "expected TLS config to be set, but was not")
   199  							assert.Equal(t, tc.expectedServerName, sentCfg.ServerName, "expected ServerName %s, got %s",
   200  								tc.expectedServerName, sentCfg.ServerName)
   201  						})
   202  					}
   203  				})
   204  			})
   205  			t.Run("connectTimeout is applied correctly", func(t *testing.T) {
   206  				testCases := []struct {
   207  					name           string
   208  					contextTimeout time.Duration
   209  					connectTimeout time.Duration
   210  					maxConnectTime time.Duration
   211  				}{
   212  					// The timeout to dial a connection should be min(context timeout, connectTimeoutMS), so 1ms for
   213  					// both of the tests declared below. Both tests also specify a 50ms max connect time to provide
   214  					// a large buffer for lag and avoid test flakiness.
   215  
   216  					{"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 50 * time.Millisecond},
   217  					{"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 50 * time.Millisecond},
   218  				}
   219  
   220  				for _, tc := range testCases {
   221  					t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) {
   222  						// Ensure the initial connection dial can be timed out and the connection propagates the error
   223  						// from the dialer in this case.
   224  
   225  						connOpts := []ConnectionOption{
   226  							WithDialer(func(Dialer) Dialer {
   227  								return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
   228  									<-ctx.Done()
   229  									return nil, ctx.Err()
   230  								})
   231  							}),
   232  							WithConnectTimeout(func(time.Duration) time.Duration {
   233  								return tc.connectTimeout
   234  							}),
   235  						}
   236  						conn := newConnection("", connOpts...)
   237  
   238  						var connectErr error
   239  						callback := func(ctx context.Context) {
   240  							connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout)
   241  							defer cancel()
   242  
   243  							connectErr = conn.connect(connectCtx)
   244  						}
   245  						assert.Soon(t, callback, tc.maxConnectTime)
   246  
   247  						ce, ok := connectErr.(ConnectionError)
   248  						assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
   249  						assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
   250  							context.DeadlineExceeded, ce.Unwrap())
   251  					})
   252  					t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) {
   253  						// Ensure the TLS handshake can be timed out and the connection propagates the error from the
   254  						// tlsConn in this case.
   255  
   256  						// Start a TCP listener on a random port and use the listener address as the
   257  						// target for connections. The listener will act as a source of connections
   258  						// that never respond, allowing the timeout logic to always trigger.
   259  						l, err := net.Listen("tcp", "localhost:0")
   260  						assert.Nil(t, err, "net.Listen() error: %q", err)
   261  						defer l.Close()
   262  
   263  						connOpts := []ConnectionOption{
   264  							WithConnectTimeout(func(time.Duration) time.Duration {
   265  								return tc.connectTimeout
   266  							}),
   267  							WithTLSConfig(func(*tls.Config) *tls.Config {
   268  								return &tls.Config{ServerName: "test"}
   269  							}),
   270  						}
   271  						conn := newConnection(address.Address(l.Addr().String()), connOpts...)
   272  
   273  						var connectErr error
   274  						callback := func(ctx context.Context) {
   275  							connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout)
   276  							defer cancel()
   277  
   278  							connectErr = conn.connect(connectCtx)
   279  						}
   280  						assert.Soon(t, callback, tc.maxConnectTime)
   281  
   282  						ce, ok := connectErr.(ConnectionError)
   283  						assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
   284  
   285  						isTimeout := func(err error) bool {
   286  							if errors.Is(err, context.DeadlineExceeded) {
   287  								return true
   288  							}
   289  							if ne, ok := err.(net.Error); ok {
   290  								return ne.Timeout()
   291  							}
   292  							return false
   293  						}
   294  						assert.True(t,
   295  							isTimeout(ce.Unwrap()),
   296  							"expected wrapped error to be a timeout error, but got %q",
   297  							ce.Unwrap())
   298  					})
   299  					t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) {
   300  						// Ensure that no additional timeout is applied to the handshake after the connection has been
   301  						// established.
   302  
   303  						var getInfoCtx, finishCtx context.Context
   304  						handshaker := &testHandshaker{
   305  							getHandshakeInformation: func(ctx context.Context, _ address.Address, _ driver.Connection) (driver.HandshakeInformation, error) {
   306  								getInfoCtx = ctx
   307  								return driver.HandshakeInformation{}, nil
   308  							},
   309  							finishHandshake: func(ctx context.Context, _ driver.Connection) error {
   310  								finishCtx = ctx
   311  								return nil
   312  							},
   313  						}
   314  
   315  						connOpts := []ConnectionOption{
   316  							WithConnectTimeout(func(time.Duration) time.Duration {
   317  								return tc.connectTimeout
   318  							}),
   319  							WithDialer(func(Dialer) Dialer {
   320  								return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   321  									return &net.TCPConn{}, nil
   322  								})
   323  							}),
   324  							WithHandshaker(func(Handshaker) Handshaker {
   325  								return handshaker
   326  							}),
   327  						}
   328  						conn := newConnection("", connOpts...)
   329  
   330  						err := conn.connect(context.Background())
   331  						assert.Nil(t, err, "connect error: %v", err)
   332  
   333  						assertNoContextTimeout := func(t *testing.T, ctx context.Context) {
   334  							t.Helper()
   335  							dl, ok := ctx.Deadline()
   336  							assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl)
   337  						}
   338  						assertNoContextTimeout(t, getInfoCtx)
   339  						assertNoContextTimeout(t, finishCtx)
   340  					})
   341  				}
   342  			})
   343  		})
   344  		t.Run("writeWireMessage", func(t *testing.T) {
   345  			t.Run("closed connection", func(t *testing.T) {
   346  				conn := &connection{id: "foobar"}
   347  				want := ConnectionError{ConnectionID: "foobar", message: "connection is closed"}
   348  				got := conn.writeWireMessage(context.Background(), []byte{})
   349  				if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   350  					t.Errorf("errors do not match. got %v; want %v", got, want)
   351  				}
   352  			})
   353  			t.Run("deadlines", func(t *testing.T) {
   354  				testCases := []struct {
   355  					name        string
   356  					ctxDeadline time.Duration
   357  					timeout     time.Duration
   358  					deadline    time.Time
   359  				}{
   360  					{"no deadline", 0, 0, time.Now().Add(1 * time.Second)},
   361  					{"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)},
   362  					{"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)},
   363  					{"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)},
   364  					{"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)},
   365  				}
   366  
   367  				for _, tc := range testCases {
   368  					t.Run(tc.name, func(t *testing.T) {
   369  						ctx := context.Background()
   370  						if tc.ctxDeadline > 0 {
   371  							var cancel context.CancelFunc
   372  							ctx, cancel = context.WithTimeout(ctx, tc.ctxDeadline)
   373  							defer cancel()
   374  						}
   375  						want := ConnectionError{
   376  							ConnectionID: "foobar",
   377  							Wrapped:      errors.New("set writeDeadline error"),
   378  							message:      "failed to set write deadline",
   379  						}
   380  						tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")}
   381  						conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, state: connConnected}
   382  						got := conn.writeWireMessage(ctx, []byte{})
   383  						if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   384  							t.Errorf("errors do not match. got %v; want %v", got, want)
   385  						}
   386  						if !tc.deadline.After(tnc.writeDeadline) {
   387  							t.Errorf("write deadline not properly set. got %v; want %v", tnc.writeDeadline, tc.deadline)
   388  						}
   389  					})
   390  				}
   391  			})
   392  			t.Run("Write", func(t *testing.T) {
   393  				writeErrMsg := "unable to write wire message to network"
   394  
   395  				t.Run("error", func(t *testing.T) {
   396  					err := errors.New("Write error")
   397  					tnc := &testNetConn{writeerr: err}
   398  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   399  					listener := newTestCancellationListener(false)
   400  					conn.cancellationListener = listener
   401  
   402  					want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: writeErrMsg}
   403  					got := conn.writeWireMessage(context.Background(), []byte{})
   404  					if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   405  						t.Errorf("errors do not match. got %v; want %v", got, want)
   406  					}
   407  					if !tnc.closed {
   408  						t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
   409  					}
   410  					listener.assertCalledOnce(t)
   411  				})
   412  				t.Run("success", func(t *testing.T) {
   413  					tnc := &testNetConn{}
   414  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   415  					listener := newTestCancellationListener(false)
   416  					conn.cancellationListener = listener
   417  
   418  					want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
   419  					err := conn.writeWireMessage(context.Background(), want)
   420  					noerr(t, err)
   421  					got := tnc.buf
   422  					if !cmp.Equal(got, want) {
   423  						t.Errorf("writeWireMessage did not write the proper bytes. got %v; want %v", got, want)
   424  					}
   425  					listener.assertCalledOnce(t)
   426  				})
   427  				t.Run("cancel in-progress write", func(t *testing.T) {
   428  					// Simulate context cancellation during a network write.
   429  
   430  					nc := newCancellationWriteConn(&testNetConn{}, 0)
   431  					conn := &connection{id: "foobar", nc: nc, state: connConnected}
   432  					listener := newTestCancellationListener(false)
   433  					conn.cancellationListener = listener
   434  
   435  					ctx, cancel := context.WithCancel(context.Background())
   436  					var err error
   437  
   438  					var wg sync.WaitGroup
   439  					wg.Add(1)
   440  					go func() {
   441  						defer wg.Done()
   442  						err = conn.writeWireMessage(ctx, []byte("foobar"))
   443  					}()
   444  
   445  					<-nc.operationStartedChan
   446  					cancel()
   447  					nc.continueChan <- struct{}{}
   448  
   449  					wg.Wait()
   450  					want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
   451  					assert.Equal(t, want, err, "expected error %v, got %v", want, err)
   452  					assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
   453  						conn.state)
   454  				})
   455  				t.Run("connection is closed if context is cancelled even if network write succeeds", func(t *testing.T) {
   456  					// Test the race condition between Write and the cancellation listener. The socket write will
   457  					// succeed, but we set the abortedForCancellation flag to true to simulate the context being
   458  					// cancelled immediately after the Write finishes.
   459  
   460  					tnc := &testNetConn{}
   461  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   462  					listener := newTestCancellationListener(true)
   463  					conn.cancellationListener = listener
   464  
   465  					want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
   466  					err := conn.writeWireMessage(context.Background(), []byte("foobar"))
   467  					assert.Equal(t, want, err, "expected error %v, got %v", want, err)
   468  					assert.Equal(t, conn.state, connDisconnected, "expected connection state %v, got %v", connDisconnected,
   469  						conn.state)
   470  				})
   471  			})
   472  		})
   473  		t.Run("readWireMessage", func(t *testing.T) {
   474  			t.Run("closed connection", func(t *testing.T) {
   475  				conn := &connection{id: "foobar"}
   476  				want := ConnectionError{ConnectionID: "foobar", message: "connection is closed"}
   477  				_, got := conn.readWireMessage(context.Background())
   478  				if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   479  					t.Errorf("errors do not match. got %v; want %v", got, want)
   480  				}
   481  			})
   482  			t.Run("deadlines", func(t *testing.T) {
   483  				testCases := []struct {
   484  					name        string
   485  					ctxDeadline time.Duration
   486  					timeout     time.Duration
   487  					deadline    time.Time
   488  				}{
   489  					{"no deadline", 0, 0, time.Now().Add(1 * time.Second)},
   490  					{"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)},
   491  					{"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)},
   492  					{"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)},
   493  					{"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)},
   494  				}
   495  
   496  				for _, tc := range testCases {
   497  					t.Run(tc.name, func(t *testing.T) {
   498  						ctx := context.Background()
   499  						if tc.ctxDeadline > 0 {
   500  							var cancel context.CancelFunc
   501  							ctx, cancel = context.WithTimeout(ctx, tc.ctxDeadline)
   502  							defer cancel()
   503  						}
   504  						want := ConnectionError{
   505  							ConnectionID: "foobar",
   506  							Wrapped:      errors.New("set readDeadline error"),
   507  							message:      "failed to set read deadline",
   508  						}
   509  						tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")}
   510  						conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, state: connConnected}
   511  						_, got := conn.readWireMessage(ctx)
   512  						if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   513  							t.Errorf("errors do not match. got %v; want %v", got, want)
   514  						}
   515  						if !tc.deadline.After(tnc.readDeadline) {
   516  							t.Errorf("read deadline not properly set. got %v; want %v", tnc.readDeadline, tc.deadline)
   517  						}
   518  					})
   519  				}
   520  			})
   521  			t.Run("Read", func(t *testing.T) {
   522  				t.Run("size read errors", func(t *testing.T) {
   523  					err := errors.New("Read error")
   524  					tnc := &testNetConn{readerr: err}
   525  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   526  					listener := newTestCancellationListener(false)
   527  					conn.cancellationListener = listener
   528  
   529  					want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "incomplete read of message header"}
   530  					_, got := conn.readWireMessage(context.Background())
   531  					if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   532  						t.Errorf("errors do not match. got %v; want %v", got, want)
   533  					}
   534  					if !tnc.closed {
   535  						t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
   536  					}
   537  					listener.assertCalledOnce(t)
   538  				})
   539  				t.Run("full message read errors", func(t *testing.T) {
   540  					err := errors.New("Read error")
   541  					tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
   542  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   543  					listener := newTestCancellationListener(false)
   544  					conn.cancellationListener = listener
   545  
   546  					want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "incomplete read of full message"}
   547  					_, got := conn.readWireMessage(context.Background())
   548  					if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   549  						t.Errorf("errors do not match. got %v; want %v", got, want)
   550  					}
   551  					if !tnc.closed {
   552  						t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
   553  					}
   554  					listener.assertCalledOnce(t)
   555  				})
   556  				t.Run("message too large errors", func(t *testing.T) {
   557  					testCases := []struct {
   558  						name   string
   559  						buffer []byte
   560  						desc   description.Server
   561  					}{
   562  						{
   563  							"message too large errors with small max message size",
   564  							[]byte{0x0A, 0x00, 0x00, 0x00}, // defines a message size of 10 in hex with the first four bytes.
   565  							description.Server{MaxMessageSize: 9},
   566  						},
   567  						{
   568  							"message too large errors with default max message size",
   569  							[]byte{0x01, 0x6C, 0xDC, 0x02}, // defines a message size of 48000001 in hex with the first four bytes.
   570  							description.Server{},
   571  						},
   572  					}
   573  					for _, tc := range testCases {
   574  						t.Run(tc.name, func(t *testing.T) {
   575  							err := errors.New("length of read message too large")
   576  							tnc := &testNetConn{buf: make([]byte, len(tc.buffer))}
   577  							copy(tnc.buf, tc.buffer)
   578  							conn := &connection{id: "foobar", nc: tnc, state: connConnected, desc: tc.desc}
   579  							listener := newTestCancellationListener(false)
   580  							conn.cancellationListener = listener
   581  
   582  							want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
   583  							_, got := conn.readWireMessage(context.Background())
   584  							if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   585  								t.Errorf("errors do not match. got %v; want %v", got, want)
   586  							}
   587  							listener.assertCalledOnce(t)
   588  						})
   589  					}
   590  				})
   591  				t.Run("success", func(t *testing.T) {
   592  					want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
   593  					tnc := &testNetConn{buf: make([]byte, len(want))}
   594  					copy(tnc.buf, want)
   595  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   596  					listener := newTestCancellationListener(false)
   597  					conn.cancellationListener = listener
   598  
   599  					got, err := conn.readWireMessage(context.Background())
   600  					noerr(t, err)
   601  					if !cmp.Equal(got, want) {
   602  						t.Errorf("did not read full wire message. got %v; want %v", got, want)
   603  					}
   604  					listener.assertCalledOnce(t)
   605  				})
   606  				t.Run("cancel in-progress read", func(t *testing.T) {
   607  					// Simulate context cancellation during a network read. This has two sub-tests to test cancellation
   608  					// when reading the msg size and when reading the rest of the msg.
   609  
   610  					testCases := []struct {
   611  						name   string
   612  						skip   int
   613  						errmsg string
   614  					}{
   615  						{"cancel size read", 0, "incomplete read of message header"},
   616  						{"cancel full message read", 1, "incomplete read of full message"},
   617  					}
   618  					for _, tc := range testCases {
   619  						t.Run(tc.name, func(t *testing.T) {
   620  							// In the full message case, the size read needs to succeed and return a non-zero size, so
   621  							// we set readBuf to indicate that the full message will have 10 bytes.
   622  							readBuf := []byte{10, 0, 0, 0}
   623  							nc := newCancellationReadConn(&testNetConn{}, tc.skip, readBuf)
   624  
   625  							conn := &connection{id: "foobar", nc: nc, state: connConnected}
   626  							listener := newTestCancellationListener(false)
   627  							conn.cancellationListener = listener
   628  
   629  							ctx, cancel := context.WithCancel(context.Background())
   630  							var err error
   631  
   632  							var wg sync.WaitGroup
   633  							wg.Add(1)
   634  							go func() {
   635  								defer wg.Done()
   636  								_, err = conn.readWireMessage(ctx)
   637  							}()
   638  
   639  							<-nc.operationStartedChan
   640  							cancel()
   641  							nc.continueChan <- struct{}{}
   642  
   643  							wg.Wait()
   644  							want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: tc.errmsg}
   645  							assert.Equal(t, want, err, "expected error %v, got %v", want, err)
   646  							assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
   647  								conn.state)
   648  						})
   649  					}
   650  				})
   651  				t.Run("closes connection if context is cancelled even if the socket read succeeds", func(t *testing.T) {
   652  					tnc := &testNetConn{buf: []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}}
   653  					conn := &connection{id: "foobar", nc: tnc, state: connConnected}
   654  					listener := newTestCancellationListener(true)
   655  					conn.cancellationListener = listener
   656  
   657  					want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: "unable to read server response"}
   658  					_, err := conn.readWireMessage(context.Background())
   659  					assert.Equal(t, want, err, "expected error %v, got %v", want, err)
   660  					assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
   661  						conn.state)
   662  				})
   663  			})
   664  		})
   665  		t.Run("close", func(t *testing.T) {
   666  			t.Run("can close a connection that failed handshaking", func(t *testing.T) {
   667  				conn := newConnection(address.Address(""),
   668  					WithHandshaker(func(Handshaker) Handshaker {
   669  						return &testHandshaker{
   670  							finishHandshake: func(context.Context, driver.Connection) error {
   671  								return errors.New("handshake err")
   672  							},
   673  						}
   674  					}),
   675  					WithDialer(func(Dialer) Dialer {
   676  						return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   677  							return &net.TCPConn{}, nil
   678  						})
   679  					}),
   680  				)
   681  
   682  				err := conn.connect(context.Background())
   683  				assert.NotNil(t, err, "expected handshake error from connect, got nil")
   684  				connState := atomic.LoadInt64(&conn.state)
   685  				assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
   686  
   687  				err = conn.close()
   688  				assert.Nil(t, err, "close error: %v", err)
   689  			})
   690  		})
   691  		t.Run("cancellation listener callback", func(t *testing.T) {
   692  			t.Run("closes connection", func(t *testing.T) {
   693  				tnc := &testNetConn{}
   694  				conn := &connection{state: connConnected, nc: tnc}
   695  
   696  				conn.cancellationListenerCallback()
   697  				assert.True(t, conn.state == connDisconnected, "expected connection state %v, got %v", connDisconnected,
   698  					conn.state)
   699  				assert.True(t, tnc.closed, "expected net.Conn to be closed but was not")
   700  			})
   701  		})
   702  	})
   703  	t.Run("Connection", func(t *testing.T) {
   704  		t.Run("nil connection does not panic", func(t *testing.T) {
   705  			conn := &Connection{}
   706  			defer func() {
   707  				if r := recover(); r != nil {
   708  					t.Fatalf("Methods on a Connection with a nil *connection should not panic, but panicked with %v", r)
   709  				}
   710  			}()
   711  
   712  			var want, got interface{}
   713  
   714  			want = ErrConnectionClosed
   715  			got = conn.WriteWireMessage(context.Background(), nil)
   716  			if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   717  				t.Errorf("errors do not match. got %v; want %v", got, want)
   718  			}
   719  			_, got = conn.ReadWireMessage(context.Background())
   720  			if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   721  				t.Errorf("errors do not match. got %v; want %v", got, want)
   722  			}
   723  
   724  			want = description.Server{}
   725  			got = conn.Description()
   726  			if !cmp.Equal(got, want) {
   727  				t.Errorf("descriptions do not match. got %v; want %v", got, want)
   728  			}
   729  
   730  			want = nil
   731  			got = conn.Close()
   732  			if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   733  				t.Errorf("errors do not match. got %v; want %v", got, want)
   734  			}
   735  
   736  			got = conn.Expire()
   737  			if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   738  				t.Errorf("errors do not match. got %v; want %v", got, want)
   739  			}
   740  
   741  			want = false
   742  			got = conn.Alive()
   743  			if !cmp.Equal(got, want) {
   744  				t.Errorf("Alive does not match. got %v; want %v", got, want)
   745  			}
   746  
   747  			want = "<closed>"
   748  			got = conn.ID()
   749  			if !cmp.Equal(got, want) {
   750  				t.Errorf("IDs do not match. got %v; want %v", got, want)
   751  			}
   752  
   753  			want = address.Address("0.0.0.0")
   754  			got = conn.Address()
   755  			if !cmp.Equal(got, want) {
   756  				t.Errorf("Addresses do not match. got %v; want %v", got, want)
   757  			}
   758  
   759  			want = address.Address("0.0.0.0")
   760  			got = conn.LocalAddress()
   761  			if !cmp.Equal(got, want) {
   762  				t.Errorf("LocalAddresses do not match. got %v; want %v", got, want)
   763  			}
   764  
   765  			want = (*int64)(nil)
   766  			got = conn.ServerConnectionID()
   767  			if !cmp.Equal(got, want) {
   768  				t.Errorf("ServerConnectionIDs do not match. got %v; want %v", got, want)
   769  			}
   770  		})
   771  
   772  		t.Run("pinning", func(t *testing.T) {
   773  			makeMultipleConnections := func(t *testing.T, numConns int) (*pool, []*Connection, func()) {
   774  				t.Helper()
   775  
   776  				addr := bootstrapConnections(t, numConns, func(nc net.Conn) {})
   777  				pool := newPool(poolConfig{
   778  					Address: address.Address(addr.String()),
   779  				})
   780  				err := pool.ready()
   781  				assert.Nil(t, err, "pool.connect() error: %v", err)
   782  
   783  				conns := make([]*Connection, 0, numConns)
   784  				for i := 0; i < numConns; i++ {
   785  					conn, err := pool.checkOut(context.Background())
   786  					assert.Nil(t, err, "checkOut error: %v", err)
   787  					conns = append(conns, &Connection{connection: conn})
   788  				}
   789  				disconnect := func() {
   790  					pool.close(context.Background())
   791  				}
   792  				return pool, conns, disconnect
   793  			}
   794  			makeOneConnection := func(t *testing.T) (*pool, *Connection, func()) {
   795  				t.Helper()
   796  
   797  				pool, conns, disconnect := makeMultipleConnections(t, 1)
   798  				return pool, conns[0], disconnect
   799  			}
   800  
   801  			assertPoolPinnedStats := func(t *testing.T, p *pool, cursorConns, txnConns uint64) {
   802  				t.Helper()
   803  
   804  				assert.Equal(t, cursorConns, p.pinnedCursorConnections, "expected %d connections to be pinned to cursors, got %d",
   805  					cursorConns, p.pinnedCursorConnections)
   806  				assert.Equal(t, txnConns, p.pinnedTransactionConnections, "expected %d connections to be pinned to transactions, got %d",
   807  					txnConns, p.pinnedTransactionConnections)
   808  			}
   809  
   810  			t.Run("cursors", func(t *testing.T) {
   811  				pool, conn, disconnect := makeOneConnection(t)
   812  				defer disconnect()
   813  
   814  				err := conn.PinToCursor()
   815  				assert.Nil(t, err, "PinToCursor error: %v", err)
   816  				assertPoolPinnedStats(t, pool, 1, 0)
   817  
   818  				err = conn.UnpinFromCursor()
   819  				assert.Nil(t, err, "UnpinFromCursor error: %v", err)
   820  
   821  				err = conn.Close()
   822  				assert.Nil(t, err, "Close error: %v", err)
   823  				assertPoolPinnedStats(t, pool, 0, 0)
   824  			})
   825  			t.Run("transactions", func(t *testing.T) {
   826  				pool, conn, disconnect := makeOneConnection(t)
   827  				defer disconnect()
   828  
   829  				err := conn.PinToTransaction()
   830  				assert.Nil(t, err, "PinToTransaction error: %v", err)
   831  				assertPoolPinnedStats(t, pool, 0, 1)
   832  
   833  				err = conn.UnpinFromTransaction()
   834  				assert.Nil(t, err, "UnpinFromTransaction error: %v", err)
   835  
   836  				err = conn.Close()
   837  				assert.Nil(t, err, "Close error: %v", err)
   838  				assertPoolPinnedStats(t, pool, 0, 0)
   839  			})
   840  			t.Run("pool is only updated for first reference", func(t *testing.T) {
   841  				pool, conn, disconnect := makeOneConnection(t)
   842  				defer disconnect()
   843  
   844  				err := conn.PinToTransaction()
   845  				assert.Nil(t, err, "PinToTransaction error: %v", err)
   846  				assertPoolPinnedStats(t, pool, 0, 1)
   847  
   848  				err = conn.PinToCursor()
   849  				assert.Nil(t, err, "PinToCursor error: %v", err)
   850  				assertPoolPinnedStats(t, pool, 0, 1)
   851  
   852  				err = conn.UnpinFromCursor()
   853  				assert.Nil(t, err, "UnpinFromCursor error: %v", err)
   854  				assertPoolPinnedStats(t, pool, 0, 1)
   855  
   856  				err = conn.UnpinFromTransaction()
   857  				assert.Nil(t, err, "UnpinFromTransaction error: %v", err)
   858  				assertPoolPinnedStats(t, pool, 0, 1)
   859  
   860  				err = conn.Close()
   861  				assert.Nil(t, err, "Close error: %v", err)
   862  				assertPoolPinnedStats(t, pool, 0, 0)
   863  			})
   864  			t.Run("multiple connections from a pool", func(t *testing.T) {
   865  				pool, conns, disconnect := makeMultipleConnections(t, 2)
   866  				defer disconnect()
   867  
   868  				first, second := conns[0], conns[1]
   869  
   870  				err := first.PinToTransaction()
   871  				assert.Nil(t, err, "PinToTransaction error: %v", err)
   872  				err = second.PinToCursor()
   873  				assert.Nil(t, err, "PinToCursor error: %v", err)
   874  				assertPoolPinnedStats(t, pool, 1, 1)
   875  
   876  				err = first.UnpinFromTransaction()
   877  				assert.Nil(t, err, "UnpinFromTransaction error: %v", err)
   878  				err = first.Close()
   879  				assert.Nil(t, err, "Close error: %v", err)
   880  				assertPoolPinnedStats(t, pool, 1, 0)
   881  
   882  				err = second.UnpinFromCursor()
   883  				assert.Nil(t, err, "UnpinFromCursor error: %v", err)
   884  				err = second.Close()
   885  				assert.Nil(t, err, "Close error: %v", err)
   886  				assertPoolPinnedStats(t, pool, 0, 0)
   887  			})
   888  			t.Run("close is ignored if connection is pinned", func(t *testing.T) {
   889  				pool, conn, disconnect := makeOneConnection(t)
   890  				defer disconnect()
   891  
   892  				err := conn.PinToCursor()
   893  				assert.Nil(t, err, "PinToCursor error: %v", err)
   894  
   895  				err = conn.Close()
   896  				assert.Nil(t, err, "Close error")
   897  				assert.NotNil(t, conn.connection, "expected connection to be pinned but it was released to the pool")
   898  				assertPoolPinnedStats(t, pool, 1, 0)
   899  			})
   900  			t.Run("expire forcefully returns connection to pool", func(t *testing.T) {
   901  				pool, conn, disconnect := makeOneConnection(t)
   902  				defer disconnect()
   903  
   904  				err := conn.PinToCursor()
   905  				assert.Nil(t, err, "PinToCursor error: %v", err)
   906  
   907  				err = conn.Expire()
   908  				assert.Nil(t, err, "Expire error")
   909  				assert.Nil(t, conn.connection, "expected connection to be released to the pool but was not")
   910  				assertPoolPinnedStats(t, pool, 0, 0)
   911  			})
   912  		})
   913  	})
   914  }
   915  
   916  func BenchmarkConnection(b *testing.B) {
   917  	b.Run("CompressWireMessage CompressorNoOp", func(b *testing.B) {
   918  		buf := make([]byte, 256)
   919  		_, err := rand.Read(buf)
   920  		if err != nil {
   921  			b.Log(err)
   922  			b.FailNow()
   923  		}
   924  		conn := Connection{connection: &connection{compressor: wiremessage.CompressorNoOp}}
   925  		for i := 0; i < b.N; i++ {
   926  			_, err := conn.CompressWireMessage(buf, nil)
   927  			if err != nil {
   928  				b.Error(err)
   929  			}
   930  		}
   931  	})
   932  }
   933  
   934  // cancellationTestNetConn is a net.Conn implementation that is used to test context.Cancellation during an in-progress
   935  // network read or write. This type has two unbuffered channels: operationStartedChan and continueChan. When Read/Write
   936  // starts, the type will write to operationStartedChan, which will block until the test reads from it. This signals to
   937  // the test that the connection has entered the net.Conn read/write. After that unblocks, the type will then read from
   938  // continueChan, which blocks until the test writes to it. This allows the test to perform operations with the guarantee
   939  // that they will complete before the read/write functions exit. Sample usage:
   940  //
   941  // nc := newCancellationWriteConn(&testNetConn{}, 0)
   942  // conn := &connection{nc}
   943  // go func() { _ = conn.writeWireMessage(ctx, []byte{"hello world"})}()
   944  // <-nc.operationStartedChan
   945  // log.Println("This print will happen inside net.Conn.Write")
   946  // nc.continueChan <- struct{}{}
   947  //
   948  // By default, the read/write methods will error after they can read from continueChan to simulate a connection being
   949  // closed after context cancellation. This type also supports skipping to allow a number of successful read/write calls
   950  // before one fails.
   951  type cancellationTestNetConn struct {
   952  	net.Conn
   953  
   954  	shouldSkip           int
   955  	skipCount            int
   956  	readBuf              []byte
   957  	operationStartedChan chan struct{}
   958  	continueChan         chan struct{}
   959  }
   960  
   961  // create a cancellationTestNetConn to test cancelling net.Conn.Write().
   962  // skip specifies the number of writes that should succeed. Successful writes will return len(writeBuffer), nil.
   963  func newCancellationWriteConn(nc net.Conn, skip int) *cancellationTestNetConn {
   964  	return &cancellationTestNetConn{
   965  		Conn:                 nc,
   966  		shouldSkip:           skip,
   967  		operationStartedChan: make(chan struct{}),
   968  		continueChan:         make(chan struct{}),
   969  	}
   970  }
   971  
   972  // create a cancellationTestNetConn to test cancelling net.Conn.Read().
   973  // skip specifies the number of reads that should succeed. Successful reads will copy the contents of readBuf into the
   974  // buffer provided to Read and will return len(readBuf), nil.
   975  func newCancellationReadConn(nc net.Conn, skip int, readBuf []byte) *cancellationTestNetConn {
   976  	return &cancellationTestNetConn{
   977  		Conn:                 nc,
   978  		shouldSkip:           skip,
   979  		readBuf:              readBuf,
   980  		operationStartedChan: make(chan struct{}),
   981  		continueChan:         make(chan struct{}),
   982  	}
   983  }
   984  
   985  func (c *cancellationTestNetConn) Read(b []byte) (int, error) {
   986  	if c.skipCount < c.shouldSkip {
   987  		c.skipCount++
   988  		copy(b, c.readBuf)
   989  		return len(c.readBuf), nil
   990  	}
   991  
   992  	c.operationStartedChan <- struct{}{}
   993  	<-c.continueChan
   994  	return 0, errors.New("cancelled read")
   995  }
   996  
   997  func (c *cancellationTestNetConn) Write(b []byte) (n int, err error) {
   998  	if c.skipCount < c.shouldSkip {
   999  		c.skipCount++
  1000  		return len(b), nil
  1001  	}
  1002  
  1003  	c.operationStartedChan <- struct{}{}
  1004  	<-c.continueChan
  1005  	return 0, errors.New("cancelled write")
  1006  }
  1007  
  1008  type testNetConn struct {
  1009  	nc  net.Conn
  1010  	buf []byte
  1011  
  1012  	deadlineerr error
  1013  	writeerr    error
  1014  	readerr     error
  1015  	closed      bool
  1016  
  1017  	deadline      time.Time
  1018  	readDeadline  time.Time
  1019  	writeDeadline time.Time
  1020  }
  1021  
  1022  func (tnc *testNetConn) Read(b []byte) (n int, err error) {
  1023  	if len(tnc.buf) > 0 {
  1024  		n := copy(b, tnc.buf)
  1025  		tnc.buf = tnc.buf[n:]
  1026  		return n, nil
  1027  	}
  1028  	if tnc.readerr != nil {
  1029  		return 0, tnc.readerr
  1030  	}
  1031  	if tnc.nc == nil {
  1032  		return 0, nil
  1033  	}
  1034  	return tnc.nc.Read(b)
  1035  }
  1036  
  1037  func (tnc *testNetConn) Write(b []byte) (n int, err error) {
  1038  	if tnc.writeerr != nil {
  1039  		return 0, tnc.writeerr
  1040  	}
  1041  	if tnc.nc == nil {
  1042  		idx := len(tnc.buf)
  1043  		tnc.buf = append(tnc.buf, make([]byte, len(b))...)
  1044  		copy(tnc.buf[idx:], b)
  1045  		return len(b), nil
  1046  	}
  1047  	return tnc.nc.Write(b)
  1048  }
  1049  
  1050  func (tnc *testNetConn) Close() error {
  1051  	tnc.closed = true
  1052  	if tnc.nc == nil {
  1053  		return nil
  1054  	}
  1055  	return tnc.nc.Close()
  1056  }
  1057  
  1058  func (tnc *testNetConn) LocalAddr() net.Addr {
  1059  	if tnc.nc == nil {
  1060  		return nil
  1061  	}
  1062  	return tnc.nc.LocalAddr()
  1063  }
  1064  
  1065  func (tnc *testNetConn) RemoteAddr() net.Addr {
  1066  	if tnc.nc == nil {
  1067  		return nil
  1068  	}
  1069  	return tnc.nc.RemoteAddr()
  1070  }
  1071  
  1072  func (tnc *testNetConn) SetDeadline(t time.Time) error {
  1073  	tnc.deadline = t
  1074  	if tnc.deadlineerr != nil {
  1075  		return tnc.deadlineerr
  1076  	}
  1077  	if tnc.nc == nil {
  1078  		return nil
  1079  	}
  1080  	return tnc.nc.SetDeadline(t)
  1081  }
  1082  
  1083  func (tnc *testNetConn) SetReadDeadline(t time.Time) error {
  1084  	tnc.readDeadline = t
  1085  	if tnc.deadlineerr != nil {
  1086  		return tnc.deadlineerr
  1087  	}
  1088  	if tnc.nc == nil {
  1089  		return nil
  1090  	}
  1091  	return tnc.nc.SetReadDeadline(t)
  1092  }
  1093  
  1094  func (tnc *testNetConn) SetWriteDeadline(t time.Time) error {
  1095  	tnc.writeDeadline = t
  1096  	if tnc.deadlineerr != nil {
  1097  		return tnc.deadlineerr
  1098  	}
  1099  	if tnc.nc == nil {
  1100  		return nil
  1101  	}
  1102  	return tnc.nc.SetWriteDeadline(t)
  1103  }
  1104  
  1105  // bootstrapConnection creates a listener that will listen for a single connection
  1106  // on the return address. The user provided run function will be called with the accepted
  1107  // connection. The user is responsible for closing the connection.
  1108  func bootstrapConnections(t *testing.T, num int, run func(net.Conn)) net.Addr {
  1109  	l, err := net.Listen("tcp", "localhost:0")
  1110  	if err != nil {
  1111  		t.Errorf("Could not set up a listener: %v", err)
  1112  		t.FailNow()
  1113  	}
  1114  	go func() {
  1115  		for i := 0; i < num; i++ {
  1116  			c, err := l.Accept()
  1117  			if err != nil {
  1118  				t.Errorf("Could not accept a connection: %v", err)
  1119  			}
  1120  			go run(c)
  1121  		}
  1122  		_ = l.Close()
  1123  	}()
  1124  	return l.Addr()
  1125  }
  1126  
  1127  type netconn struct {
  1128  	net.Conn
  1129  	closed chan struct{}
  1130  	d      *dialer
  1131  }
  1132  
  1133  func (nc *netconn) Close() error {
  1134  	nc.closed <- struct{}{}
  1135  	nc.d.connclosed(nc)
  1136  	return nc.Conn.Close()
  1137  }
  1138  
  1139  type writeFailConn struct {
  1140  	net.Conn
  1141  }
  1142  
  1143  func (wfc *writeFailConn) Write([]byte) (int, error) {
  1144  	return 0, errors.New("Write error")
  1145  }
  1146  
  1147  func (wfc *writeFailConn) SetWriteDeadline(time.Time) error {
  1148  	return nil
  1149  }
  1150  
  1151  type dialer struct {
  1152  	Dialer
  1153  	opened        map[*netconn]struct{}
  1154  	closed        map[*netconn]struct{}
  1155  	closeCallBack func()
  1156  	sync.Mutex
  1157  }
  1158  
  1159  func newdialer(d Dialer) *dialer {
  1160  	return &dialer{Dialer: d, opened: make(map[*netconn]struct{}), closed: make(map[*netconn]struct{})}
  1161  }
  1162  
  1163  func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  1164  	d.Lock()
  1165  	defer d.Unlock()
  1166  	c, err := d.Dialer.DialContext(ctx, network, address)
  1167  	if err != nil {
  1168  		return nil, err
  1169  	}
  1170  	nc := &netconn{Conn: c, closed: make(chan struct{}, 1), d: d}
  1171  	d.opened[nc] = struct{}{}
  1172  	return nc, nil
  1173  }
  1174  
  1175  func (d *dialer) connclosed(nc *netconn) {
  1176  	d.Lock()
  1177  	defer d.Unlock()
  1178  	d.closed[nc] = struct{}{}
  1179  	if d.closeCallBack != nil {
  1180  		d.closeCallBack()
  1181  	}
  1182  }
  1183  
  1184  func (d *dialer) lenopened() int {
  1185  	d.Lock()
  1186  	defer d.Unlock()
  1187  	return len(d.opened)
  1188  }
  1189  
  1190  func (d *dialer) lenclosed() int {
  1191  	d.Lock()
  1192  	defer d.Unlock()
  1193  	return len(d.closed)
  1194  }
  1195  
  1196  type testCancellationListener struct {
  1197  	listener         *cancellListener
  1198  	numListen        int
  1199  	numStopListening int
  1200  	aborted          bool
  1201  }
  1202  
  1203  // This function creates a new testCancellationListener. The aborted parameter specifies the value that should be
  1204  // returned by the StopListening method.
  1205  func newTestCancellationListener(aborted bool) *testCancellationListener {
  1206  	return &testCancellationListener{
  1207  		listener: newCancellListener(),
  1208  		aborted:  aborted,
  1209  	}
  1210  }
  1211  
  1212  func (tcl *testCancellationListener) Listen(ctx context.Context, abortFn func()) {
  1213  	tcl.numListen++
  1214  	tcl.listener.Listen(ctx, abortFn)
  1215  }
  1216  
  1217  func (tcl *testCancellationListener) StopListening() bool {
  1218  	tcl.numStopListening++
  1219  	tcl.listener.StopListening()
  1220  	return tcl.aborted
  1221  }
  1222  
  1223  func (tcl *testCancellationListener) assertCalledOnce(t *testing.T) {
  1224  	assert.Equal(t, 1, tcl.numListen, "expected Listen to be called once, got %d", tcl.numListen)
  1225  	assert.Equal(t, 1, tcl.numStopListening, "expected StopListening to be called once, got %d", tcl.numListen)
  1226  }
  1227  

View as plain text