Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"net"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/bson"
    19  	"go.mongodb.org/mongo-driver/bson/primitive"
    20  	"go.mongodb.org/mongo-driver/event"
    21  	"go.mongodb.org/mongo-driver/internal/driverutil"
    22  	"go.mongodb.org/mongo-driver/internal/logger"
    23  	"go.mongodb.org/mongo-driver/mongo/address"
    24  	"go.mongodb.org/mongo-driver/mongo/description"
    25  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    26  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    27  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    28  )
    29  
    30  const minHeartbeatInterval = 500 * time.Millisecond
    31  const wireVersion42 = 8 // Wire version for MongoDB 4.2
    32  
    33  // Server state constants.
    34  const (
    35  	serverDisconnected int64 = iota
    36  	serverDisconnecting
    37  	serverConnected
    38  )
    39  
    40  func serverStateString(state int64) string {
    41  	switch state {
    42  	case serverDisconnected:
    43  		return "Disconnected"
    44  	case serverDisconnecting:
    45  		return "Disconnecting"
    46  	case serverConnected:
    47  		return "Connected"
    48  	}
    49  
    50  	return ""
    51  }
    52  
    53  var (
    54  	// ErrServerClosed occurs when an attempt to Get a connection is made after
    55  	// the server has been closed.
    56  	ErrServerClosed = errors.New("server is closed")
    57  	// ErrServerConnected occurs when an attempt to Connect is made after a server
    58  	// has already been connected.
    59  	ErrServerConnected = errors.New("server is connected")
    60  
    61  	errCheckCancelled = errors.New("server check cancelled")
    62  	emptyDescription  = description.NewDefaultServer("")
    63  )
    64  
    65  // SelectedServer represents a specific server that was selected during server selection.
    66  // It contains the kind of the topology it was selected from.
    67  type SelectedServer struct {
    68  	*Server
    69  
    70  	Kind description.TopologyKind
    71  }
    72  
    73  // Description returns a description of the server as of the last heartbeat.
    74  func (ss *SelectedServer) Description() description.SelectedServer {
    75  	sdesc := ss.Server.Description()
    76  	return description.SelectedServer{
    77  		Server: sdesc,
    78  		Kind:   ss.Kind,
    79  	}
    80  }
    81  
    82  // Server is a single server within a topology.
    83  type Server struct {
    84  	// The following integer fields must be accessed using the atomic package and should be at the
    85  	// beginning of the struct.
    86  	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
    87  	// - suggested layout: https://go101.org/article/memory-layout.html
    88  
    89  	state          int64
    90  	operationCount int64
    91  
    92  	cfg     *serverConfig
    93  	address address.Address
    94  
    95  	// connection related fields
    96  	pool *pool
    97  
    98  	// goroutine management fields
    99  	done          chan struct{}
   100  	checkNow      chan struct{}
   101  	disconnecting chan struct{}
   102  	closewg       sync.WaitGroup
   103  
   104  	// description related fields
   105  	desc                   atomic.Value // holds a description.Server
   106  	updateTopologyCallback atomic.Value
   107  	topologyID             primitive.ObjectID
   108  
   109  	// subscriber related fields
   110  	subLock             sync.Mutex
   111  	subscribers         map[uint64]chan description.Server
   112  	currentSubscriberID uint64
   113  	subscriptionsClosed bool
   114  
   115  	// heartbeat and cancellation related fields
   116  	// globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down.
   117  	// heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be
   118  	// cancelled automatically during shutdown.
   119  	heartbeatLock      sync.Mutex
   120  	conn               *connection
   121  	globalCtx          context.Context
   122  	globalCtxCancel    context.CancelFunc
   123  	heartbeatCtx       context.Context
   124  	heartbeatCtxCancel context.CancelFunc
   125  
   126  	processErrorLock sync.Mutex
   127  	rttMonitor       *rttMonitor
   128  }
   129  
   130  // updateTopologyCallback is a callback provided to Connect or ConnectServer. It is called when the parent Topology
   131  // instance should be updated based on a new server description. The callback must return the server description that
   132  // should be stored by the server.
   133  type updateTopologyCallback func(description.Server) description.Server
   134  
   135  // ConnectServer creates a new Server and then initializes it using the
   136  // Connect method.
   137  func ConnectServer(
   138  	addr address.Address,
   139  	updateCallback updateTopologyCallback,
   140  	topologyID primitive.ObjectID,
   141  	opts ...ServerOption,
   142  ) (*Server, error) {
   143  	srvr := NewServer(addr, topologyID, opts...)
   144  	err := srvr.Connect(updateCallback)
   145  	if err != nil {
   146  		return nil, err
   147  	}
   148  	return srvr, nil
   149  }
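
// exampleConnectServer is an editorial sketch, not part of the original file:
// it shows one way to create and connect a Server for a single address and
// then shut it down. The address and the identity update callback below are
// assumptions made for the example.
func exampleConnectServer() error {
	srv, err := ConnectServer(
		address.Address("localhost:27017"),
		// Identity callback: store whatever description the monitor produces.
		func(desc description.Server) description.Server { return desc },
		primitive.NewObjectID(),
	)
	if err != nil {
		return err
	}
	// Disconnect stops monitoring, closes the heartbeat connection, and closes
	// the connection pool.
	return srv.Disconnect(context.Background())
}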
   150  
   151  // NewServer creates a new server. The mongodb server at the address will be monitored
   152  // on an internal monitoring goroutine.
   153  func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...ServerOption) *Server {
   154  	cfg := newServerConfig(opts...)
   155  	globalCtx, globalCtxCancel := context.WithCancel(context.Background())
   156  	s := &Server{
   157  		state: serverDisconnected,
   158  
   159  		cfg:     cfg,
   160  		address: addr,
   161  
   162  		done:          make(chan struct{}),
   163  		checkNow:      make(chan struct{}, 1),
   164  		disconnecting: make(chan struct{}),
   165  
   166  		topologyID: topologyID,
   167  
   168  		subscribers:     make(map[uint64]chan description.Server),
   169  		globalCtx:       globalCtx,
   170  		globalCtxCancel: globalCtxCancel,
   171  	}
   172  	s.desc.Store(description.NewDefaultServer(addr))
   173  	rttCfg := &rttConfig{
   174  		interval:           cfg.heartbeatInterval,
   175  		minRTTWindow:       5 * time.Minute,
   176  		createConnectionFn: s.createConnection,
   177  		createOperationFn:  s.createBaseOperation,
   178  	}
   179  	s.rttMonitor = newRTTMonitor(rttCfg)
   180  
   181  	pc := poolConfig{
   182  		Address:          addr,
   183  		MinPoolSize:      cfg.minConns,
   184  		MaxPoolSize:      cfg.maxConns,
   185  		MaxConnecting:    cfg.maxConnecting,
   186  		MaxIdleTime:      cfg.poolMaxIdleTime,
   187  		MaintainInterval: cfg.poolMaintainInterval,
   188  		LoadBalanced:     cfg.loadBalanced,
   189  		PoolMonitor:      cfg.poolMonitor,
   190  		Logger:           cfg.logger,
   191  		handshakeErrFn:   s.ProcessHandshakeError,
   192  	}
   193  
   194  	connectionOpts := copyConnectionOpts(cfg.connectionOpts)
   195  	s.pool = newPool(pc, connectionOpts...)
   196  	s.publishServerOpeningEvent(s.address)
   197  
   198  	return s
   199  }
   200  
   201  func mustLogServerMessage(srv *Server) bool {
   202  	return srv.cfg.logger != nil && srv.cfg.logger.LevelComponentEnabled(
   203  		logger.LevelDebug, logger.ComponentTopology)
   204  }
   205  
   206  func logServerMessage(srv *Server, msg string, keysAndValues ...interface{}) {
   207  	serverHost, serverPort, err := net.SplitHostPort(srv.address.String())
   208  	if err != nil {
   209  		serverHost = srv.address.String()
   210  		serverPort = ""
   211  	}
   212  
   213  	var driverConnectionID uint64
   214  	var serverConnectionID *int64
   215  
   216  	if srv.conn != nil {
   217  		driverConnectionID = srv.conn.driverConnectionID
   218  		serverConnectionID = srv.conn.serverConnectionID
   219  	}
   220  
   221  	srv.cfg.logger.Print(logger.LevelDebug,
   222  		logger.ComponentTopology,
   223  		msg,
   224  		logger.SerializeServer(logger.Server{
   225  			DriverConnectionID: driverConnectionID,
   226  			TopologyID:         srv.topologyID,
   227  			Message:            msg,
   228  			ServerConnectionID: serverConnectionID,
   229  			ServerHost:         serverHost,
   230  			ServerPort:         serverPort,
   231  		}, keysAndValues...)...)
   232  }
   233  
   234  // Connect initializes the Server by starting background monitoring goroutines.
   235  // This method must be called before a Server can be used.
   236  func (s *Server) Connect(updateCallback updateTopologyCallback) error {
   237  	if !atomic.CompareAndSwapInt64(&s.state, serverDisconnected, serverConnected) {
   238  		return ErrServerConnected
   239  	}
   240  
   241  	desc := description.NewDefaultServer(s.address)
   242  	if s.cfg.loadBalanced {
   243  		// LBs automatically start off with kind LoadBalancer because there is no monitoring routine for state changes.
   244  		desc.Kind = description.LoadBalancer
   245  	}
   246  	s.desc.Store(desc)
   247  	s.updateTopologyCallback.Store(updateCallback)
   248  
   249  	if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced {
   250  		s.closewg.Add(1)
   251  		go s.update()
   252  	}
   253  
   254  	// The CMAP spec describes that pools should only be marked "ready" when the server description
   255  	// is updated to something other than "Unknown". However, we maintain the previous Server
   256  	// behavior here and immediately mark the pool as ready during Connect() to simplify and speed
   257  	// up the Client startup behavior. The risk of marking a pool as ready proactively during
   258  	// Connect() is that we could attempt to create connections to a server that was configured
   259  	// erroneously until the first server check or checkOut() failure occurs, when the SDAM error
   260  	// handler would transition the Server back to "Unknown" and set the pool to "paused".
   261  	return s.pool.ready()
   262  }
   263  
   264  // Disconnect closes sockets to the server referenced by this Server.
   265  // Subscriptions to this Server will be closed. Disconnect will shutdown
   266  // any monitoring goroutines, close the idle connection pool, and will
   267  // wait until all the in use connections have been returned to the connection
   268  // pool and are closed before returning. If the context expires via
   269  // cancellation, deadline, or timeout before the in use connections have been
   270  // returned, the in use connections will be closed, resulting in the failure of
   271  // any in flight read or write operations. If this method returns with no
   272  // errors, all connections associated with this Server have been closed.
   273  func (s *Server) Disconnect(ctx context.Context) error {
   274  	if !atomic.CompareAndSwapInt64(&s.state, serverConnected, serverDisconnecting) {
   275  		return ErrServerClosed
   276  	}
   277  
   278  	s.updateTopologyCallback.Store((updateTopologyCallback)(nil))
   279  
   280  	// Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done
   281  	// channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end.
   282  	// The done channel is closed before cancelling the check so the update() routine will immediately detect that it
   283  	// can stop rather than trying to create new connections until the read from done succeeds.
   284  	s.globalCtxCancel()
   285  	close(s.done)
   286  	s.cancelCheck()
   287  
   288  	s.rttMonitor.disconnect()
   289  	s.pool.close(ctx)
   290  
   291  	s.closewg.Wait()
   292  	atomic.StoreInt64(&s.state, serverDisconnected)
   293  
   294  	return nil
   295  }
   296  
   297  // Connection gets a connection to the server.
   298  func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
   299  	if atomic.LoadInt64(&s.state) != serverConnected {
   300  		return nil, ErrServerClosed
   301  	}
   302  
   303  	// Increment the operation count before calling checkOut to make sure that all connection
   304  	// requests are included in the operation count, including those in the wait queue. If we got an
   305  	// error instead of a connection, immediately decrement the operation count.
   306  	atomic.AddInt64(&s.operationCount, 1)
   307  	conn, err := s.pool.checkOut(ctx)
   308  	if err != nil {
   309  		atomic.AddInt64(&s.operationCount, -1)
   310  		return nil, err
   311  	}
   312  
   313  	return &Connection{
   314  		connection: conn,
   315  		cleanupServerFn: func() {
   316  			// Decrement the operation count whenever the caller is done with the connection. Note
   317  			// that cleanupServerFn() is not called while the connection is pinned to a cursor or
   318  			// transaction, so the operation count is not decremented until the cursor is closed or
   319  			// the transaction is committed or aborted. Use an int64 instead of a uint64 to mitigate
   320  			// the impact of any possible bugs that could cause the uint64 to underflow, which would
   321  			// make the server much less selectable.
   322  			atomic.AddInt64(&s.operationCount, -1)
   323  		},
   324  	}, nil
   325  }
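
// exampleCheckOut is an editorial sketch, not part of the original file: it
// checks a connection out of a connected Server's pool, which increments the
// operation count, and returns the connection when done.
func exampleCheckOut(ctx context.Context, srv *Server) error {
	conn, err := srv.Connection(ctx) // operation count +1 on success
	if err != nil {
		return err
	}
	// Closing the Connection returns it to the pool and runs cleanupServerFn,
	// which decrements the operation count used for server selection.
	defer func() { _ = conn.Close() }()

	fmt.Println("in-progress operations:", srv.OperationCount())
	return nil
}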
   326  
   327  // ProcessHandshakeError implements SDAM error handling for errors that occur before a connection
   328  // finishes handshaking.
   329  func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
   330  	// Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the
   331  	// error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to
   332  	// use for clearing the pool.
   333  	if err == nil || s.cfg.loadBalanced && serviceID == nil {
   334  		return
   335  	}
   336  	// Ignore the error if the connection is stale.
   337  	if generation, _ := s.pool.generation.getGeneration(serviceID); startingGenerationNumber < generation {
   338  		return
   339  	}
   340  
   341  	// Unwrap any connection errors. If there is no wrapped connection error, then the error should
   342  	// not result in any Server state change (e.g. a command error from the database).
   343  	wrappedConnErr := unwrapConnectionError(err)
   344  	if wrappedConnErr == nil {
   345  		return
   346  	}
   347  
   348  	// Must hold the processErrorLock while updating the server description and clearing the pool.
   349  	// Not holding the lock leads to possible out-of-order processing of pool.clear() and
   350  	// pool.ready() calls from concurrent server description updates.
   351  	s.processErrorLock.Lock()
   352  	defer s.processErrorLock.Unlock()
   353  
   354  	// Since the only kind of ConnectionError we receive from pool.checkOut will be an initialization error, we should set
   355  	// the description.Server appropriately. The description should not have a TopologyVersion because the staleness
   356  	// checking logic above has already determined that this description is not stale.
   357  	s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil))
   358  	s.pool.clear(err, serviceID)
   359  	s.cancelCheck()
   360  }
   361  
   362  // Description returns a description of the server as of the last heartbeat.
   363  func (s *Server) Description() description.Server {
   364  	return s.desc.Load().(description.Server)
   365  }
   366  
   367  // SelectedDescription returns a description.SelectedServer with a Kind of
   368  // Single. This can be used when performing tasks like monitoring a batch
   369  // of servers and you want to run one off commands against those servers.
   370  func (s *Server) SelectedDescription() description.SelectedServer {
   371  	sdesc := s.Description()
   372  	return description.SelectedServer{
   373  		Server: sdesc,
   374  		Kind:   description.Single,
   375  	}
   376  }
   377  
   378  // Subscribe returns a ServerSubscription which has a channel on which all
   379  // updated server descriptions will be sent. The channel will have a buffer
   380  // size of one, and will be pre-populated with the current description.
   381  func (s *Server) Subscribe() (*ServerSubscription, error) {
   382  	if atomic.LoadInt64(&s.state) != serverConnected {
   383  		return nil, ErrSubscribeAfterClosed
   384  	}
   385  	ch := make(chan description.Server, 1)
   386  	ch <- s.desc.Load().(description.Server)
   387  
   388  	s.subLock.Lock()
   389  	defer s.subLock.Unlock()
   390  	if s.subscriptionsClosed {
   391  		return nil, ErrSubscribeAfterClosed
   392  	}
   393  	id := s.currentSubscriberID
   394  	s.subscribers[id] = ch
   395  	s.currentSubscriberID++
   396  
   397  	ss := &ServerSubscription{
   398  		C:  ch,
   399  		s:  s,
   400  		id: id,
   401  	}
   402  
   403  	return ss, nil
   404  }
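
// exampleWatchDescriptions is an editorial sketch, not part of the original
// file: it prints every server description update until the subscription
// channel is closed, which happens when the server disconnects.
func exampleWatchDescriptions(srv *Server) error {
	sub, err := srv.Subscribe()
	if err != nil {
		return err
	}
	defer func() { _ = sub.Unsubscribe() }()

	// The channel is pre-populated with the current description, so the first
	// iteration runs immediately.
	for desc := range sub.C {
		fmt.Printf("server %s is now %v\n", desc.Addr, desc.Kind)
	}
	return nil
}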
   405  
   406  // RequestImmediateCheck will cause the server to send a heartbeat immediately
   407  // instead of waiting for the heartbeat timeout.
   408  func (s *Server) RequestImmediateCheck() {
   409  	select {
   410  	case s.checkNow <- struct{}{}:
   411  	default:
   412  	}
   413  }
   414  
   415  // getWriteConcernErrorForProcessing extracts a driver.WriteConcernError from the provided error. This function returns
   416  // (error, true) if the error is a WriteConcernError that falls under the requirements for SDAM error
   417  // handling, and (nil, false) otherwise.
   418  func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bool) {
   419  	var writeCmdErr driver.WriteCommandError
   420  	if !errors.As(err, &writeCmdErr) {
   421  		return nil, false
   422  	}
   423  
   424  	wcerr := writeCmdErr.WriteConcernError
   425  	if wcerr != nil && (wcerr.NodeIsRecovering() || wcerr.NotPrimary()) {
   426  		return wcerr, true
   427  	}
   428  	return nil, false
   429  }
   430  
   431  // ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
   432  func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult {
   433  	// Ignore nil errors.
   434  	if err == nil {
   435  		return driver.NoChange
   436  	}
   437  
   438  	// Ignore errors from stale connections because the error came from a previous generation of the
   439  	// connection pool. The root cause of the error has already been handled, which is what caused
   440  	// the pool generation to increment. Processing errors for stale connections could result in
   441  	// handling the same error root cause multiple times (e.g. a temporary network interrupt causing
   442  	// all connections to the same server to return errors).
   443  	if conn.Stale() {
   444  		return driver.NoChange
   445  	}
   446  
   447  	// Must hold the processErrorLock while updating the server description and clearing the pool.
   448  	// Not holding the lock leads to possible out-of-order processing of pool.clear() and
   449  	// pool.ready() calls from concurrent server description updates.
   450  	s.processErrorLock.Lock()
   451  	defer s.processErrorLock.Unlock()
   452  
   453  	// Get the wire version and service ID from the connection description because they will never
   454  	// change for the lifetime of a connection and can possibly be different between connections to
   455  	// the same server.
   456  	connDesc := conn.Description()
   457  	wireVersion := connDesc.WireVersion
   458  	serviceID := connDesc.ServiceID
   459  
   460  	// Get the topology version from the Server description because the Server description is
   461  	// updated by heartbeats and errors, so typically has a more up-to-date topology version.
   462  	serverDesc := s.desc.Load().(description.Server)
   463  	topologyVersion := serverDesc.TopologyVersion
   464  
   465  	// We don't currently update the Server topology version when we create new application
   466  	// connections, so it's possible for a connection's topology version to be newer than the
   467  	// Server's topology version. Pick the "newest" of the two topology versions.
   468  	// Technically a nil topology version on a new database response should be considered a new
   469  	// topology version and replace the Server's topology version. However, we don't know if the
   470  	// connection's topology version is based on a new or old database response, so we ignore a nil
   471  	// topology version on the connection for now.
   472  	//
   473  	// TODO(GODRIVER-2841): Remove this logic once we set the Server description when we create
   474  	// TODO application connections because then the Server's topology version will always be the
   475  	// TODO latest known.
   476  	if tv := connDesc.TopologyVersion; tv != nil && topologyVersion.CompareToIncoming(tv) < 0 {
   477  		topologyVersion = tv
   478  	}
   479  
   480  	// Invalidate server description if not primary or node recovering error occurs.
   481  	// These errors can be reported as a command error or a write concern error.
   482  	if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotPrimary()) {
   483  		// Ignore errors that came from when the database was on a previous topology version.
   484  		if topologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 {
   485  			return driver.NoChange
   486  		}
   487  
   488  		// updates description to unknown
   489  		s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion))
   490  		s.RequestImmediateCheck()
   491  
   492  		res := driver.ServerMarkedUnknown
   493  		// If the node is shutting down or is older than 4.2, we synchronously clear the pool
   494  		if cerr.NodeIsShuttingDown() || wireVersion == nil || wireVersion.Max < wireVersion42 {
   495  			res = driver.ConnectionPoolCleared
   496  			s.pool.clear(err, serviceID)
   497  		}
   498  
   499  		return res
   500  	}
   501  	if wcerr, ok := getWriteConcernErrorForProcessing(err); ok {
   502  		// Ignore errors that came from when the database was on a previous topology version.
   503  		if topologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 {
   504  			return driver.NoChange
   505  		}
   506  
   507  		// updates description to unknown
   508  		s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion))
   509  		s.RequestImmediateCheck()
   510  
   511  		res := driver.ServerMarkedUnknown
   512  		// If the node is shutting down or is older than 4.2, we synchronously clear the pool
   513  		if wcerr.NodeIsShuttingDown() || wireVersion == nil || wireVersion.Max < wireVersion42 {
   514  			res = driver.ConnectionPoolCleared
   515  			s.pool.clear(err, serviceID)
   516  		}
   517  		return res
   518  	}
   519  
   520  	wrappedConnErr := unwrapConnectionError(err)
   521  	if wrappedConnErr == nil {
   522  		return driver.NoChange
   523  	}
   524  
   525  	// Ignore transient timeout errors.
   526  	if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() {
   527  		return driver.NoChange
   528  	}
   529  	if errors.Is(wrappedConnErr, context.Canceled) || errors.Is(wrappedConnErr, context.DeadlineExceeded) {
   530  		return driver.NoChange
   531  	}
   532  
   533  	// For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress
   534  	// monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with
   535  	// updateDescription.
   536  	s.updateDescription(description.NewServerFromError(s.address, err, nil))
   537  	s.pool.clear(err, serviceID)
   538  	s.cancelCheck()
   539  	return driver.ConnectionPoolCleared
   540  }
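
// exampleHandleError is an editorial sketch, not part of the original file:
// it shows how a caller holding a checked-out connection can consult the SDAM
// error handler after an operation fails and react to the result.
func exampleHandleError(srv *Server, conn driver.Connection, opErr error) {
	switch srv.ProcessError(opErr, conn) {
	case driver.ConnectionPoolCleared:
		// The pool generation was bumped, so existing connections are stale.
		fmt.Println("pool cleared after error:", opErr)
	case driver.ServerMarkedUnknown:
		// The server description was reset; selection avoids this server until
		// the next successful heartbeat.
		fmt.Println("server marked Unknown after error:", opErr)
	case driver.NoChange:
		// Stale connection, transient timeout, or a non-state-changing error.
	}
}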
   541  
   542  // update handles performing heartbeats and updating any subscribers with the
   543  // newest description.Server retrieved.
   544  func (s *Server) update() {
   545  	defer s.closewg.Done()
   546  	heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
   547  	rateLimiter := time.NewTicker(minHeartbeatInterval)
   548  	defer heartbeatTicker.Stop()
   549  	defer rateLimiter.Stop()
   550  	checkNow := s.checkNow
   551  	done := s.done
   552  
   553  	defer logUnexpectedFailure(s.cfg.logger, "Encountered unexpected failure updating server")
   554  
   555  	closeServer := func() {
   556  		s.subLock.Lock()
   557  		for id, c := range s.subscribers {
   558  			close(c)
   559  			delete(s.subscribers, id)
   560  		}
   561  		s.subscriptionsClosed = true
   562  		s.subLock.Unlock()
   563  
   564  		// We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks
   565  		// below detect that the server is being closed, so we can be sure that the connection isn't being used.
   566  		if s.conn != nil {
   567  			_ = s.conn.close()
   568  		}
   569  	}
   570  
   571  	waitUntilNextCheck := func() {
   572  		// Wait until heartbeatFrequency elapses, an application operation requests an immediate check, or the server
   573  		// is disconnecting.
   574  		select {
   575  		case <-heartbeatTicker.C:
   576  		case <-checkNow:
   577  		case <-done:
   578  			// Return because the next update iteration will check the done channel again and clean up.
   579  			return
   580  		}
   581  
   582  		// Ensure we only return if minHeartbeatInterval has elapsed or the server is disconnecting.
   583  		select {
   584  		case <-rateLimiter.C:
   585  		case <-done:
   586  			return
   587  		}
   588  	}
   589  
   590  	timeoutCnt := 0
   591  	for {
   592  		// Check if the server is disconnecting. Even if waitUntilNextCheck has already read from the done channel, we
   593  		// can safely read from it again because Disconnect closes the channel.
   594  		select {
   595  		case <-done:
   596  			closeServer()
   597  			return
   598  		default:
   599  		}
   600  
   601  		previousDescription := s.Description()
   602  
   603  		// Perform the next check.
   604  		desc, err := s.check()
   605  		if errors.Is(err, errCheckCancelled) {
   606  			if atomic.LoadInt64(&s.state) != serverConnected {
   607  				continue
   608  			}
   609  
   610  			// If the server is not disconnecting, the check was cancelled by an application operation after an error.
   611  			// Wait before running the next check.
   612  			waitUntilNextCheck()
   613  			continue
   614  		}
   615  
   616  		if isShortcut := func() bool {
   617  			// Must hold the processErrorLock while updating the server description and clearing the
   618  			// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
   619  			// pool.ready() calls from concurrent server description updates.
   620  			s.processErrorLock.Lock()
   621  			defer s.processErrorLock.Unlock()
   622  
   623  			s.updateDescription(desc)
   624  			// Retry after the first timeout before clearing the pool in case of a FAAS pause as
   625  			// described in GODRIVER-2577.
   626  			if err := unwrapConnectionError(desc.LastError); err != nil && timeoutCnt < 1 {
   627  				if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
   628  					timeoutCnt++
   629  					// We want to immediately retry on timeout error. Continue to next loop.
   630  					return true
   631  				}
   632  				if err, ok := err.(net.Error); ok && err.Timeout() {
   633  					timeoutCnt++
   634  					// We want to immediately retry on timeout error. Continue to next loop.
   635  					return true
   636  				}
   637  			}
   638  			if err := desc.LastError; err != nil {
   639  				// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
   640  				// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
   641  				// IDs.
   642  				if timeoutCnt > 0 {
   643  					s.pool.clearAll(err, nil)
   644  				} else {
   645  					s.pool.clear(err, nil)
   646  				}
   647  			}
   648  			// We're either not handling a timeout error, or we just handled the 2nd consecutive
   649  			// timeout error. In either case, reset the timeout count to 0 and return false to
   650  			// continue the normal check process.
   651  			timeoutCnt = 0
   652  			return false
   653  		}(); isShortcut {
   654  			continue
   655  		}
   656  
   657  		// If the server supports streaming or we're already streaming, we want to move to streaming the next response
   658  		// without waiting. If the server has transitioned to Unknown from a network error, we want to do another
   659  		// check without waiting in case it was a transient error and the server isn't actually down.
   660  		connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming()
   661  		transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil &&
   662  			previousDescription.Kind != description.Unknown
   663  
   664  		if isStreamingEnabled(s) && isStreamable(s) && !s.rttMonitor.started {
   665  			s.rttMonitor.connect()
   666  		}
   667  
   668  		if isStreamable(s) || connectionIsStreaming || transitionedFromNetworkError {
   669  			continue
   670  		}
   671  
   672  		// The server either does not support the streamable protocol or is not in a healthy state, so we wait until
   673  		// the next check.
   674  		waitUntilNextCheck()
   675  	}
   676  }
   677  
   678  // updateDescription handles updating the description on the Server, notifying
   679  // subscribers, and marking the connection pool as ready when the new
   680  // description is not Unknown. Clearing the pool after an error is handled by
   681  // the caller once the description has been updated.
   682  func (s *Server) updateDescription(desc description.Server) {
   683  	if s.cfg.loadBalanced {
   684  		// In load balanced mode, there are no updates from the monitoring routine. For errors encountered in pooled
   685  		// connections, the server should not be marked Unknown to ensure that the LB remains selectable.
   686  		return
   687  	}
   688  
   689  	defer logUnexpectedFailure(s.cfg.logger, "Encountered unexpected failure updating server description")
   690  
   691  	// Anytime we update the server description to something other than "unknown", set the pool to
   692  	// "ready". Do this before updating the description so that connections can be checked out as
   693  	// soon as the server is selectable. If the pool is already ready, this operation is a no-op.
   694  	// Note that this behavior is roughly consistent with the current Go driver behavior (connects
   695  	// to all servers, even non-data-bearing nodes) but deviates slightly from CMAP spec, which
   696  	// specifies a more restricted set of server descriptions and topologies that should mark the
   697  	// pool ready. We don't have access to the topology here, so prefer the current Go driver
   698  	// behavior for simplicity.
   699  	if desc.Kind != description.Unknown {
   700  		_ = s.pool.ready()
   701  	}
   702  
   703  	// Use the updateTopologyCallback to update the parent Topology and get the description that should be stored.
   704  	callback, ok := s.updateTopologyCallback.Load().(updateTopologyCallback)
   705  	if ok && callback != nil {
   706  		desc = callback(desc)
   707  	}
   708  	s.desc.Store(desc)
   709  
   710  	s.subLock.Lock()
   711  	for _, c := range s.subscribers {
   712  		select {
   713  		// drain the channel if it isn't empty
   714  		case <-c:
   715  		default:
   716  		}
   717  		c <- desc
   718  	}
   719  	s.subLock.Unlock()
   720  }
   721  
   722  // createConnection creates a new connection instance but does not call connect on it. The caller must call connect
   723  // before the connection can be used for network operations.
   724  func (s *Server) createConnection() *connection {
   725  	opts := copyConnectionOpts(s.cfg.connectionOpts)
   726  	opts = append(opts,
   727  		WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   728  		WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   729  		WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   730  		// We override whatever handshaker is currently attached to the options with a basic
   731  		// one because we need to make sure we don't do auth.
   732  		WithHandshaker(func(h Handshaker) Handshaker {
   733  			return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts).
   734  				ServerAPI(s.cfg.serverAPI)
   735  		}),
   736  		// Override any monitors specified in options with nil to avoid monitoring heartbeats.
   737  		WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }),
   738  	)
   739  
   740  	return newConnection(s.address, opts...)
   741  }
   742  
   743  func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption {
   744  	optsCopy := make([]ConnectionOption, len(opts))
   745  	copy(optsCopy, opts)
   746  	return optsCopy
   747  }
   748  
   749  func (s *Server) setupHeartbeatConnection() error {
   750  	conn := s.createConnection()
   751  
   752  	// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
   753  	s.heartbeatLock.Lock()
   754  	if s.heartbeatCtxCancel != nil {
   755  		// Ensure the previous context is cancelled to avoid a leak.
   756  		s.heartbeatCtxCancel()
   757  	}
   758  	s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
   759  	s.conn = conn
   760  	s.heartbeatLock.Unlock()
   761  
   762  	return s.conn.connect(s.heartbeatCtx)
   763  }
   764  
   765  // cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server.
   766  func (s *Server) cancelCheck() {
   767  	var conn *connection
   768  
   769  	// Take heartbeatLock for mutual exclusion with the checks in the update function.
   770  	s.heartbeatLock.Lock()
   771  	if s.heartbeatCtx != nil {
   772  		s.heartbeatCtxCancel()
   773  	}
   774  	conn = s.conn
   775  	s.heartbeatLock.Unlock()
   776  
   777  	if conn == nil {
   778  		return
   779  	}
   780  
   781  	// If the connection exists, we need to wait for it to be connected because conn.connect() and
   782  	// conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its
   783  	// state was set back to disconnected, so calling conn.close() will be a no-op.
   784  	conn.closeConnectContext()
   785  	conn.wait()
   786  	_ = conn.close()
   787  }
   788  
   789  func (s *Server) checkWasCancelled() bool {
   790  	return s.heartbeatCtx.Err() != nil
   791  }
   792  
   793  func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello {
   794  	return operation.
   795  		NewHello().
   796  		ClusterClock(s.cfg.clock).
   797  		Deployment(driver.SingleConnectionDeployment{C: conn}).
   798  		ServerAPI(s.cfg.serverAPI)
   799  }
   800  
   801  func isStreamingEnabled(srv *Server) bool {
   802  	switch srv.cfg.serverMonitoringMode {
   803  	case connstring.ServerMonitoringModeStream:
   804  		return true
   805  	case connstring.ServerMonitoringModePoll:
   806  		return false
   807  	default:
   808  		return driverutil.GetFaasEnvName() == ""
   809  	}
   810  }
   811  
   812  func isStreamable(srv *Server) bool {
   813  	return srv.Description().Kind != description.Unknown && srv.Description().TopologyVersion != nil
   814  }
   815  
   816  func (s *Server) check() (description.Server, error) {
   817  	var descPtr *description.Server
   818  	var err error
   819  	var duration time.Duration
   820  
   821  	start := time.Now()
   822  
   823  	// Create a new connection if this is the first check, the connection was closed after an error during the previous
   824  	// check, or the previous check was cancelled.
   825  	if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
   826  		connID := "0"
   827  		if s.conn != nil {
   828  			connID = s.conn.ID()
   829  		}
   830  		s.publishServerHeartbeatStartedEvent(connID, false)
   831  		// Create a new connection and add its handshake RTT as a sample.
   832  		err = s.setupHeartbeatConnection()
   833  		duration = time.Since(start)
   834  		connID = "0"
   835  		if s.conn != nil {
   836  			connID = s.conn.ID()
   837  		}
   838  		if err == nil {
   839  			// Use the description from the connection handshake as the value for this check.
   840  			s.rttMonitor.addSample(s.conn.helloRTT)
   841  			descPtr = &s.conn.desc
   842  			s.publishServerHeartbeatSucceededEvent(connID, duration, s.conn.desc, false)
   843  		} else {
   844  			err = unwrapConnectionError(err)
   845  			s.publishServerHeartbeatFailedEvent(connID, duration, err, false)
   846  		}
   847  	} else {
   848  		// An existing connection is being used. Use the server description properties to execute the right heartbeat.
   849  
   850  		// Wrap conn in a type that implements driver.StreamerConnection.
   851  		heartbeatConn := initConnection{s.conn}
   852  		baseOperation := s.createBaseOperation(heartbeatConn)
   853  		previousDescription := s.Description()
   854  		streamable := isStreamingEnabled(s) && isStreamable(s)
   855  
   856  		s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
   857  
   858  		switch {
   859  		case s.conn.getCurrentlyStreaming():
   860  			// The connection is already in a streaming state, so we stream the next response.
   861  			err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn)
   862  		case streamable:
   863  			// The server supports the streamable protocol. Set the socket timeout to
   864  			// connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so
   865  			// the wire message will advertise streaming support to the server.
   866  
   867  			// Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13).
   868  			maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6
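			// For example, the default 10-second heartbeat interval gives
			// int64(10*time.Second)/1e6 = 10000, i.e. a maxAwaitTimeMS of 10,000ms.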
   869  			// If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS +
   870  			// heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS
   871  			// server-side.
   872  			socketTimeout := s.cfg.heartbeatTimeout
   873  			if socketTimeout != 0 {
   874  				socketTimeout += s.cfg.heartbeatInterval
   875  			}
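			// For example, a 30-second heartbeat timeout combined with the default
			// 10-second heartbeat interval yields a 40-second socket timeout.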
   876  			s.conn.setSocketTimeout(socketTimeout)
   877  			baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion).
   878  				MaxAwaitTimeMS(maxAwaitTimeMS)
   879  			s.conn.setCanStream(true)
   880  			err = baseOperation.Execute(s.heartbeatCtx)
   881  		default:
   882  			// The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and
   883  			// execute a regular heartbeat without any additional parameters.
   884  
   885  			s.conn.setSocketTimeout(s.cfg.heartbeatTimeout)
   886  			err = baseOperation.Execute(s.heartbeatCtx)
   887  		}
   888  
   889  		duration = time.Since(start)
   890  
   891  		// Record an RTT sample in the polling case so that the RTT-short-circuit
   892  		// feature of CSOT is not disabled when the server is < 4.4 or when the
   893  		// user has requested polling.
   894  		if !streamable {
   895  			s.rttMonitor.addSample(duration)
   896  		}
   897  
   898  		if err == nil {
   899  			tempDesc := baseOperation.Result(s.address)
   900  			descPtr = &tempDesc
   901  			s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, tempDesc, s.conn.getCurrentlyStreaming() || streamable)
   902  		} else {
   903  			// Close the connection here rather than below so we ensure we're not closing a connection that wasn't
   904  			// successfully created.
   905  			if s.conn != nil {
   906  				_ = s.conn.close()
   907  			}
   908  			s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, s.conn.getCurrentlyStreaming() || streamable)
   909  		}
   910  	}
   911  
   912  	if descPtr != nil {
   913  		// The check was successful. Set the average RTT and the heartbeat interval on the description and return.
   914  		desc := *descPtr
   915  		desc = desc.SetAverageRTT(s.rttMonitor.EWMA())
   916  		desc.HeartbeatInterval = s.cfg.heartbeatInterval
   917  		return desc, nil
   918  	}
   919  
   920  	if s.checkWasCancelled() {
   921  		// If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller
   922  		// will know that an actual error didn't occur.
   923  		return emptyDescription, errCheckCancelled
   924  	}
   925  
   926  	// An error occurred. We reset the RTT monitor for all errors and return an Unknown description. The pool must also
   927  	// be cleared, but only after the description has already been updated, so that is handled by the caller.
   928  	topologyVersion := extractTopologyVersion(err)
   929  	s.rttMonitor.reset()
   930  	return description.NewServerFromError(s.address, err, topologyVersion), nil
   931  }
   932  
   933  func extractTopologyVersion(err error) *description.TopologyVersion {
   934  	if ce, ok := err.(ConnectionError); ok {
   935  		err = ce.Wrapped
   936  	}
   937  
   938  	switch converted := err.(type) {
   939  	case driver.Error:
   940  		return converted.TopologyVersion
   941  	case driver.WriteCommandError:
   942  		if converted.WriteConcernError != nil {
   943  			return converted.WriteConcernError.TopologyVersion
   944  		}
   945  	}
   946  
   947  	return nil
   948  }
   949  
   950  // RTTMonitor returns this server's round-trip-time monitor.
   951  func (s *Server) RTTMonitor() driver.RTTMonitor {
   952  	return s.rttMonitor
   953  }
   954  
   955  // OperationCount returns the current number of in-progress operations for this server.
   956  func (s *Server) OperationCount() int64 {
   957  	return atomic.LoadInt64(&s.operationCount)
   958  }
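
// exampleLeastBusy is an editorial sketch, not part of the original file: it
// shows the kind of comparison server-selection logic can make with
// OperationCount, preferring the server with fewer in-progress operations.
func exampleLeastBusy(a, b *Server) *Server {
	if a.OperationCount() <= b.OperationCount() {
		return a
	}
	return b
}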
   959  
   960  // String implements the Stringer interface.
   961  func (s *Server) String() string {
   962  	desc := s.Description()
   963  	state := atomic.LoadInt64(&s.state)
   964  	str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
   965  		s.address, desc.Kind, serverStateString(state))
   966  	if len(desc.Tags) != 0 {
   967  		str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
   968  	}
   969  	if state == serverConnected {
   970  		str += fmt.Sprintf(", Average RTT: %s, Min RTT: %s", desc.AverageRTT, s.RTTMonitor().Min())
   971  	}
   972  	if desc.LastError != nil {
   973  		str += fmt.Sprintf(", Last error: %s", desc.LastError)
   974  	}
   975  
   976  	return str
   977  }
   978  
   979  // ServerSubscription represents a subscription to the description.Server updates for
   980  // a specific server.
   981  type ServerSubscription struct {
   982  	C  <-chan description.Server
   983  	s  *Server
   984  	id uint64
   985  }
   986  
   987  // Unsubscribe unsubscribes this ServerSubscription from updates and closes the
   988  // subscription channel.
   989  func (ss *ServerSubscription) Unsubscribe() error {
   990  	ss.s.subLock.Lock()
   991  	defer ss.s.subLock.Unlock()
   992  	if ss.s.subscriptionsClosed {
   993  		return nil
   994  	}
   995  
   996  	ch, ok := ss.s.subscribers[ss.id]
   997  	if !ok {
   998  		return nil
   999  	}
  1000  
  1001  	close(ch)
  1002  	delete(ss.s.subscribers, ss.id)
  1003  
  1004  	return nil
  1005  }
  1006  
  1007  // publishes a ServerOpeningEvent to indicate the server is being initialized
  1008  func (s *Server) publishServerOpeningEvent(addr address.Address) {
  1009  	if s == nil {
  1010  		return
  1011  	}
  1012  
  1013  	serverOpening := &event.ServerOpeningEvent{
  1014  		Address:    addr,
  1015  		TopologyID: s.topologyID,
  1016  	}
  1017  
  1018  	if s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil {
  1019  		s.cfg.serverMonitor.ServerOpening(serverOpening)
  1020  	}
  1021  
  1022  	if mustLogServerMessage(s) {
  1023  		logServerMessage(s, logger.TopologyServerOpening)
  1024  	}
  1025  }
  1026  
  1027  // publishes a ServerHeartbeatStartedEvent to indicate a hello command has started
  1028  func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) {
  1029  	serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{
  1030  		ConnectionID: connectionID,
  1031  		Awaited:      await,
  1032  	}
  1033  
  1034  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil {
  1035  		s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted)
  1036  	}
  1037  
  1038  	if mustLogServerMessage(s) {
  1039  		logServerMessage(s, logger.TopologyServerHeartbeatStarted,
  1040  			logger.KeyAwaited, await)
  1041  	}
  1042  }
  1043  
  1044  // publishes a ServerHeartbeatSucceededEvent to indicate hello has succeeded
  1045  func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string,
  1046  	duration time.Duration,
  1047  	desc description.Server,
  1048  	await bool,
  1049  ) {
  1050  	serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{
  1051  		DurationNanos: duration.Nanoseconds(),
  1052  		Duration:      duration,
  1053  		Reply:         desc,
  1054  		ConnectionID:  connectionID,
  1055  		Awaited:       await,
  1056  	}
  1057  
  1058  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil {
  1059  		s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded)
  1060  	}
  1061  
  1062  	if mustLogServerMessage(s) {
  1063  		descRaw, _ := bson.Marshal(struct {
  1064  			description.Server `bson:",inline"`
  1065  			Ok                 int32
  1066  		}{
  1067  			Server: desc,
  1068  			Ok: func() int32 {
  1069  				if desc.LastError != nil {
  1070  					return 0
  1071  				}
  1072  
  1073  				return 1
  1074  			}(),
  1075  		})
  1076  
  1077  		logServerMessage(s, logger.TopologyServerHeartbeatSucceeded,
  1078  			logger.KeyAwaited, await,
  1079  			logger.KeyDurationMS, duration.Milliseconds(),
  1080  			logger.KeyReply, bson.Raw(descRaw).String())
  1081  	}
  1082  }
  1083  
  1084  // publishes a ServerHeartbeatFailedEvent to indicate hello has failed
  1085  func (s *Server) publishServerHeartbeatFailedEvent(connectionID string,
  1086  	duration time.Duration,
  1087  	err error,
  1088  	await bool,
  1089  ) {
  1090  	serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{
  1091  		DurationNanos: duration.Nanoseconds(),
  1092  		Duration:      duration,
  1093  		Failure:       err,
  1094  		ConnectionID:  connectionID,
  1095  		Awaited:       await,
  1096  	}
  1097  
  1098  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil {
  1099  		s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed)
  1100  	}
  1101  
  1102  	if mustLogServerMessage(s) {
  1103  		logServerMessage(s, logger.TopologyServerHeartbeatFailed,
  1104  			logger.KeyAwaited, await,
  1105  			logger.KeyDurationMS, duration.Milliseconds(),
  1106  			logger.KeyFailure, err.Error())
  1107  	}
  1108  }
  1109  
  1110  // unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error.
  1111  func unwrapConnectionError(err error) error {
  1112  	// This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then
  1113  	// return ConnectionError.Wrapped.
  1114  
  1115  	connErr, ok := err.(ConnectionError)
  1116  	if ok {
  1117  		return connErr.Wrapped
  1118  	}
  1119  
  1120  	driverErr, ok := err.(driver.Error)
  1121  	if !ok || !driverErr.NetworkError() {
  1122  		return nil
  1123  	}
  1124  
  1125  	connErr, ok = driverErr.Wrapped.(ConnectionError)
  1126  	if ok {
  1127  		return connErr.Wrapped
  1128  	}
  1129  
  1130  	return nil
  1131  }
  1132  
