
Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/topology.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  // Package topology contains types that handle the discovery, monitoring, and selection
     8  // of servers. This package is designed to expose enough inner workings of service discovery
     9  // and monitoring to allow low-level applications to have fine-grained control, while hiding
    10  // most of the detailed implementation of the algorithms.
    11  package topology // import "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
    12  
    13  import (
    14  	"context"
    15  	"errors"
    16  	"fmt"
    17  	"net"
    18  	"strconv"
    19  	"strings"
    20  	"sync"
    21  	"sync/atomic"
    22  	"time"
    23  
    24  	"go.mongodb.org/mongo-driver/bson/primitive"
    25  	"go.mongodb.org/mongo-driver/event"
    26  	"go.mongodb.org/mongo-driver/internal/logger"
    27  	"go.mongodb.org/mongo-driver/internal/randutil"
    28  	"go.mongodb.org/mongo-driver/mongo/address"
    29  	"go.mongodb.org/mongo-driver/mongo/description"
    30  	"go.mongodb.org/mongo-driver/mongo/options"
    31  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    32  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    33  	"go.mongodb.org/mongo-driver/x/mongo/driver/dns"
    34  )
    35  
    36  // Topology state constants.
    37  const (
    38  	topologyDisconnected int64 = iota
    39  	topologyDisconnecting
    40  	topologyConnected
    41  	topologyConnecting
    42  )
    43  
    44  // ErrSubscribeAfterClosed is returned when a user attempts to subscribe to a
    45  // closed Server or Topology.
    46  var ErrSubscribeAfterClosed = errors.New("cannot subscribe after closeConnection")
    47  
    48  // ErrTopologyClosed is returned when a user attempts to call a method on a
    49  // closed Topology.
    50  var ErrTopologyClosed = errors.New("topology is closed")
    51  
    52  // ErrTopologyConnected is returned when a user attempts to Connect to an
    53  // already connected Topology.
    54  var ErrTopologyConnected = errors.New("topology is connected or connecting")
    55  
    56  // ErrServerSelectionTimeout is returned from server selection when the server
    57  // selection process took longer than allowed by the timeout.
    58  var ErrServerSelectionTimeout = errors.New("server selection timeout")
    59  
    60  // MonitorMode represents the way in which a server is monitored.
    61  type MonitorMode uint8
    62  
    63  // random is a package-global pseudo-random number generator.
    64  var random = randutil.NewLockedRand()
    65  
    66  // These constants are the available monitoring modes.
    67  const (
    68  	AutomaticMode MonitorMode = iota
    69  	SingleMode
    70  )
    71  
    72  // Topology represents a MongoDB deployment.
    73  type Topology struct {
    74  	state int64
    75  
    76  	cfg *Config
    77  
    78  	desc atomic.Value // holds a description.Topology
    79  
    80  	dnsResolver *dns.Resolver
    81  
    82  	done chan struct{}
    83  
    84  	pollingRequired   bool
    85  	pollingDone       chan struct{}
    86  	pollingwg         sync.WaitGroup
    87  	rescanSRVInterval time.Duration
    88  	pollHeartbeatTime atomic.Value // holds a bool
    89  
    90  	hosts []string
    91  
    92  	updateCallback updateTopologyCallback
    93  	fsm            *fsm
    94  
    95  	// This should really be encapsulated into its own type. This will likely
    96  	// require a redesign so we can share a minimum of data between the
    97  	// subscribers and the topology.
    98  	subscribers         map[uint64]chan description.Topology
    99  	currentSubscriberID uint64
   100  	subscriptionsClosed bool
   101  	subLock             sync.Mutex
   102  
   103  	// We should redesign how we Connect and handle individual servers. This is
   104  	// too difficult to maintain and it's rather easy to accidentally access
   105  	// the servers without acquiring the lock or checking if the servers are
   106  	// closed. This lock should also be an RWMutex.
   107  	serversLock   sync.Mutex
   108  	serversClosed bool
   109  	servers       map[address.Address]*Server
   110  
   111  	id primitive.ObjectID
   112  }
   113  
   114  var (
   115  	_ driver.Deployment = &Topology{}
   116  	_ driver.Subscriber = &Topology{}
   117  )
   118  
   119  type serverSelectionState struct {
   120  	selector    description.ServerSelector
   121  	timeoutChan <-chan time.Time
   122  }
   123  
   124  func newServerSelectionState(selector description.ServerSelector, timeoutChan <-chan time.Time) serverSelectionState {
   125  	return serverSelectionState{
   126  		selector:    selector,
   127  		timeoutChan: timeoutChan,
   128  	}
   129  }
   130  
   131  // New creates a new topology. A "nil" config is interpreted as the default configuration.
   132  func New(cfg *Config) (*Topology, error) {
   133  	if cfg == nil {
   134  		var err error
   135  		cfg, err = NewConfig(options.Client(), nil)
   136  		if err != nil {
   137  			return nil, err
   138  		}
   139  	}
   140  
   141  	t := &Topology{
   142  		cfg:               cfg,
   143  		done:              make(chan struct{}),
   144  		pollingDone:       make(chan struct{}),
   145  		rescanSRVInterval: 60 * time.Second,
   146  		fsm:               newFSM(),
   147  		subscribers:       make(map[uint64]chan description.Topology),
   148  		servers:           make(map[address.Address]*Server),
   149  		dnsResolver:       dns.DefaultResolver,
   150  		id:                primitive.NewObjectID(),
   151  	}
   152  	t.desc.Store(description.Topology{})
   153  	t.updateCallback = func(desc description.Server) description.Server {
   154  		return t.apply(context.TODO(), desc)
   155  	}
   156  
   157  	if t.cfg.URI != "" {
   158  		connStr, err := connstring.Parse(t.cfg.URI)
   159  		if err != nil {
   160  			return nil, err
   161  		}
   162  		t.pollingRequired = (connStr.Scheme == connstring.SchemeMongoDBSRV) && !t.cfg.LoadBalanced
   163  		t.hosts = connStr.RawHosts
   164  	}
   165  
   166  	t.publishTopologyOpeningEvent()
   167  
   168  	return t, nil
   169  }
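
        // Example (editor's sketch, not part of the original source): constructing
        // a Topology. Passing nil to New selects the default configuration; the
        // URI below is purely illustrative.
        //
        //	cfg, err := NewConfig(options.Client().ApplyURI("mongodb://localhost:27017"), nil)
        //	if err != nil {
        //		return err
        //	}
        //	topo, err := New(cfg) // or New(nil) for the defaults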
   170  
   171  func mustLogTopologyMessage(topo *Topology, level logger.Level) bool {
   172  	return topo.cfg.logger != nil && topo.cfg.logger.LevelComponentEnabled(
   173  		level, logger.ComponentTopology)
   174  }
   175  
   176  func logTopologyMessage(topo *Topology, level logger.Level, msg string, keysAndValues ...interface{}) {
   177  	topo.cfg.logger.Print(level,
   178  		logger.ComponentTopology,
   179  		msg,
   180  		logger.SerializeTopology(logger.Topology{
   181  			ID:      topo.id,
   182  			Message: msg,
   183  		}, keysAndValues...)...)
   184  }
   185  
   186  func logTopologyThirdPartyUsage(topo *Topology, parsedHosts []string) {
   187  	thirdPartyMessages := [2]string{
   188  		`You appear to be connected to a CosmosDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb`,
   189  		`You appear to be connected to a DocumentDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/documentdb`,
   190  	}
   191  
   192  	thirdPartySuffixes := map[string]int{
   193  		".cosmos.azure.com":            0,
   194  		".docdb.amazonaws.com":         1,
   195  		".docdb-elastic.amazonaws.com": 1,
   196  	}
   197  
   198  	hostSet := make([]bool, len(thirdPartyMessages))
   199  	for _, host := range parsedHosts {
   200  		if h, _, err := net.SplitHostPort(host); err == nil {
   201  			host = h
   202  		}
   203  		for suffix, env := range thirdPartySuffixes {
   204  			if !strings.HasSuffix(host, suffix) {
   205  				continue
   206  			}
   207  			if hostSet[env] {
   208  				break
   209  			}
   210  			hostSet[env] = true
   211  			logTopologyMessage(topo, logger.LevelInfo, thirdPartyMessages[env])
   212  		}
   213  	}
   214  }
   215  
   216  func mustLogServerSelection(topo *Topology, level logger.Level) bool {
   217  	return topo.cfg.logger != nil && topo.cfg.logger.LevelComponentEnabled(
   218  		level, logger.ComponentServerSelection)
   219  }
   220  
   221  func logServerSelection(
   222  	ctx context.Context,
   223  	topo *Topology,
   224  	level logger.Level,
   225  	msg string,
   226  	srvSelector description.ServerSelector,
   227  	keysAndValues ...interface{},
   228  ) {
   229  	var srvSelectorString string
   230  
   231  	selectorStringer, ok := srvSelector.(fmt.Stringer)
   232  	if ok {
   233  		srvSelectorString = selectorStringer.String()
   234  	}
   235  
   236  	operationName, _ := logger.OperationName(ctx)
   237  	operationID, _ := logger.OperationID(ctx)
   238  
   239  	topo.cfg.logger.Print(level,
   240  		logger.ComponentServerSelection,
   241  		msg,
   242  		logger.SerializeServerSelection(logger.ServerSelection{
   243  			Selector:            srvSelectorString,
   244  			Operation:           operationName,
   245  			OperationID:         &operationID,
   246  			TopologyDescription: topo.String(),
   247  		}, keysAndValues...)...)
   248  }
   249  
   250  func logServerSelectionSucceeded(
   251  	ctx context.Context,
   252  	topo *Topology,
   253  	srvSelector description.ServerSelector,
   254  	server *SelectedServer,
   255  ) {
   256  	host, port, err := net.SplitHostPort(server.address.String())
   257  	if err != nil {
   258  		host = server.address.String()
   259  		port = ""
   260  	}
   261  
   262  	portInt64, _ := strconv.ParseInt(port, 10, 32)
   263  
   264  	logServerSelection(ctx, topo, logger.LevelDebug, logger.ServerSelectionSucceeded, srvSelector,
   265  		logger.KeyServerHost, host,
   266  		logger.KeyServerPort, portInt64)
   267  }
   268  
   269  func logServerSelectionFailed(
   270  	ctx context.Context,
   271  	topo *Topology,
   272  	srvSelector description.ServerSelector,
   273  	err error,
   274  ) {
   275  	logServerSelection(ctx, topo, logger.LevelDebug, logger.ServerSelectionFailed, srvSelector,
   276  		logger.KeyFailure, err.Error())
   277  }
   278  
   279  // logUnexpectedFailure is a defer-recover function for logging unexpected
   280  // failures encountered while maintaining a topology.
   281  //
   282  // Most topology maintenance actions, such as updating a server, should not take
   283  // down a client's application. This function makes a best-effort attempt to log
   284  // unexpected failures. If the logger passed to this function is nil, then the
   285  // recovery will be silent.
   286  func logUnexpectedFailure(log *logger.Logger, msg string, callbacks ...func()) {
   287  	r := recover()
   288  	if r == nil {
   289  		return
   290  	}
   291  
   292  	defer func() {
   293  		for _, clbk := range callbacks {
   294  			clbk()
   295  		}
   296  	}()
   297  
   298  	if log == nil {
   299  		return
   300  	}
   301  
   302  	log.Print(logger.LevelInfo, logger.ComponentTopology, fmt.Sprintf("%s: %v", msg, r))
   303  }
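
        // Usage sketch (editor's addition; doMaintenance and cleanup are
        // hypothetical): the function is intended to be deferred, so that a panic
        // in a maintenance goroutine is recovered, logged, and followed by the
        // given cleanup callbacks.
        //
        //	go func() {
        //		defer logUnexpectedFailure(log, "Encountered unexpected failure", cleanup)
        //		doMaintenance()
        //	}()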
   304  
   305  // Connect initializes a Topology and starts the monitoring process. This function
   306  // must be called to properly monitor the topology.
   307  func (t *Topology) Connect() error {
   308  	if !atomic.CompareAndSwapInt64(&t.state, topologyDisconnected, topologyConnecting) {
   309  		return ErrTopologyConnected
   310  	}
   311  
   312  	t.desc.Store(description.Topology{})
   313  	var err error
   314  	t.serversLock.Lock()
   315  
   316  	// A replica set name sets the initial topology type to ReplicaSetNoPrimary unless a direct connection is also
   317  	// specified, in which case the initial type is Single.
   318  	if t.cfg.ReplicaSetName != "" {
   319  		t.fsm.SetName = t.cfg.ReplicaSetName
   320  		t.fsm.Kind = description.ReplicaSetNoPrimary
   321  	}
   322  
   323  	// A direct connection unconditionally sets the topology type to Single.
   324  	if t.cfg.Mode == SingleMode {
   325  		t.fsm.Kind = description.Single
   326  	}
   327  
   328  	for _, a := range t.cfg.SeedList {
   329  		addr := address.Address(a).Canonicalize()
   330  		t.fsm.Servers = append(t.fsm.Servers, description.NewDefaultServer(addr))
   331  	}
   332  
   333  	switch {
   334  	case t.cfg.LoadBalanced:
   335  		// In LoadBalanced mode, we mock a series of events: TopologyDescriptionChanged from Unknown to LoadBalanced,
   336  		// ServerDescriptionChanged from Unknown to LoadBalancer, and then TopologyDescriptionChanged to reflect the
   337  		// previous ServerDescriptionChanged event. We publish all of these events here because we don't start server
   338  		// monitoring routines in this mode, so we have to mock state changes.
   339  
   340  		// Transition from Unknown with no servers to LoadBalanced with a single Unknown server.
   341  		t.fsm.Kind = description.LoadBalanced
   342  		t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology)
   343  
   344  		addr := address.Address(t.cfg.SeedList[0]).Canonicalize()
   345  		if err := t.addServer(addr); err != nil {
   346  			t.serversLock.Unlock()
   347  			return err
   348  		}
   349  
   350  		// Transition the server from Unknown to LoadBalancer.
   351  		newServerDesc := t.servers[addr].Description()
   352  		t.publishServerDescriptionChangedEvent(t.fsm.Servers[0], newServerDesc)
   353  
   354  		// Transition from LoadBalanced with an Unknown server to LoadBalanced with a LoadBalancer.
   355  		oldDesc := t.fsm.Topology
   356  		t.fsm.Servers = []description.Server{newServerDesc}
   357  		t.desc.Store(t.fsm.Topology)
   358  		t.publishTopologyDescriptionChangedEvent(oldDesc, t.fsm.Topology)
   359  	default:
   360  		// In non-LB mode, we only publish an initial TopologyDescriptionChanged event from Unknown with no servers to
   361  		// the current state (e.g. Unknown with one or more servers if we're discovering or Single with one server if
   362  		// we're connecting directly). Other events are published when state changes occur due to responses in the
   363  		// server monitoring goroutines.
   364  
   365  		newDesc := description.Topology{
   366  			Kind:                     t.fsm.Kind,
   367  			Servers:                  t.fsm.Servers,
   368  			SessionTimeoutMinutesPtr: t.fsm.SessionTimeoutMinutesPtr,
   369  
   370  			// TODO(GODRIVER-2885): This field can be removed once
   371  			// legacy SessionTimeoutMinutes is removed.
   372  			SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
   373  		}
   374  		t.desc.Store(newDesc)
   375  		t.publishTopologyDescriptionChangedEvent(description.Topology{}, t.fsm.Topology)
   376  		for _, a := range t.cfg.SeedList {
   377  			addr := address.Address(a).Canonicalize()
   378  			err = t.addServer(addr)
   379  			if err != nil {
   380  				t.serversLock.Unlock()
   381  				return err
   382  			}
   383  		}
   384  	}
   385  
   386  	t.serversLock.Unlock()
   387  	if mustLogTopologyMessage(t, logger.LevelInfo) {
   388  		logTopologyThirdPartyUsage(t, t.hosts)
   389  	}
   390  	if t.pollingRequired {
   391  		// sanity check before passing the hostname to the resolver
   392  		if len(t.hosts) != 1 {
   393  			return fmt.Errorf("URI with SRV must include one and only one hostname")
   394  		}
   395  		_, _, err = net.SplitHostPort(t.hosts[0])
   396  		if err == nil {
   397  			// we were able to successfully extract a port from the host,
   398  			// but should not be able to when using SRV
   399  			return fmt.Errorf("URI with SRV must not include a port number")
   400  		}
   401  		t.pollingwg.Add(1)
   402  		go t.pollSRVRecords(t.hosts[0])
   403  	}
   404  
   405  	t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected
   406  
   407  	atomic.StoreInt64(&t.state, topologyConnected)
   408  	return nil
   409  }
   410  
   411  // Disconnect closes the topology. It stops the monitoring goroutines and
   412  // closes all open subscriptions.
   413  func (t *Topology) Disconnect(ctx context.Context) error {
   414  	if !atomic.CompareAndSwapInt64(&t.state, topologyConnected, topologyDisconnecting) {
   415  		return ErrTopologyClosed
   416  	}
   417  
   418  	servers := make(map[address.Address]*Server)
   419  	t.serversLock.Lock()
   420  	t.serversClosed = true
   421  	for addr, server := range t.servers {
   422  		servers[addr] = server
   423  	}
   424  	t.serversLock.Unlock()
   425  
   426  	for _, server := range servers {
   427  		_ = server.Disconnect(ctx)
   428  		t.publishServerClosedEvent(server.address)
   429  	}
   430  
   431  	t.subLock.Lock()
   432  	for id, ch := range t.subscribers {
   433  		close(ch)
   434  		delete(t.subscribers, id)
   435  	}
   436  	t.subscriptionsClosed = true
   437  	t.subLock.Unlock()
   438  
   439  	if t.pollingRequired {
   440  		t.pollingDone <- struct{}{}
   441  		t.pollingwg.Wait()
   442  	}
   443  
   444  	t.desc.Store(description.Topology{})
   445  
   446  	atomic.StoreInt64(&t.state, topologyDisconnected)
   447  	t.publishTopologyClosedEvent()
   448  	return nil
   449  }
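
        // Lifecycle sketch (editor's addition): Connect and Disconnect bracket all
        // use of a Topology, and Disconnect returns ErrTopologyClosed if the
        // Topology is not currently connected.
        //
        //	topo, err := New(nil)
        //	if err != nil {
        //		return err
        //	}
        //	if err := topo.Connect(); err != nil {
        //		return err
        //	}
        //	defer func() { _ = topo.Disconnect(context.Background()) }()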
   450  
   451  // Description returns a description of the topology.
   452  func (t *Topology) Description() description.Topology {
   453  	td, ok := t.desc.Load().(description.Topology)
   454  	if !ok {
   455  		td = description.Topology{}
   456  	}
   457  	return td
   458  }
   459  
   460  // Kind returns the topology kind of this Topology.
   461  func (t *Topology) Kind() description.TopologyKind { return t.Description().Kind }
   462  
   463  // Subscribe returns a Subscription on which each updated description.Topology
   464  // will be sent. The channel of the subscription will have a buffer size of one,
   465  // and will be pre-populated with the current description.Topology.
   466  // Subscribe implements the driver.Subscriber interface.
   467  func (t *Topology) Subscribe() (*driver.Subscription, error) {
   468  	if atomic.LoadInt64(&t.state) != topologyConnected {
   469  		return nil, errors.New("cannot subscribe to Topology that is not connected")
   470  	}
   471  	ch := make(chan description.Topology, 1)
   472  	td, ok := t.desc.Load().(description.Topology)
   473  	if !ok {
   474  		td = description.Topology{}
   475  	}
   476  	ch <- td
   477  
   478  	t.subLock.Lock()
   479  	defer t.subLock.Unlock()
   480  	if t.subscriptionsClosed {
   481  		return nil, ErrSubscribeAfterClosed
   482  	}
   483  	id := t.currentSubscriberID
   484  	t.subscribers[id] = ch
   485  	t.currentSubscriberID++
   486  
   487  	return &driver.Subscription{
   488  		Updates: ch,
   489  		ID:      id,
   490  	}, nil
   491  }
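
        // Consumption sketch (editor's addition): the channel is buffered with the
        // current description, so the first receive never blocks, and it is closed
        // by Unsubscribe or Disconnect, which ends the range loop.
        //
        //	sub, err := topo.Subscribe()
        //	if err != nil {
        //		return err
        //	}
        //	defer func() { _ = topo.Unsubscribe(sub) }()
        //	for desc := range sub.Updates {
        //		_ = desc // react to each new description.Topology
        //	}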
   492  
   493  // Unsubscribe unsubscribes the given subscription from the topology and closes the subscription channel.
   494  // Unsubscribe implements the driver.Subscriber interface.
   495  func (t *Topology) Unsubscribe(sub *driver.Subscription) error {
   496  	t.subLock.Lock()
   497  	defer t.subLock.Unlock()
   498  
   499  	if t.subscriptionsClosed {
   500  		return nil
   501  	}
   502  
   503  	ch, ok := t.subscribers[sub.ID]
   504  	if !ok {
   505  		return nil
   506  	}
   507  
   508  	close(ch)
   509  	delete(t.subscribers, sub.ID)
   510  	return nil
   511  }
   512  
   513  // RequestImmediateCheck will send heartbeats to all the servers in the
   514  // topology right away, instead of waiting for the heartbeat timeout.
   515  func (t *Topology) RequestImmediateCheck() {
   516  	if atomic.LoadInt64(&t.state) != topologyConnected {
   517  		return
   518  	}
   519  	t.serversLock.Lock()
   520  	for _, server := range t.servers {
   521  		server.RequestImmediateCheck()
   522  	}
   523  	t.serversLock.Unlock()
   524  }
   525  
   526  // SelectServer selects a server using the given selector. SelectServer complies with the
   527  // server selection spec, and will time out after serverSelectionTimeout or when the
   528  // parent context is done.
   529  func (t *Topology) SelectServer(ctx context.Context, ss description.ServerSelector) (driver.Server, error) {
   530  	if atomic.LoadInt64(&t.state) != topologyConnected {
   531  		if mustLogServerSelection(t, logger.LevelDebug) {
   532  			logServerSelectionFailed(ctx, t, ss, ErrTopologyClosed)
   533  		}
   534  
   535  		return nil, ErrTopologyClosed
   536  	}
   537  	var ssTimeoutCh <-chan time.Time
   538  
   539  	if t.cfg.ServerSelectionTimeout > 0 {
   540  		ssTimeout := time.NewTimer(t.cfg.ServerSelectionTimeout)
   541  		ssTimeoutCh = ssTimeout.C
   542  		defer ssTimeout.Stop()
   543  	}
   544  
   545  	var doneOnce bool
   546  	var sub *driver.Subscription
   547  	selectionState := newServerSelectionState(ss, ssTimeoutCh)
   548  
   549  	// Record the start time.
   550  	startTime := time.Now()
   551  	for {
   552  		var suitable []description.Server
   553  		var selectErr error
   554  
   555  		if !doneOnce {
   556  			if mustLogServerSelection(t, logger.LevelDebug) {
   557  				logServerSelection(ctx, t, logger.LevelDebug, logger.ServerSelectionStarted, ss)
   558  			}
   559  
   560  			// for the first pass, select a server from the current description.
   561  			// this improves selection speed for up-to-date topology descriptions.
   562  			suitable, selectErr = t.selectServerFromDescription(t.Description(), selectionState)
   563  			doneOnce = true
   564  		} else {
   565  			// if the first pass didn't select a server, the previous description did not contain a suitable server, so
   566  			// we subscribe to the topology and attempt to obtain a server from that subscription
   567  			if sub == nil {
   568  				var err error
   569  				sub, err = t.Subscribe()
   570  				if err != nil {
   571  					if mustLogServerSelection(t, logger.LevelDebug) {
   572  						logServerSelectionFailed(ctx, t, ss, err)
   573  					}
   574  
   575  					return nil, err
   576  				}
   577  				defer func() { _ = t.Unsubscribe(sub) }()
   578  			}
   579  
   580  			suitable, selectErr = t.selectServerFromSubscription(ctx, sub.Updates, selectionState)
   581  		}
   582  		if selectErr != nil {
   583  			if mustLogServerSelection(t, logger.LevelDebug) {
   584  				logServerSelectionFailed(ctx, t, ss, selectErr)
   585  			}
   586  
   587  			return nil, selectErr
   588  		}
   589  
   590  		if len(suitable) == 0 {
   591  			// try again if there are no servers available
   592  			if mustLogServerSelection(t, logger.LevelInfo) {
   593  				elapsed := time.Since(startTime)
   594  				remainingTimeMS := t.cfg.ServerSelectionTimeout - elapsed
   595  
   596  				logServerSelection(ctx, t, logger.LevelInfo, logger.ServerSelectionWaiting, ss,
   597  					logger.KeyRemainingTimeMS, remainingTimeMS.Milliseconds())
   598  			}
   599  
   600  			continue
   601  		}
   602  
   603  		// If there's only one suitable server description, try to find the associated server and
   604  		// return it. This is an optimization primarily for standalone and load-balanced deployments.
   605  		if len(suitable) == 1 {
   606  			server, err := t.FindServer(suitable[0])
   607  			if err != nil {
   608  				if mustLogServerSelection(t, logger.LevelDebug) {
   609  					logServerSelectionFailed(ctx, t, ss, err)
   610  				}
   611  
   612  				return nil, err
   613  			}
   614  			if server == nil {
   615  				continue
   616  			}
   617  
   618  			if mustLogServerSelection(t, logger.LevelDebug) {
   619  				logServerSelectionSucceeded(ctx, t, ss, server)
   620  			}
   621  
   622  			return server, nil
   623  		}
   624  
   625  		// Randomly select 2 suitable server descriptions and find servers for them. We select two
   626  		// so we can pick the one with fewer in-progress operations below.
   627  		desc1, desc2 := pick2(suitable)
   628  		server1, err := t.FindServer(desc1)
   629  		if err != nil {
   630  			if mustLogServerSelection(t, logger.LevelDebug) {
   631  				logServerSelectionFailed(ctx, t, ss, err)
   632  			}
   633  
   634  			return nil, err
   635  		}
   636  		server2, err := t.FindServer(desc2)
   637  		if err != nil {
   638  			if mustLogServerSelection(t, logger.LevelDebug) {
   639  				logServerSelectionFailed(ctx, t, ss, err)
   640  			}
   641  
   642  			return nil, err
   643  		}
   644  
   645  		// If we don't have an actual server for one or both of the provided descriptions, either
   646  		// return the one server we have, or try again if they're both nil. This could happen for a
   647  		// number of reasons, including that the server has since stopped being a part of this
   648  		// topology.
   649  		if server1 == nil || server2 == nil {
   650  			if server1 == nil && server2 == nil {
   651  				continue
   652  			}
   653  
   654  			if server1 != nil {
   655  				if mustLogServerSelection(t, logger.LevelDebug) {
   656  					logServerSelectionSucceeded(ctx, t, ss, server1)
   657  				}
   658  				return server1, nil
   659  			}
   660  
   661  			if mustLogServerSelection(t, logger.LevelDebug) {
   662  				logServerSelectionSucceeded(ctx, t, ss, server2)
   663  			}
   664  
   665  			return server2, nil
   666  		}
   667  
   668  		// Of the two randomly selected suitable servers, pick the one with fewer in-use connections.
   669  		// We use in-use connections as an analog for in-progress operations because they are almost
   670  		// always the same value for a given server.
   671  		if server1.OperationCount() < server2.OperationCount() {
   672  			if mustLogServerSelection(t, logger.LevelDebug) {
   673  				logServerSelectionSucceeded(ctx, t, ss, server1)
   674  			}
   675  
   676  			return server1, nil
   677  		}
   678  
   679  		if mustLogServerSelection(t, logger.LevelDebug) {
   680  			logServerSelectionSucceeded(ctx, t, ss, server2)
   681  		}
   682  		return server2, nil
   683  	}
   684  }
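
        // Selection sketch (editor's addition): a read-preference selector from the
        // description package is a typical argument (readpref is
        // go.mongodb.org/mongo-driver/mongo/readpref).
        //
        //	srv, err := topo.SelectServer(ctx, description.ReadPrefSelector(readpref.Primary()))
        //	if err != nil {
        //		return err
        //	}
        //	_ = srv // e.g. check out a connection from the selected server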
   685  
   686  // pick2 returns 2 random server descriptions from the input slice of server descriptions,
   687  // guaranteeing that the same element from the slice is not picked twice. The order of server
   688  // descriptions in the input slice may be modified. If fewer than 2 server descriptions are
   689  // provided, pick2 will panic.
   690  func pick2(ds []description.Server) (description.Server, description.Server) {
   691  	// Select a random index from the input slice and keep the server description from that index.
   692  	idx := random.Intn(len(ds))
   693  	s1 := ds[idx]
   694  
   695  	// Swap the selected index to the end and reslice to remove it so we don't pick the same server
   696  	// description twice.
   697  	ds[idx], ds[len(ds)-1] = ds[len(ds)-1], ds[idx]
   698  	ds = ds[:len(ds)-1]
   699  
   700  	// Select another random index from the input slice and return both selected server descriptions.
   701  	return s1, ds[random.Intn(len(ds))]
   702  }
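
        // Illustration (editor's note; s0, s1, s2 are placeholders): with three
        // candidates, pick2 returns two distinct descriptions chosen at random and
        // may reorder the input slice as a side effect.
        //
        //	ds := []description.Server{s0, s1, s2}
        //	a, b := pick2(ds) // a and b are never the same element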
   703  
   704  // FindServer will attempt to find a server that fits the given server description.
   705  // This method will return nil, nil if a matching server could not be found.
   706  func (t *Topology) FindServer(selected description.Server) (*SelectedServer, error) {
   707  	if atomic.LoadInt64(&t.state) != topologyConnected {
   708  		return nil, ErrTopologyClosed
   709  	}
   710  	t.serversLock.Lock()
   711  	defer t.serversLock.Unlock()
   712  	server, ok := t.servers[selected.Addr]
   713  	if !ok {
   714  		return nil, nil
   715  	}
   716  
   717  	desc := t.Description()
   718  	return &SelectedServer{
   719  		Server: server,
   720  		Kind:   desc.Kind,
   721  	}, nil
   722  }
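
        // Caller sketch (editor's addition): a nil, nil return signals that the
        // server has left the topology and selection should be retried, not that
        // an error occurred.
        //
        //	srv, err := topo.FindServer(desc)
        //	if err != nil {
        //		return err // the topology is closed
        //	}
        //	if srv == nil {
        //		// re-run server selection
        //	}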
   723  
   724  // selectServerFromSubscription loops until a topology description is available for server selection. It returns
   725  // when the given context expires, server selection timeout is reached, or a description containing a selectable
   726  // server is available.
   727  func (t *Topology) selectServerFromSubscription(ctx context.Context, subscriptionCh <-chan description.Topology,
   728  	selectionState serverSelectionState) ([]description.Server, error) {
   729  
   730  	current := t.Description()
   731  	for {
   732  		select {
   733  		case <-ctx.Done():
   734  			return nil, ServerSelectionError{Wrapped: ctx.Err(), Desc: current}
   735  		case <-selectionState.timeoutChan:
   736  			return nil, ServerSelectionError{Wrapped: ErrServerSelectionTimeout, Desc: current}
   737  		case current = <-subscriptionCh:
   738  		}
   739  
   740  		suitable, err := t.selectServerFromDescription(current, selectionState)
   741  		if err != nil {
   742  			return nil, err
   743  		}
   744  
   745  		if len(suitable) > 0 {
   746  			return suitable, nil
   747  		}
   748  		t.RequestImmediateCheck()
   749  	}
   750  }
   751  
   752  // selectServerFromDescription processes the given topology description and returns a slice of suitable servers.
   753  func (t *Topology) selectServerFromDescription(desc description.Topology,
   754  	selectionState serverSelectionState) ([]description.Server, error) {
   755  
   756  	// Unlike selectServerFromSubscription, this code path does not check ctx.Done or selectionState.timeoutChan because
   757  	// selecting a server from a description is not a blocking operation.
   758  
   759  	if desc.CompatibilityErr != nil {
   760  		return nil, desc.CompatibilityErr
   761  	}
   762  
   763  	// If the topology kind is LoadBalanced, the LB is the only server and it is always considered selectable. The
   764  	// selectors exported by the driver should already return the LB as a candidate, but this check ensures that
   765  	// the LB is always selectable even if a user of the low-level driver provides a custom selector.
   766  	if desc.Kind == description.LoadBalanced {
   767  		return desc.Servers, nil
   768  	}
   769  
   770  	allowedIndexes := make([]int, 0, len(desc.Servers))
   771  	for i, s := range desc.Servers {
   772  		if s.Kind != description.Unknown {
   773  			allowedIndexes = append(allowedIndexes, i)
   774  		}
   775  	}
   776  
   777  	allowed := make([]description.Server, len(allowedIndexes))
   778  	for i, idx := range allowedIndexes {
   779  		allowed[i] = desc.Servers[idx]
   780  	}
   781  
   782  	suitable, err := selectionState.selector.SelectServer(desc, allowed)
   783  	if err != nil {
   784  		return nil, ServerSelectionError{Wrapped: err, Desc: desc}
   785  	}
   786  	return suitable, nil
   787  }
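
        // Custom selector sketch (editor's addition): low-level callers can supply
        // any description.ServerSelector; the ServerSelectorFunc adapter in the
        // description package wraps a plain function.
        //
        //	selector := description.ServerSelectorFunc(func(
        //		t description.Topology, candidates []description.Server,
        //	) ([]description.Server, error) {
        //		return candidates, nil // accept every candidate
        //	})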
   788  
   789  func (t *Topology) pollSRVRecords(hosts string) {
   790  	defer t.pollingwg.Done()
   791  
   792  	serverConfig := newServerConfig(t.cfg.ServerOpts...)
   793  	heartbeatInterval := serverConfig.heartbeatInterval
   794  
   795  	pollTicker := time.NewTicker(t.rescanSRVInterval)
   796  	defer pollTicker.Stop()
   797  	t.pollHeartbeatTime.Store(false)
   798  	var doneOnce bool
   799  	defer logUnexpectedFailure(t.cfg.logger, "Encountered unexpected failure polling SRV records", func() {
   800  		if !doneOnce {
   801  			<-t.pollingDone
   802  		}
   803  	})
   804  
   805  	for {
   806  		select {
   807  		case <-pollTicker.C:
   808  		case <-t.pollingDone:
   809  			doneOnce = true
   810  			return
   811  		}
   812  		topoKind := t.Description().Kind
   813  		if !(topoKind == description.Unknown || topoKind == description.Sharded) {
   814  			break
   815  		}
   816  
   817  		parsedHosts, err := t.dnsResolver.ParseHosts(hosts, t.cfg.SRVServiceName, false)
   818  		// DNS problem or no verified hosts returned
   819  		if err != nil || len(parsedHosts) == 0 {
   820  			if !t.pollHeartbeatTime.Load().(bool) {
   821  				pollTicker.Stop()
   822  				pollTicker = time.NewTicker(heartbeatInterval)
   823  				t.pollHeartbeatTime.Store(true)
   824  			}
   825  			continue
   826  		}
   827  		if t.pollHeartbeatTime.Load().(bool) {
   828  			pollTicker.Stop()
   829  			pollTicker = time.NewTicker(t.rescanSRVInterval)
   830  			t.pollHeartbeatTime.Store(false)
   831  		}
   832  
   833  		cont := t.processSRVResults(parsedHosts)
   834  		if !cont {
   835  			break
   836  		}
   837  	}
   838  	<-t.pollingDone
   839  	doneOnce = true
   840  }
   841  
   842  func (t *Topology) processSRVResults(parsedHosts []string) bool {
   843  	t.serversLock.Lock()
   844  	defer t.serversLock.Unlock()
   845  
   846  	if t.serversClosed {
   847  		return false
   848  	}
   849  	prev := t.fsm.Topology
   850  	diff := diffHostList(t.fsm.Topology, parsedHosts)
   851  
   852  	if len(diff.Added) == 0 && len(diff.Removed) == 0 {
   853  		return true
   854  	}
   855  
   856  	for _, r := range diff.Removed {
   857  		addr := address.Address(r).Canonicalize()
   858  		s, ok := t.servers[addr]
   859  		if !ok {
   860  			continue
   861  		}
   862  		go func() {
   863  			cancelCtx, cancel := context.WithCancel(context.Background())
   864  			cancel()
   865  			_ = s.Disconnect(cancelCtx)
   866  		}()
   867  		delete(t.servers, addr)
   868  		t.fsm.removeServerByAddr(addr)
   869  		t.publishServerClosedEvent(s.address)
   870  	}
   871  
   872  	// Now that we've removed all the hosts that disappeared from the SRV record, add any new
   873  	// hosts that appeared in it. If adding all of the new hosts would push the number of
   874  	// servers past srvMaxHosts, shuffle the list of added hosts first.
   875  	if t.cfg.SRVMaxHosts > 0 && len(t.servers)+len(diff.Added) > t.cfg.SRVMaxHosts {
   876  		random.Shuffle(len(diff.Added), func(i, j int) {
   877  			diff.Added[i], diff.Added[j] = diff.Added[j], diff.Added[i]
   878  		})
   879  	}
   880  	// Add all added hosts until the number of servers reaches srvMaxHosts.
   881  	for _, a := range diff.Added {
   882  		if t.cfg.SRVMaxHosts > 0 && len(t.servers) >= t.cfg.SRVMaxHosts {
   883  			break
   884  		}
   885  		addr := address.Address(a).Canonicalize()
   886  		_ = t.addServer(addr)
   887  		t.fsm.addServer(addr)
   888  	}
   889  
   890  	// store new description
   891  	newDesc := description.Topology{
   892  		Kind:                     t.fsm.Kind,
   893  		Servers:                  t.fsm.Servers,
   894  		SessionTimeoutMinutesPtr: t.fsm.SessionTimeoutMinutesPtr,
   895  
   896  		// TODO(GODRIVER-2885): This field can be removed once legacy
   897  		// SessionTimeoutMinutes is removed.
   898  		SessionTimeoutMinutes: t.fsm.SessionTimeoutMinutes,
   899  	}
   900  	t.desc.Store(newDesc)
   901  
   902  	if !prev.Equal(newDesc) {
   903  		t.publishTopologyDescriptionChangedEvent(prev, newDesc)
   904  	}
   905  
   906  	t.subLock.Lock()
   907  	for _, ch := range t.subscribers {
   908  		// Drain any unread description so the send below cannot block.
   909  		select {
   910  		case <-ch:
   911  		default:
   912  		}
   913  		ch <- newDesc
   914  	}
   915  	t.subLock.Unlock()
   916  
   917  	return true
   918  }
   919  
   920  // apply updates the Topology and its underlying FSM based on the provided server description and returns the server
   921  // description that should be stored.
   922  func (t *Topology) apply(ctx context.Context, desc description.Server) description.Server {
   923  	t.serversLock.Lock()
   924  	defer t.serversLock.Unlock()
   925  
   926  	ind, ok := t.fsm.findServer(desc.Addr)
   927  	if t.serversClosed || !ok {
   928  		return desc
   929  	}
   930  
   931  	prev := t.fsm.Topology
   932  	oldDesc := t.fsm.Servers[ind]
   933  	if oldDesc.TopologyVersion.CompareToIncoming(desc.TopologyVersion) > 0 {
   934  		return oldDesc
   935  	}
   936  
   937  	var current description.Topology
   938  	current, desc = t.fsm.apply(desc)
   939  
   940  	if !oldDesc.Equal(desc) {
   941  		t.publishServerDescriptionChangedEvent(oldDesc, desc)
   942  	}
   943  
   944  	diff := diffTopology(prev, current)
   945  
   946  	for _, removed := range diff.Removed {
   947  		if s, ok := t.servers[removed.Addr]; ok {
   948  			go func() {
   949  				cancelCtx, cancel := context.WithCancel(ctx)
   950  				cancel()
   951  				_ = s.Disconnect(cancelCtx)
   952  			}()
   953  			delete(t.servers, removed.Addr)
   954  			t.publishServerClosedEvent(s.address)
   955  		}
   956  	}
   957  
   958  	for _, added := range diff.Added {
   959  		_ = t.addServer(added.Addr)
   960  	}
   961  
   962  	t.desc.Store(current)
   963  	if !prev.Equal(current) {
   964  		t.publishTopologyDescriptionChangedEvent(prev, current)
   965  	}
   966  
   967  	t.subLock.Lock()
   968  	for _, ch := range t.subscribers {
   969  		// Drain any unread description so the send below cannot block.
   970  		select {
   971  		case <-ch:
   972  		default:
   973  		}
   974  		ch <- current
   975  	}
   976  	t.subLock.Unlock()
   977  
   978  	return desc
   979  }
   980  
   981  func (t *Topology) addServer(addr address.Address) error {
   982  	if _, ok := t.servers[addr]; ok {
   983  		return nil
   984  	}
   985  
   986  	svr, err := ConnectServer(addr, t.updateCallback, t.id, t.cfg.ServerOpts...)
   987  	if err != nil {
   988  		return err
   989  	}
   990  
   991  	t.servers[addr] = svr
   992  
   993  	return nil
   994  }
   995  
   996  // String implements the fmt.Stringer interface.
   997  func (t *Topology) String() string {
   998  	desc := t.Description()
   999  
  1000  	serversStr := ""
  1001  	t.serversLock.Lock()
  1002  	defer t.serversLock.Unlock()
  1003  	for _, s := range t.servers {
  1004  		serversStr += "{ " + s.String() + " }, "
  1005  	}
  1006  	return fmt.Sprintf("Type: %s, Servers: [%s]", desc.Kind, serversStr)
  1007  }
  1008  
  1009  // publishes a ServerDescriptionChangedEvent to indicate the server description has changed
  1010  func (t *Topology) publishServerDescriptionChangedEvent(prev description.Server, current description.Server) {
  1011  	serverDescriptionChanged := &event.ServerDescriptionChangedEvent{
  1012  		Address:             current.Addr,
  1013  		TopologyID:          t.id,
  1014  		PreviousDescription: prev,
  1015  		NewDescription:      current,
  1016  	}
  1017  
  1018  	if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerDescriptionChanged != nil {
  1019  		t.cfg.ServerMonitor.ServerDescriptionChanged(serverDescriptionChanged)
  1020  	}
  1021  }
  1022  
  1023  // publishes a ServerClosedEvent to indicate the server has closed
  1024  func (t *Topology) publishServerClosedEvent(addr address.Address) {
  1025  	serverClosed := &event.ServerClosedEvent{
  1026  		Address:    addr,
  1027  		TopologyID: t.id,
  1028  	}
  1029  
  1030  	if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.ServerClosed != nil {
  1031  		t.cfg.ServerMonitor.ServerClosed(serverClosed)
  1032  	}
  1033  
  1034  	if mustLogTopologyMessage(t, logger.LevelDebug) {
  1035  		serverHost, serverPort, err := net.SplitHostPort(addr.String())
  1036  		if err != nil {
  1037  			serverHost = addr.String()
  1038  			serverPort = ""
  1039  		}
  1040  
  1041  		portInt64, _ := strconv.ParseInt(serverPort, 10, 32)
  1042  
  1043  		logTopologyMessage(t, logger.LevelDebug, logger.TopologyServerClosed,
  1044  			logger.KeyServerHost, serverHost,
  1045  			logger.KeyServerPort, portInt64)
  1046  	}
  1047  }
  1048  
  1049  // publishes a TopologyDescriptionChangedEvent to indicate the topology description has changed
  1050  func (t *Topology) publishTopologyDescriptionChangedEvent(prev description.Topology, current description.Topology) {
  1051  	topologyDescriptionChanged := &event.TopologyDescriptionChangedEvent{
  1052  		TopologyID:          t.id,
  1053  		PreviousDescription: prev,
  1054  		NewDescription:      current,
  1055  	}
  1056  
  1057  	if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyDescriptionChanged != nil {
  1058  		t.cfg.ServerMonitor.TopologyDescriptionChanged(topologyDescriptionChanged)
  1059  	}
  1060  
  1061  	if mustLogTopologyMessage(t, logger.LevelDebug) {
  1062  		logTopologyMessage(t, logger.LevelDebug, logger.TopologyDescriptionChanged,
  1063  			logger.KeyPreviousDescription, prev.String(),
  1064  			logger.KeyNewDescription, current.String())
  1065  	}
  1066  }
  1067  
  1068  // publishes a TopologyOpeningEvent to indicate the topology is being initialized
  1069  func (t *Topology) publishTopologyOpeningEvent() {
  1070  	topologyOpening := &event.TopologyOpeningEvent{
  1071  		TopologyID: t.id,
  1072  	}
  1073  
  1074  	if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyOpening != nil {
  1075  		t.cfg.ServerMonitor.TopologyOpening(topologyOpening)
  1076  	}
  1077  
  1078  	if mustLogTopologyMessage(t, logger.LevelDebug) {
  1079  		logTopologyMessage(t, logger.LevelDebug, logger.TopologyOpening)
  1080  	}
  1081  }
  1082  
  1083  // publishes a TopologyClosedEvent to indicate the topology has been closed
  1084  func (t *Topology) publishTopologyClosedEvent() {
  1085  	topologyClosed := &event.TopologyClosedEvent{
  1086  		TopologyID: t.id,
  1087  	}
  1088  
  1089  	if t.cfg.ServerMonitor != nil && t.cfg.ServerMonitor.TopologyClosed != nil {
  1090  		t.cfg.ServerMonitor.TopologyClosed(topologyClosed)
  1091  	}
  1092  
  1093  	if mustLogTopologyMessage(t, logger.LevelDebug) {
  1094  		logTopologyMessage(t, logger.LevelDebug, logger.TopologyClosed)
  1095  	}
  1096  }
  1097  
