...
1
2
3
4
5
6
7 package topology
8
9 import (
10 "context"
11 "fmt"
12 "math"
13 "sync"
14 "time"
15
16 "github.com/montanaflynn/stats"
17 "go.mongodb.org/mongo-driver/x/mongo/driver"
18 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
19 )
20
// Tuning constants for the RTT monitor.
const (
	// rttAlphaValue is the weight given to the newest sample when addSample updates the
	// exponentially weighted moving average RTT.
	rttAlphaValue = 0.2

	// minSamples is the lower bound on the sample buffer size and also the number of
	// positive samples required before the minimum and 90th-percentile RTTs are reported.
	minSamples = 10

	// maxSamples is the upper bound on the sample buffer size.
	maxSamples = 500
)
26
27 type rttConfig struct {
28
29
30 interval time.Duration
31
32
33
34 timeout time.Duration
35
36 minRTTWindow time.Duration
37 createConnectionFn func() *connection
38 createOperationFn func(driver.Connection) *operation.Hello
39 }
40
41 type rttMonitor struct {
42 mu sync.RWMutex
43
44
45
46
47 connMu sync.Mutex
48 samples []time.Duration
49 offset int
50 minRTT time.Duration
51 rtt90 time.Duration
52 averageRTT time.Duration
53 averageRTTSet bool
54
55 closeWg sync.WaitGroup
56 cfg *rttConfig
57 ctx context.Context
58 cancelFn context.CancelFunc
59 started bool
60 }
61
62 var _ driver.RTTMonitor = &rttMonitor{}
63
64 func newRTTMonitor(cfg *rttConfig) *rttMonitor {
65 if cfg.interval <= 0 {
66 panic("RTT monitor interval must be greater than 0")
67 }
68
69 ctx, cancel := context.WithCancel(context.Background())
70
71
72 numSamples := int(math.Max(minSamples, math.Min(maxSamples, float64((cfg.minRTTWindow)/cfg.interval))))
73
74 return &rttMonitor{
75 samples: make([]time.Duration, numSamples),
76 cfg: cfg,
77 ctx: ctx,
78 cancelFn: cancel,
79 }
80 }
81
82 func (r *rttMonitor) connect() {
83 r.connMu.Lock()
84 defer r.connMu.Unlock()
85
86 r.started = true
87 r.closeWg.Add(1)
88
89 go func() {
90 defer r.closeWg.Done()
91
92 r.start()
93 }()
94 }
95
96 func (r *rttMonitor) disconnect() {
97 r.connMu.Lock()
98 defer r.connMu.Unlock()
99
100 if !r.started {
101 return
102 }
103
104 r.cancelFn()
105
106
107 r.closeWg.Wait()
108 }
109
110 func (r *rttMonitor) start() {
111 var conn *connection
112 defer func() {
113 if conn != nil {
114
115
116
117
118 conn.closeConnectContext()
119 conn.wait()
120 _ = conn.close()
121 }
122 }()
123
124 ticker := time.NewTicker(r.cfg.interval)
125 defer ticker.Stop()
126
127 for {
128 conn := r.cfg.createConnectionFn()
129 err := conn.connect(r.ctx)
130
131
132
133
134 if err == nil {
135 r.addSample(conn.helloRTT)
136 r.runHellos(conn)
137 }
138
139
140
141 _ = conn.close()
142
143
144
145 select {
146 case <-ticker.C:
147 case <-r.ctx.Done():
148 return
149 }
150 }
151 }
152
153
154
155 func (r *rttMonitor) runHellos(conn *connection) {
156 ticker := time.NewTicker(r.cfg.interval)
157 defer ticker.Stop()
158
159 for {
160
161
162 select {
163 case <-ticker.C:
164 case <-r.ctx.Done():
165 return
166 }
167
168
169
170
171
172
173
174
175 timeout := r.cfg.timeout
176 if timeout <= 0 {
177 timeout = conn.config.connectTimeout
178 }
179 ctx, cancel := context.WithTimeout(r.ctx, timeout)
180
181 start := time.Now()
182 err := r.cfg.createOperationFn(initConnection{conn}).Execute(ctx)
183 cancel()
184 if err != nil {
185 return
186 }
187
188
189
190 r.addSample(time.Since(start))
191 }
192 }
193
194
195
196 func (r *rttMonitor) reset() {
197 r.mu.Lock()
198 defer r.mu.Unlock()
199
200 for i := range r.samples {
201 r.samples[i] = 0
202 }
203 r.offset = 0
204 r.minRTT = 0
205 r.rtt90 = 0
206 r.averageRTT = 0
207 r.averageRTTSet = false
208 }
209
210 func (r *rttMonitor) addSample(rtt time.Duration) {
211
212
213 r.mu.Lock()
214 defer r.mu.Unlock()
215
216 r.samples[r.offset] = rtt
217 r.offset = (r.offset + 1) % len(r.samples)
218
219
220
221 r.minRTT = min(r.samples, minSamples)
222 r.rtt90 = percentile(90.0, r.samples, minSamples)
223
224 if !r.averageRTTSet {
225 r.averageRTT = rtt
226 r.averageRTTSet = true
227 return
228 }
229
230 r.averageRTT = time.Duration(rttAlphaValue*float64(rtt) + (1-rttAlphaValue)*float64(r.averageRTT))
231 }
232
233
234
235
// min returns the smallest positive duration in samples. Non-positive entries are treated
// as empty slots and skipped. If fewer than minSamples positive entries exist (or none at
// all), it returns 0.
func min(samples []time.Duration, minSamples int) time.Duration {
	positive := 0
	smallest := time.Duration(math.MaxInt64)
	for _, s := range samples {
		if s <= 0 {
			continue
		}
		positive++
		if s < smallest {
			smallest = s
		}
	}

	if positive == 0 || positive < minSamples {
		return 0
	}
	return smallest
}
252
253
254
255
256 func percentile(perc float64, samples []time.Duration, minSamples int) time.Duration {
257
258 floatSamples := make([]float64, 0, len(samples))
259 for _, sample := range samples {
260 if sample > 0 {
261 floatSamples = append(floatSamples, float64(sample))
262 }
263 }
264 if len(floatSamples) == 0 || len(floatSamples) < minSamples {
265 return 0
266 }
267
268 p, err := stats.Percentile(floatSamples, perc)
269 if err != nil {
270 panic(fmt.Errorf("x/mongo/driver/topology: error calculating %f percentile RTT: %w for samples:\n%v", perc, err, floatSamples))
271 }
272 return time.Duration(p)
273 }
274
275
276 func (r *rttMonitor) EWMA() time.Duration {
277 r.mu.RLock()
278 defer r.mu.RUnlock()
279
280 return r.averageRTT
281 }
282
283
284 func (r *rttMonitor) Min() time.Duration {
285 r.mu.RLock()
286 defer r.mu.RUnlock()
287
288 return r.minRTT
289 }
290
291
292 func (r *rttMonitor) P90() time.Duration {
293 r.mu.RLock()
294 defer r.mu.RUnlock()
295
296 return r.rtt90
297 }
298
299
300 func (r *rttMonitor) Stats() string {
301 r.mu.RLock()
302 defer r.mu.RUnlock()
303
304
305 var sum float64
306 floatSamples := make([]float64, 0, len(r.samples))
307 for _, sample := range r.samples {
308 if sample > 0 {
309 floatSamples = append(floatSamples, float64(sample))
310 sum += float64(sample)
311 }
312 }
313
314 var avg, stdDev float64
315 if len(floatSamples) > 0 {
316 avg = sum / float64(len(floatSamples))
317
318 var err error
319 stdDev, err = stats.StandardDeviation(floatSamples)
320 if err != nil {
321 panic(fmt.Errorf("x/mongo/driver/topology: error calculating standard deviation RTT: %w for samples:\n%v", err, floatSamples))
322 }
323 }
324
325 return fmt.Sprintf(
326 "network round-trip time stats: avg: %v, min: %v, 90th pct: %v, stddev: %v",
327 time.Duration(avg),
328 r.minRTT,
329 r.rtt90,
330 time.Duration(stdDev))
331 }
332
View as plain text