1
2
3
4
5
6
7 package topology
8
9 import (
10 "bufio"
11 "bytes"
12 "context"
13 "encoding/json"
14 "errors"
15 "fmt"
16 "io/ioutil"
17 "path"
18 "sync/atomic"
19 "testing"
20 "time"
21
22 "go.mongodb.org/mongo-driver/bson/primitive"
23 "go.mongodb.org/mongo-driver/internal/assert"
24 "go.mongodb.org/mongo-driver/internal/logger"
25 "go.mongodb.org/mongo-driver/internal/require"
26 "go.mongodb.org/mongo-driver/internal/spectest"
27 "go.mongodb.org/mongo-driver/mongo/address"
28 "go.mongodb.org/mongo-driver/mongo/description"
29 "go.mongodb.org/mongo-driver/mongo/options"
30 "go.mongodb.org/mongo-driver/mongo/readpref"
31 "go.mongodb.org/mongo-driver/x/mongo/driver"
32 )
33
// testTimeout is the deadline used for the contexts created with
// context.WithTimeout in the tests below.
const testTimeout = 2 * time.Second
35
// noerr stops the calling test immediately when err is non-nil, reporting the
// error at the caller's line. It does nothing when err is nil.
func noerr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("Unexpected error: %v", err)
}
43
// compareErrors reports whether two errors are equivalent: either both are nil,
// or both are non-nil and produce the same Error() string.
func compareErrors(err1, err2 error) bool {
	switch {
	case err1 == nil && err2 == nil:
		return true
	case err1 == nil || err2 == nil:
		// Exactly one of the two is nil.
		return false
	default:
		return err1.Error() == err2.Error()
	}
}
59
60 func TestServerSelection(t *testing.T) {
61 var selectFirst description.ServerSelectorFunc = func(_ description.Topology, candidates []description.Server) ([]description.Server, error) {
62 if len(candidates) == 0 {
63 return []description.Server{}, nil
64 }
65 return candidates[0:1], nil
66 }
67 var selectNone description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) {
68 return []description.Server{}, nil
69 }
70 var errSelectionError = errors.New("encountered an error in the selector")
71 var selectError description.ServerSelectorFunc = func(description.Topology, []description.Server) ([]description.Server, error) {
72 return nil, errSelectionError
73 }
74
75 t.Run("Success", func(t *testing.T) {
76 topo, err := New(nil)
77 noerr(t, err)
78 desc := description.Topology{
79 Servers: []description.Server{
80 {Addr: address.Address("one"), Kind: description.Standalone},
81 {Addr: address.Address("two"), Kind: description.Standalone},
82 {Addr: address.Address("three"), Kind: description.Standalone},
83 },
84 }
85 subCh := make(chan description.Topology, 1)
86 subCh <- desc
87
88 state := newServerSelectionState(selectFirst, nil)
89 srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
90 noerr(t, err)
91 if len(srvs) != 1 {
92 t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
93 }
94 if srvs[0].Addr != desc.Servers[0].Addr {
95 t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[0].Addr)
96 }
97 })
98 t.Run("Compatibility Error Min Version Too High", func(t *testing.T) {
99 topo, err := New(nil)
100 noerr(t, err)
101 desc := description.Topology{
102 Kind: description.Single,
103 Servers: []description.Server{
104 {Addr: address.Address("one:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 11, Min: 11}},
105 {Addr: address.Address("two:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}},
106 {Addr: address.Address("three:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 6}},
107 },
108 }
109 want := fmt.Errorf(
110 "server at %s requires wire version %d, but this version of the Go driver only supports up to %d",
111 desc.Servers[0].Addr.String(),
112 desc.Servers[0].WireVersion.Min,
113 SupportedWireVersions.Max,
114 )
115 desc.CompatibilityErr = want
116 atomic.StoreInt64(&topo.state, topologyConnected)
117 topo.desc.Store(desc)
118 _, err = topo.SelectServer(context.Background(), selectFirst)
119 assert.Equal(t, err, want, "expected %v, got %v", want, err)
120 })
121 t.Run("Compatibility Error Max Version Too Low", func(t *testing.T) {
122 topo, err := New(nil)
123 noerr(t, err)
124 desc := description.Topology{
125 Kind: description.Single,
126 Servers: []description.Server{
127 {Addr: address.Address("one:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 21, Min: 6}},
128 {Addr: address.Address("two:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}},
129 {Addr: address.Address("three:27017"), Kind: description.Standalone, WireVersion: &description.VersionRange{Max: 9, Min: 2}},
130 },
131 }
132 want := fmt.Errorf(
133 "server at %s reports wire version %d, but this version of the Go driver requires "+
134 "at least 6 (MongoDB 3.6)",
135 desc.Servers[0].Addr.String(),
136 desc.Servers[0].WireVersion.Max,
137 )
138 desc.CompatibilityErr = want
139 atomic.StoreInt64(&topo.state, topologyConnected)
140 topo.desc.Store(desc)
141 _, err = topo.SelectServer(context.Background(), selectFirst)
142 assert.Equal(t, err, want, "expected %v, got %v", want, err)
143 })
144 t.Run("Updated", func(t *testing.T) {
145 topo, err := New(nil)
146 noerr(t, err)
147 desc := description.Topology{Servers: []description.Server{}}
148 subCh := make(chan description.Topology, 1)
149 subCh <- desc
150
151 resp := make(chan []description.Server)
152 go func() {
153 state := newServerSelectionState(selectFirst, nil)
154 srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
155 noerr(t, err)
156 resp <- srvs
157 }()
158
159 desc = description.Topology{
160 Servers: []description.Server{
161 {Addr: address.Address("one"), Kind: description.Standalone},
162 {Addr: address.Address("two"), Kind: description.Standalone},
163 {Addr: address.Address("three"), Kind: description.Standalone},
164 },
165 }
166 select {
167 case subCh <- desc:
168 case <-time.After(100 * time.Millisecond):
169 t.Error("Timed out while trying to send topology description")
170 }
171
172 var srvs []description.Server
173 select {
174 case srvs = <-resp:
175 case <-time.After(100 * time.Millisecond):
176 t.Errorf("Timed out while trying to retrieve selected servers")
177 }
178
179 if len(srvs) != 1 {
180 t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
181 }
182 if srvs[0].Addr != desc.Servers[0].Addr {
183 t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[0].Addr)
184 }
185 })
186 t.Run("Cancel", func(t *testing.T) {
187 desc := description.Topology{
188 Servers: []description.Server{
189 {Addr: address.Address("one"), Kind: description.Standalone},
190 {Addr: address.Address("two"), Kind: description.Standalone},
191 {Addr: address.Address("three"), Kind: description.Standalone},
192 },
193 }
194 topo, err := New(nil)
195 noerr(t, err)
196 subCh := make(chan description.Topology, 1)
197 subCh <- desc
198 resp := make(chan error)
199 ctx, cancel := context.WithCancel(context.Background())
200 go func() {
201 state := newServerSelectionState(selectNone, nil)
202 _, err := topo.selectServerFromSubscription(ctx, subCh, state)
203 resp <- err
204 }()
205
206 select {
207 case err := <-resp:
208 t.Errorf("Received error from server selection too soon: %v", err)
209 case <-time.After(100 * time.Millisecond):
210 }
211
212 cancel()
213
214 select {
215 case err = <-resp:
216 case <-time.After(100 * time.Millisecond):
217 t.Errorf("Timed out while trying to retrieve selected servers")
218 }
219
220 want := ServerSelectionError{Wrapped: context.Canceled, Desc: desc}
221 assert.Equal(t, err, want, "Incorrect error received. got %v; want %v", err, want)
222 })
223 t.Run("Timeout", func(t *testing.T) {
224 desc := description.Topology{
225 Servers: []description.Server{
226 {Addr: address.Address("one"), Kind: description.Standalone},
227 {Addr: address.Address("two"), Kind: description.Standalone},
228 {Addr: address.Address("three"), Kind: description.Standalone},
229 },
230 }
231 topo, err := New(nil)
232 noerr(t, err)
233 subCh := make(chan description.Topology, 1)
234 subCh <- desc
235 resp := make(chan error)
236 timeout := make(chan time.Time)
237 go func() {
238 state := newServerSelectionState(selectNone, timeout)
239 _, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
240 resp <- err
241 }()
242
243 select {
244 case err := <-resp:
245 t.Errorf("Received error from server selection too soon: %v", err)
246 case timeout <- time.Now():
247 }
248
249 select {
250 case err = <-resp:
251 case <-time.After(100 * time.Millisecond):
252 t.Errorf("Timed out while trying to retrieve selected servers")
253 }
254
255 if err == nil {
256 t.Fatalf("did not receive error from server selection")
257 }
258 })
259 t.Run("Error", func(t *testing.T) {
260 desc := description.Topology{
261 Servers: []description.Server{
262 {Addr: address.Address("one"), Kind: description.Standalone},
263 {Addr: address.Address("two"), Kind: description.Standalone},
264 {Addr: address.Address("three"), Kind: description.Standalone},
265 },
266 }
267 topo, err := New(nil)
268 noerr(t, err)
269 subCh := make(chan description.Topology, 1)
270 subCh <- desc
271 resp := make(chan error)
272 timeout := make(chan time.Time)
273 go func() {
274 state := newServerSelectionState(selectError, timeout)
275 _, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
276 resp <- err
277 }()
278
279 select {
280 case err = <-resp:
281 case <-time.After(100 * time.Millisecond):
282 t.Errorf("Timed out while trying to retrieve selected servers")
283 }
284
285 if err == nil {
286 t.Fatalf("did not receive error from server selection")
287 }
288 })
289 t.Run("findServer returns topology kind", func(t *testing.T) {
290 topo, err := New(nil)
291 noerr(t, err)
292 atomic.StoreInt64(&topo.state, topologyConnected)
293 srvr, err := ConnectServer(address.Address("one"), topo.updateCallback, topo.id)
294 noerr(t, err)
295 topo.servers[address.Address("one")] = srvr
296 desc := topo.desc.Load().(description.Topology)
297 desc.Kind = description.Single
298 topo.desc.Store(desc)
299
300 selected := description.Server{Addr: address.Address("one")}
301
302 ss, err := topo.FindServer(selected)
303 noerr(t, err)
304 if ss.Kind != description.Single {
305 t.Errorf("findServer does not properly set the topology description kind. got %v; want %v", ss.Kind, description.Single)
306 }
307 })
308 t.Run("Update on not primary error", func(t *testing.T) {
309 topo, err := New(nil)
310 noerr(t, err)
311 atomic.StoreInt64(&topo.state, topologyConnected)
312
313 addr1 := address.Address("one")
314 addr2 := address.Address("two")
315 addr3 := address.Address("three")
316 desc := description.Topology{
317 Servers: []description.Server{
318 {Addr: addr1, Kind: description.RSPrimary},
319 {Addr: addr2, Kind: description.RSSecondary},
320 {Addr: addr3, Kind: description.RSSecondary},
321 },
322 }
323
324
325 for _, srv := range desc.Servers {
326 s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id)
327 noerr(t, err)
328 topo.servers[srv.Addr] = s
329 }
330
331
332 desc = description.Topology{
333 Servers: []description.Server{
334 {Addr: addr1, Kind: description.RSSecondary},
335 {Addr: addr2, Kind: description.RSPrimary},
336 {Addr: addr3, Kind: description.RSSecondary},
337 },
338 }
339
340 subCh := make(chan description.Topology, 1)
341 subCh <- desc
342
343
344 serv, err := topo.FindServer(desc.Servers[0])
345 noerr(t, err)
346 atomic.StoreInt64(&serv.state, serverConnected)
347 _ = serv.ProcessError(driver.Error{Message: driver.LegacyNotPrimaryErrMsg}, initConnection{})
348
349 resp := make(chan []description.Server)
350
351 go func() {
352
353 state := newServerSelectionState(description.WriteSelector(), nil)
354 srvs, err := topo.selectServerFromSubscription(context.Background(), subCh, state)
355 noerr(t, err)
356 resp <- srvs
357 }()
358
359 var srvs []description.Server
360 select {
361 case srvs = <-resp:
362 case <-time.After(100 * time.Millisecond):
363 t.Errorf("Timed out while trying to retrieve selected servers")
364 }
365
366 if len(srvs) != 1 {
367 t.Errorf("Incorrect number of descriptions returned. got %d; want %d", len(srvs), 1)
368 }
369 if srvs[0].Addr != desc.Servers[1].Addr {
370 t.Errorf("Incorrect sever selected. got %s; want %s", srvs[0].Addr, desc.Servers[1].Addr)
371 }
372 })
373 t.Run("fast path does not subscribe or check timeouts", func(t *testing.T) {
374
375 topo, err := New(nil)
376 noerr(t, err)
377 atomic.StoreInt64(&topo.state, topologyConnected)
378
379 primaryAddr := address.Address("one")
380 desc := description.Topology{
381 Servers: []description.Server{
382 {Addr: primaryAddr, Kind: description.RSPrimary},
383 },
384 }
385 topo.desc.Store(desc)
386 for _, srv := range desc.Servers {
387 s, err := ConnectServer(srv.Addr, topo.updateCallback, topo.id)
388 noerr(t, err)
389 topo.servers[srv.Addr] = s
390 }
391
392
393
394 topo.subscriptionsClosed = true
395 ctx, cancel := context.WithCancel(context.Background())
396 cancel()
397 selectedServer, err := topo.SelectServer(ctx, description.WriteSelector())
398 noerr(t, err)
399 selectedAddr := selectedServer.(*SelectedServer).address
400 assert.Equal(t, primaryAddr, selectedAddr, "expected address %v, got %v", primaryAddr, selectedAddr)
401 })
402 t.Run("default to selecting from subscription if fast path fails", func(t *testing.T) {
403 topo, err := New(nil)
404 noerr(t, err)
405
406 atomic.StoreInt64(&topo.state, topologyConnected)
407 desc := description.Topology{
408 Servers: []description.Server{},
409 }
410 topo.desc.Store(desc)
411
412 topo.subscriptionsClosed = true
413 _, err = topo.SelectServer(context.Background(), description.WriteSelector())
414 assert.Equal(t, ErrSubscribeAfterClosed, err, "expected error %v, got %v", ErrSubscribeAfterClosed, err)
415 })
416 }
417
// TestSessionTimeout checks the SessionTimeoutMinutesPtr value published on the
// topology description after server descriptions are applied with topo.apply.
// Each subtest seeds topo.fsm.Servers directly, applies one or two server
// descriptions, and then inspects the resulting topology description.
func TestSessionTimeout(t *testing.T) {
	// int64ToPtr returns a pointer to a copy of its argument.
	int64ToPtr := func(i64 int64) *int64 { return &i64 }

	t.Run("UpdateSessionTimeout", func(t *testing.T) {
		topo, err := New(nil)
		noerr(t, err)
		topo.servers["foo"] = nil
		// Seed the FSM with a single primary advertising a 60-minute timeout.
		topo.fsm.Servers = []description.Server{
			{
				Addr:                     address.Address("foo").Canonicalize(),
				Kind:                     description.RSPrimary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
		defer cancel()

		// Apply an update for the same server that lowers the timeout to 30.
		desc := description.Server{
			Addr:                     "foo",
			Kind:                     description.RSPrimary,
			SessionTimeoutMinutes:    30,
			SessionTimeoutMinutesPtr: int64ToPtr(30),
		}
		topo.apply(ctx, desc)

		currDesc := topo.desc.Load().(description.Topology)
		want := int64(30)
		require.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
			"session timeout minutes mismatch")
	})
	t.Run("MultipleUpdates", func(t *testing.T) {
		topo, err := New(nil)
		noerr(t, err)
		topo.fsm.Kind = description.ReplicaSetWithPrimary
		topo.servers["foo"] = nil
		topo.servers["bar"] = nil
		topo.fsm.Servers = []description.Server{
			{
				Addr:                     address.Address("foo").Canonicalize(),
				Kind:                     description.RSPrimary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
			{
				Addr:                     address.Address("bar").Canonicalize(),
				Kind:                     description.RSSecondary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
		defer cancel()

		// Two successive updates, 30 then 20 minutes; the expected result is
		// 20, the value from the second update.
		desc1 := description.Server{
			Addr:                     "foo",
			Kind:                     description.RSPrimary,
			SessionTimeoutMinutes:    30,
			SessionTimeoutMinutesPtr: int64ToPtr(30),
			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
		}

		desc2 := description.Server{
			Addr:                     "bar",
			Kind:                     description.RSPrimary,
			SessionTimeoutMinutes:    20,
			SessionTimeoutMinutesPtr: int64ToPtr(20),
			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
		}
		topo.apply(ctx, desc1)
		topo.apply(ctx, desc2)

		currDesc := topo.Description()
		want := int64(20)
		require.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
			"session timeout minutes mismatch")
	})
	t.Run("NoUpdate", func(t *testing.T) {
		topo, err := New(nil)
		noerr(t, err)
		topo.servers["foo"] = nil
		topo.servers["bar"] = nil
		topo.fsm.Servers = []description.Server{
			{
				Addr:                     address.Address("foo").Canonicalize(),
				Kind:                     description.RSPrimary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
			{
				Addr:                     address.Address("bar").Canonicalize(),
				Kind:                     description.RSSecondary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
		defer cancel()

		// The first update sets 20 minutes; the second advertises a higher
		// value (30). The expected result stays at 20.
		desc1 := description.Server{
			Addr:                     "foo",
			Kind:                     description.RSPrimary,
			SessionTimeoutMinutes:    20,
			SessionTimeoutMinutesPtr: int64ToPtr(20),
			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
		}

		desc2 := description.Server{
			Addr:                     "bar",
			Kind:                     description.RSPrimary,
			SessionTimeoutMinutes:    30,
			SessionTimeoutMinutesPtr: int64ToPtr(30),
			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
		}
		topo.apply(ctx, desc1)
		topo.apply(ctx, desc2)

		currDesc := topo.desc.Load().(description.Topology)
		want := int64(20)
		require.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
			"session timeout minutes mismatch")
	})
	t.Run("TimeoutDataBearing", func(t *testing.T) {
		topo, err := New(nil)
		noerr(t, err)
		topo.servers["foo"] = nil
		topo.servers["bar"] = nil
		topo.fsm.Servers = []description.Server{
			{
				Addr:                     address.Address("foo").Canonicalize(),
				Kind:                     description.RSPrimary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
			{
				Addr:                     address.Address("bar").Canonicalize(),
				Kind:                     description.RSSecondary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
		defer cancel()

		desc1 := description.Server{
			Addr:                     "foo",
			Kind:                     description.RSPrimary,
			SessionTimeoutMinutes:    20,
			SessionTimeoutMinutesPtr: int64ToPtr(20),
			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
		}

		// The second update comes from a server of kind Unknown with a lower
		// timeout (10). The expected result below is still 20, i.e. this
		// server's value is not reflected.
		desc2 := description.Server{
			Addr:                     "bar",
			Kind:                     description.Unknown,
			SessionTimeoutMinutes:    10,
			SessionTimeoutMinutesPtr: int64ToPtr(10),
			Members:                  []address.Address{address.Address("foo").Canonicalize(), address.Address("bar").Canonicalize()},
		}
		topo.apply(ctx, desc1)
		topo.apply(ctx, desc2)

		currDesc := topo.desc.Load().(description.Topology)
		want := int64(20)
		assert.Equal(t, &want, currDesc.SessionTimeoutMinutesPtr,
			"session timeout minutes mismatch")
	})
	t.Run("MixedSessionSupport", func(t *testing.T) {
		topo, err := New(nil)
		noerr(t, err)
		topo.fsm.Kind = description.ReplicaSetWithPrimary
		topo.servers["one"] = nil
		topo.servers["two"] = nil
		topo.servers["three"] = nil
		topo.fsm.Servers = []description.Server{
			{
				Addr:                     address.Address("one").Canonicalize(),
				Kind:                     description.RSPrimary,
				SessionTimeoutMinutes:    20,
				SessionTimeoutMinutesPtr: int64ToPtr(20),
			},
			{
				// This server sets no session timeout fields at all.
				Addr: address.Address("two").Canonicalize(),
				Kind: description.RSSecondary,
			},
			{
				Addr:                     address.Address("three").Canonicalize(),
				Kind:                     description.RSPrimary,
				SessionTimeoutMinutes:    60,
				SessionTimeoutMinutesPtr: int64ToPtr(60),
			},
		}

		ctx, cancel := context.WithTimeout(context.Background(), testTimeout)
		defer cancel()

		desc := description.Server{
			Addr:                     address.Address("three"),
			Kind:                     description.RSSecondary,
			SessionTimeoutMinutes:    30,
			SessionTimeoutMinutesPtr: int64ToPtr(30),
		}

		topo.apply(ctx, desc)

		// With one server reporting no session timeout, the topology-level
		// pointer is expected to be nil.
		currDesc := topo.desc.Load().(description.Topology)
		require.Nil(t, currDesc.SessionTimeoutMinutesPtr,
			"session timeout minutes mismatch. got: %d. expected: nil", currDesc.SessionTimeoutMinutes)
	})
}
633
634 func TestMinPoolSize(t *testing.T) {
635 cfg, err := NewConfig(options.Client().SetHosts([]string{"localhost:27017"}).SetMinPoolSize(10), nil)
636 if err != nil {
637 t.Errorf("error constructing topology config: %v", err)
638 }
639
640 topo, err := New(cfg)
641 if err != nil {
642 t.Errorf("topology.New shouldn't error. got: %v", err)
643 }
644 err = topo.Connect()
645 if err != nil {
646 t.Errorf("topology.Connect shouldn't error. got: %v", err)
647 }
648 }
649
650 func TestTopology_String_Race(_ *testing.T) {
651 ch := make(chan bool)
652 topo := &Topology{
653 servers: make(map[address.Address]*Server),
654 }
655
656 go func() {
657 topo.serversLock.Lock()
658 srv := &Server{}
659 srv.desc.Store(description.Server{})
660 topo.servers[address.Address("127.0.0.1:27017")] = srv
661 topo.serversLock.Unlock()
662 ch <- true
663 }()
664
665 go func() {
666 _ = topo.String()
667 ch <- true
668 }()
669
670 <-ch
671 <-ch
672 }
673
674 func TestTopologyConstruction(t *testing.T) {
675 t.Run("construct with URI", func(t *testing.T) {
676 testCases := []struct {
677 name string
678 uri string
679 pollingRequired bool
680 }{
681 {
682 name: "normal",
683 uri: "mongodb://localhost:27017",
684 pollingRequired: false,
685 },
686 }
687 for _, tc := range testCases {
688 t.Run(tc.name, func(t *testing.T) {
689 cfg, err := NewConfig(options.Client().ApplyURI(tc.uri), nil)
690 assert.Nil(t, err, "error constructing topology config: %v", err)
691
692 topo, err := New(cfg)
693 assert.Nil(t, err, "topology.New error: %v", err)
694
695 assert.Equal(t, tc.uri, topo.cfg.URI, "expected topology URI to be %v, got %v", tc.uri, topo.cfg.URI)
696 assert.Equal(t, tc.pollingRequired, topo.pollingRequired,
697 "expected topo.pollingRequired to be %v, got %v", tc.pollingRequired, topo.pollingRequired)
698 })
699 }
700 })
701 }
702
// mockLogSink is a test double for a log sink. It records the message of every
// Info call, in call order, and ignores everything else.
type mockLogSink struct {
	msgs []string
}

// Info records msg. The level and the key/value arguments are discarded.
func (sink *mockLogSink) Info(_ int, msg string, _ ...interface{}) {
	sink.msgs = append(sink.msgs, msg)
}

// Error does nothing; errors are not recorded.
func (*mockLogSink) Error(_ error, _ string, _ ...interface{}) {}
713
714
715
716 func TestTopologyConstructionLogging(t *testing.T) {
717 const (
718 cosmosDBMsg = `You appear to be connected to a CosmosDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb`
719 documentDBMsg = `You appear to be connected to a DocumentDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/documentdb`
720 )
721
722 newLoggerOptions := func(sink options.LogSink) *options.LoggerOptions {
723 return options.
724 Logger().
725 SetSink(sink).
726 SetComponentLevel(options.LogComponentTopology, options.LogLevelInfo)
727 }
728
729 t.Run("CosmosDB URIs", func(t *testing.T) {
730 t.Parallel()
731
732 testCases := []struct {
733 name string
734 uri string
735 msgs []string
736 }{
737 {
738 name: "normal",
739 uri: "mongodb://a.mongo.cosmos.azure.com:19555/",
740 msgs: []string{cosmosDBMsg},
741 },
742 {
743 name: "multiple hosts",
744 uri: "mongodb://a.mongo.cosmos.azure.com:1955,b.mongo.cosmos.azure.com:19555/",
745 msgs: []string{cosmosDBMsg},
746 },
747 {
748 name: "case-insensitive matching",
749 uri: "mongodb://a.MONGO.COSMOS.AZURE.COM:19555/",
750 msgs: []string{},
751 },
752 {
753 name: "Mixing genuine and nongenuine hosts (unlikely in practice)",
754 uri: "mongodb://a.example.com:27017,b.mongo.cosmos.azure.com:19555/",
755 msgs: []string{cosmosDBMsg},
756 },
757 }
758 for _, tc := range testCases {
759 tc := tc
760
761 t.Run(tc.name, func(t *testing.T) {
762 t.Parallel()
763
764 sink := &mockLogSink{}
765 cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
766 require.Nil(t, err, "error constructing topology config: %v", err)
767
768 topo, err := New(cfg)
769 require.Nil(t, err, "topology.New error: %v", err)
770
771 err = topo.Connect()
772 assert.Nil(t, err, "Connect error: %v", err)
773
774 assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
775 })
776 }
777 })
778 t.Run("DocumentDB URIs", func(t *testing.T) {
779 t.Parallel()
780
781 testCases := []struct {
782 name string
783 uri string
784 msgs []string
785 }{
786 {
787 name: "normal",
788 uri: "mongodb://a.docdb.amazonaws.com:27017/",
789 msgs: []string{documentDBMsg},
790 },
791 {
792 name: "normal",
793 uri: "mongodb://a.docdb-elastic.amazonaws.com:27017/",
794 msgs: []string{documentDBMsg},
795 },
796 {
797 name: "multiple hosts",
798 uri: "mongodb://a.docdb.amazonaws.com:27017,a.docdb-elastic.amazonaws.com:27017/",
799 msgs: []string{documentDBMsg},
800 },
801 {
802 name: "case-insensitive matching",
803 uri: "mongodb://a.DOCDB.AMAZONAWS.COM:27017/",
804 msgs: []string{},
805 },
806 {
807 name: "case-insensitive matching",
808 uri: "mongodb://a.DOCDB-ELASTIC.AMAZONAWS.COM:27017/",
809 msgs: []string{},
810 },
811 {
812 name: "Mixing genuine and nongenuine hosts (unlikely in practice)",
813 uri: "mongodb://a.example.com:27017,b.docdb.amazonaws.com:27017/",
814 msgs: []string{documentDBMsg},
815 },
816 {
817 name: "Mixing genuine and nongenuine hosts (unlikely in practice)",
818 uri: "mongodb://a.example.com:27017,b.docdb-elastic.amazonaws.com:27017/",
819 msgs: []string{documentDBMsg},
820 },
821 }
822 for _, tc := range testCases {
823 tc := tc
824
825 t.Run(tc.name, func(t *testing.T) {
826 t.Parallel()
827
828 sink := &mockLogSink{}
829 cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
830 require.Nil(t, err, "error constructing topology config: %v", err)
831
832 topo, err := New(cfg)
833 require.Nil(t, err, "topology.New error: %v", err)
834
835 err = topo.Connect()
836 assert.Nil(t, err, "Connect error: %v", err)
837
838 assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
839 })
840 }
841 })
842 t.Run("Mixing CosmosDB and DocumentDB URIs", func(t *testing.T) {
843 t.Parallel()
844
845 testCases := []struct {
846 name string
847 uri string
848 msgs []string
849 }{
850 {
851 name: "Mixing hosts",
852 uri: "mongodb://a.mongo.cosmos.azure.com:19555,a.docdb.amazonaws.com:27017/",
853 msgs: []string{cosmosDBMsg, documentDBMsg},
854 },
855 }
856 for _, tc := range testCases {
857 tc := tc
858
859 t.Run(tc.name, func(t *testing.T) {
860 t.Parallel()
861
862 sink := &mockLogSink{}
863 cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
864 require.Nil(t, err, "error constructing topology config: %v", err)
865
866 topo, err := New(cfg)
867 require.Nil(t, err, "topology.New error: %v", err)
868
869 err = topo.Connect()
870 assert.Nil(t, err, "Connect error: %v", err)
871
872 assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
873 })
874 }
875 })
876 t.Run("genuine URIs", func(t *testing.T) {
877 t.Parallel()
878
879 testCases := []struct {
880 name string
881 uri string
882 msgs []string
883 }{
884 {
885 name: "normal",
886 uri: "mongodb://a.example.com:27017/",
887 msgs: []string{},
888 },
889 {
890 name: "socket",
891 uri: "mongodb://%2Ftmp%2Fmongodb-27017.sock/",
892 msgs: []string{},
893 },
894 {
895 name: "srv",
896 uri: "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
897 msgs: []string{},
898 },
899 {
900 name: "multiple hosts",
901 uri: "mongodb://a.example.com:27017,b.example.com:27017/",
902 msgs: []string{},
903 },
904 {
905 name: "unexpected suffix",
906 uri: "mongodb://a.mongo.cosmos.azure.com.tld:19555/",
907 msgs: []string{},
908 },
909 {
910 name: "unexpected suffix",
911 uri: "mongodb://a.docdb.amazonaws.com.tld:27017/",
912 msgs: []string{},
913 },
914 {
915 name: "unexpected suffix",
916 uri: "mongodb://a.docdb-elastic.amazonaws.com.tld:27017/",
917 msgs: []string{},
918 },
919 }
920 for _, tc := range testCases {
921 tc := tc
922
923 t.Run(tc.name, func(t *testing.T) {
924 t.Parallel()
925
926 sink := &mockLogSink{}
927 cfg, err := NewConfig(options.Client().ApplyURI(tc.uri).SetLoggerOptions(newLoggerOptions(sink)), nil)
928 require.Nil(t, err, "error constructing topology config: %v", err)
929
930 topo, err := New(cfg)
931 require.Nil(t, err, "topology.New error: %v", err)
932
933 err = topo.Connect()
934 assert.Nil(t, err, "Connect error: %v", err)
935
936 assert.ElementsMatch(t, tc.msgs, sink.msgs, "expected messages to be %v, got %v", tc.msgs, sink.msgs)
937 })
938 }
939 })
940 }
941
// inWindowServer is one server entry of an in-window spec test's topology
// description, as decoded from the test's JSON file.
type inWindowServer struct {
	Address string `json:"address"`
	// Type is the server kind name; it is translated with serverKindFromString.
	Type string `json:"type"`
	// AvgRTTMS is the server's average round-trip time in milliseconds.
	AvgRTTMS int64 `json:"avg_rtt_ms"`
}
947
// inWindowTopology is the topology description section of an in-window spec
// test.
type inWindowTopology struct {
	// Type is the topology kind name; it is translated with topologyKindFromString.
	Type    string           `json:"type"`
	Servers []inWindowServer `json:"servers"`
}
952
// inWindowOutcome is the expected result of an in-window spec test.
type inWindowOutcome struct {
	// Tolerance is the allowed absolute deviation of an observed selection
	// frequency from its expected value.
	Tolerance float64 `json:"tolerance"`
	// ExpectedFrequencies maps a server address to the fraction of iterations
	// in which that server is expected to be selected.
	ExpectedFrequencies map[string]float64 `json:"expected_frequencies"`
}
957
// inWindowTopologyState is the mocked state of one server: the operation count
// that the test assigns to the server with the given address.
type inWindowTopologyState struct {
	Address        string `json:"address"`
	OperationCount int64  `json:"operation_count"`
}
962
// inWindowTestCase is the top-level structure of an in-window spec test JSON
// file.
type inWindowTestCase struct {
	TopologyDescription inWindowTopology        `json:"topology_description"`
	MockedTopologyState []inWindowTopologyState `json:"mocked_topology_state"`
	// Iterations is the number of server selections to perform.
	Iterations int             `json:"iterations"`
	Outcome    inWindowOutcome `json:"outcome"`
}
969
970
971
972
973
974
975
976
977 func TestServerSelectionSpecInWindow(t *testing.T) {
978 const testsDir = "../../../../testdata/server-selection/in_window"
979
980 files := spectest.FindJSONFilesInDir(t, testsDir)
981
982 for _, file := range files {
983 t.Run(file, func(t *testing.T) {
984 runInWindowTest(t, testsDir, file)
985 })
986 }
987 }
988
989 func runInWindowTest(t *testing.T, directory string, filename string) {
990 filepath := path.Join(directory, filename)
991 content, err := ioutil.ReadFile(filepath)
992 require.NoError(t, err)
993
994 var test inWindowTestCase
995 require.NoError(t, json.Unmarshal(content, &test))
996
997
998
999 servers := make(map[string]*Server, len(test.TopologyDescription.Servers))
1000 descriptions := make([]description.Server, 0, len(test.TopologyDescription.Servers))
1001 for _, testDesc := range test.TopologyDescription.Servers {
1002 server := NewServer(
1003 address.Address(testDesc.Address),
1004 primitive.NilObjectID,
1005 withMonitoringDisabled(func(bool) bool { return true }))
1006 servers[testDesc.Address] = server
1007
1008 desc := description.Server{
1009 Kind: serverKindFromString(t, testDesc.Type),
1010 Addr: address.Address(testDesc.Address),
1011 AverageRTT: time.Duration(testDesc.AvgRTTMS) * time.Millisecond,
1012 AverageRTTSet: true,
1013 }
1014
1015 if testDesc.AvgRTTMS > 0 {
1016 desc.AverageRTT = time.Duration(testDesc.AvgRTTMS) * time.Millisecond
1017 desc.AverageRTTSet = true
1018 }
1019
1020 descriptions = append(descriptions, desc)
1021 }
1022
1023
1024
1025 for _, state := range test.MockedTopologyState {
1026 servers[state.Address].operationCount = state.OperationCount
1027 }
1028
1029
1030
1031
1032 topology, err := New(nil)
1033 require.NoError(t, err, "error creating new Topology")
1034 topology.state = topologyConnected
1035 topology.desc.Store(description.Topology{
1036 Kind: topologyKindFromString(t, test.TopologyDescription.Type),
1037 Servers: descriptions,
1038 })
1039 for addr, server := range servers {
1040 topology.servers[address.Address(addr)] = server
1041 }
1042
1043
1044
1045 counts := make(map[string]int, len(test.TopologyDescription.Servers))
1046 for i := 0; i < test.Iterations; i++ {
1047 selected, err := topology.SelectServer(
1048 context.Background(),
1049 description.ReadPrefSelector(readpref.Nearest()))
1050 require.NoError(t, err, "error selecting server")
1051 counts[string(selected.(*SelectedServer).address)]++
1052 }
1053
1054
1055
1056 frequencies := make(map[string]float64, len(counts))
1057 for addr, count := range counts {
1058 frequencies[addr] = float64(count) / float64(test.Iterations)
1059 }
1060
1061
1062
1063 for addr, expected := range test.Outcome.ExpectedFrequencies {
1064 actual := frequencies[addr]
1065
1066
1067
1068 if expected == 1 || expected == 0 {
1069 assert.Equal(
1070 t,
1071 expected,
1072 actual,
1073 "expected frequency of %q to be equal to %f, but is %f",
1074 addr, expected, actual)
1075 continue
1076 }
1077
1078
1079
1080
1081 low := expected - test.Outcome.Tolerance
1082 high := expected + test.Outcome.Tolerance
1083 assert.True(
1084 t,
1085 actual >= low && actual <= high,
1086 "expected frequency of %q to be in range [%f, %f], but is %f",
1087 addr, low, high, actual)
1088 }
1089 }
1090
1091 func topologyKindFromString(t *testing.T, s string) description.TopologyKind {
1092 t.Helper()
1093
1094 switch s {
1095 case "Single":
1096 return description.Single
1097 case "ReplicaSet":
1098 return description.ReplicaSet
1099 case "ReplicaSetNoPrimary":
1100 return description.ReplicaSetNoPrimary
1101 case "ReplicaSetWithPrimary":
1102 return description.ReplicaSetWithPrimary
1103 case "Sharded":
1104 return description.Sharded
1105 case "LoadBalanced":
1106 return description.LoadBalanced
1107 case "Unknown":
1108 return description.Unknown
1109 default:
1110 t.Fatalf("unrecognized topology kind: %q", s)
1111 }
1112
1113 return description.Unknown
1114 }
1115
1116 func serverKindFromString(t *testing.T, s string) description.ServerKind {
1117 t.Helper()
1118
1119 switch s {
1120 case "Standalone":
1121 return description.Standalone
1122 case "RSOther":
1123 return description.RSMember
1124 case "RSPrimary":
1125 return description.RSPrimary
1126 case "RSSecondary":
1127 return description.RSSecondary
1128 case "RSArbiter":
1129 return description.RSArbiter
1130 case "RSGhost":
1131 return description.RSGhost
1132 case "Mongos":
1133 return description.Mongos
1134 case "LoadBalancer":
1135 return description.LoadBalancer
1136 case "PossiblePrimary", "Unknown":
1137
1138 return description.Unknown
1139 default:
1140 t.Fatalf("unrecognized server kind: %q", s)
1141 }
1142
1143 return description.Unknown
1144 }
1145
1146 func BenchmarkSelectServerFromDescription(b *testing.B) {
1147 for _, bcase := range []struct {
1148 name string
1149 serversHook func(servers []description.Server)
1150 }{
1151 {
1152 name: "AllFit",
1153 serversHook: func(servers []description.Server) {},
1154 },
1155 {
1156 name: "AllButOneFit",
1157 serversHook: func(servers []description.Server) {
1158 servers[0].Kind = description.Unknown
1159 },
1160 },
1161 {
1162 name: "HalfFit",
1163 serversHook: func(servers []description.Server) {
1164 for i := 0; i < len(servers); i += 2 {
1165 servers[i].Kind = description.Unknown
1166 }
1167 },
1168 },
1169 {
1170 name: "OneFit",
1171 serversHook: func(servers []description.Server) {
1172 for i := 1; i < len(servers); i++ {
1173 servers[i].Kind = description.Unknown
1174 }
1175 },
1176 },
1177 } {
1178 bcase := bcase
1179
1180 b.Run(bcase.name, func(b *testing.B) {
1181 s := description.Server{
1182 Addr: address.Address("localhost:27017"),
1183 HeartbeatInterval: time.Duration(10) * time.Second,
1184 LastWriteTime: time.Date(2017, 2, 11, 14, 0, 0, 0, time.UTC),
1185 LastUpdateTime: time.Date(2017, 2, 11, 14, 0, 2, 0, time.UTC),
1186 Kind: description.Mongos,
1187 WireVersion: &description.VersionRange{Min: 6, Max: 21},
1188 }
1189 servers := make([]description.Server, 100)
1190 for i := 0; i < len(servers); i++ {
1191 servers[i] = s
1192 }
1193 bcase.serversHook(servers)
1194 desc := description.Topology{
1195 Servers: servers,
1196 }
1197
1198 timeout := make(chan time.Time)
1199 b.ResetTimer()
1200 b.RunParallel(func(p *testing.PB) {
1201 b.ReportAllocs()
1202 for p.Next() {
1203 var c Topology
1204 _, _ = c.selectServerFromDescription(desc, newServerSelectionState(selectNone, timeout))
1205 }
1206 })
1207 })
1208 }
1209 }
1210
1211 func TestLogUnexpectedFailure(t *testing.T) {
1212 t.Parallel()
1213
1214
1215 newIOLogger := func() (*logger.Logger, *bytes.Buffer, *bufio.Writer) {
1216 buf := bytes.NewBuffer(nil)
1217 w := bufio.NewWriter(buf)
1218
1219 ioSink := logger.NewIOSink(w)
1220
1221 ioLogger, err := logger.New(ioSink, logger.DefaultMaxDocumentLength, map[logger.Component]logger.Level{
1222 logger.ComponentTopology: logger.LevelDebug,
1223 })
1224
1225 assert.NoError(t, err)
1226
1227 return ioLogger, buf, w
1228 }
1229
1230
1231 newNilLogger := func() (*logger.Logger, *bytes.Buffer, *bufio.Writer) {
1232 return nil, &bytes.Buffer{}, &bufio.Writer{}
1233 }
1234
1235 tests := []struct {
1236 name string
1237 msg string
1238 newLogger func() (*logger.Logger, *bytes.Buffer, *bufio.Writer)
1239 panicValue interface{}
1240 want interface{}
1241 }{
1242 {
1243 name: "nil logger",
1244 msg: "",
1245 newLogger: newNilLogger,
1246 panicValue: 1,
1247 want: nil,
1248 },
1249 {
1250 name: "valid logger",
1251 msg: "test",
1252 newLogger: newIOLogger,
1253 panicValue: 1,
1254 want: "test: 1",
1255 },
1256 {
1257 name: "valid logger with error panic",
1258 msg: "test",
1259 newLogger: newIOLogger,
1260 panicValue: errors.New("err"),
1261 want: "test: err",
1262 },
1263 }
1264
1265 for _, test := range tests {
1266 test := test
1267
1268 t.Run(test.name, func(t *testing.T) {
1269 t.Parallel()
1270
1271 log, buf, w := test.newLogger()
1272
1273 func() {
1274 defer logUnexpectedFailure(log, test.msg)
1275
1276 panic(test.panicValue)
1277 }()
1278
1279 assert.NoError(t, w.Flush())
1280
1281 got := map[string]interface{}{}
1282 _ = json.Unmarshal(buf.Bytes(), &got)
1283
1284 assert.Equal(t, test.want, got[logger.KeyMessage])
1285 })
1286 }
1287 }
1288
View as plain text