1
2
3
4
5
6
7
8
9
10 package topology
11
12 import (
13 "context"
14 "crypto/tls"
15 "crypto/x509"
16 "errors"
17 "io/ioutil"
18 "net"
19 "os"
20 "runtime"
21 "sync"
22 "sync/atomic"
23 "testing"
24 "time"
25
26 "github.com/google/go-cmp/cmp"
27 "go.mongodb.org/mongo-driver/bson/primitive"
28 "go.mongodb.org/mongo-driver/event"
29 "go.mongodb.org/mongo-driver/internal/assert"
30 "go.mongodb.org/mongo-driver/internal/eventtest"
31 "go.mongodb.org/mongo-driver/internal/require"
32 "go.mongodb.org/mongo-driver/mongo/address"
33 "go.mongodb.org/mongo-driver/mongo/description"
34 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
35 "go.mongodb.org/mongo-driver/x/mongo/driver"
36 "go.mongodb.org/mongo-driver/x/mongo/driver/auth"
37 "go.mongodb.org/mongo-driver/x/mongo/driver/drivertest"
38 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
39 )
40
41 type channelNetConnDialer struct{}
42
43 func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (net.Conn, error) {
44 cnc := &drivertest.ChannelNetConn{
45 Written: make(chan []byte, 1),
46 ReadResp: make(chan []byte, 2),
47 ReadErr: make(chan error, 1),
48 }
49 if err := cnc.AddResponse(makeHelloReply()); err != nil {
50 return nil, err
51 }
52
53 return cnc, nil
54 }
55
// errorQueue is a mutex-guarded FIFO of errors. Entries may be nil.
type errorQueue struct {
	errors []error
	mutex  sync.Mutex
}

// head returns the first queued entry without removing it. It returns nil
// when the queue is empty (and also when the first entry itself is nil).
func (eq *errorQueue) head() error {
	eq.mutex.Lock()
	defer eq.mutex.Unlock()

	if len(eq.errors) == 0 {
		return nil
	}
	return eq.errors[0]
}

