...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/rtt_monitor.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"context"
    11  	"fmt"
    12  	"math"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/montanaflynn/stats"
    17  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    18  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    19  )
    20  
    21  const (
    22  	rttAlphaValue = 0.2
    23  	minSamples    = 10
    24  	maxSamples    = 500
    25  )
    26  
// rttConfig holds the settings used to construct an rttMonitor.
type rttConfig struct {
	// The minimum interval between RTT measurements. The actual interval may be greater if running
	// the operation takes longer than the interval. Must be greater than 0; newRTTMonitor panics
	// otherwise.
	interval time.Duration

	// The timeout applied to running the "hello" operation. If the timeout is reached while running
	// the operation, the RTT sample is discarded. If this is <= 0, runHellos falls back to the
	// monitoring connection's connectTimeout.
	// NOTE(review): this comment previously stated "The default is 1 minute"; no such default is
	// applied in this file — confirm whether the caller sets it.
	timeout time.Duration

	// minRTTWindow is the time span of samples to retain for the min and 90th-percentile RTT.
	// Divided by interval, it sizes the sample buffer (clamped to [minSamples, maxSamples]).
	minRTTWindow time.Duration
	// createConnectionFn returns the connection used for RTT measurement; the monitor calls
	// connect on it before use.
	createConnectionFn func() *connection
	// createOperationFn builds the "hello" operation that is timed on the monitoring connection.
	createOperationFn func(driver.Connection) *operation.Hello
}
    40  
