...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  //go:build go1.13
     8  // +build go1.13
     9  
    10  package topology
    11  
    12  import (
    13  	"context"
    14  	"crypto/tls"
    15  	"crypto/x509"
    16  	"errors"
    17  	"io/ioutil"
    18  	"net"
    19  	"os"
    20  	"runtime"
    21  	"sync"
    22  	"sync/atomic"
    23  	"testing"
    24  	"time"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  	"go.mongodb.org/mongo-driver/bson/primitive"
    28  	"go.mongodb.org/mongo-driver/event"
    29  	"go.mongodb.org/mongo-driver/internal/assert"
    30  	"go.mongodb.org/mongo-driver/internal/eventtest"
    31  	"go.mongodb.org/mongo-driver/internal/require"
    32  	"go.mongodb.org/mongo-driver/mongo/address"
    33  	"go.mongodb.org/mongo-driver/mongo/description"
    34  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    35  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    36  	"go.mongodb.org/mongo-driver/x/mongo/driver/auth"
    37  	"go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
    38  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    39  )
    40  
    41  type channelNetConnDialer struct{}
    42  
    43  func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (net.Conn, error) {
    44  	cnc := &drivertest.ChannelNetConn{
    45  		Written:  make(chan []byte, 1),
    46  		ReadResp: make(chan []byte, 2),
    47  		ReadErr:  make(chan error, 1),
    48  	}
    49  	if err := cnc.AddResponse(makeHelloReply()); err != nil {
    50  		return nil, err
    51  	}
    52  
    53  	return cnc, nil
    54  }
    55  
    56  type errorQueue struct {
    57  	errors []error
    58  	mutex  sync.Mutex
    59  }
    60  
    61  func (eq *errorQueue) head() error {
    62  	eq.mutex.Lock()
    63  	defer eq.mutex.Unlock()
    64  	if len(eq.errors) > 0 {
    65  		return eq.errors[0]
    66  	}
    67  	return nil
    68  }
    69  
    70  func (eq *errorQueue) dequeue() bool {
    71  	eq.mutex.Lock()
    72  	defer eq.mutex.Unlock()
    73  	if len(eq.errors) > 0 {
    74  		eq.errors = eq.errors[1:]
    75  		return true
    76  	}
    77  	return false
    78  }
    79  
    80  type timeoutConn struct {
    81  	net.Conn
    82  	errors *errorQueue
    83  }
    84  
    85  func (c *timeoutConn) Read(b []byte) (int, error) {
    86  	n, err := 0, c.errors.head()
    87  	if err == nil {
    88  		n, err = c.Conn.Read(b)
    89  	}
    90  	return n, err
    91  }
    92  
    93  func (c *timeoutConn) Write(b []byte) (int, error) {
    94  	n, err := 0, c.errors.head()
    95  	if err == nil {
    96  		n, err = c.Conn.Write(b)
    97  	}
    98  	return n, err
    99  }
   100  
   101  type timeoutDialer struct {
   102  	Dialer
   103  	errors *errorQueue
   104  }
   105  
   106  func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
   107  	c, e := d.Dialer.DialContext(ctx, network, address)
   108  
   109  	if caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE"); len(caFile) > 0 {
   110  		pem, err := ioutil.ReadFile(caFile)
   111  		if err != nil {
   112  			return nil, err
   113  		}
   114  
   115  		ca := x509.NewCertPool()
   116  		if !ca.AppendCertsFromPEM(pem) {
   117  			return nil, errors.New("unable to load CA file")
   118  		}
   119  
   120  		config := &tls.Config{
   121  			InsecureSkipVerify: true,
   122  			RootCAs:            ca,
   123  		}
   124  		c = tls.Client(c, config)
   125  	}
   126  	return &timeoutConn{c, d.errors}, e
   127  }
   128  
   129  // TestServerHeartbeatTimeout tests timeout retry and preemptive canceling.
   130  func TestServerHeartbeatTimeout(t *testing.T) {
   131  	if os.Getenv("DOCKER_RUNNING") != "" {
   132  		t.Skip("Skipping this test in docker.")
   133  	}
   134  
   135  	networkTimeoutError := &net.DNSError{
   136  		IsTimeout: true,
   137  	}
   138  
   139  	testCases := []struct {
   140  		desc                string
   141  		ioErrors            []error
   142  		expectInterruptions int
   143  	}{
   144  		{
   145  			desc:                "one single timeout should not clear the pool",
   146  			ioErrors:            []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
   147  			expectInterruptions: 0,
   148  		},
   149  		{
   150  			desc:                "continuous timeouts should clear the pool with interruption",
   151  			ioErrors:            []error{nil, networkTimeoutError, networkTimeoutError, nil},
   152  			expectInterruptions: 1,
   153  		},
   154  	}
   155  	for _, tc := range testCases {
   156  		tc := tc
   157  		t.Run(tc.desc, func(t *testing.T) {
   158  			t.Parallel()
   159  
   160  			var wg sync.WaitGroup
   161  			wg.Add(1)
   162  
   163  			errors := &errorQueue{errors: tc.ioErrors}
   164  			tpm := eventtest.NewTestPoolMonitor()
   165  			server := NewServer(
   166  				address.Address("localhost:27017"),
   167  				primitive.NewObjectID(),
   168  				WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor {
   169  					return tpm.PoolMonitor
   170  				}),
   171  				WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
   172  					return append(opts,
   173  						WithDialer(func(d Dialer) Dialer {
   174  							var dialer net.Dialer
   175  							return &timeoutDialer{&dialer, errors}
   176  						}))
   177  				}),
   178  				WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor {
   179  					return &event.ServerMonitor{
   180  						ServerHeartbeatSucceeded: func(e *event.ServerHeartbeatSucceededEvent) {
   181  							if !errors.dequeue() {
   182  								wg.Done()
   183  							}
   184  						},
   185  						ServerHeartbeatFailed: func(e *event.ServerHeartbeatFailedEvent) {
   186  							if !errors.dequeue() {
   187  								wg.Done()
   188  							}
   189  						},
   190  					}
   191  				}),
   192  				WithHeartbeatInterval(func(time.Duration) time.Duration {
   193  					return 200 * time.Millisecond
   194  				}),
   195  			)
   196  			require.NoError(t, server.Connect(nil))
   197  			wg.Wait()
   198  			interruptions := tpm.Interruptions()
   199  			assert.Equal(t, tc.expectInterruptions, interruptions, "expected %d interruption but got %d", tc.expectInterruptions, interruptions)
   200  		})
   201  	}
   202  }
   203  
   204  // TestServerConnectionTimeout tests how different timeout errors are handled during connection
   205  // creation and server handshake.
   206  func TestServerConnectionTimeout(t *testing.T) {
   207  	testCases := []struct {
   208  		desc              string
   209  		dialer            func(Dialer) Dialer
   210  		handshaker        func(Handshaker) Handshaker
   211  		operationTimeout  time.Duration
   212  		connectTimeout    time.Duration
   213  		expectErr         bool
   214  		expectPoolCleared bool
   215  	}{
   216  		{
   217  			desc:              "successful connection should not clear the pool",
   218  			expectErr:         false,
   219  			expectPoolCleared: false,
   220  		},
   221  		{
   222  			desc: "timeout error during dialing should clear the pool",
   223  			dialer: func(Dialer) Dialer {
   224  				var d net.Dialer
   225  				return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
   226  					// Wait for the passed in context to time out. Expect the error returned by
   227  					// DialContext() to be treated as a timeout caused by reaching connectTimeoutMS.
   228  					<-ctx.Done()
   229  					return d.DialContext(ctx, network, addr)
   230  				})
   231  			},
   232  			operationTimeout:  1 * time.Minute,
   233  			connectTimeout:    100 * time.Millisecond,
   234  			expectErr:         true,
   235  			expectPoolCleared: true,
   236  		},
   237  		{
   238  			desc: "timeout error during dialing with no operation timeout should clear the pool",
   239  			dialer: func(Dialer) Dialer {
   240  				var d net.Dialer
   241  				return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
   242  					// Wait for the passed in context to time out. Expect the error returned by
   243  					// DialContext() to be treated as a timeout caused by reaching connectTimeoutMS.
   244  					<-ctx.Done()
   245  					return d.DialContext(ctx, network, addr)
   246  				})
   247  			},
   248  			operationTimeout:  0, // Uses a context.Background() with no timeout.
   249  			connectTimeout:    100 * time.Millisecond,
   250  			expectErr:         true,
   251  			expectPoolCleared: true,
   252  		},
   253  		{
   254  			desc: "dial errors unrelated to context timeouts should clear the pool",
   255  			dialer: func(Dialer) Dialer {
   256  				var d net.Dialer
   257  				return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
   258  					// Try to dial an invalid TCP address and expect an error.
   259  					return d.DialContext(ctx, "tcp", "300.0.0.0:nope")
   260  				})
   261  			},
   262  			expectErr:         true,
   263  			expectPoolCleared: true,
   264  		},
   265  		{
   266  			desc: "operation context timeout with unrelated dial errors should clear the pool",
   267  			dialer: func(Dialer) Dialer {
   268  				var d net.Dialer
   269  				return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
   270  					// Try to dial an invalid TCP address and expect an error.
   271  					c, err := d.DialContext(ctx, "tcp", "300.0.0.0:nope")
   272  					// Wait for the passed in context to time out. Expect that the context error is
   273  					// ignored because the dial error is not a timeout.
   274  					<-ctx.Done()
   275  					return c, err
   276  				})
   277  			},
   278  			operationTimeout:  1 * time.Millisecond,
   279  			connectTimeout:    100 * time.Millisecond,
   280  			expectErr:         true,
   281  			expectPoolCleared: true,
   282  		},
   283  	}
   284  
   285  	for _, tc := range testCases {
   286  		tc := tc
   287  		t.Run(tc.desc, func(t *testing.T) {
   288  			t.Parallel()
   289  
   290  			// Create a TCP listener on a random port. The listener will accept connections but not
   291  			// read or write to them.
   292  			l, err := net.Listen("tcp", "127.0.0.1:0")
   293  			require.NoError(t, err)
   294  			defer func() {
   295  				_ = l.Close()
   296  			}()
   297  
   298  			tpm := eventtest.NewTestPoolMonitor()
   299  			server := NewServer(
   300  				address.Address(l.Addr().String()),
   301  				primitive.NewObjectID(),
   302  				WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor {
   303  					return tpm.PoolMonitor
   304  				}),
   305  				// Replace the default dialer and handshaker with the test dialer and handshaker, if
   306  				// present.
   307  				WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
   308  					if tc.connectTimeout > 0 {
   309  						opts = append(opts, WithConnectTimeout(func(time.Duration) time.Duration { return tc.connectTimeout }))
   310  					}
   311  					if tc.dialer != nil {
   312  						opts = append(opts, WithDialer(tc.dialer))
   313  					}
   314  					if tc.handshaker != nil {
   315  						opts = append(opts, WithHandshaker(tc.handshaker))
   316  					}
   317  					return opts
   318  				}),
   319  				// Disable monitoring to prevent unrelated failures from the RTT monitor and
   320  				// heartbeats from unexpectedly clearing the connection pool.
   321  				withMonitoringDisabled(func(bool) bool { return true }),
   322  			)
   323  			require.NoError(t, server.Connect(nil))
   324  
   325  			// Create a context with the operation timeout if one is specified in the test case.
   326  			ctx := context.Background()
   327  			if tc.operationTimeout > 0 {
   328  				var cancel context.CancelFunc
   329  				ctx, cancel = context.WithTimeout(ctx, tc.operationTimeout)
   330  				defer cancel()
   331  			}
   332  			_, err = server.Connection(ctx)
   333  			if tc.expectErr {
   334  				assert.NotNil(t, err, "expected an error but got nil")
   335  			} else {
   336  				assert.Nil(t, err, "expected no error but got %s", err)
   337  			}
   338  
   339  			// If we expect the pool to be cleared, watch for all events until we get a
   340  			// "ConnectionPoolCleared" event or until we hit a 10s time limit.
   341  			if tc.expectPoolCleared {
   342  				assert.Eventually(t,
   343  					tpm.IsPoolCleared,
   344  					10*time.Second,
   345  					100*time.Millisecond,
   346  					"expected pool to be cleared within 10s but was not cleared")
   347  			}
   348  
   349  			// Disconnect the server then close the events channel and expect that no more events
   350  			// are sent on the channel.
   351  			_ = server.Disconnect(context.Background())
   352  
   353  			// If we don't expect the pool to be cleared, check all events after the server is
   354  			// disconnected and make sure none were "ConnectionPoolCleared".
   355  			if !tc.expectPoolCleared {
   356  				assert.False(t, tpm.IsPoolCleared(), "expected pool to not be cleared but was cleared")
   357  			}
   358  		})
   359  	}
   360  }
   361  
   362  func TestServer(t *testing.T) {
   363  	var serverTestTable = []struct {
   364  		name            string
   365  		connectionError bool
   366  		networkError    bool
   367  		hasDesc         bool
   368  	}{
   369  		{"auth_error", true, false, false},
   370  		{"no_error", false, false, false},
   371  		{"network_error_no_desc", false, true, false},
   372  		{"network_error_desc", false, true, true},
   373  	}
   374  
   375  	authErr := ConnectionError{Wrapped: &auth.Error{}, init: true}
   376  	netErr := ConnectionError{Wrapped: &net.AddrError{}, init: true}
   377  	for _, tt := range serverTestTable {
   378  		t.Run(tt.name, func(t *testing.T) {
   379  			var returnConnectionError bool
   380  			s := NewServer(
   381  				address.Address("localhost"),
   382  				primitive.NewObjectID(),
   383  				WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
   384  					return append(connOpts,
   385  						WithHandshaker(func(Handshaker) Handshaker {
   386  							return &testHandshaker{
   387  								finishHandshake: func(context.Context, driver.Connection) error {
   388  									var err error
   389  									if tt.connectionError && returnConnectionError {
   390  										err = authErr.Wrapped
   391  									}
   392  									return err
   393  								},
   394  							}
   395  						}),
   396  						WithDialer(func(Dialer) Dialer {
   397  							return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   398  								var err error
   399  								if tt.networkError && returnConnectionError {
   400  									err = netErr.Wrapped
   401  								}
   402  								return &net.TCPConn{}, err
   403  							})
   404  						}),
   405  					)
   406  				}),
   407  			)
   408  
   409  			var desc *description.Server
   410  			descript := s.Description()
   411  			if tt.hasDesc {
   412  				desc = &descript
   413  				require.Nil(t, desc.LastError)
   414  			}
   415  			err := s.pool.ready()
   416  			require.NoError(t, err, "pool.ready() error")
   417  			defer s.pool.close(context.Background())
   418  
   419  			s.state = serverConnected
   420  
   421  			// The internal connection pool resets the generation number once the number of connections in a generation
   422  			// reaches zero, which will cause some of these tests to fail because they assert that the generation
   423  			// number after a connection failure is 1. To workaround this, we call Connection() twice: once to
   424  			// successfully establish a connection and once to actually do the behavior described in the test case.
   425  			_, err = s.Connection(context.Background())
   426  			assert.Nil(t, err, "error getting initial connection: %v", err)
   427  			returnConnectionError = true
   428  			_, err = s.Connection(context.Background())
   429  
   430  			switch {
   431  			case tt.connectionError && !cmp.Equal(err, authErr, cmp.Comparer(compareErrors)):
   432  				t.Errorf("Expected connection error. got %v; want %v", err, authErr)
   433  			case tt.networkError && !cmp.Equal(err, netErr, cmp.Comparer(compareErrors)):
   434  				t.Errorf("Expected network error. got %v; want %v", err, netErr)
   435  			case !tt.connectionError && !tt.networkError && err != nil:
   436  				t.Errorf("Expected error to be nil. got %v; want %v", err, "<nil>")
   437  			}
   438  
   439  			if tt.hasDesc {
   440  				require.Equal(t, s.Description().Kind, (description.ServerKind)(description.Unknown))
   441  				require.NotNil(t, s.Description().LastError)
   442  			}
   443  
   444  			generation, _ := s.pool.generation.getGeneration(nil)
   445  			if (tt.connectionError || tt.networkError) && generation != 1 {
   446  				t.Errorf("Expected pool to be drained once on connection or network error. got %d; want %d", generation, 1)
   447  			}
   448  		})
   449  	}
   450  
   451  	t.Run("multiple connection initialization errors are processed correctly", func(t *testing.T) {
   452  		assertGenerationStats := func(t *testing.T, server *Server, serviceID primitive.ObjectID, wantGeneration, wantNumConns uint64) {
   453  			t.Helper()
   454  
   455  			getGeneration := func(serviceIDPtr *primitive.ObjectID) uint64 {
   456  				generation, _ := server.pool.generation.getGeneration(serviceIDPtr)
   457  				return generation
   458  			}
   459  
   460  			// On connection failure, the connection is removed and closed after delivering the
   461  			// error to Connection(), so it may still count toward the generation connection count
   462  			// briefly. Wait up to 100ms for the generation connection count to reach the target.
   463  			assert.Eventuallyf(t,
   464  				func() bool {
   465  					generation, _ := server.pool.generation.getGeneration(&serviceID)
   466  					numConns := server.pool.generation.getNumConns(&serviceID)
   467  					return generation == wantGeneration && numConns == wantNumConns
   468  				},
   469  				100*time.Millisecond,
   470  				10*time.Millisecond,
   471  				"expected generation number %v, got %v; expected connection count %v, got %v",
   472  				wantGeneration,
   473  				getGeneration(&serviceID),
   474  				wantNumConns,
   475  				server.pool.generation.getNumConns(&serviceID))
   476  		}
   477  
   478  		testCases := []struct {
   479  			name               string
   480  			loadBalanced       bool
   481  			dialErr            error
   482  			getInfoErr         error
   483  			finishHandshakeErr error
   484  			finalGeneration    uint64
   485  			numNewConns        uint64
   486  		}{
   487  			// For LB clusters, errors for dialing and the initial handshake are ignored.
   488  			{"dial errors are ignored for load balancers", true, netErr.Wrapped, nil, nil, 0, 1},
   489  			{"initial handshake errors are ignored for load balancers", true, nil, netErr.Wrapped, nil, 0, 1},
   490  
   491  			// For LB clusters, post-handshake errors clear the pool, but do not update the server
   492  			// description or pause the pool.
   493  			{"post-handshake errors are not ignored for load balancers", true, nil, nil, netErr.Wrapped, 5, 1},
   494  
   495  			// For non-LB clusters, the first error sets the server to Unknown and clears and pauses
   496  			// the pool. All subsequent attempts to check out a connection without updating the
   497  			// server description return an error because the pool is paused.
   498  			{"dial errors are not ignored for non-lb clusters", false, netErr.Wrapped, nil, nil, 1, 1},
   499  			{"initial handshake errors are not ignored for non-lb clusters", false, nil, netErr.Wrapped, nil, 1, 1},
   500  			{"post-handshake errors are not ignored for non-lb clusters", false, nil, nil, netErr.Wrapped, 1, 1},
   501  		}
   502  		for _, tc := range testCases {
   503  			tc := tc // Capture range variable.
   504  
   505  			t.Run(tc.name, func(t *testing.T) {
   506  				var returnConnectionError bool
   507  				var serviceID primitive.ObjectID
   508  				if tc.loadBalanced {
   509  					serviceID = primitive.NewObjectID()
   510  				}
   511  
   512  				handshaker := &testHandshaker{
   513  					getHandshakeInformation: func(_ context.Context, addr address.Address, _ driver.Connection) (driver.HandshakeInformation, error) {
   514  						if tc.getInfoErr != nil && returnConnectionError {
   515  							return driver.HandshakeInformation{}, tc.getInfoErr
   516  						}
   517  
   518  						desc := description.NewDefaultServer(addr)
   519  						if tc.loadBalanced {
   520  							desc.ServiceID = &serviceID
   521  						}
   522  						return driver.HandshakeInformation{Description: desc}, nil
   523  					},
   524  					finishHandshake: func(context.Context, driver.Connection) error {
   525  						if tc.finishHandshakeErr != nil && returnConnectionError {
   526  							return tc.finishHandshakeErr
   527  						}
   528  						return nil
   529  					},
   530  				}
   531  				connOpts := []ConnectionOption{
   532  					WithDialer(func(Dialer) Dialer {
   533  						return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
   534  							var err error
   535  							if returnConnectionError && tc.dialErr != nil {
   536  								err = tc.dialErr
   537  							}
   538  							return &net.TCPConn{}, err
   539  						})
   540  					}),
   541  					WithHandshaker(func(Handshaker) Handshaker {
   542  						return handshaker
   543  					}),
   544  					WithConnectionLoadBalanced(func(bool) bool {
   545  						return tc.loadBalanced
   546  					}),
   547  				}
   548  				serverOpts := []ServerOption{
   549  					WithServerLoadBalanced(func(bool) bool {
   550  						return tc.loadBalanced
   551  					}),
   552  					WithConnectionOptions(func(...ConnectionOption) []ConnectionOption {
   553  						return connOpts
   554  					}),
   555  					// Disable the monitoring routine because we're only testing pooled connections and we don't want
   556  					// errors in monitoring to clear the pool and make this test flaky.
   557  					withMonitoringDisabled(func(bool) bool {
   558  						return true
   559  					}),
   560  					// With the default maxConnecting (2), there are multiple goroutines creating
   561  					// connections. Because handshake errors are processed after returning the error
   562  					// to checkOut(), it's possible for extra connection requests to be processed
   563  					// after a handshake error before the connection pool is cleared and paused. Set
   564  					// maxConnecting=1 to minimize the number of extra connection requests processed
   565  					// before the connection pool is cleared and paused.
   566  					WithMaxConnecting(func(uint64) uint64 { return 1 }),
   567  				}
   568  
   569  				server, err := ConnectServer(address.Address("localhost:27017"), nil, primitive.NewObjectID(), serverOpts...)
   570  				assert.Nil(t, err, "ConnectServer error: %v", err)
   571  				defer func() {
   572  					_ = server.Disconnect(context.Background())
   573  				}()
   574  
   575  				_, err = server.Connection(context.Background())
   576  				assert.Nil(t, err, "Connection error: %v", err)
   577  				assertGenerationStats(t, server, serviceID, 0, 1)
   578  
   579  				returnConnectionError = true
   580  				for i := 0; i < 5; i++ {
   581  					_, err = server.Connection(context.Background())
   582  					switch {
   583  					case tc.dialErr != nil || tc.getInfoErr != nil || tc.finishHandshakeErr != nil:
   584  						assert.NotNil(t, err, "expected Connection error at iteration %d, got nil", i)
   585  					default:
   586  						assert.Nil(t, err, "Connection error at iteration %d: %v", i, err)
   587  					}
   588  				}
   589  				assertGenerationStats(t, server, serviceID, tc.finalGeneration, tc.numNewConns)
   590  			})
   591  		}
   592  	})
   593  
   594  	t.Run("Cannot starve connection request", func(t *testing.T) {
   595  		cleanup := make(chan struct{})
   596  		addr := bootstrapConnections(t, 3, func(nc net.Conn) {
   597  			<-cleanup
   598  			_ = nc.Close()
   599  		})
   600  		d := newdialer(&net.Dialer{})
   601  		s := NewServer(address.Address(addr.String()),
   602  			primitive.NewObjectID(),
   603  			WithConnectionOptions(func(option ...ConnectionOption) []ConnectionOption {
   604  				return []ConnectionOption{WithDialer(func(_ Dialer) Dialer { return d })}
   605  			}),
   606  			WithMaxConnections(func(u uint64) uint64 {
   607  				return 1
   608  			}))
   609  		s.state = serverConnected
   610  		err := s.pool.ready()
   611  		noerr(t, err)
   612  		defer s.pool.close(context.Background())
   613  
   614  		conn, err := s.Connection(context.Background())
   615  		noerr(t, err)
   616  		if d.lenopened() != 1 {
   617  			t.Errorf("Should have opened 1 connections, but didn't. got %d; want %d", d.lenopened(), 1)
   618  		}
   619  
   620  		var wg sync.WaitGroup
   621  
   622  		wg.Add(1)
   623  		ch := make(chan struct{})
   624  		go func() {
   625  			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   626  			defer cancel()
   627  			ch <- struct{}{}
   628  			_, err := s.Connection(ctx)
   629  			if err != nil {
   630  				t.Errorf("Should not be able to starve connection request, but got error: %v", err)
   631  			}
   632  			wg.Done()
   633  		}()
   634  		<-ch
   635  		runtime.Gosched()
   636  		err = conn.Close()
   637  		noerr(t, err)
   638  		wg.Wait()
   639  		close(cleanup)
   640  	})
   641  
   642  	t.Run("update topology", func(t *testing.T) {
   643  		var updated atomic.Value // bool
   644  		updated.Store(false)
   645  
   646  		updateCallback := func(desc description.Server) description.Server {
   647  			updated.Store(true)
   648  			return desc
   649  		}
   650  		s, err := ConnectServer(address.Address("localhost"), updateCallback, primitive.NewObjectID())
   651  		require.NoError(t, err)
   652  		s.updateDescription(description.Server{Addr: s.address})
   653  		require.True(t, updated.Load().(bool))
   654  	})
   655  	t.Run("heartbeat", func(t *testing.T) {
   656  		// test that client metadata is sent on handshakes but not heartbeats
   657  		dialer := &channelNetConnDialer{}
   658  		dialerOpt := WithDialer(func(Dialer) Dialer {
   659  			return dialer
   660  		})
   661  		serverOpt := WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
   662  			return append(connOpts, dialerOpt)
   663  		})
   664  
   665  		s := NewServer(address.Address("localhost:27017"), primitive.NewObjectID(), serverOpt)
   666  
   667  		// do a heartbeat with a nil connection so a new one will be dialed
   668  		_, err := s.check()
   669  		assert.Nil(t, err, "check error: %v", err)
   670  		assert.NotNil(t, s.conn, "no connection dialed in check")
   671  
   672  		channelConn := s.conn.nc.(*drivertest.ChannelNetConn)
   673  		wm := channelConn.GetWrittenMessage()
   674  		if wm == nil {
   675  			t.Fatal("no wire message written for handshake")
   676  		}
   677  		if !includesClientMetadata(t, wm) {
   678  			t.Fatal("client metadata expected in handshake but not found")
   679  		}
   680  
   681  		// do a heartbeat with a non-nil connection
   682  		if err = channelConn.AddResponse(makeHelloReply()); err != nil {
   683  			t.Fatalf("error adding response: %v", err)
   684  		}
   685  		_, err = s.check()
   686  		assert.Nil(t, err, "check error: %v", err)
   687  
   688  		wm = channelConn.GetWrittenMessage()
   689  		if wm == nil {
   690  			t.Fatal("no wire message written for heartbeat")
   691  		}
   692  		if includesClientMetadata(t, wm) {
   693  			t.Fatal("client metadata not expected in heartbeat but found")
   694  		}
   695  	})
   696  	t.Run("heartbeat monitoring", func(t *testing.T) {
   697  		var publishedEvents []interface{}
   698  
   699  		serverHeartbeatStarted := func(e *event.ServerHeartbeatStartedEvent) {
   700  			publishedEvents = append(publishedEvents, *e)
   701  		}
   702  
   703  		serverHeartbeatSucceeded := func(e *event.ServerHeartbeatSucceededEvent) {
   704  			publishedEvents = append(publishedEvents, *e)
   705  		}
   706  
   707  		serverHeartbeatFailed := func(e *event.ServerHeartbeatFailedEvent) {
   708  			publishedEvents = append(publishedEvents, *e)
   709  		}
   710  
   711  		sdam := &event.ServerMonitor{
   712  			ServerHeartbeatStarted:   serverHeartbeatStarted,
   713  			ServerHeartbeatSucceeded: serverHeartbeatSucceeded,
   714  			ServerHeartbeatFailed:    serverHeartbeatFailed,
   715  		}
   716  
   717  		dialer := &channelNetConnDialer{}
   718  		dialerOpt := WithDialer(func(Dialer) Dialer {
   719  			return dialer
   720  		})
   721  		serverOpts := []ServerOption{
   722  			WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
   723  				return append(connOpts, dialerOpt)
   724  			}),
   725  			withMonitoringDisabled(func(bool) bool { return true }),
   726  			WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return sdam }),
   727  		}
   728  
   729  		s := NewServer(address.Address("localhost:27017"), primitive.NewObjectID(), serverOpts...)
   730  
   731  		// set up heartbeat connection, which doesn't send events
   732  		_, err := s.check()
   733  		assert.Nil(t, err, "check error: %v", err)
   734  
   735  		channelConn := s.conn.nc.(*drivertest.ChannelNetConn)
   736  		_ = channelConn.GetWrittenMessage()
   737  
   738  		t.Run("success", func(t *testing.T) {
   739  			publishedEvents = nil
   740  			// do a heartbeat with a non-nil connection
   741  			if err = channelConn.AddResponse(makeHelloReply()); err != nil {
   742  				t.Fatalf("error adding response: %v", err)
   743  			}
   744  			_, err = s.check()
   745  			_ = channelConn.GetWrittenMessage()
   746  			assert.Nil(t, err, "check error: %v", err)
   747  
   748  			assert.Equal(t, len(publishedEvents), 2, "expected %v events, got %v", 2, len(publishedEvents))
   749  
   750  			started, ok := publishedEvents[0].(event.ServerHeartbeatStartedEvent)
   751  			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatStartedEvent{}, publishedEvents[0])
   752  			assert.Equal(t, started.ConnectionID, s.conn.ID(), "expected connectionID to match")
   753  			assert.False(t, started.Awaited, "expected awaited to be false")
   754  
   755  			succeeded, ok := publishedEvents[1].(event.ServerHeartbeatSucceededEvent)
   756  			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatSucceededEvent{}, publishedEvents[1])
   757  			assert.Equal(t, succeeded.ConnectionID, s.conn.ID(), "expected connectionID to match")
   758  			assert.Equal(t, succeeded.Reply.Addr, s.address, "expected address %v, got %v", s.address, succeeded.Reply.Addr)
   759  			assert.False(t, succeeded.Awaited, "expected awaited to be false")
   760  		})
   761  		t.Run("failure", func(t *testing.T) {
   762  			publishedEvents = nil
   763  			// do a heartbeat with a non-nil connection
   764  			readErr := errors.New("error")
   765  			channelConn.ReadErr <- readErr
   766  			_, err = s.check()
   767  			_ = channelConn.GetWrittenMessage()
   768  			assert.Nil(t, err, "check error: %v", err)
   769  
   770  			assert.Equal(t, len(publishedEvents), 2, "expected %v events, got %v", 2, len(publishedEvents))
   771  
   772  			started, ok := publishedEvents[0].(event.ServerHeartbeatStartedEvent)
   773  			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatStartedEvent{}, publishedEvents[0])
   774  			assert.Equal(t, started.ConnectionID, s.conn.ID(), "expected connectionID to match")
   775  			assert.False(t, started.Awaited, "expected awaited to be false")
   776  
   777  			failed, ok := publishedEvents[1].(event.ServerHeartbeatFailedEvent)
   778  			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatFailedEvent{}, publishedEvents[1])
   779  			assert.Equal(t, failed.ConnectionID, s.conn.ID(), "expected connectionID to match")
   780  			assert.False(t, failed.Awaited, "expected awaited to be false")
   781  			assert.True(t, errors.Is(failed.Failure, readErr), "expected Failure to be %v, got: %v", readErr, failed.Failure)
   782  		})
   783  	})
   784  	t.Run("WithServerAppName", func(t *testing.T) {
   785  		name := "test"
   786  
   787  		s := NewServer(address.Address("localhost"),
   788  			primitive.NewObjectID(),
   789  			WithServerAppName(func(string) string { return name }))
   790  		require.Equal(t, name, s.cfg.appname, "expected appname to be: %v, got: %v", name, s.cfg.appname)
   791  	})
   792  	t.Run("createConnection overwrites WithSocketTimeout", func(t *testing.T) {
   793  		socketTimeout := 40 * time.Second
   794  
   795  		s := NewServer(
   796  			address.Address("localhost"),
   797  			primitive.NewObjectID(),
   798  			WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
   799  				return append(
   800  					connOpts,
   801  					WithReadTimeout(func(time.Duration) time.Duration { return socketTimeout }),
   802  					WithWriteTimeout(func(time.Duration) time.Duration { return socketTimeout }),
   803  				)
   804  			}),
   805  		)
   806  
   807  		conn := s.createConnection()
   808  		assert.Equal(t, s.cfg.heartbeatTimeout, 10*time.Second, "expected heartbeatTimeout to be: %v, got: %v", 10*time.Second, s.cfg.heartbeatTimeout)
   809  		assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout)
   810  		assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout)
   811  	})
   812  	t.Run("heartbeat contexts are not leaked", func(t *testing.T) {
   813  		// The context created for heartbeats should be cancelled when it is no longer needed to avoid leaks.
   814  
   815  		server, err := ConnectServer(
   816  			address.Address("invalid"),
   817  			nil,
   818  			primitive.NewObjectID(),
   819  			withMonitoringDisabled(func(bool) bool {
   820  				return true
   821  			}),
   822  		)
   823  		assert.Nil(t, err, "ConnectServer error: %v", err)
   824  
   825  		// Expect check to return an error in the server description because the server address doesn't exist. This is
   826  		// OK because we just want to ensure the heartbeat context is created.
   827  		desc, err := server.check()
   828  		assert.Nil(t, err, "check error: %v", err)
   829  		assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
   830  		assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil")
   831  		assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err())
   832  
   833  		// Override heartbeatCtxCancel with a wrapper that records whether or not it was called.
   834  		oldCancelFn := server.heartbeatCtxCancel
   835  		var previousCtxCancelled bool
   836  		server.heartbeatCtxCancel = func() {
   837  			previousCtxCancelled = true
   838  			oldCancelFn()
   839  		}
   840  
   841  		// The second check call should attempt to create a new heartbeat connection and should cancel the previous
   842  		// heartbeatCtx during the process.
   843  		desc, err = server.check()
   844  		assert.Nil(t, err, "check error: %v", err)
   845  		assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
   846  		assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not")
   847  	})
   848  }
   849  
   850  func TestServer_ProcessError(t *testing.T) {
   851  	t.Parallel()
   852  
   853  	processID := primitive.NewObjectID()
   854  	newProcessID := primitive.NewObjectID()
   855  
   856  	testCases := []struct {
   857  		name string
   858  
   859  		startDescription description.Server // Initial server description at the start of the test.
   860  
   861  		inputErr  error             // ProcessError error input.
   862  		inputConn driver.Connection // ProcessError conn input.
   863  
   864  		want            driver.ProcessErrorResult // Expected ProcessError return value.
   865  		wantGeneration  uint64                    // Expected resulting connection pool generation.
   866  		wantDescription description.Server        // Expected resulting server description.
   867  	}{
   868  		// Test that a nil error does not change the Server state.
   869  		{
   870  			name: "nil error",
   871  			startDescription: description.Server{
   872  				Kind: description.RSPrimary,
   873  			},
   874  			inputErr:       nil,
   875  			want:           driver.NoChange,
   876  			wantGeneration: 0,
   877  			wantDescription: description.Server{
   878  				Kind: description.RSPrimary,
   879  			},
   880  		},
   881  		// Test that errors that occur on stale connections are ignored.
   882  		{
   883  			name: "stale connection",
   884  			startDescription: description.Server{
   885  				Kind: description.RSPrimary,
   886  			},
   887  			inputErr: errors.New("foo"),
   888  			inputConn: newProcessErrorTestConn(
   889  				&description.VersionRange{
   890  					Max: 17,
   891  				},
   892  				true),
   893  			want:           driver.NoChange,
   894  			wantGeneration: 0,
   895  			wantDescription: description.Server{
   896  				Kind: description.RSPrimary,
   897  			},
   898  		},
   899  		// Test that errors that do not indicate a database state change or connection error are
   900  		// ignored.
   901  		{
   902  			name: "non state change error",
   903  			startDescription: description.Server{
   904  				Kind: description.RSPrimary,
   905  			},
   906  			inputErr: driver.Error{
   907  				Code: 1,
   908  			},
   909  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
   910  			want:           driver.NoChange,
   911  			wantGeneration: 0,
   912  			wantDescription: description.Server{
   913  				Kind: description.RSPrimary,
   914  			},
   915  		},
   916  		// Test that a "not writable primary" error with an old topology version is ignored.
   917  		{
   918  			name:             "stale not writable primary error",
   919  			startDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
   920  			inputErr: driver.Error{
   921  				Code: 10107, // NotWritablePrimary
   922  				TopologyVersion: &description.TopologyVersion{
   923  					ProcessID: processID,
   924  					Counter:   0,
   925  				},
   926  			},
   927  			inputConn:       newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
   928  			want:            driver.NoChange,
   929  			wantGeneration:  0,
   930  			wantDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
   931  		},
   932  		// Test that a "not writable primary" error with an newer topology version marks the Server
   933  		// as "unknown" and updates its topology version.
   934  		{
   935  			name:             "new not writable primary error",
   936  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
   937  			inputErr: driver.Error{
   938  				Code: 10107, // NotWritablePrimary
   939  				TopologyVersion: &description.TopologyVersion{
   940  					ProcessID: processID,
   941  					Counter:   1,
   942  				},
   943  			},
   944  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
   945  			want:           driver.ServerMarkedUnknown,
   946  			wantGeneration: 0,
   947  			wantDescription: newServerDescription(description.Unknown, processID, 1, driver.Error{
   948  				Code: 10107, // NotWritablePrimary
   949  				TopologyVersion: &description.TopologyVersion{
   950  					ProcessID: processID,
   951  					Counter:   1,
   952  				},
   953  			}),
   954  		},
   955  		// Test that a "not writable primary" error with an different topology process ID marks the Server as
   956  		// "unknown" and updates its topology version.
   957  		{
   958  			name:             "new process ID not writable primary error",
   959  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
   960  			inputErr: driver.Error{
   961  				Code: 10107, // NotWritablePrimary
   962  				TopologyVersion: &description.TopologyVersion{
   963  					ProcessID: newProcessID,
   964  					Counter:   0,
   965  				},
   966  			},
   967  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
   968  			want:           driver.ServerMarkedUnknown,
   969  			wantGeneration: 0,
   970  			wantDescription: newServerDescription(description.Unknown, newProcessID, 0, driver.Error{
   971  				Code: 10107, // NotWritablePrimary
   972  				TopologyVersion: &description.TopologyVersion{
   973  					ProcessID: newProcessID,
   974  					Counter:   0,
   975  				},
   976  			}),
   977  		},
   978  		// Test that a connection with a newer topology version overrides the server topology
   979  		// version and causes an error with the same topology version to be ignored.
   980  		// TODO(GODRIVER-2841): Remove this test case.
   981  		{
   982  			name:             "newer connection topology version",
   983  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
   984  			inputErr: driver.Error{
   985  				Code: 10107, // NotWritablePrimary
   986  				TopologyVersion: &description.TopologyVersion{
   987  					ProcessID: processID,
   988  					Counter:   1,
   989  				},
   990  			},
   991  			inputConn: &processErrorTestConn{
   992  				description: description.Server{
   993  					WireVersion: &description.VersionRange{Max: 17},
   994  					TopologyVersion: &description.TopologyVersion{
   995  						ProcessID: processID,
   996  						Counter:   1,
   997  					},
   998  				},
   999  				stale: false,
  1000  			},
  1001  			want:            driver.NoChange,
  1002  			wantGeneration:  0,
  1003  			wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1004  		},
  1005  		// Test that a "node is shutting down" error with a newer topology version clears the
  1006  		// connection pool, marks the Server as "unknown", and updates its topology version.
  1007  		{
  1008  			name:             "new shutdown error",
  1009  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1010  			inputErr: driver.Error{
  1011  				Code: 11600, // InterruptedAtShutdown
  1012  				TopologyVersion: &description.TopologyVersion{
  1013  					ProcessID: processID,
  1014  					Counter:   1,
  1015  				},
  1016  			},
  1017  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1018  			want:           driver.ConnectionPoolCleared,
  1019  			wantGeneration: 1,
  1020  			wantDescription: newServerDescription(description.Unknown, processID, 1, driver.Error{
  1021  				Code: 11600, // InterruptedAtShutdown
  1022  				TopologyVersion: &description.TopologyVersion{
  1023  					ProcessID: processID,
  1024  					Counter:   1,
  1025  				},
  1026  			}),
  1027  		},
  1028  		// Test that a "not writable primary" error with a stale topology version is ignored.
  1029  		{
  1030  			name:             "stale not writable primary write concern error",
  1031  			startDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
  1032  			inputErr: driver.WriteCommandError{
  1033  				WriteConcernError: &driver.WriteConcernError{
  1034  					Code: 10107, // NotWritablePrimary
  1035  					TopologyVersion: &description.TopologyVersion{
  1036  						ProcessID: processID,
  1037  						Counter:   0,
  1038  					},
  1039  				},
  1040  			},
  1041  			inputConn:       newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1042  			want:            driver.NoChange,
  1043  			wantGeneration:  0,
  1044  			wantDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
  1045  		},
  1046  		// Test that a "not writable primary" error with a newer topology version marks the Server
  1047  		// as "unknown" and updates its topology version.
  1048  		{
  1049  			name:             "new not writable primary write concern error",
  1050  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1051  			inputErr: driver.WriteCommandError{
  1052  				WriteConcernError: &driver.WriteConcernError{
  1053  					Code: 10107, // NotWritablePrimary
  1054  					TopologyVersion: &description.TopologyVersion{
  1055  						ProcessID: processID,
  1056  						Counter:   1,
  1057  					},
  1058  				},
  1059  			},
  1060  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1061  			want:           driver.ServerMarkedUnknown,
  1062  			wantGeneration: 0,
  1063  			wantDescription: newServerDescription(description.Unknown, processID, 1, driver.WriteCommandError{
  1064  				WriteConcernError: &driver.WriteConcernError{
  1065  					Code: 10107, // NotWritablePrimary
  1066  					TopologyVersion: &description.TopologyVersion{
  1067  						ProcessID: processID,
  1068  						Counter:   1,
  1069  					},
  1070  				},
  1071  			}),
  1072  		},
  1073  		// Test that "node is shutting down" errors that have a newer topology version than the
  1074  		// local Server topology version mark the Server as "unknown" and clear the connection pool.
  1075  		{
  1076  			name:             "new shutdown write concern error",
  1077  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1078  			inputErr: driver.WriteCommandError{
  1079  				WriteConcernError: &driver.WriteConcernError{
  1080  					Code: 11600, // InterruptedAtShutdown
  1081  					TopologyVersion: &description.TopologyVersion{
  1082  						ProcessID: processID,
  1083  						Counter:   1,
  1084  					},
  1085  				},
  1086  			},
  1087  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1088  			want:           driver.ConnectionPoolCleared,
  1089  			wantGeneration: 1,
  1090  			wantDescription: newServerDescription(description.Unknown, processID, 1, driver.WriteCommandError{
  1091  				WriteConcernError: &driver.WriteConcernError{
  1092  					Code: 11600, // InterruptedAtShutdown
  1093  					TopologyVersion: &description.TopologyVersion{
  1094  						ProcessID: processID,
  1095  						Counter:   1,
  1096  					},
  1097  				},
  1098  			}),
  1099  		},
  1100  		// Test that "node is recovering" or "not writable primary" errors that have a newer
  1101  		// topology version than the local Server topology version and appear to be from MongoDB
  1102  		// servers before 4.2 mark the Server as "unknown" and clear the connection pool.
  1103  		{
  1104  			name:             "older than 4.2 write concern error",
  1105  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1106  			inputErr: driver.WriteCommandError{
  1107  				WriteConcernError: &driver.WriteConcernError{
  1108  					Code: 10107, // NotWritablePrimary
  1109  					TopologyVersion: &description.TopologyVersion{
  1110  						ProcessID: processID,
  1111  						Counter:   1,
  1112  					},
  1113  				},
  1114  			},
  1115  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 7}, false),
  1116  			want:           driver.ConnectionPoolCleared,
  1117  			wantGeneration: 1,
  1118  			wantDescription: newServerDescription(description.Unknown, processID, 1, driver.WriteCommandError{
  1119  				WriteConcernError: &driver.WriteConcernError{
  1120  					Code: 10107, // NotWritablePrimary
  1121  					TopologyVersion: &description.TopologyVersion{
  1122  						ProcessID: processID,
  1123  						Counter:   1,
  1124  					},
  1125  				},
  1126  			}),
  1127  		},
  1128  		// Test that a network timeout error, such as a DNS lookup timeout error, is ignored.
  1129  		{
  1130  			name:             "network timeout error",
  1131  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1132  			inputErr: driver.Error{
  1133  				Labels: []string{driver.NetworkError},
  1134  				Wrapped: ConnectionError{
  1135  					// Use a net.Error implementation that can return true from its Timeout() function.
  1136  					Wrapped: &net.DNSError{
  1137  						IsTimeout: true,
  1138  					},
  1139  				},
  1140  			},
  1141  			inputConn:       newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1142  			want:            driver.NoChange,
  1143  			wantGeneration:  0,
  1144  			wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1145  		},
  1146  		// Test that a context canceled error is ignored.
  1147  		{
  1148  			name:             "context canceled error",
  1149  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1150  			inputErr: driver.Error{
  1151  				Labels: []string{driver.NetworkError},
  1152  				Wrapped: ConnectionError{
  1153  					Wrapped: context.Canceled,
  1154  				},
  1155  			},
  1156  			inputConn:       newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1157  			want:            driver.NoChange,
  1158  			wantGeneration:  0,
  1159  			wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1160  		},
  1161  		// Test that a non-timeout network error, such as an address lookup error, marks the server
  1162  		// as "unknown" and sets its topology version to nil.
  1163  		{
  1164  			name:             "non-timeout network error",
  1165  			startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
  1166  			inputErr: driver.Error{
  1167  				Labels: []string{driver.NetworkError},
  1168  				Wrapped: ConnectionError{
  1169  					// Use a net.Error implementation that always returns false from its Timeout() function.
  1170  					Wrapped: &net.AddrError{},
  1171  				},
  1172  			},
  1173  			inputConn:      newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
  1174  			want:           driver.ConnectionPoolCleared,
  1175  			wantGeneration: 1,
  1176  			wantDescription: description.Server{
  1177  				Kind: description.Unknown,
  1178  				LastError: driver.Error{
  1179  					Labels: []string{driver.NetworkError},
  1180  					Wrapped: ConnectionError{
  1181  						Wrapped: &net.AddrError{},
  1182  					},
  1183  				},
  1184  			},
  1185  		},
  1186  	}
  1187  
  1188  	for _, tc := range testCases {
  1189  		tc := tc // Capture range variable.
  1190  
  1191  		t.Run(tc.name, func(t *testing.T) {
  1192  			t.Parallel()
  1193  
  1194  			server := NewServer(address.Address(""), primitive.NewObjectID())
  1195  			server.state = serverConnected
  1196  			err := server.pool.ready()
  1197  			require.Nil(t, err, "pool.ready() error: %v", err)
  1198  
  1199  			server.desc.Store(tc.startDescription)
  1200  
  1201  			got := server.ProcessError(tc.inputErr, tc.inputConn)
  1202  			assert.Equal(t, tc.want, got, "expected and actual ProcessError result are different")
  1203  
  1204  			desc := server.Description()
  1205  			assert.Equal(t,
  1206  				tc.wantDescription,
  1207  				desc,
  1208  				"expected and actual server descriptions are different")
  1209  
  1210  			generation, _ := server.pool.generation.getGeneration(nil)
  1211  			assert.Equal(t,
  1212  				tc.wantGeneration,
  1213  				generation,
  1214  				"expected and actual pool generation are different")
  1215  		})
  1216  	}
  1217  }
  1218  
  1219  // includesClientMetadata will return true if the wire message includes the
  1220  // "client" field.
  1221  func includesClientMetadata(t *testing.T, wm []byte) bool {
  1222  	t.Helper()
  1223  
  1224  	var ok bool
  1225  	_, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
  1226  	if !ok {
  1227  		t.Fatal("could not read header")
  1228  	}
  1229  	_, wm, ok = wiremessage.ReadQueryFlags(wm)
  1230  	if !ok {
  1231  		t.Fatal("could not read flags")
  1232  	}
  1233  	_, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
  1234  	if !ok {
  1235  		t.Fatal("could not read fullCollectionName")
  1236  	}
  1237  	_, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
  1238  	if !ok {
  1239  		t.Fatal("could not read numberToSkip")
  1240  	}
  1241  	_, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
  1242  	if !ok {
  1243  		t.Fatal("could not read numberToReturn")
  1244  	}
  1245  	var query bsoncore.Document
  1246  	query, wm, ok = wiremessage.ReadQueryQuery(wm)
  1247  	if !ok {
  1248  		t.Fatal("could not read query")
  1249  	}
  1250  
  1251  	if _, err := query.LookupErr("client"); err == nil {
  1252  		return true
  1253  	}
  1254  	if _, err := query.LookupErr("$query", "client"); err == nil {
  1255  		return true
  1256  	}
  1257  
  1258  	return false
  1259  }
  1260  