// dequeue drops the first queued entry. It reports true if an entry was
// removed and false if the queue was already empty.
func (eq *errorQueue) dequeue() bool {
	eq.mutex.Lock()
	defer eq.mutex.Unlock()

	if len(eq.errors) == 0 {
		return false
	}
	eq.errors = eq.errors[1:]
	return true
}
79
80 type timeoutConn struct {
81 net.Conn
82 errors *errorQueue
83 }
84
85 func (c *timeoutConn) Read(b []byte) (int, error) {
86 n, err := 0, c.errors.head()
87 if err == nil {
88 n, err = c.Conn.Read(b)
89 }
90 return n, err
91 }
92
93 func (c *timeoutConn) Write(b []byte) (int, error) {
94 n, err := 0, c.errors.head()
95 if err == nil {
96 n, err = c.Conn.Write(b)
97 }
98 return n, err
99 }
100
101 type timeoutDialer struct {
102 Dialer
103 errors *errorQueue
104 }
105
106 func (d *timeoutDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
107 c, e := d.Dialer.DialContext(ctx, network, address)
108
109 if caFile := os.Getenv("MONGO_GO_DRIVER_CA_FILE"); len(caFile) > 0 {
110 pem, err := ioutil.ReadFile(caFile)
111 if err != nil {
112 return nil, err
113 }
114
115 ca := x509.NewCertPool()
116 if !ca.AppendCertsFromPEM(pem) {
117 return nil, errors.New("unable to load CA file")
118 }
119
120 config := &tls.Config{
121 InsecureSkipVerify: true,
122 RootCAs: ca,
123 }
124 c = tls.Client(c, config)
125 }
126 return &timeoutConn{c, d.errors}, e
127 }
128
129
130 func TestServerHeartbeatTimeout(t *testing.T) {
131 if os.Getenv("DOCKER_RUNNING") != "" {
132 t.Skip("Skipping this test in docker.")
133 }
134
135 networkTimeoutError := &net.DNSError{
136 IsTimeout: true,
137 }
138
139 testCases := []struct {
140 desc string
141 ioErrors []error
142 expectInterruptions int
143 }{
144 {
145 desc: "one single timeout should not clear the pool",
146 ioErrors: []error{nil, networkTimeoutError, nil, networkTimeoutError, nil},
147 expectInterruptions: 0,
148 },
149 {
150 desc: "continuous timeouts should clear the pool with interruption",
151 ioErrors: []error{nil, networkTimeoutError, networkTimeoutError, nil},
152 expectInterruptions: 1,
153 },
154 }
155 for _, tc := range testCases {
156 tc := tc
157 t.Run(tc.desc, func(t *testing.T) {
158 t.Parallel()
159
160 var wg sync.WaitGroup
161 wg.Add(1)
162
163 errors := &errorQueue{errors: tc.ioErrors}
164 tpm := eventtest.NewTestPoolMonitor()
165 server := NewServer(
166 address.Address("localhost:27017"),
167 primitive.NewObjectID(),
168 WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor {
169 return tpm.PoolMonitor
170 }),
171 WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
172 return append(opts,
173 WithDialer(func(d Dialer) Dialer {
174 var dialer net.Dialer
175 return &timeoutDialer{&dialer, errors}
176 }))
177 }),
178 WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor {
179 return &event.ServerMonitor{
180 ServerHeartbeatSucceeded: func(e *event.ServerHeartbeatSucceededEvent) {
181 if !errors.dequeue() {
182 wg.Done()
183 }
184 },
185 ServerHeartbeatFailed: func(e *event.ServerHeartbeatFailedEvent) {
186 if !errors.dequeue() {
187 wg.Done()
188 }
189 },
190 }
191 }),
192 WithHeartbeatInterval(func(time.Duration) time.Duration {
193 return 200 * time.Millisecond
194 }),
195 )
196 require.NoError(t, server.Connect(nil))
197 wg.Wait()
198 interruptions := tpm.Interruptions()
199 assert.Equal(t, tc.expectInterruptions, interruptions, "expected %d interruption but got %d", tc.expectInterruptions, interruptions)
200 })
201 }
202 }
203
204
205
206 func TestServerConnectionTimeout(t *testing.T) {
207 testCases := []struct {
208 desc string
209 dialer func(Dialer) Dialer
210 handshaker func(Handshaker) Handshaker
211 operationTimeout time.Duration
212 connectTimeout time.Duration
213 expectErr bool
214 expectPoolCleared bool
215 }{
216 {
217 desc: "successful connection should not clear the pool",
218 expectErr: false,
219 expectPoolCleared: false,
220 },
221 {
222 desc: "timeout error during dialing should clear the pool",
223 dialer: func(Dialer) Dialer {
224 var d net.Dialer
225 return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
226
227
228 <-ctx.Done()
229 return d.DialContext(ctx, network, addr)
230 })
231 },
232 operationTimeout: 1 * time.Minute,
233 connectTimeout: 100 * time.Millisecond,
234 expectErr: true,
235 expectPoolCleared: true,
236 },
237 {
238 desc: "timeout error during dialing with no operation timeout should clear the pool",
239 dialer: func(Dialer) Dialer {
240 var d net.Dialer
241 return DialerFunc(func(ctx context.Context, network, addr string) (net.Conn, error) {
242
243
244 <-ctx.Done()
245 return d.DialContext(ctx, network, addr)
246 })
247 },
248 operationTimeout: 0,
249 connectTimeout: 100 * time.Millisecond,
250 expectErr: true,
251 expectPoolCleared: true,
252 },
253 {
254 desc: "dial errors unrelated to context timeouts should clear the pool",
255 dialer: func(Dialer) Dialer {
256 var d net.Dialer
257 return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
258
259 return d.DialContext(ctx, "tcp", "300.0.0.0:nope")
260 })
261 },
262 expectErr: true,
263 expectPoolCleared: true,
264 },
265 {
266 desc: "operation context timeout with unrelated dial errors should clear the pool",
267 dialer: func(Dialer) Dialer {
268 var d net.Dialer
269 return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
270
271 c, err := d.DialContext(ctx, "tcp", "300.0.0.0:nope")
272
273
274 <-ctx.Done()
275 return c, err
276 })
277 },
278 operationTimeout: 1 * time.Millisecond,
279 connectTimeout: 100 * time.Millisecond,
280 expectErr: true,
281 expectPoolCleared: true,
282 },
283 }
284
285 for _, tc := range testCases {
286 tc := tc
287 t.Run(tc.desc, func(t *testing.T) {
288 t.Parallel()
289
290
291
292 l, err := net.Listen("tcp", "127.0.0.1:0")
293 require.NoError(t, err)
294 defer func() {
295 _ = l.Close()
296 }()
297
298 tpm := eventtest.NewTestPoolMonitor()
299 server := NewServer(
300 address.Address(l.Addr().String()),
301 primitive.NewObjectID(),
302 WithConnectionPoolMonitor(func(*event.PoolMonitor) *event.PoolMonitor {
303 return tpm.PoolMonitor
304 }),
305
306
307 WithConnectionOptions(func(opts ...ConnectionOption) []ConnectionOption {
308 if tc.connectTimeout > 0 {
309 opts = append(opts, WithConnectTimeout(func(time.Duration) time.Duration { return tc.connectTimeout }))
310 }
311 if tc.dialer != nil {
312 opts = append(opts, WithDialer(tc.dialer))
313 }
314 if tc.handshaker != nil {
315 opts = append(opts, WithHandshaker(tc.handshaker))
316 }
317 return opts
318 }),
319
320
321 withMonitoringDisabled(func(bool) bool { return true }),
322 )
323 require.NoError(t, server.Connect(nil))
324
325
326 ctx := context.Background()
327 if tc.operationTimeout > 0 {
328 var cancel context.CancelFunc
329 ctx, cancel = context.WithTimeout(ctx, tc.operationTimeout)
330 defer cancel()
331 }
332 _, err = server.Connection(ctx)
333 if tc.expectErr {
334 assert.NotNil(t, err, "expected an error but got nil")
335 } else {
336 assert.Nil(t, err, "expected no error but got %s", err)
337 }
338
339
340
341 if tc.expectPoolCleared {
342 assert.Eventually(t,
343 tpm.IsPoolCleared,
344 10*time.Second,
345 100*time.Millisecond,
346 "expected pool to be cleared within 10s but was not cleared")
347 }
348
349
350
351 _ = server.Disconnect(context.Background())
352
353
354
355 if !tc.expectPoolCleared {
356 assert.False(t, tpm.IsPoolCleared(), "expected pool to not be cleared but was cleared")
357 }
358 })
359 }
360 }
361
// TestServer exercises Server through a table of connection-error cases
// followed by a series of named subtests: pool generation handling on
// connection initialization errors, connection request starvation, description
// updates, heartbeat wire messages, heartbeat monitoring events, app name and
// timeout configuration, and cancellation of the previous heartbeat context.
func TestServer(t *testing.T) {
	// Columns: name, connectionError, networkError, hasDesc.
	var serverTestTable = []struct {
		name string
		connectionError bool
		networkError bool
		hasDesc bool
	}{
		{"auth_error", true, false, false},
		{"no_error", false, false, false},
		{"network_error_no_desc", false, true, false},
		{"network_error_desc", false, true, true},
	}

	// The handshaker/dialer stubs below return the Wrapped values of these
	// errors; s.Connection is expected to surface them as these ConnectionErrors.
	authErr := ConnectionError{Wrapped: &auth.Error{}, init: true}
	netErr := ConnectionError{Wrapped: &net.AddrError{}, init: true}
	for _, tt := range serverTestTable {
		t.Run(tt.name, func(t *testing.T) {
			// returnConnectionError gates error injection so that the first
			// checkout below can succeed.
			var returnConnectionError bool
			s := NewServer(
				address.Address("localhost"),
				primitive.NewObjectID(),
				WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
					return append(connOpts,
						WithHandshaker(func(Handshaker) Handshaker {
							return &testHandshaker{
								finishHandshake: func(context.Context, driver.Connection) error {
									var err error
									if tt.connectionError && returnConnectionError {
										err = authErr.Wrapped
									}
									return err
								},
							}
						}),
						WithDialer(func(Dialer) Dialer {
							return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
								var err error
								if tt.networkError && returnConnectionError {
									err = netErr.Wrapped
								}
								return &net.TCPConn{}, err
							})
						}),
					)
				}),
			)

			// For hasDesc cases, confirm the server description starts without
			// a LastError.
			var desc *description.Server
			descript := s.Description()
			if tt.hasDesc {
				desc = &descript
				require.Nil(t, desc.LastError)
			}
			err := s.pool.ready()
			require.NoError(t, err, "pool.ready() error")
			defer s.pool.close(context.Background())

			// The server state is set directly; Connect is not called here.
			s.state = serverConnected

			// The first checkout happens before error injection is enabled and
			// must succeed; the second checkout is the one under test.
			// NOTE(review): presumably the first connection keeps the pool's
			// generation bookkeeping alive for the assertion below - confirm.
			_, err = s.Connection(context.Background())
			assert.Nil(t, err, "error getting initial connection: %v", err)
			returnConnectionError = true
			_, err = s.Connection(context.Background())

			switch {
			case tt.connectionError && !cmp.Equal(err, authErr, cmp.Comparer(compareErrors)):
				t.Errorf("Expected connection error. got %v; want %v", err, authErr)
			case tt.networkError && !cmp.Equal(err, netErr, cmp.Comparer(compareErrors)):
				t.Errorf("Expected network error. got %v; want %v", err, netErr)
			case !tt.connectionError && !tt.networkError && err != nil:
				t.Errorf("Expected error to be nil. got %v; want %v", err, "<nil>")
			}

			if tt.hasDesc {
				require.Equal(t, s.Description().Kind, (description.ServerKind)(description.Unknown))
				require.NotNil(t, s.Description().LastError)
			}

			// The generation is read with a nil service ID; any injected error
			// is expected to have advanced it to exactly 1.
			generation, _ := s.pool.generation.getGeneration(nil)
			if (tt.connectionError || tt.networkError) && generation != 1 {
				t.Errorf("Expected pool to be drained once on connection or network error. got %d; want %d", generation, 1)
			}
		})
	}

	t.Run("multiple connection initialization errors are processed correctly", func(t *testing.T) {
		// assertGenerationStats polls for up to 100ms (every 10ms) until the
		// pool's generation number and connection count for serviceID match the
		// wanted values. The "got" values in the failure message are evaluated
		// when the assertion is invoked, not at the moment it times out.
		assertGenerationStats := func(t *testing.T, server *Server, serviceID primitive.ObjectID, wantGeneration, wantNumConns uint64) {
			t.Helper()

			getGeneration := func(serviceIDPtr *primitive.ObjectID) uint64 {
				generation, _ := server.pool.generation.getGeneration(serviceIDPtr)
				return generation
			}

			// NOTE(review): polling presumably allows for connection cleanup
			// that completes shortly after Connection() returns - confirm.
			assert.Eventuallyf(t,
				func() bool {
					generation, _ := server.pool.generation.getGeneration(&serviceID)
					numConns := server.pool.generation.getNumConns(&serviceID)
					return generation == wantGeneration && numConns == wantNumConns
				},
				100*time.Millisecond,
				10*time.Millisecond,
				"expected generation number %v, got %v; expected connection count %v, got %v",
				wantGeneration,
				getGeneration(&serviceID),
				wantNumConns,
				server.pool.generation.getNumConns(&serviceID))
		}

		testCases := []struct {
			name string
			loadBalanced bool
			dialErr error
			getInfoErr error
			finishHandshakeErr error
			finalGeneration uint64
			numNewConns uint64
		}{
			// Load-balanced: dial and initial-handshake errors are expected to
			// leave the generation at 0.
			{"dial errors are ignored for load balancers", true, netErr.Wrapped, nil, nil, 0, 1},
			{"initial handshake errors are ignored for load balancers", true, nil, netErr.Wrapped, nil, 0, 1},

			// Load-balanced: the expected final generation of 5 matches the 5
			// failing checkouts performed in the loop below.
			{"post-handshake errors are not ignored for load balancers", true, nil, nil, netErr.Wrapped, 5, 1},

			// Non-load-balanced: every error kind is expected to end at
			// generation 1 even though 5 failing checkouts are attempted.
			{"dial errors are not ignored for non-lb clusters", false, netErr.Wrapped, nil, nil, 1, 1},
			{"initial handshake errors are not ignored for non-lb clusters", false, nil, netErr.Wrapped, nil, 1, 1},
			{"post-handshake errors are not ignored for non-lb clusters", false, nil, nil, netErr.Wrapped, 1, 1},
		}
		for _, tc := range testCases {
			tc := tc

			t.Run(tc.name, func(t *testing.T) {
				// returnConnectionError gates error injection in the stubs
				// below. serviceID stays the zero ObjectID unless the case is
				// load balanced.
				var returnConnectionError bool
				var serviceID primitive.ObjectID
				if tc.loadBalanced {
					serviceID = primitive.NewObjectID()
				}

				handshaker := &testHandshaker{
					getHandshakeInformation: func(_ context.Context, addr address.Address, _ driver.Connection) (driver.HandshakeInformation, error) {
						if tc.getInfoErr != nil && returnConnectionError {
							return driver.HandshakeInformation{}, tc.getInfoErr
						}

						desc := description.NewDefaultServer(addr)
						if tc.loadBalanced {
							desc.ServiceID = &serviceID
						}
						return driver.HandshakeInformation{Description: desc}, nil
					},
					finishHandshake: func(context.Context, driver.Connection) error {
						if tc.finishHandshakeErr != nil && returnConnectionError {
							return tc.finishHandshakeErr
						}
						return nil
					},
				}
				connOpts := []ConnectionOption{
					WithDialer(func(Dialer) Dialer {
						return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
							var err error
							if returnConnectionError && tc.dialErr != nil {
								err = tc.dialErr
							}
							return &net.TCPConn{}, err
						})
					}),
					WithHandshaker(func(Handshaker) Handshaker {
						return handshaker
					}),
					WithConnectionLoadBalanced(func(bool) bool {
						return tc.loadBalanced
					}),
				}
				serverOpts := []ServerOption{
					WithServerLoadBalanced(func(bool) bool {
						return tc.loadBalanced
					}),
					WithConnectionOptions(func(...ConnectionOption) []ConnectionOption {
						return connOpts
					}),
					// NOTE(review): monitoring is disabled, presumably so
					// background heartbeats cannot affect the generation
					// stats - confirm.
					withMonitoringDisabled(func(bool) bool {
						return true
					}),
					// NOTE(review): max connecting is limited to 1, presumably
					// to make connection establishment deterministic - confirm.
					WithMaxConnecting(func(uint64) uint64 { return 1 }),
				}

				server, err := ConnectServer(address.Address("localhost:27017"), nil, primitive.NewObjectID(), serverOpts...)
				assert.Nil(t, err, "ConnectServer error: %v", err)
				defer func() {
					_ = server.Disconnect(context.Background())
				}()

				// Initial checkout without error injection: generation 0 with
				// one connection.
				_, err = server.Connection(context.Background())
				assert.Nil(t, err, "Connection error: %v", err)
				assertGenerationStats(t, server, serviceID, 0, 1)

				// Enable error injection and attempt 5 more checkouts.
				returnConnectionError = true
				for i := 0; i < 5; i++ {
					_, err = server.Connection(context.Background())
					switch {
					case tc.dialErr != nil || tc.getInfoErr != nil || tc.finishHandshakeErr != nil:
						assert.NotNil(t, err, "expected Connection error at iteration %d, got nil", i)
					default:
						assert.Nil(t, err, "Connection error at iteration %d: %v", i, err)
					}
				}
				assertGenerationStats(t, server, serviceID, tc.finalGeneration, tc.numNewConns)
			})
		}
	})

	t.Run("Cannot starve connection request", func(t *testing.T) {
		// Connections handed to the bootstrapConnections callback stay open
		// until cleanup is closed at the end of the subtest.
		cleanup := make(chan struct{})
		addr := bootstrapConnections(t, 3, func(nc net.Conn) {
			<-cleanup
			_ = nc.Close()
		})
		d := newdialer(&net.Dialer{})
		// WithMaxConnections returns 1, so only one connection may exist.
		s := NewServer(address.Address(addr.String()),
			primitive.NewObjectID(),
			WithConnectionOptions(func(option ...ConnectionOption) []ConnectionOption {
				return []ConnectionOption{WithDialer(func(_ Dialer) Dialer { return d })}
			}),
			WithMaxConnections(func(u uint64) uint64 {
				return 1
			}))
		s.state = serverConnected
		err := s.pool.ready()
		noerr(t, err)
		defer s.pool.close(context.Background())

		conn, err := s.Connection(context.Background())
		noerr(t, err)
		if d.lenopened() != 1 {
			t.Errorf("Should have opened 1 connections, but didn't. got %d; want %d", d.lenopened(), 1)
		}

		var wg sync.WaitGroup

		// Request a second connection while the first is still checked out.
		// ch signals that the goroutine has started; the request must succeed
		// within its 1s timeout once conn is closed below.
		wg.Add(1)
		ch := make(chan struct{})
		go func() {
			ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
			defer cancel()
			ch <- struct{}{}
			_, err := s.Connection(ctx)
			if err != nil {
				t.Errorf("Should not be able to starve connection request, but got error: %v", err)
			}
			wg.Done()
		}()
		<-ch
		// Yield so the goroutine gets a chance to issue its request before the
		// connection is returned.
		runtime.Gosched()
		err = conn.Close()
		noerr(t, err)
		wg.Wait()
		close(cleanup)
	})

	t.Run("update topology", func(t *testing.T) {
		// updated records whether the update callback passed to ConnectServer
		// was invoked.
		var updated atomic.Value
		updated.Store(false)

		updateCallback := func(desc description.Server) description.Server {
			updated.Store(true)
			return desc
		}
		s, err := ConnectServer(address.Address("localhost"), updateCallback, primitive.NewObjectID())
		require.NoError(t, err)
		s.updateDescription(description.Server{Addr: s.address})
		require.True(t, updated.Load().(bool))
	})
	t.Run("heartbeat", func(t *testing.T) {
		// Use a dialer that returns in-memory ChannelNetConn connections.
		dialer := &channelNetConnDialer{}
		dialerOpt := WithDialer(func(Dialer) Dialer {
			return dialer
		})
		serverOpt := WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
			return append(connOpts, dialerOpt)
		})

		s := NewServer(address.Address("localhost:27017"), primitive.NewObjectID(), serverOpt)

		// First check: a connection must be dialed and the written message
		// must include client metadata.
		_, err := s.check()
		assert.Nil(t, err, "check error: %v", err)
		assert.NotNil(t, s.conn, "no connection dialed in check")

		channelConn := s.conn.nc.(*drivertest.ChannelNetConn)
		wm := channelConn.GetWrittenMessage()
		if wm == nil {
			t.Fatal("no wire message written for handshake")
		}
		if !includesClientMetadata(t, wm) {
			t.Fatal("client metadata expected in handshake but not found")
		}

		// Second check: queue another reply; the written message must not
		// include client metadata.
		if err = channelConn.AddResponse(makeHelloReply()); err != nil {
			t.Fatalf("error adding response: %v", err)
		}
		_, err = s.check()
		assert.Nil(t, err, "check error: %v", err)

		wm = channelConn.GetWrittenMessage()
		if wm == nil {
			t.Fatal("no wire message written for heartbeat")
		}
		if includesClientMetadata(t, wm) {
			t.Fatal("client metadata not expected in heartbeat but found")
		}
	})
	t.Run("heartbeat monitoring", func(t *testing.T) {
		// Each callback appends a copy of the event it receives to
		// publishedEvents.
		var publishedEvents []interface{}

		serverHeartbeatStarted := func(e *event.ServerHeartbeatStartedEvent) {
			publishedEvents = append(publishedEvents, *e)
		}

		serverHeartbeatSucceeded := func(e *event.ServerHeartbeatSucceededEvent) {
			publishedEvents = append(publishedEvents, *e)
		}

		serverHeartbeatFailed := func(e *event.ServerHeartbeatFailedEvent) {
			publishedEvents = append(publishedEvents, *e)
		}

		sdam := &event.ServerMonitor{
			ServerHeartbeatStarted: serverHeartbeatStarted,
			ServerHeartbeatSucceeded: serverHeartbeatSucceeded,
			ServerHeartbeatFailed: serverHeartbeatFailed,
		}

		dialer := &channelNetConnDialer{}
		dialerOpt := WithDialer(func(Dialer) Dialer {
			return dialer
		})
		serverOpts := []ServerOption{
			WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
				return append(connOpts, dialerOpt)
			}),
			withMonitoringDisabled(func(bool) bool { return true }),
			WithServerMonitor(func(*event.ServerMonitor) *event.ServerMonitor { return sdam }),
		}

		s := NewServer(address.Address("localhost:27017"), primitive.NewObjectID(), serverOpts...)

		// Initial check to establish the connection; its written message is
		// read and discarded.
		_, err := s.check()
		assert.Nil(t, err, "check error: %v", err)

		channelConn := s.conn.nc.(*drivertest.ChannelNetConn)
		_ = channelConn.GetWrittenMessage()

		t.Run("success", func(t *testing.T) {
			// Reset the recorded events left over from the initial check.
			publishedEvents = nil
			// Queue a reply so the next check succeeds.
			if err = channelConn.AddResponse(makeHelloReply()); err != nil {
				t.Fatalf("error adding response: %v", err)
			}
			_, err = s.check()
			_ = channelConn.GetWrittenMessage()
			assert.Nil(t, err, "check error: %v", err)

			assert.Equal(t, len(publishedEvents), 2, "expected %v events, got %v", 2, len(publishedEvents))

			started, ok := publishedEvents[0].(event.ServerHeartbeatStartedEvent)
			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatStartedEvent{}, publishedEvents[0])
			assert.Equal(t, started.ConnectionID, s.conn.ID(), "expected connectionID to match")
			assert.False(t, started.Awaited, "expected awaited to be false")

			succeeded, ok := publishedEvents[1].(event.ServerHeartbeatSucceededEvent)
			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatSucceededEvent{}, publishedEvents[1])
			assert.Equal(t, succeeded.ConnectionID, s.conn.ID(), "expected connectionID to match")
			assert.Equal(t, succeeded.Reply.Addr, s.address, "expected address %v, got %v", s.address, succeeded.Reply.Addr)
			assert.False(t, succeeded.Awaited, "expected awaited to be false")
		})
		t.Run("failure", func(t *testing.T) {
			// Reset the recorded events left over from the previous subtest.
			publishedEvents = nil
			// Inject a read error for the next check. check itself is expected
			// to return a nil error; the failure is reported via the event.
			readErr := errors.New("error")
			channelConn.ReadErr <- readErr
			_, err = s.check()
			_ = channelConn.GetWrittenMessage()
			assert.Nil(t, err, "check error: %v", err)

			assert.Equal(t, len(publishedEvents), 2, "expected %v events, got %v", 2, len(publishedEvents))

			started, ok := publishedEvents[0].(event.ServerHeartbeatStartedEvent)
			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatStartedEvent{}, publishedEvents[0])
			assert.Equal(t, started.ConnectionID, s.conn.ID(), "expected connectionID to match")
			assert.False(t, started.Awaited, "expected awaited to be false")

			failed, ok := publishedEvents[1].(event.ServerHeartbeatFailedEvent)
			assert.True(t, ok, "expected type %T, got %T", event.ServerHeartbeatFailedEvent{}, publishedEvents[1])
			assert.Equal(t, failed.ConnectionID, s.conn.ID(), "expected connectionID to match")
			assert.False(t, failed.Awaited, "expected awaited to be false")
			assert.True(t, errors.Is(failed.Failure, readErr), "expected Failure to be %v, got: %v", readErr, failed.Failure)
		})
	})
	t.Run("WithServerAppName", func(t *testing.T) {
		// The app name returned by the option must end up in the server config.
		name := "test"

		s := NewServer(address.Address("localhost"),
			primitive.NewObjectID(),
			WithServerAppName(func(string) string { return name }))
		require.Equal(t, name, s.cfg.appname, "expected appname to be: %v, got: %v", name, s.cfg.appname)
	})
	t.Run("createConnection overwrites WithSocketTimeout", func(t *testing.T) {
		// Read/write timeouts of 40s are configured, yet the connection from
		// createConnection is expected to use heartbeatTimeout (asserted to be
		// 10s) for both.
		socketTimeout := 40 * time.Second

		s := NewServer(
			address.Address("localhost"),
			primitive.NewObjectID(),
			WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
				return append(
					connOpts,
					WithReadTimeout(func(time.Duration) time.Duration { return socketTimeout }),
					WithWriteTimeout(func(time.Duration) time.Duration { return socketTimeout }),
				)
			}),
		)

		conn := s.createConnection()
		assert.Equal(t, s.cfg.heartbeatTimeout, 10*time.Second, "expected heartbeatTimeout to be: %v, got: %v", 10*time.Second, s.cfg.heartbeatTimeout)
		assert.Equal(t, s.cfg.heartbeatTimeout, conn.readTimeout, "expected readTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.readTimeout)
		assert.Equal(t, s.cfg.heartbeatTimeout, conn.writeTimeout, "expected writeTimeout to be: %v, got: %v", s.cfg.heartbeatTimeout, conn.writeTimeout)
	})
	t.Run("heartbeat contexts are not leaked", func(t *testing.T) {
		// The server uses the address "invalid" with monitoring disabled, so
		// checks are driven manually below.
		server, err := ConnectServer(
			address.Address("invalid"),
			nil,
			primitive.NewObjectID(),
			withMonitoringDisabled(func(bool) bool {
				return true
			}),
		)
		assert.Nil(t, err, "ConnectServer error: %v", err)

		// First check: it is expected to record an error in the description
		// while leaving a live (non-cancelled) heartbeat context behind.
		desc, err := server.check()
		assert.Nil(t, err, "check error: %v", err)
		assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
		assert.NotNil(t, server.heartbeatCtx, "expected heartbeatCtx to be non-nil, got nil")
		assert.Nil(t, server.heartbeatCtx.Err(), "expected heartbeatCtx error to be nil, got %v", server.heartbeatCtx.Err())

		// Wrap the stored cancel function so the test can observe whether the
		// next check cancels the previous heartbeat context.
		oldCancelFn := server.heartbeatCtxCancel
		var previousCtxCancelled bool
		server.heartbeatCtxCancel = func() {
			previousCtxCancelled = true
			oldCancelFn()
		}

		// Second check: it must invoke the wrapped cancel function.
		desc, err = server.check()
		assert.Nil(t, err, "check error: %v", err)
		assert.NotNil(t, desc.LastError, "expected server description to contain an error, got nil")
		assert.True(t, previousCtxCancelled, "expected check to cancel previous context but did not")
	})
}
849
850 func TestServer_ProcessError(t *testing.T) {
851 t.Parallel()
852
853 processID := primitive.NewObjectID()
854 newProcessID := primitive.NewObjectID()
855
856 testCases := []struct {
857 name string
858
859 startDescription description.Server
860
861 inputErr error
862 inputConn driver.Connection
863
864 want driver.ProcessErrorResult
865 wantGeneration uint64
866 wantDescription description.Server
867 }{
868
869 {
870 name: "nil error",
871 startDescription: description.Server{
872 Kind: description.RSPrimary,
873 },
874 inputErr: nil,
875 want: driver.NoChange,
876 wantGeneration: 0,
877 wantDescription: description.Server{
878 Kind: description.RSPrimary,
879 },
880 },
881
882 {
883 name: "stale connection",
884 startDescription: description.Server{
885 Kind: description.RSPrimary,
886 },
887 inputErr: errors.New("foo"),
888 inputConn: newProcessErrorTestConn(
889 &description.VersionRange{
890 Max: 17,
891 },
892 true),
893 want: driver.NoChange,
894 wantGeneration: 0,
895 wantDescription: description.Server{
896 Kind: description.RSPrimary,
897 },
898 },
899
900
901 {
902 name: "non state change error",
903 startDescription: description.Server{
904 Kind: description.RSPrimary,
905 },
906 inputErr: driver.Error{
907 Code: 1,
908 },
909 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
910 want: driver.NoChange,
911 wantGeneration: 0,
912 wantDescription: description.Server{
913 Kind: description.RSPrimary,
914 },
915 },
916
917 {
918 name: "stale not writable primary error",
919 startDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
920 inputErr: driver.Error{
921 Code: 10107,
922 TopologyVersion: &description.TopologyVersion{
923 ProcessID: processID,
924 Counter: 0,
925 },
926 },
927 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
928 want: driver.NoChange,
929 wantGeneration: 0,
930 wantDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
931 },
932
933
934 {
935 name: "new not writable primary error",
936 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
937 inputErr: driver.Error{
938 Code: 10107,
939 TopologyVersion: &description.TopologyVersion{
940 ProcessID: processID,
941 Counter: 1,
942 },
943 },
944 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
945 want: driver.ServerMarkedUnknown,
946 wantGeneration: 0,
947 wantDescription: newServerDescription(description.Unknown, processID, 1, driver.Error{
948 Code: 10107,
949 TopologyVersion: &description.TopologyVersion{
950 ProcessID: processID,
951 Counter: 1,
952 },
953 }),
954 },
955
956
957 {
958 name: "new process ID not writable primary error",
959 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
960 inputErr: driver.Error{
961 Code: 10107,
962 TopologyVersion: &description.TopologyVersion{
963 ProcessID: newProcessID,
964 Counter: 0,
965 },
966 },
967 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
968 want: driver.ServerMarkedUnknown,
969 wantGeneration: 0,
970 wantDescription: newServerDescription(description.Unknown, newProcessID, 0, driver.Error{
971 Code: 10107,
972 TopologyVersion: &description.TopologyVersion{
973 ProcessID: newProcessID,
974 Counter: 0,
975 },
976 }),
977 },
978
979
980
981 {
982 name: "newer connection topology version",
983 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
984 inputErr: driver.Error{
985 Code: 10107,
986 TopologyVersion: &description.TopologyVersion{
987 ProcessID: processID,
988 Counter: 1,
989 },
990 },
991 inputConn: &processErrorTestConn{
992 description: description.Server{
993 WireVersion: &description.VersionRange{Max: 17},
994 TopologyVersion: &description.TopologyVersion{
995 ProcessID: processID,
996 Counter: 1,
997 },
998 },
999 stale: false,
1000 },
1001 want: driver.NoChange,
1002 wantGeneration: 0,
1003 wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1004 },
1005
1006
1007 {
1008 name: "new shutdown error",
1009 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1010 inputErr: driver.Error{
1011 Code: 11600,
1012 TopologyVersion: &description.TopologyVersion{
1013 ProcessID: processID,
1014 Counter: 1,
1015 },
1016 },
1017 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1018 want: driver.ConnectionPoolCleared,
1019 wantGeneration: 1,
1020 wantDescription: newServerDescription(description.Unknown, processID, 1, driver.Error{
1021 Code: 11600,
1022 TopologyVersion: &description.TopologyVersion{
1023 ProcessID: processID,
1024 Counter: 1,
1025 },
1026 }),
1027 },
1028
1029 {
1030 name: "stale not writable primary write concern error",
1031 startDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
1032 inputErr: driver.WriteCommandError{
1033 WriteConcernError: &driver.WriteConcernError{
1034 Code: 10107,
1035 TopologyVersion: &description.TopologyVersion{
1036 ProcessID: processID,
1037 Counter: 0,
1038 },
1039 },
1040 },
1041 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1042 want: driver.NoChange,
1043 wantGeneration: 0,
1044 wantDescription: newServerDescription(description.RSPrimary, processID, 1, nil),
1045 },
1046
1047
1048 {
1049 name: "new not writable primary write concern error",
1050 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1051 inputErr: driver.WriteCommandError{
1052 WriteConcernError: &driver.WriteConcernError{
1053 Code: 10107,
1054 TopologyVersion: &description.TopologyVersion{
1055 ProcessID: processID,
1056 Counter: 1,
1057 },
1058 },
1059 },
1060 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1061 want: driver.ServerMarkedUnknown,
1062 wantGeneration: 0,
1063 wantDescription: newServerDescription(description.Unknown, processID, 1, driver.WriteCommandError{
1064 WriteConcernError: &driver.WriteConcernError{
1065 Code: 10107,
1066 TopologyVersion: &description.TopologyVersion{
1067 ProcessID: processID,
1068 Counter: 1,
1069 },
1070 },
1071 }),
1072 },
1073
1074
1075 {
1076 name: "new shutdown write concern error",
1077 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1078 inputErr: driver.WriteCommandError{
1079 WriteConcernError: &driver.WriteConcernError{
1080 Code: 11600,
1081 TopologyVersion: &description.TopologyVersion{
1082 ProcessID: processID,
1083 Counter: 1,
1084 },
1085 },
1086 },
1087 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1088 want: driver.ConnectionPoolCleared,
1089 wantGeneration: 1,
1090 wantDescription: newServerDescription(description.Unknown, processID, 1, driver.WriteCommandError{
1091 WriteConcernError: &driver.WriteConcernError{
1092 Code: 11600,
1093 TopologyVersion: &description.TopologyVersion{
1094 ProcessID: processID,
1095 Counter: 1,
1096 },
1097 },
1098 }),
1099 },
1100
1101
1102
1103 {
1104 name: "older than 4.2 write concern error",
1105 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1106 inputErr: driver.WriteCommandError{
1107 WriteConcernError: &driver.WriteConcernError{
1108 Code: 10107,
1109 TopologyVersion: &description.TopologyVersion{
1110 ProcessID: processID,
1111 Counter: 1,
1112 },
1113 },
1114 },
1115 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 7}, false),
1116 want: driver.ConnectionPoolCleared,
1117 wantGeneration: 1,
1118 wantDescription: newServerDescription(description.Unknown, processID, 1, driver.WriteCommandError{
1119 WriteConcernError: &driver.WriteConcernError{
1120 Code: 10107,
1121 TopologyVersion: &description.TopologyVersion{
1122 ProcessID: processID,
1123 Counter: 1,
1124 },
1125 },
1126 }),
1127 },
1128
1129 {
1130 name: "network timeout error",
1131 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1132 inputErr: driver.Error{
1133 Labels: []string{driver.NetworkError},
1134 Wrapped: ConnectionError{
1135
1136 Wrapped: &net.DNSError{
1137 IsTimeout: true,
1138 },
1139 },
1140 },
1141 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1142 want: driver.NoChange,
1143 wantGeneration: 0,
1144 wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1145 },
1146
1147 {
1148 name: "context canceled error",
1149 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1150 inputErr: driver.Error{
1151 Labels: []string{driver.NetworkError},
1152 Wrapped: ConnectionError{
1153 Wrapped: context.Canceled,
1154 },
1155 },
1156 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1157 want: driver.NoChange,
1158 wantGeneration: 0,
1159 wantDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1160 },
1161
1162
1163 {
1164 name: "non-timeout network error",
1165 startDescription: newServerDescription(description.RSPrimary, processID, 0, nil),
1166 inputErr: driver.Error{
1167 Labels: []string{driver.NetworkError},
1168 Wrapped: ConnectionError{
1169
1170 Wrapped: &net.AddrError{},
1171 },
1172 },
1173 inputConn: newProcessErrorTestConn(&description.VersionRange{Max: 17}, false),
1174 want: driver.ConnectionPoolCleared,
1175 wantGeneration: 1,
1176 wantDescription: description.Server{
1177 Kind: description.Unknown,
1178 LastError: driver.Error{
1179 Labels: []string{driver.NetworkError},
1180 Wrapped: ConnectionError{
1181 Wrapped: &net.AddrError{},
1182 },
1183 },
1184 },
1185 },
1186 }
1187
1188 for _, tc := range testCases {
1189 tc := tc
1190
1191 t.Run(tc.name, func(t *testing.T) {
1192 t.Parallel()
1193
1194 server := NewServer(address.Address(""), primitive.NewObjectID())
1195 server.state = serverConnected
1196 err := server.pool.ready()
1197 require.Nil(t, err, "pool.ready() error: %v", err)
1198
1199 server.desc.Store(tc.startDescription)
1200
1201 got := server.ProcessError(tc.inputErr, tc.inputConn)
1202 assert.Equal(t, tc.want, got, "expected and actual ProcessError result are different")
1203
1204 desc := server.Description()
1205 assert.Equal(t,
1206 tc.wantDescription,
1207 desc,
1208 "expected and actual server descriptions are different")
1209
1210 generation, _ := server.pool.generation.getGeneration(nil)
1211 assert.Equal(t,
1212 tc.wantGeneration,
1213 generation,
1214 "expected and actual pool generation are different")
1215 })
1216 }
1217 }
1218
1219
1220
1221 func includesClientMetadata(t *testing.T, wm []byte) bool {
1222 t.Helper()
1223
1224 var ok bool
1225 _, _, _, _, wm, ok = wiremessage.ReadHeader(wm)
1226 if !ok {
1227 t.Fatal("could not read header")
1228 }
1229 _, wm, ok = wiremessage.ReadQueryFlags(wm)
1230 if !ok {
1231 t.Fatal("could not read flags")
1232 }
1233 _, wm, ok = wiremessage.ReadQueryFullCollectionName(wm)
1234 if !ok {
1235 t.Fatal("could not read fullCollectionName")
1236 }
1237 _, wm, ok = wiremessage.ReadQueryNumberToSkip(wm)
1238 if !ok {
1239 t.Fatal("could not read numberToSkip")
1240 }
1241 _, wm, ok = wiremessage.ReadQueryNumberToReturn(wm)
1242 if !ok {
1243 t.Fatal("could not read numberToReturn")
1244 }
1245 var query bsoncore.Document
1246 query, wm, ok = wiremessage.ReadQueryQuery(wm)
1247 if !ok {
1248 t.Fatal("could not read query")
1249 }
1250
1251 if _, err := query.LookupErr("client"); err == nil {
1252 return true
1253 }
1254 if _, err := query.LookupErr("$query", "client"); err == nil {
1255 return true
1256 }
1257
1258 return false
1259 }
1260
1261
1262
1263
1264 type processErrorTestConn struct {
1265
1266
1267 driver.Connection
1268 description description.Server
1269 stale bool
1270 }
1271
1272 func newProcessErrorTestConn(wireVersion *description.VersionRange, stale bool) *processErrorTestConn {
1273 return &processErrorTestConn{
1274 description: description.Server{
1275 WireVersion: wireVersion,
1276 },
1277 stale: stale,
1278 }
1279 }
1280
1281 func (p *processErrorTestConn) Stale() bool {
1282 return p.stale
1283 }
1284
1285 func (p *processErrorTestConn) Description() description.Server {
1286 return p.description
1287 }
1288
1289
1290
1291 func newServerDescription(
1292 kind description.ServerKind,
1293 processID primitive.ObjectID,
1294 counter int64,
1295 lastError error,
1296 ) description.Server {
1297 return description.Server{
1298 Kind: kind,
1299 TopologyVersion: &description.TopologyVersion{
1300 ProcessID: processID,
1301 Counter: counter,
1302 },
1303 LastError: lastError,
1304 }
1305 }
1306
View as plain text