// rttMonitor measures round-trip time (RTT) to a server on a dedicated connection and maintains
// an exponentially weighted moving average, a windowed minimum, and a windowed 90th percentile of
// the observed samples.
type rttMonitor struct {
	// mu guards samples, offset, minRTT, rtt90, averageRTT, and averageRTTSet.
	mu sync.RWMutex

	// connMu guards connecting and disconnecting. This is necessary since
	// disconnecting will await the cancellation of a started connection. The
	// use case for rttMonitor.connect needs to be goroutine safe.
	connMu sync.Mutex
	// samples is a ring buffer of recent RTT samples; zero entries are unused slots.
	samples []time.Duration
	// offset is the index in samples where the next sample will be written.
	offset int
	// minRTT is the minimum non-zero sample, or 0 until at least minSamples samples exist.
	minRTT time.Duration
	// rtt90 is the 90th-percentile sample, or 0 until at least minSamples samples exist.
	rtt90 time.Duration
	// averageRTT is the EWMA of all samples, weighted by rttAlphaValue.
	averageRTT time.Duration
	// averageRTTSet reports whether averageRTT has been seeded with a first sample.
	averageRTTSet bool

	// closeWg tracks the monitoring goroutine launched by connect so disconnect can wait for it.
	closeWg sync.WaitGroup
	cfg     *rttConfig
	// ctx is cancelled by disconnect to stop monitoring; it also parents each hello timeout.
	ctx      context.Context
	cancelFn context.CancelFunc
	// started records that connect has been called; guarded by connMu.
	started bool
}
    61  
    62  var _ driver.RTTMonitor = &rttMonitor{}
    63  
    64  func newRTTMonitor(cfg *rttConfig) *rttMonitor {
    65  	if cfg.interval <= 0 {
    66  		panic("RTT monitor interval must be greater than 0")
    67  	}
    68  
    69  	ctx, cancel := context.WithCancel(context.Background())
    70  	// Determine the number of samples we need to keep to store the minWindow of RTT durations. The
    71  	// number of samples must be between [10, 500].
    72  	numSamples := int(math.Max(minSamples, math.Min(maxSamples, float64((cfg.minRTTWindow)/cfg.interval))))
    73  
    74  	return &rttMonitor{
    75  		samples:  make([]time.Duration, numSamples),
    76  		cfg:      cfg,
    77  		ctx:      ctx,
    78  		cancelFn: cancel,
    79  	}
    80  }
    81  
    82  func (r *rttMonitor) connect() {
    83  	r.connMu.Lock()
    84  	defer r.connMu.Unlock()
    85  
    86  	r.started = true
    87  	r.closeWg.Add(1)
    88  
    89  	go func() {
    90  		defer r.closeWg.Done()
    91  
    92  		r.start()
    93  	}()
    94  }
    95  
    96  func (r *rttMonitor) disconnect() {
    97  	r.connMu.Lock()
    98  	defer r.connMu.Unlock()
    99  
   100  	if !r.started {
   101  		return
   102  	}
   103  
   104  	r.cancelFn()
   105  
   106  	// Wait for the existing connection to complete.
   107  	r.closeWg.Wait()
   108  }
   109  
   110  func (r *rttMonitor) start() {
   111  	var conn *connection
   112  	defer func() {
   113  		if conn != nil {
   114  			// If the connection exists, we need to wait for it to be connected because
   115  			// conn.connect() and conn.close() cannot be called concurrently. If the connection
   116  			// wasn't successfully opened, its state was set back to disconnected, so calling
   117  			// conn.close() will be a no-op.
   118  			conn.closeConnectContext()
   119  			conn.wait()
   120  			_ = conn.close()
   121  		}
   122  	}()
   123  
   124  	ticker := time.NewTicker(r.cfg.interval)
   125  	defer ticker.Stop()
   126  
   127  	for {
   128  		conn := r.cfg.createConnectionFn()
   129  		err := conn.connect(r.ctx)
   130  
   131  		// Add an RTT sample from the new connection handshake and start a runHellos() loop if we
   132  		// successfully established the new connection. Otherwise, close the connection and try to
   133  		// create another new connection.
   134  		if err == nil {
   135  			r.addSample(conn.helloRTT)
   136  			r.runHellos(conn)
   137  		}
   138  
   139  		// Close any connection here because we're either about to try to create another new
   140  		// connection or we're about to exit the loop.
   141  		_ = conn.close()
   142  
   143  		// If a connection error happens quickly, always wait for the monitoring interval to try
   144  		// to create a new connection to prevent creating connections too quickly.
   145  		select {
   146  		case <-ticker.C:
   147  		case <-r.ctx.Done():
   148  			return
   149  		}
   150  	}
   151  }
   152  
   153  // runHellos runs "hello" operations in a loop using the provided connection, measuring and
   154  // recording the operation durations as RTT samples. If it encounters any errors, it returns.
   155  func (r *rttMonitor) runHellos(conn *connection) {
   156  	ticker := time.NewTicker(r.cfg.interval)
   157  	defer ticker.Stop()
   158  
   159  	for {
   160  		// Assume that the connection establishment recorded the first RTT sample, so wait for the
   161  		// first tick before trying to record another RTT sample.
   162  		select {
   163  		case <-ticker.C:
   164  		case <-r.ctx.Done():
   165  			return
   166  		}
   167  
   168  		// Create a Context with the operation timeout specified in the RTT monitor config. If a
   169  		// timeout is not set in the RTT monitor config, default to the connection's
   170  		// "connectTimeoutMS". The purpose of the timeout is to allow the RTT monitor to continue
   171  		// monitoring server RTTs after an operation gets stuck. An operation can get stuck if the
   172  		// server or a proxy stops responding to requests on the RTT connection but does not close
   173  		// the TCP socket, effectively creating an operation that will never complete. We expect
   174  		// that "connectTimeoutMS" provides at least enough time for a single round trip.
   175  		timeout := r.cfg.timeout
   176  		if timeout <= 0 {
   177  			timeout = conn.config.connectTimeout
   178  		}
   179  		ctx, cancel := context.WithTimeout(r.ctx, timeout)
   180  
   181  		start := time.Now()
   182  		err := r.cfg.createOperationFn(initConnection{conn}).Execute(ctx)
   183  		cancel()
   184  		if err != nil {
   185  			return
   186  		}
   187  		// Only record a sample if the "hello" operation was successful. If it was not successful,
   188  		// the operation may not have actually performed a complete round trip, so the duration may
   189  		// be artificially short.
   190  		r.addSample(time.Since(start))
   191  	}
   192  }
   193  
   194  // reset sets the average and min RTT to 0. This should only be called from the server monitor when an error
   195  // occurs during a server check. Errors in the RTT monitor should not reset the RTTs.
   196  func (r *rttMonitor) reset() {
   197  	r.mu.Lock()
   198  	defer r.mu.Unlock()
   199  
   200  	for i := range r.samples {
   201  		r.samples[i] = 0
   202  	}
   203  	r.offset = 0
   204  	r.minRTT = 0
   205  	r.rtt90 = 0
   206  	r.averageRTT = 0
   207  	r.averageRTTSet = false
   208  }
   209  
   210  func (r *rttMonitor) addSample(rtt time.Duration) {
   211  	// Lock for the duration of this method. We're doing compuationally inexpensive work very infrequently, so lock
   212  	// contention isn't expected.
   213  	r.mu.Lock()
   214  	defer r.mu.Unlock()
   215  
   216  	r.samples[r.offset] = rtt
   217  	r.offset = (r.offset + 1) % len(r.samples)
   218  	// Set the minRTT and 90th percentile RTT of all collected samples. Require at least 10 samples before
   219  	// setting these to prevent noisy samples on startup from artificially increasing RTT and to allow the
   220  	// calculation of a 90th percentile.
   221  	r.minRTT = min(r.samples, minSamples)
   222  	r.rtt90 = percentile(90.0, r.samples, minSamples)
   223  
   224  	if !r.averageRTTSet {
   225  		r.averageRTT = rtt
   226  		r.averageRTTSet = true
   227  		return
   228  	}
   229  
   230  	r.averageRTT = time.Duration(rttAlphaValue*float64(rtt) + (1-rttAlphaValue)*float64(r.averageRTT))
   231  }
   232  
   233  // min returns the minimum value of the slice of duration samples. Zero values are not considered
   234  // samples and are ignored. If no samples or fewer than minSamples are found in the slice, min
   235  // returns 0.
   236  func min(samples []time.Duration, minSamples int) time.Duration {
   237  	count := 0
   238  	min := time.Duration(math.MaxInt64)
   239  	for _, d := range samples {
   240  		if d > 0 {
   241  			count++
   242  		}
   243  		if d > 0 && d < min {
   244  			min = d
   245  		}
   246  	}
   247  	if count == 0 || count < minSamples {
   248  		return 0
   249  	}
   250  	return min
   251  }
   252  
   253  // percentile returns the specified percentile value of the slice of duration samples. Zero values
   254  // are not considered samples and are ignored. If no samples or fewer than minSamples are found
   255  // in the slice, percentile returns 0.
   256  func percentile(perc float64, samples []time.Duration, minSamples int) time.Duration {
   257  	// Convert Durations to float64s.
   258  	floatSamples := make([]float64, 0, len(samples))
   259  	for _, sample := range samples {
   260  		if sample > 0 {
   261  			floatSamples = append(floatSamples, float64(sample))
   262  		}
   263  	}
   264  	if len(floatSamples) == 0 || len(floatSamples) < minSamples {
   265  		return 0
   266  	}
   267  
   268  	p, err := stats.Percentile(floatSamples, perc)
   269  	if err != nil {
   270  		panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %w for samples:\n%v", perc, err, floatSamples))
   271  	}
   272  	return time.Duration(p)
   273  }
   274  
   275  // EWMA returns the exponentially weighted moving average observed round-trip time.
   276  func (r *rttMonitor) EWMA() time.Duration {
   277  	r.mu.RLock()
   278  	defer r.mu.RUnlock()
   279  
   280  	return r.averageRTT
   281  }
   282  
   283  // Min returns the minimum observed round-trip time over the window period.
   284  func (r *rttMonitor) Min() time.Duration {
   285  	r.mu.RLock()
   286  	defer r.mu.RUnlock()
   287  
   288  	return r.minRTT
   289  }
   290  
   291  // P90 returns the 90th percentile observed round-trip time over the window period.
   292  func (r *rttMonitor) P90() time.Duration {
   293  	r.mu.RLock()
   294  	defer r.mu.RUnlock()
   295  
   296  	return r.rtt90
   297  }
   298  
   299  // Stats returns stringified stats of the current state of the monitor.
   300  func (r *rttMonitor) Stats() string {
   301  	r.mu.RLock()
   302  	defer r.mu.RUnlock()
   303  
   304  	// Calculate standard deviation and average (non-EWMA) of samples.
   305  	var sum float64
   306  	floatSamples := make([]float64, 0, len(r.samples))
   307  	for _, sample := range r.samples {
   308  		if sample > 0 {
   309  			floatSamples = append(floatSamples, float64(sample))
   310  			sum += float64(sample)
   311  		}
   312  	}
   313  
   314  	var avg, stdDev float64
   315  	if len(floatSamples) > 0 {
   316  		avg = sum / float64(len(floatSamples))
   317  
   318  		var err error
   319  		stdDev, err = stats.StandardDeviation(floatSamples)
   320  		if err != nil {
   321  			panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %w for samples:\n%v", err, floatSamples))
   322  		}
   323  	}
   324  
   325  	return fmt.Sprintf(
   326  		"network round-trip time stats: avg: %v, min: %v, 90th pct: %v, stddev: %v",
   327  		time.Duration(avg),
   328  		r.minRTT,
   329  		r.rtt90,
   330  		time.Duration(stdDev))
   331  }
   332  

View as plain text