// processErrorTestConn is a driver.Connection implementation used by tests
// for Server.ProcessError. This type should not be used for other tests
// because it does not implement all of the functions of the interface: only
// Stale and Description are defined on it directly, and every other method is
// promoted from the embedded driver.Connection.
//
// NOTE(review): the constructors visible in this file leave the embedded
// driver.Connection nil, so calling any promoted method would presumably
// panic — confirm before reusing this type elsewhere.
type processErrorTestConn struct {
	// Embed a driver.Connection to quickly implement the interface without
	// implementing all methods.
	driver.Connection
	description description.Server // value returned by Description
	stale       bool               // value returned by Stale
}
  1271  
  1272  func newProcessErrorTestConn(wireVersion *description.VersionRange, stale bool) *processErrorTestConn {
  1273  	return &processErrorTestConn{
  1274  		description: description.Server{
  1275  			WireVersion: wireVersion,
  1276  		},
  1277  		stale: stale,
  1278  	}
  1279  }
  1280  
// Stale returns the stale flag this test connection was constructed with.
func (p *processErrorTestConn) Stale() bool {
	return p.stale
}
  1284  
// Description returns the server description this test connection was
// constructed with.
func (p *processErrorTestConn) Description() description.Server {
	return p.description
}
  1288  
  1289  // newServerDescription is a convenience function for creating a server description with a specified
  1290  // kind, topology version process ID and counter, and last error.
  1291  func newServerDescription(
  1292  	kind description.ServerKind,
  1293  	processID primitive.ObjectID,
  1294  	counter int64,
  1295  	lastError error,
  1296  ) description.Server {
  1297  	return description.Server{
  1298  		Kind: kind,
  1299  		TopologyVersion: &description.TopologyVersion{
  1300  			ProcessID: processID,
  1301  			Counter:   counter,
  1302  		},
  1303  		LastError: lastError,
  1304  	}
  1305  }
  1306  

View as plain text