...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"bufio"
    11  	"bytes"
    12  	"context"
    13  	"encoding/json"
    14  	"errors"
    15  	"fmt"
    16  	"io/ioutil"
    17  	"path"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  
    22  	"go.mongodb.org/mongo-driver/bson/primitive"
    23  	"go.mongodb.org/mongo-driver/internal/assert"
    24  	"go.mongodb.org/mongo-driver/internal/logger"
    25  	"go.mongodb.org/mongo-driver/internal/require"
    26  	"go.mongodb.org/mongo-driver/internal/spectest"
    27  	"go.mongodb.org/mongo-driver/mongo/address"
    28  	"go.mongodb.org/mongo-driver/mongo/description"
    29  	"go.mongodb.org/mongo-driver/mongo/options"
    30  	"go.mongodb.org/mongo-driver/mongo/readpref"
    31  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    32  )
    33  
    34  const testTimeout = 2 * time.Second
    35  
    36  func noerr(t *testing.T, err error) {
    37  	t.Helper()
    38  	if err != nil {
    39  		t.Errorf("Unexpected error: %v", err)
    40  		t.FailNow()
    41  	}
    42  }
    43  
    44  func compareErrors(err1, err2 error) bool {
    45  	if err1 == nil && err2 == nil {
    46  		return true
    47  	}
    48  
    49  	if err1 == nil || err2 == nil {
    50  		return false
    51  	}
    52  
    53  	if err1.Error() != err2.Error() {
    54  		return false
    55  	}
    56  
    57  	return true
    58  }
    59  
    60  func TestServerSelection(t *testing.T) {
    61  	var selectFirst description.ServerSelectorFunc = func(_ description.Topology, candidates []description.Server) ([]description.Server, error) {
    62  		if len(candidates) == 0 {
    63  			return []description.Server{}, nil
    64  		}
    65  		return candidates[0:1], nil
    66  	}
    67  	var selectNone description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) {
    68  		return []description.Server{}, nil
    69  	}
    70  	var errSelectionError = errors.New("encountered an error in the selector")
    71  	var selectError description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) {
    72  		return nil, errSelectionError
    73  	}
    74  
    75  	t.Run("Success", func(t *testing.T) {
    76  		topo, err := New(nil)
    77  		noerr(t, err)
    78  		desc := description.Topology{
    79  			Servers: []description.Server{
    80  				{Addr: address.Address("one"), Kind: description.Standalone},
    81  				{Addr: address.Address("two"), Kind: description.Standalone},
    82  				{Addr: address.Address("three"), Kind: description.Standalone},
    83  			},
    84  		}
    85  		subCh := make(chan description.Topology, 1)
    86  		subCh <- desc
    87  
    88  		state := newServerSelectionState(selectFirst, nil)
    89  		srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
    90  		noerr(t, err)
    91  		if len(srvs) != 1 {
    92  			t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
    93  		}
    94  		if srvs[0].Addr != desc.Servers[0].Addr {
    95  			t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[0].Addr)
    96  		}
    97  	})
    98  	t.Run("Compatibility Error Min Version Too High", func(t *testing.T) {
    99  		topo, err := New(nil)
   100  		noerr(t, err)
   101  		desc := description.Topology{
   102  			Kind: description.Single,
   103  			Servers: []description.Server{
   104  				{Addr: address.Address("one:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 11, Min: 11}},
   105  				{Addr: address.Address("two:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}},
   106  				{Addr: address.Address("three:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}},
   107  			},
   108  		}
   109  		want := fmt.Errorf(
   110  			"server at %s requires wire version %d, but this version of the Go driver only supports up to %d",
   111  			desc.Servers[0].Addr.String(),
   112  			desc.Servers[0].WireVersion.Min,
   113  			SupportedWireVersions.Max,
   114  		)
   115  		desc.CompatibilityErr = want
   116  		atomic.StoreInt64(&topo.state, topologyConnected)
   117  		topo.desc.Store(desc)
   118  		_, err = topo.SelectServer(context.Background(), selectFirst)
   119  		assert.Equal(t, err, want, "expected %v, got %v", want, err)
   120  	})
   121  	t.Run("Compatibility Error Max Version Too Low", func(t *testing.T) {
   122  		topo, err := New(nil)
   123  		noerr(t, err)
   124  		desc := description.Topology{
   125  			Kind: description.Single,
   126  			Servers: []description.Server{
   127  				{Addr: address.Address("one:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 21, Min: 6}},
   128  				{Addr: address.Address("two:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}},
   129  				{Addr: address.Address("three:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}},
   130  			},
   131  		}
   132  		want := fmt.Errorf(
   133  			"server at %s reports wire version %d, but this version of the Go driver requires "+
   134  				"at least 6 (MongoDB 3.6)",
   135  			desc.Servers[0].Addr.String(),
   136  			desc.Servers[0].WireVersion.Max,
   137  		)
   138  		desc.CompatibilityErr = want
   139  		atomic.StoreInt64(&topo.state, topologyConnected)
   140  		topo.desc.Store(desc)
   141  		_, err = topo.SelectServer(context.Background(), selectFirst)
   142  		assert.Equal(t, err, want, "expected %v, got %v", want, err)
   143  	})
   144  	t.Run("Updated", func(t *testing.T) {
   145  		topo, err := New(nil)
   146  		noerr(t, err)
   147  		desc := description.Topology{Servers: []description.Server{}}
   148  		subCh := make(chan description.Topology, 1)
   149  		subCh <- desc
   150  
   151  		resp := make(chan []description.Server)
   152  		go func() {
   153  			state := newServerSelectionState(selectFirst, nil)
   154  			srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
   155  			noerr(t, err)
   156  			resp <- srvs
   157  		}()
   158  
   159  		desc = description.Topology{
   160  			Servers: []description.Server{
   161  				{Addr: address.Address("one"), Kind: description.Standalone},
   162  				{Addr: address.Address("two"), Kind: description.Standalone},
   163  				{Addr: address.Address("three"), Kind: description.Standalone},
   164  			},
   165  		}
   166  		select {
   167  		case subCh <- desc:
   168  		case <-time.After(100 * time.Millisecond):
   169  			t.Error("Timed out while trying to send topology description")
   170  		}
   171  
   172  		var srvs []description.Server
   173  		select {
   174  		case srvs = <-resp:
   175  		case <-time.After(100 * time.Millisecond):
   176  			t.Errorf("Timed out while trying to retrieve selected servers")
   177  		}
   178  
   179  		if len(srvs) != 1 {
   180  			t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
   181  		}
   182  		if srvs[0].Addr != desc.Servers[0].Addr {
   183  			t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[0].Addr)
   184  		}
   185  	})
   186  	t.Run("Cancel", func(t *testing.T) {
   187  		desc := description.Topology{
   188  			Servers: []description.Server{
   189  				{Addr: address.Address("one"), Kind: description.Standalone},
   190  				{Addr: address.Address("two"), Kind: description.Standalone},
   191  				{Addr: address.Address("three"), Kind: description.Standalone},
   192  			},
   193  		}
   194  		topo, err := New(nil)
   195  		noerr(t, err)
   196  		subCh := make(chan description.Topology, 1)
   197  		subCh <- desc
   198  		resp := make(chan error)
   199  		ctx, cancel := context.WithCancel(context.Background())
   200  		go func() {
   201  			state := newServerSelectionState(selectNone, nil)
   202  			_, err := topo.selectServerFromSubscription(ctx, subCh, state)
   203  			resp <- err
   204  		}()
   205  
   206  		select {
   207  		case err := <-resp:
   208  			t.Errorf("Received error from server selection too soon: %v", err)
   209  		case <-time.After(100 * time.Millisecond):
   210  		}
   211  
   212  		cancel()
   213  
   214  		select {
   215  		case err = <-resp:
   216  		case <-time.After(100 * time.Millisecond):
   217  			t.Errorf("Timed out while trying to retrieve selected servers")
   218  		}
   219  
   220  		want := ServerSelectionError{Wrapped: context.Canceled, Desc: desc}
   221  		assert.Equal(t, err, want, "Incorrect error received. got %v; want %v", err, want)
   222  	})
   223  	t.Run("Timeout", func(t *testing.T) {
   224  		desc := description.Topology{
   225  			Servers: []description.Server{
   226  				{Addr: address.Address("one"), Kind: description.Standalone},
   227  				{Addr: address.Address("two"), Kind: description.Standalone},
   228  				{Addr: address.Address("three"), Kind: description.Standalone},
   229  			},
   230  		}
   231  		topo, err := New(nil)
   232  		noerr(t, err)
   233  		subCh := make(chan description.Topology, 1)
   234  		subCh <- desc
   235  		resp := make(chan error)
   236  		timeout := make(chan time.Time)
   237  		go func() {
   238  			state := newServerSelectionState(selectNone, timeout)
   239  			_, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
   240  			resp <- err
   241  		}()
   242  
   243  		select {
   244  		case err := <-resp:
   245  			t.Errorf("Received error from server selection too soon: %v", err)
   246  		case timeout <- time.Now():
   247  		}
   248  
   249  		select {
   250  		case err = <-resp:
   251  		case <-time.After(100 * time.Millisecond):
   252  			t.Errorf("Timed out while trying to retrieve selected servers")
   253  		}
   254  
   255  		if err == nil {
   256  			t.Fatalf("did not receive error from server selection")
   257  		}
   258  	})
   259  	t.Run("Error", func(t *testing.T) {
   260  		desc := description.Topology{
   261  			Servers: []description.Server{
   262  				{Addr: address.Address("one"), Kind: description.Standalone},
   263  				{Addr: address.Address("two"), Kind: description.Standalone},
   264  				{Addr: address.Address("three"), Kind: description.Standalone},
   265  			},
   266  		}
   267  		topo, err := New(nil)
   268  		noerr(t, err)
   269  		subCh := make(chan description.Topology, 1)
   270  		subCh <- desc
   271  		resp := make(chan error)
   272  		timeout := make(chan time.Time)
   273  		go func() {
   274  			state := newServerSelectionState(selectError, timeout)
   275  			_, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
   276  			resp <- err
   277  		}()
   278  
   279  		select {
   280  		case err = <-resp:
   281  		case <-time.After(100 * time.Millisecond):
   282  			t.Errorf("Timed out while trying to retrieve selected servers")
   283  		}
   284  
   285  		if err == nil {
   286  			t.Fatalf("did not receive error from server selection")
   287  		}
   288  	})
   289  	t.Run("findServer returns topology kind", func(t *testing.T) {
   290  		topo, err := New(nil)
   291  		noerr(t, err)
   292  		atomic.StoreInt64(&topo.state, topologyConnected)
   293  		srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id)
   294  		noerr(t, err)
   295  		topo.servers[address.Address("one")] = srvr
   296  		desc := topo.desc.Load().(description.Topology)
   297  		desc.Kind = description.Single
   298  		topo.desc.Store(desc)
   299  
   300  		selected := description.Server{Addr: address.Address("one")}
   301  
   302  		ss, err := topo.FindServer(selected)
   303  		noerr(t, err)
   304  		if ss.Kind != description.Single {
   305  			t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.Single)
   306  		}
   307  	})
   308  	t.Run("Update on not primary error", func(t *testing.T) {
   309  		topo, err := New(nil)
   310  		noerr(t, err)
   311  		atomic.StoreInt64(&topo.state, topologyConnected)
   312  
   313  		addr1 := address.Address("one")
   314  		addr2 := address.Address("two")
   315  		addr3 := address.Address("three")
   316  		desc := description.Topology{
   317  			Servers: []description.Server{
   318  				{Addr: addr1, Kind: description.RSPrimary},
   319  				{Addr: addr2, Kind: description.RSSecondary},
   320  				{Addr: addr3, Kind: description.RSSecondary},
   321  			},
   322  		}
   323  
   324  		// manually add the servers to the topology
   325  		for _, srv := range desc.Servers {
   326  			s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id)
   327  			noerr(t, err)
   328  			topo.servers[srv.Addr] = s
   329  		}
   330  
   331  		// Send updated description
   332  		desc = description.Topology{
   333  			Servers: []description.Server{
   334  				{Addr: addr1, Kind: description.RSSecondary},
   335  				{Addr: addr2, Kind: description.RSPrimary},
   336  				{Addr: addr3, Kind: description.RSSecondary},
   337  			},
   338  		}
   339  
   340  		subCh := make(chan description.Topology, 1)
   341  		subCh <- desc
   342  
   343  		// send a not primary error to the server forcing an update
   344  		serv, err := topo.FindServer(desc.Servers[0])
   345  		noerr(t, err)
   346  		atomic.StoreInt64(&serv.state, serverConnected)
   347  		_ = serv.ProcessError(driver.Error{Message: driver.LegacyNotPrimaryErrMsg}, initConnection{})
   348  
   349  		resp := make(chan []description.Server)
   350  
   351  		go func() {
   352  			// server selection should discover the new topology
   353  			state := newServerSelectionState(description.WriteSelector(), nil)
   354  			srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
   355  			noerr(t, err)
   356  			resp <- srvs
   357  		}()
   358  
   359  		var srvs []description.Server
   360  		select {
   361  		case srvs = <-resp:
   362  		case <-time.After(100 * time.Millisecond):
   363  			t.Errorf("Timed out while trying to retrieve selected servers")
   364  		}
   365  
   366  		if len(srvs) != 1 {
   367  			t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
   368  		}
   369  		if srvs[0].Addr != desc.Servers[1].Addr {
   370  			t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[1].Addr)
   371  		}
   372  	})
   373  	t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) {
   374  		// Assert that the server selection fast path does not create a Subscription or check for timeout errors.
   375  		topo, err := New(nil)
   376  		noerr(t, err)
   377  		atomic.StoreInt64(&topo.state, topologyConnected)
   378  
   379  		primaryAddr := address.Address("one")
   380  		desc := description.Topology{
   381  			Servers: []description.Server{
   382  				{Addr: primaryAddr, Kind: description.RSPrimary},
   383  			},
   384  		}
   385  		topo.desc.Store(desc)
   386  		for _, srv := range desc.Servers {
   387  			s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id)
   388  			noerr(t, err)
   389  			topo.servers[srv.Addr] = s
   390  		}
   391  
   392  		// Manually close subscriptions so calls to Subscribe will error and pass in a cancelled context to ensure the
   393  		// fast path ignores timeout errors.
   394  		topo.subscriptionsClosed = true
   395  		ctx, cancel := context.WithCancel(context.Background())
   396  		cancel()
   397  		selectedServer, err := topo.SelectServer(ctx, description.WriteSelector())
   398  		noerr(t, err)
   399  		selectedAddr := selectedServer.(*SelectedServer).address
   400  		assert.Equal(t, primaryAddr, selectedAddr, "expected address %v, got %v", primaryAddr, selectedAddr)
   401  	})
   402  	t.Run("default to selecting from subscription if fast path fails", func(t *testing.T) {
   403  		topo, err := New(nil)
   404  		noerr(t, err)
   405  
   406  		atomic.StoreInt64(&topo.state, topologyConnected)
   407  		desc := description.Topology{
   408  			Servers: []description.Server{},
   409  		}
   410  		topo.desc.Store(desc)
   411  
   412  		topo.subscriptionsClosed = true
   413  		_, err = topo.SelectServer(context.Background(), description.WriteSelector())
   414  		assert.Equal(t, ErrSubscribeAfterClosed, err, "expected error %v, got %v", ErrSubscribeAfterClosed, err)
   415  	})
   416  }
   417  
   418  func TestSessionTimeout(t *testing.T) {
   419  	int64ToPtr := func(i64 int64) *int64 { return &i64 }
   420  
   421  	t.Run("UpdateSessionTimeout", func(t *testing.T) {
   422  		topo, err := New(nil)
   423  		noerr(t, err)
   424  		topo.servers["foo"] = nil
   425  		topo.fsm.Servers = []description.Server{
   426  			{
   427  				Addr:                     address.Address("foo").Canonicalize(),
   428  				Kind:                     description.RSPrimary,
   429  				SessionTimeoutMinutes:    60,
   430  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   431  			},
   432  		}
   433  
   434  		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   435  		defer cancel()
   436  
   437  		desc := description.Server{
   438  			Addr:                     "foo",
   439  			Kind:                     description.RSPrimary,
   440  			SessionTimeoutMinutes:    30,
   441  			SessionTimeoutMinutesPtr: int64ToPtr(30),
   442  		}
   443  		topo.apply(ctx, desc)
   444  
   445  		currDesc := topo.desc.Load().(description.Topology)
   446  		want := int64(30)
   447  		require.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
   448  			"session timeout minutes mismatch")
   449  	})
   450  	t.Run("MultipleUpdates", func(t *testing.T) {
   451  		topo, err := New(nil)
   452  		noerr(t, err)
   453  		topo.fsm.Kind = description.ReplicaSetWithPrimary
   454  		topo.servers["foo"] = nil
   455  		topo.servers["bar"] = nil
   456  		topo.fsm.Servers = []description.Server{
   457  			{
   458  				Addr:                     address.Address("foo").Canonicalize(),
   459  				Kind:                     description.RSPrimary,
   460  				SessionTimeoutMinutes:    60,
   461  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   462  			},
   463  			{
   464  				Addr:                     address.Address("bar").Canonicalize(),
   465  				Kind:                     description.RSSecondary,
   466  				SessionTimeoutMinutes:    60,
   467  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   468  			},
   469  		}
   470  
   471  		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   472  		defer cancel()
   473  
   474  		desc1 := description.Server{
   475  			Addr:                     "foo",
   476  			Kind:                     description.RSPrimary,
   477  			SessionTimeoutMinutes:    30,
   478  			SessionTimeoutMinutesPtr: int64ToPtr(30),
   479  			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
   480  		}
   481  		// should update because new timeout is lower
   482  		desc2 := description.Server{
   483  			Addr:                     "bar",
   484  			Kind:                     description.RSPrimary,
   485  			SessionTimeoutMinutes:    20,
   486  			SessionTimeoutMinutesPtr: int64ToPtr(20),
   487  			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
   488  		}
   489  		topo.apply(ctx, desc1)
   490  		topo.apply(ctx, desc2)
   491  
   492  		currDesc := topo.Description()
   493  		want := int64(20)
   494  		require.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
   495  			"session timeout minutes mismatch")
   496  	})
   497  	t.Run("NoUpdate", func(t *testing.T) {
   498  		topo, err := New(nil)
   499  		noerr(t, err)
   500  		topo.servers["foo"] = nil
   501  		topo.servers["bar"] = nil
   502  		topo.fsm.Servers = []description.Server{
   503  			{
   504  				Addr:                     address.Address("foo").Canonicalize(),
   505  				Kind:                     description.RSPrimary,
   506  				SessionTimeoutMinutes:    60,
   507  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   508  			},
   509  			{
   510  				Addr:                     address.Address("bar").Canonicalize(),
   511  				Kind:                     description.RSSecondary,
   512  				SessionTimeoutMinutes:    60,
   513  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   514  			},
   515  		}
   516  
   517  		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   518  		defer cancel()
   519  
   520  		desc1 := description.Server{
   521  			Addr:                     "foo",
   522  			Kind:                     description.RSPrimary,
   523  			SessionTimeoutMinutes:    20,
   524  			SessionTimeoutMinutesPtr: int64ToPtr(20),
   525  			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
   526  		}
   527  		// should not update because new timeout is higher
   528  		desc2 := description.Server{
   529  			Addr:                     "bar",
   530  			Kind:                     description.RSPrimary,
   531  			SessionTimeoutMinutes:    30,
   532  			SessionTimeoutMinutesPtr: int64ToPtr(30),
   533  			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
   534  		}
   535  		topo.apply(ctx, desc1)
   536  		topo.apply(ctx, desc2)
   537  
   538  		currDesc := topo.desc.Load().(description.Topology)
   539  		want := int64(20)
   540  		require.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
   541  			"session timeout minutes mismatch")
   542  	})
   543  	t.Run("TimeoutDataBearing", func(t *testing.T) {
   544  		topo, err := New(nil)
   545  		noerr(t, err)
   546  		topo.servers["foo"] = nil
   547  		topo.servers["bar"] = nil
   548  		topo.fsm.Servers = []description.Server{
   549  			{
   550  				Addr:                     address.Address("foo").Canonicalize(),
   551  				Kind:                     description.RSPrimary,
   552  				SessionTimeoutMinutes:    60,
   553  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   554  			},
   555  			{
   556  				Addr:                     address.Address("bar").Canonicalize(),
   557  				Kind:                     description.RSSecondary,
   558  				SessionTimeoutMinutes:    60,
   559  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   560  			},
   561  		}
   562  
   563  		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   564  		defer cancel()
   565  
   566  		desc1 := description.Server{
   567  			Addr:                     "foo",
   568  			Kind:                     description.RSPrimary,
   569  			SessionTimeoutMinutes:    20,
   570  			SessionTimeoutMinutesPtr: int64ToPtr(20),
   571  			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
   572  		}
   573  		// should not update because not a data bearing server
   574  		desc2 := description.Server{
   575  			Addr:                     "bar",
   576  			Kind:                     description.Unknown,
   577  			SessionTimeoutMinutes:    10,
   578  			SessionTimeoutMinutesPtr: int64ToPtr(10),
   579  			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
   580  		}
   581  		topo.apply(ctx, desc1)
   582  		topo.apply(ctx, desc2)
   583  
   584  		currDesc := topo.desc.Load().(description.Topology)
   585  		want := int64(20)
   586  		assert.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
   587  			"session timeout minutes mismatch")
   588  	})
   589  	t.Run("MixedSessionSupport", func(t *testing.T) {
   590  		topo, err := New(nil)
   591  		noerr(t, err)
   592  		topo.fsm.Kind = description.ReplicaSetWithPrimary
   593  		topo.servers["one"] = nil
   594  		topo.servers["two"] = nil
   595  		topo.servers["three"] = nil
   596  		topo.fsm.Servers = []description.Server{
   597  			{
   598  				Addr:                     address.Address("one").Canonicalize(),
   599  				Kind:                     description.RSPrimary,
   600  				SessionTimeoutMinutes:    20,
   601  				SessionTimeoutMinutesPtr: int64ToPtr(20),
   602  			},
   603  			{
   604  				// does not support sessions
   605  				Addr: address.Address("two").Canonicalize(),
   606  				Kind: description.RSSecondary,
   607  			},
   608  			{
   609  				Addr:                     address.Address("three").Canonicalize(),
   610  				Kind:                     description.RSPrimary,
   611  				SessionTimeoutMinutes:    60,
   612  				SessionTimeoutMinutesPtr: int64ToPtr(60),
   613  			},
   614  		}
   615  
   616  		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
   617  		defer cancel()
   618  
   619  		desc := description.Server{
   620  			Addr:                     address.Address("three"),
   621  			Kind:                     description.RSSecondary,
   622  			SessionTimeoutMinutes:    30,
   623  			SessionTimeoutMinutesPtr: int64ToPtr(30),
   624  		}
   625  
   626  		topo.apply(ctx, desc)
   627  
   628  		currDesc := topo.desc.Load().(description.Topology)
   629  		require.Nil(t, currDesc.SessionTimeoutMinutesPtr,
   630  			"session timeout minutes mismatch. got: %d. expected: nil", currDesc.SessionTimeoutMinutes)
   631  	})
   632  }
   633  
   634  func TestMinPoolSize(t *testing.T) {
   635  	cfg, err := NewConfig(options.Client().SetHosts([]string{"localhost:27017"}).SetMinPoolSize(10), nil)
   636  	if err != nil {
   637  		t.Errorf("error constructing topology config: %v", err)
   638  	}
   639  
   640  	topo, err := New(cfg)
   641  	if err != nil {
   642  		t.Errorf("topology.New shouldn't error. got: %v", err)
   643  	}
   644  	err = topo.Connect()
   645  	if err != nil {
   646  		t.Errorf("topology.Connect shouldn't error. got: %v", err)
   647  	}
   648  }
   649  
   650  func TestTopology_String_Race(_ *testing.T) {
   651  	ch := make(chan bool)
   652  	topo := &Topology{
   653  		servers: make(map[address.Address]*Server),
   654  	}
   655  
   656  	go func() {
   657  		topo.serversLock.Lock()
   658  		srv := &Server{}
   659  		srv.desc.Store(description.Server{})
   660  		topo.servers[address.Address("127.0.0.1:27017")] = srv
   661  		topo.serversLock.Unlock()
   662  		ch <- true
   663  	}()
   664  
   665  	go func() {
   666  		_ = topo.String()
   667  		ch <- true
   668  	}()
   669  
   670  	<-ch
   671  	<-ch
   672  }
   673  
   674  func TestTopologyConstruction(t *testing.T) {
   675  	t.Run("construct with URI", func(t *testing.T) {
   676  		testCases := []struct {
   677  			name            string
   678  			uri             string
   679  			pollingRequired bool
   680  		}{
   681  			{
   682  				name:            "normal",
   683  				uri:             "mongodb://localhost:27017",
   684  				pollingRequired: false,
   685  			},
   686  		}
   687  		for _, tc := range testCases {
   688  			t.Run(tc.name, func(t *testing.T) {
   689  				cfg, err := NewConfig(options.Client().ApplyURI(tc.uri), nil)
   690  				assert.Nil(t, err, "error constructing topology config: %v", err)
   691  
   692  				topo, err := New(cfg)
   693  				assert.Nil(t, err, "topology.New error: %v", err)
   694  
   695  				assert.Equal(t, tc.uri, topo.cfg.URI, "expected topology URI to be %v, got %v", tc.uri, topo.cfg.URI)
   696  				assert.Equal(t, tc.pollingRequired, topo.pollingRequired,
   697  					"expected topo.pollingRequired to be %v, got %v", tc.pollingRequired, topo.pollingRequired)
   698  			})
   699  		}
   700  	})
   701  }
   702  
   703  type mockLogSink struct {
   704  	msgs []string
   705  }
   706  
   707  func (s *mockLogSink) Info(_ int, msg string, _ ...interface{}) {
   708  	s.msgs = append(s.msgs, msg)
   709  }
   710  func (*mockLogSink) Error(error, string, ...interface{}) {
   711  	// Do nothing.
   712  }
   713  
   714  // Note: SRV connection strings are intentionally untested, since initial
   715  // lookup responses cannot be easily mocked.
   716  func TestTopologyConstructionLogging(t *testing.T) {
   717  	const (
   718  		cosmosDBMsg   = `You appear to be connected to a CosmosDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb`
   719  		documentDBMsg = `You appear to be connected to a DocumentDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/documentdb`
   720  	)
   721  
   722  	newLoggerOptions := func(sink options.LogSink) *options.LoggerOptions {
   723  		return options.
   724  			Logger().
   725  			SetSink(sink).
   726  			SetComponentLevel(options.LogComponentTopology, options.LogLevelInfo)
   727  	}
   728  
   729  	t.Run("CosmosDB URIs", func(t *testing.T) {
   730  		t.Parallel()
   731  
   732  		testCases := []struct {
   733  			name string
   734  			uri  string
   735  			msgs []string
   736  		}{
   737  			{
   738  				name: "normal",
   739  				uri:  "mongodb://a.mongo.cosmos.azure.com:19555/",
   740  				msgs: []string{cosmosDBMsg},
   741  			},
   742  			{
   743  				name: "multiple hosts",
   744  				uri:  "mongodb://a.mongo.cosmos.azure.com:1955,b.mongo.cosmos.azure.com:19555/",
   745  				msgs: []string{cosmosDBMsg},
   746  			},
   747  			{
   748  				name: "case-insensitive matching",
   749  				uri:  "mongodb://a.MONGO.COSMOS.AZURE.COM:19555/",
   750  				msgs: []string{},
   751  			},
   752  			{
   753  				name: "Mixing genuine and nongenuine hosts (unlikely in practice)",
   754  				uri:  "mongodb://a.example.com:27017,b.mongo.cosmos.azure.com:19555/",
   755  				msgs: []string{cosmosDBMsg},
   756  			},
   757  		}
   758  		for _, tc := range testCases {
   759  			tc := tc
   760  
   761  			t.Run(tc.name, func(t *testing.T) {
   762  				t.Parallel()
   763  
   764  				sink := &mockLogSink{}
   765  				cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
   766  				require.Nil(t, err, "error constructing topology config: %v", err)
   767  
   768  				topo, err := New(cfg)
   769  				require.Nil(t, err, "topology.New error: %v", err)
   770  
   771  				err = topo.Connect()
   772  				assert.Nil(t, err, "Connect error: %v", err)
   773  
   774  				assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
   775  			})
   776  		}
   777  	})
   778  	t.Run("DocumentDB URIs", func(t *testing.T) {
   779  		t.Parallel()
   780  
   781  		testCases := []struct {
   782  			name string
   783  			uri  string
   784  			msgs []string
   785  		}{
   786  			{
   787  				name: "normal",
   788  				uri:  "mongodb://a.docdb.amazonaws.com:27017/",
   789  				msgs: []string{documentDBMsg},
   790  			},
   791  			{
   792  				name: "normal",
   793  				uri:  "mongodb://a.docdb-elastic.amazonaws.com:27017/",
   794  				msgs: []string{documentDBMsg},
   795  			},
   796  			{
   797  				name: "multiple hosts",
   798  				uri:  "mongodb://a.docdb.amazonaws.com:27017,a.docdb-elastic.amazonaws.com:27017/",
   799  				msgs: []string{documentDBMsg},
   800  			},
   801  			{
   802  				name: "case-insensitive matching",
   803  				uri:  "mongodb://a.DOCDB.AMAZONAWS.COM:27017/",
   804  				msgs: []string{},
   805  			},
   806  			{
   807  				name: "case-insensitive matching",
   808  				uri:  "mongodb://a.DOCDB-ELASTIC.AMAZONAWS.COM:27017/",
   809  				msgs: []string{},
   810  			},
   811  			{
   812  				name: "Mixing genuine and nongenuine hosts (unlikely in practice)",
   813  				uri:  "mongodb://a.example.com:27017,b.docdb.amazonaws.com:27017/",
   814  				msgs: []string{documentDBMsg},
   815  			},
   816  			{
   817  				name: "Mixing genuine and nongenuine hosts (unlikely in practice)",
   818  				uri:  "mongodb://a.example.com:27017,b.docdb-elastic.amazonaws.com:27017/",
   819  				msgs: []string{documentDBMsg},
   820  			},
   821  		}
   822  		for _, tc := range testCases {
   823  			tc := tc
   824  
   825  			t.Run(tc.name, func(t *testing.T) {
   826  				t.Parallel()
   827  
   828  				sink := &mockLogSink{}
   829  				cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
   830  				require.Nil(t, err, "error constructing topology config: %v", err)
   831  
   832  				topo, err := New(cfg)
   833  				require.Nil(t, err, "topology.New error: %v", err)
   834  
   835  				err = topo.Connect()
   836  				assert.Nil(t, err, "Connect error: %v", err)
   837  
   838  				assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
   839  			})
   840  		}
   841  	})
   842  	t.Run("Mixing CosmosDB and DocumentDB URIs", func(t *testing.T) {
   843  		t.Parallel()
   844  
   845  		testCases := []struct {
   846  			name string
   847  			uri  string
   848  			msgs []string
   849  		}{
   850  			{
   851  				name: "Mixing hosts",
   852  				uri:  "mongodb://a.mongo.cosmos.azure.com:19555,a.docdb.amazonaws.com:27017/",
   853  				msgs: []string{cosmosDBMsg, documentDBMsg},
   854  			},
   855  		}
   856  		for _, tc := range testCases {
   857  			tc := tc
   858  
   859  			t.Run(tc.name, func(t *testing.T) {
   860  				t.Parallel()
   861  
   862  				sink := &mockLogSink{}
   863  				cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
   864  				require.Nil(t, err, "error constructing topology config: %v", err)
   865  
   866  				topo, err := New(cfg)
   867  				require.Nil(t, err, "topology.New error: %v", err)
   868  
   869  				err = topo.Connect()
   870  				assert.Nil(t, err, "Connect error: %v", err)
   871  
   872  				assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
   873  			})
   874  		}
   875  	})
   876  	t.Run("genuine URIs", func(t *testing.T) {
   877  		t.Parallel()
   878  
   879  		testCases := []struct {
   880  			name string
   881  			uri  string
   882  			msgs []string
   883  		}{
   884  			{
   885  				name: "normal",
   886  				uri:  "mongodb://a.example.com:27017/",
   887  				msgs: []string{},
   888  			},
   889  			{
   890  				name: "socket",
   891  				uri:  "mongodb://%2Ftmp%2Fmongodb-27017.sock/",
   892  				msgs: []string{},
   893  			},
   894  			{
   895  				name: "srv",
   896  				uri:  "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
   897  				msgs: []string{},
   898  			},
   899  			{
   900  				name: "multiple hosts",
   901  				uri:  "mongodb://a.example.com:27017,b.example.com:27017/",
   902  				msgs: []string{},
   903  			},
   904  			{
   905  				name: "unexpected suffix",
   906  				uri:  "mongodb://a.mongo.cosmos.azure.com.tld:19555/",
   907  				msgs: []string{},
   908  			},
   909  			{
   910  				name: "unexpected suffix",
   911  				uri:  "mongodb://a.docdb.amazonaws.com.tld:27017/",
   912  				msgs: []string{},
   913  			},
   914  			{
   915  				name: "unexpected suffix",
   916  				uri:  "mongodb://a.docdb-elastic.amazonaws.com.tld:27017/",
   917  				msgs: []string{},
   918  			},
   919  		}
   920  		for _, tc := range testCases {
   921  			tc := tc
   922  
   923  			t.Run(tc.name, func(t *testing.T) {
   924  				t.Parallel()
   925  
   926  				sink := &mockLogSink{}
   927  				cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
   928  				require.Nil(t, err, "error constructing topology config: %v", err)
   929  
   930  				topo, err := New(cfg)
   931  				require.Nil(t, err, "topology.New error: %v", err)
   932  
   933  				err = topo.Connect()
   934  				assert.Nil(t, err, "Connect error: %v", err)
   935  
   936  				assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
   937  			})
   938  		}
   939  	})
   940  }
   941  
   942  type inWindowServer struct {
   943  	Address  string `json:"address"`
   944  	Type     string `json:"type"`
   945  	AvgRTTMS int64  `json:"avg_rtt_ms"`
   946  }
   947  
   948  type inWindowTopology struct {
   949  	Type    string           `json:"type"`
   950  	Servers []inWindowServer `json:"servers"`
   951  }
   952  
   953  type inWindowOutcome struct {
   954  	Tolerance           float64            `json:"tolerance"`
   955  	ExpectedFrequencies map[string]float64 `json:"expected_frequencies"`
   956  }
   957  
   958  type inWindowTopologyState struct {
   959  	Address        string `json:"address"`
   960  	OperationCount int64  `json:"operation_count"`
   961  }
   962  
   963  type inWindowTestCase struct {
   964  	TopologyDescription inWindowTopology        `json:"topology_description"`
   965  	MockedTopologyState []inWindowTopologyState `json:"mocked_topology_state"`
   966  	Iterations          int                     `json:"iterations"`
   967  	Outcome             inWindowOutcome         `json:"outcome"`
   968  }
   969  
   970  // TestServerSelectionSpecInWindow runs the "in_window" server selection spec tests. This test is
   971  // in the "topology" package instead of the "description" package (where the rest of the server
   972  // selection spec tests are) because it primarily tests load-based server selection. Load-based
   973  // server selection is implemented in Topology.SelectServer() because it requires knowledge of the
   974  // current "operation count" (the number of currently running operations) for each server, so it
   975  // can't be effectively accomplished just with server descriptions like most other server selection
   976  // algorithms.
   977  func TestServerSelectionSpecInWindow(t *testing.T) {
   978  	const testsDir = "../../../../testdata/server-selection/in_window"
   979  
   980  	files := spectest.FindJSONFilesInDir(t, testsDir)
   981  
   982  	for _, file := range files {
   983  		t.Run(file, func(t *testing.T) {
   984  			runInWindowTest(t, testsDir, file)
   985  		})
   986  	}
   987  }
   988  
   989  func runInWindowTest(t *testing.T, directory string, filename string) {
   990  	filepath := path.Join(directory, filename)
   991  	content, err := ioutil.ReadFile(filepath)
   992  	require.NoError(t, err)
   993  
   994  	var test inWindowTestCase
   995  	require.NoError(t, json.Unmarshal(content, &test))
   996  
   997  	// For each server described in the test's "topology_description", create both a *Server and
   998  	// description.Server, which are both required to run Topology.SelectServer().
   999  	servers := make(map[string]*Server, len(test.TopologyDescription.Servers))
  1000  	descriptions := make([]description.Server, 0, len(test.TopologyDescription.Servers))
  1001  	for _, testDesc := range test.TopologyDescription.Servers {
  1002  		server := NewServer(
  1003  			address.Address(testDesc.Address),
  1004  			primitive.NilObjectID,
  1005  			withMonitoringDisabled(func(bool) bool { return true }))
  1006  		servers[testDesc.Address] = server
  1007  
  1008  		desc := description.Server{
  1009  			Kind:          serverKindFromString(t, testDesc.Type),
  1010  			Addr:          address.Address(testDesc.Address),
  1011  			AverageRTT:    time.Duration(testDesc.AvgRTTMS) * time.Millisecond,
  1012  			AverageRTTSet: true,
  1013  		}
  1014  
  1015  		if testDesc.AvgRTTMS > 0 {
  1016  			desc.AverageRTT = time.Duration(testDesc.AvgRTTMS) * time.Millisecond
  1017  			desc.AverageRTTSet = true
  1018  		}
  1019  
  1020  		descriptions = append(descriptions, desc)
  1021  	}
  1022  
  1023  	// For each server state in the test's "mocked_topology_state", set the connection pool's
  1024  	// in-use connections count to the test operation count value.
  1025  	for _, state := range test.MockedTopologyState {
  1026  		servers[state.Address].operationCount = state.OperationCount
  1027  	}
  1028  
  1029  	// Create a new Topology, set the state to "connected", store a topology description
  1030  	// containing all server descriptions created from the test server descriptions, and copy
  1031  	// all *Server instances to the Topology's servers list.
  1032  	topology, err := New(nil)
  1033  	require.NoError(t, err, "error creating new Topology")
  1034  	topology.state = topologyConnected
  1035  	topology.desc.Store(description.Topology{
  1036  		Kind:    topologyKindFromString(t, test.TopologyDescription.Type),
  1037  		Servers: descriptions,
  1038  	})
  1039  	for addr, server := range servers {
  1040  		topology.servers[address.Address(addr)] = server
  1041  	}
  1042  
  1043  	// Run server selection the required number of times and record how many times each server
  1044  	// address was selected.
  1045  	counts := make(map[string]int, len(test.TopologyDescription.Servers))
  1046  	for i := 0; i < test.Iterations; i++ {
  1047  		selected, err := topology.SelectServer(
  1048  			context.Background(),
  1049  			description.ReadPrefSelector(readpref.Nearest()))
  1050  		require.NoError(t, err, "error selecting server")
  1051  		counts[string(selected.(*SelectedServer).address)]++
  1052  	}
  1053  
  1054  	// Convert the server selection counts to selection frequencies by dividing the counts by
  1055  	// the total number of server selection attempts.
  1056  	frequencies := make(map[string]float64, len(counts))
  1057  	for addr, count := range counts {
  1058  		frequencies[addr] = float64(count) / float64(test.Iterations)
  1059  	}
  1060  
  1061  	// Assert that the observed server selection frequency for each server address matches the
  1062  	// expected server selection frequency.
  1063  	for addr, expected := range test.Outcome.ExpectedFrequencies {
  1064  		actual := frequencies[addr]
  1065  
  1066  		// If the expected frequency for a given server is 1 or 0, then the observed frequency
  1067  		// MUST be exactly equal to the expected one.
  1068  		if expected == 1 || expected == 0 {
  1069  			assert.Equal(
  1070  				t,
  1071  				expected,
  1072  				actual,
  1073  				"expected frequency of %q to be equal to %f, but is %f",
  1074  				addr, expected, actual)
  1075  			continue
  1076  		}
  1077  
  1078  		// Otherwise, check if the expected frequency is within the given tolerance range.
  1079  		// TODO(GODRIVER-2179): Use assert.Deltaf() when we migrate all test code to the "testify/assert" or an
  1080  		// TODO API-compatible library for assertions.
  1081  		low := expected - test.Outcome.Tolerance
  1082  		high := expected + test.Outcome.Tolerance
  1083  		assert.True(
  1084  			t,
  1085  			actual >= low && actual <= high,
  1086  			"expected frequency of %q to be in range [%f, %f], but is %f",
  1087  			addr, low, high, actual)
  1088  	}
  1089  }
  1090  
  1091  func topologyKindFromString(t *testing.T, s string) description.TopologyKind {
  1092  	t.Helper()
  1093  
  1094  	switch s {
  1095  	case "Single":
  1096  		return description.Single
  1097  	case "ReplicaSet":
  1098  		return description.ReplicaSet
  1099  	case "ReplicaSetNoPrimary":
  1100  		return description.ReplicaSetNoPrimary
  1101  	case "ReplicaSetWithPrimary":
  1102  		return description.ReplicaSetWithPrimary
  1103  	case "Sharded":
  1104  		return description.Sharded
  1105  	case "LoadBalanced":
  1106  		return description.LoadBalanced
  1107  	case "Unknown":
  1108  		return description.Unknown
  1109  	default:
  1110  		t.Fatalf("unrecognized topology kind: %q", s)
  1111  	}
  1112  
  1113  	return description.Unknown
  1114  }
  1115  
  1116  func serverKindFromString(t *testing.T, s string) description.ServerKind {
  1117  	t.Helper()
  1118  
  1119  	switch s {
  1120  	case "Standalone":
  1121  		return description.Standalone
  1122  	case "RSOther":
  1123  		return description.RSMember
  1124  	case "RSPrimary":
  1125  		return description.RSPrimary
  1126  	case "RSSecondary":
  1127  		return description.RSSecondary
  1128  	case "RSArbiter":
  1129  		return description.RSArbiter
  1130  	case "RSGhost":
  1131  		return description.RSGhost
  1132  	case "Mongos":
  1133  		return description.Mongos
  1134  	case "LoadBalancer":
  1135  		return description.LoadBalancer
  1136  	case "PossiblePrimary", "Unknown":
  1137  		// Go does not have a PossiblePrimary server type and per the SDAM spec, this type is synonymous with Unknown.
  1138  		return description.Unknown
  1139  	default:
  1140  		t.Fatalf("unrecognized server kind: %q", s)
  1141  	}
  1142  
  1143  	return description.Unknown
  1144  }
  1145  
  1146  func BenchmarkSelectServerFromDescription(b *testing.B) {
  1147  	for _, bcase := range []struct {
  1148  		name        string
  1149  		serversHook func(servers []description.Server)
  1150  	}{
  1151  		{
  1152  			name:        "AllFit",
  1153  			serversHook: func(servers []description.Server) {},
  1154  		},
  1155  		{
  1156  			name: "AllButOneFit",
  1157  			serversHook: func(servers []description.Server) {
  1158  				servers[0].Kind = description.Unknown
  1159  			},
  1160  		},
  1161  		{
  1162  			name: "HalfFit",
  1163  			serversHook: func(servers []description.Server) {
  1164  				for i := 0; i < len(servers); i += 2 {
  1165  					servers[i].Kind = description.Unknown
  1166  				}
  1167  			},
  1168  		},
  1169  		{
  1170  			name: "OneFit",
  1171  			serversHook: func(servers []description.Server) {
  1172  				for i := 1; i < len(servers); i++ {
  1173  					servers[i].Kind = description.Unknown
  1174  				}
  1175  			},
  1176  		},
  1177  	} {
  1178  		bcase := bcase
  1179  
  1180  		b.Run(bcase.name, func(b *testing.B) {
  1181  			s := description.Server{
  1182  				Addr:              address.Address("localhost:27017"),
  1183  				HeartbeatInterval: time.Duration(10) * time.Second,
  1184  				LastWriteTime:     time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC),
  1185  				LastUpdateTime:    time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC),
  1186  				Kind:              description.Mongos,
  1187  				WireVersion:       &description.VersionRange{Min: 6, Max: 21},
  1188  			}
  1189  			servers := make([]description.Server, 100)
  1190  			for i := 0; i < len(servers); i++ {
  1191  				servers[i] = s
  1192  			}
  1193  			bcase.serversHook(servers)
  1194  			desc := description.Topology{
  1195  				Servers: servers,
  1196  			}
  1197  
  1198  			timeout := make(chan time.Time)
  1199  			b.ResetTimer()
  1200  			b.RunParallel(func(p *testing.PB) {
  1201  				b.ReportAllocs()
  1202  				for p.Next() {
  1203  					var c Topology
  1204  					_, _ = c.selectServerFromDescription(desc, newServerSelectionState(selectNone, timeout))
  1205  				}
  1206  			})
  1207  		})
  1208  	}
  1209  }
  1210  
  1211  func TestLogUnexpectedFailure(t *testing.T) {
  1212  	t.Parallel()
  1213  
  1214  	// newIOLogger will log data using an io sink.
  1215  	newIOLogger := func() (*logger.Logger, *bytes.Buffer, *bufio.Writer) {
  1216  		buf := bytes.NewBuffer(nil)
  1217  		w := bufio.NewWriter(buf)
  1218  
  1219  		ioSink := logger.NewIOSink(w)
  1220  
  1221  		ioLogger, err := logger.New(ioSink, logger.DefaultMaxDocumentLength, map[logger.Component]logger.Level{
  1222  			logger.ComponentTopology: logger.LevelDebug,
  1223  		})
  1224  
  1225  		assert.NoError(t, err)
  1226  
  1227  		return ioLogger, buf, w
  1228  	}
  1229  
  1230  	// newNilLogger will return a nil logger with empty buffer and writer.
  1231  	newNilLogger := func() (*logger.Logger, *bytes.Buffer, *bufio.Writer) {
  1232  		return nil, &bytes.Buffer{}, &bufio.Writer{}
  1233  	}
  1234  
  1235  	tests := []struct {
  1236  		name       string
  1237  		msg        string
  1238  		newLogger  func() (*logger.Logger, *bytes.Buffer, *bufio.Writer)
  1239  		panicValue interface{}
  1240  		want       interface{} // Either a string or nil
  1241  	}{
  1242  		{
  1243  			name:       "nil logger",
  1244  			msg:        "",
  1245  			newLogger:  newNilLogger,
  1246  			panicValue: 1,
  1247  			want:       nil,
  1248  		},
  1249  		{
  1250  			name:       "valid logger",
  1251  			msg:        "test",
  1252  			newLogger:  newIOLogger,
  1253  			panicValue: 1,
  1254  			want:       "test: 1",
  1255  		},
  1256  		{
  1257  			name:       "valid logger with error panic",
  1258  			msg:        "test",
  1259  			newLogger:  newIOLogger,
  1260  			panicValue: errors.New("err"),
  1261  			want:       "test: err",
  1262  		},
  1263  	}
  1264  
  1265  	for _, test := range tests {
  1266  		test := test
  1267  
  1268  		t.Run(test.name, func(t *testing.T) {
  1269  			t.Parallel()
  1270  
  1271  			log, buf, w := test.newLogger()
  1272  
  1273  			func() {
  1274  				defer logUnexpectedFailure(log, test.msg)
  1275  
  1276  				panic(test.panicValue)
  1277  			}()
  1278  
  1279  			assert.NoError(t, w.Flush())
  1280  
  1281  			got := map[string]interface{}{}
  1282  			_ = json.Unmarshal(buf.Bytes(), &got)
  1283  
  1284  			assert.Equal(t, test.want, got[logger.KeyMessage])
  1285  		})
  1286  	}
  1287  }
  1288  

View as plain text