1
2
3
4
5
6
7 package topology
8
9 import (
10 "context"
11 "crypto/tls"
12 "errors"
13 "math/rand"
14 "net"
15 "sync"
16 "sync/atomic"
17 "testing"
18 "time"
19
20 "github.com/google/go-cmp/cmp"
21 "go.mongodb.org/mongo-driver/internal/assert"
22 "go.mongodb.org/mongo-driver/mongo/address"
23 "go.mongodb.org/mongo-driver/mongo/description"
24 "go.mongodb.org/mongo-driver/x/mongo/driver"
25 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
26 )
27
28 type testHandshaker struct {
29 getHandshakeInformation func(context.Context, address.Address, driver.Connection) (driver.HandshakeInformation, error)
30 finishHandshake func(context.Context, driver.Connection) error
31 }
32
33
34 func (th *testHandshaker) GetHandshakeInformation(ctx context.Context, addr address.Address, conn driver.Connection) (driver.HandshakeInformation, error) {
35 if th.getHandshakeInformation != nil {
36 return th.getHandshakeInformation(ctx, addr, conn)
37 }
38 return driver.HandshakeInformation{}, nil
39 }
40
41
42 func (th *testHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
43 if th.finishHandshake != nil {
44 return th.finishHandshake(ctx, conn)
45 }
46 return nil
47 }
48
49 var _ driver.Handshaker = &testHandshaker{}
50
51 func TestConnection(t *testing.T) {
52 t.Run("connection", func(t *testing.T) {
53 t.Run("newConnection", func(t *testing.T) {
54 t.Run("no default idle timeout", func(t *testing.T) {
55 conn := newConnection(address.Address(""))
56 wantTimeout := time.Duration(0)
57 assert.Equal(t, wantTimeout, conn.idleTimeout, "expected idle timeout %v, got %v", wantTimeout,
58 conn.idleTimeout)
59 })
60 })
61 t.Run("connect", func(t *testing.T) {
62 t.Run("dialer error", func(t *testing.T) {
63 err := errors.New("dialer error")
64 var want error = ConnectionError{Wrapped: err, init: true}
65 conn := newConnection(address.Address(""), WithDialer(func(Dialer) Dialer {
66 return DialerFunc(func(context.Context, string, string) (net.Conn, error) { return nil, err })
67 }))
68 got := conn.connect(context.Background())
69 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
70 t.Errorf("errors do not match. got %v; want %v", got, want)
71 }
72 connState := atomic.LoadInt64(&conn.state)
73 assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
74 })
75 t.Run("handshaker error", func(t *testing.T) {
76 err := errors.New("handshaker error")
77 var want error = ConnectionError{Wrapped: err, init: true}
78 conn := newConnection(address.Address(""),
79 WithHandshaker(func(Handshaker) Handshaker {
80 return &testHandshaker{
81 finishHandshake: func(context.Context, driver.Connection) error {
82 return err
83 },
84 }
85 }),
86 WithDialer(func(Dialer) Dialer {
87 return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
88 return &net.TCPConn{}, nil
89 })
90 }),
91 )
92 got := conn.connect(context.Background())
93 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
94 t.Errorf("errors do not match. got %v; want %v", got, want)
95 }
96 connState := atomic.LoadInt64(&conn.state)
97 assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
98 })
99 t.Run("context is not pinned by connect", func(t *testing.T) {
100
101
102
103
104 t.Run("connect succeeds", func(t *testing.T) {
105
106
107 conn := newConnection(address.Address(""),
108 WithDialer(func(Dialer) Dialer {
109 return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
110 return &net.TCPConn{}, nil
111 })
112 }),
113 WithHandshaker(func(Handshaker) Handshaker {
114 return &testHandshaker{}
115 }),
116 )
117
118 err := conn.connect(context.Background())
119 assert.Nil(t, err, "error establishing connection: %v", err)
120 assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared")
121 })
122 t.Run("connect cancelled", func(t *testing.T) {
123
124
125
126
127
128 doneChan := make(chan struct{})
129 conn := newConnection(address.Address(""),
130 WithDialer(func(Dialer) Dialer {
131 return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
132 <-doneChan
133 return &net.TCPConn{}, nil
134 })
135 }),
136 WithHandshaker(func(Handshaker) Handshaker {
137 return &testHandshaker{}
138 }),
139 )
140
141
142 var wg sync.WaitGroup
143 wg.Add(1)
144 go func() {
145 defer wg.Done()
146 _ = conn.connect(context.Background())
147 }()
148
149
150 conn.closeConnectContext()
151 assert.Nil(t, conn.cancelConnectContext, "cancellation function was not cleared")
152 close(doneChan)
153 wg.Wait()
154 })
155 })
156 t.Run("tls", func(t *testing.T) {
157 t.Run("connection source is set to default if unspecified", func(t *testing.T) {
158 conn := newConnection(address.Address(""))
159 assert.NotNil(t, conn.config.tlsConnectionSource, "expected tlsConnectionSource to be set but was not")
160 })
161 t.Run("server name", func(t *testing.T) {
162 testCases := []struct {
163 name string
164 addr address.Address
165 cfg *tls.Config
166 expectedServerName string
167 }{
168 {"set to connection address if empty", "localhost:27017", &tls.Config{}, "localhost"},
169 {"left alone if non-empty", "localhost:27017", &tls.Config{ServerName: "other"}, "other"},
170 }
171 for _, tc := range testCases {
172 t.Run(tc.name, func(t *testing.T) {
173 var sentCfg *tls.Config
174 var testTLSConnectionSource tlsConnectionSourceFn = func(nc net.Conn, cfg *tls.Config) tlsConn {
175 sentCfg = cfg
176 return tls.Client(nc, cfg)
177 }
178
179 connOpts := []ConnectionOption{
180 WithDialer(func(Dialer) Dialer {
181 return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
182 return &net.TCPConn{}, nil
183 })
184 }),
185 WithHandshaker(func(Handshaker) Handshaker {
186 return &testHandshaker{}
187 }),
188 WithTLSConfig(func(*tls.Config) *tls.Config {
189 return tc.cfg
190 }),
191 withTLSConnectionSource(func(tlsConnectionSource) tlsConnectionSource {
192 return testTLSConnectionSource
193 }),
194 }
195 conn := newConnection(tc.addr, connOpts...)
196
197 _ = conn.connect(context.Background())
198 assert.NotNil(t, sentCfg, "expected TLS config to be set, but was not")
199 assert.Equal(t, tc.expectedServerName, sentCfg.ServerName, "expected ServerName %s, got %s",
200 tc.expectedServerName, sentCfg.ServerName)
201 })
202 }
203 })
204 })
205 t.Run("connectTimeout is applied correctly", func(t *testing.T) {
206 testCases := []struct {
207 name string
208 contextTimeout time.Duration
209 connectTimeout time.Duration
210 maxConnectTime time.Duration
211 }{
212
213
214
215
216 {"context timeout is lower", 1 * time.Millisecond, 100 * time.Millisecond, 50 * time.Millisecond},
217 {"connect timeout is lower", 100 * time.Millisecond, 1 * time.Millisecond, 50 * time.Millisecond},
218 }
219
220 for _, tc := range testCases {
221 t.Run("timeout applied to socket establishment: "+tc.name, func(t *testing.T) {
222
223
224
225 connOpts := []ConnectionOption{
226 WithDialer(func(Dialer) Dialer {
227 return DialerFunc(func(ctx context.Context, _, _ string) (net.Conn, error) {
228 <-ctx.Done()
229 return nil, ctx.Err()
230 })
231 }),
232 WithConnectTimeout(func(time.Duration) time.Duration {
233 return tc.connectTimeout
234 }),
235 }
236 conn := newConnection("", connOpts...)
237
238 var connectErr error
239 callback := func(ctx context.Context) {
240 connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout)
241 defer cancel()
242
243 connectErr = conn.connect(connectCtx)
244 }
245 assert.Soon(t, callback, tc.maxConnectTime)
246
247 ce, ok := connectErr.(ConnectionError)
248 assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
249 assert.Equal(t, context.DeadlineExceeded, ce.Unwrap(), "expected wrapped error to be %v, got %v",
250 context.DeadlineExceeded, ce.Unwrap())
251 })
252 t.Run("timeout applied to TLS handshake: "+tc.name, func(t *testing.T) {
253
254
255
256
257
258
259 l, err := net.Listen("tcp", "localhost:0")
260 assert.Nil(t, err, "net.Listen() error: %q", err)
261 defer l.Close()
262
263 connOpts := []ConnectionOption{
264 WithConnectTimeout(func(time.Duration) time.Duration {
265 return tc.connectTimeout
266 }),
267 WithTLSConfig(func(*tls.Config) *tls.Config {
268 return &tls.Config{ServerName: "test"}
269 }),
270 }
271 conn := newConnection(address.Address(l.Addr().String()), connOpts...)
272
273 var connectErr error
274 callback := func(ctx context.Context) {
275 connectCtx, cancel := context.WithTimeout(ctx, tc.contextTimeout)
276 defer cancel()
277
278 connectErr = conn.connect(connectCtx)
279 }
280 assert.Soon(t, callback, tc.maxConnectTime)
281
282 ce, ok := connectErr.(ConnectionError)
283 assert.True(t, ok, "expected error %v to be of type %T", connectErr, ConnectionError{})
284
285 isTimeout := func(err error) bool {
286 if errors.Is(err, context.DeadlineExceeded) {
287 return true
288 }
289 if ne, ok := err.(net.Error); ok {
290 return ne.Timeout()
291 }
292 return false
293 }
294 assert.True(t,
295 isTimeout(ce.Unwrap()),
296 "expected wrapped error to be a timeout error, but got %q",
297 ce.Unwrap())
298 })
299 t.Run("timeout is not applied to handshaker: "+tc.name, func(t *testing.T) {
300
301
302
303 var getInfoCtx, finishCtx context.Context
304 handshaker := &testHandshaker{
305 getHandshakeInformation: func(ctx context.Context, _ address.Address, _ driver.Connection) (driver.HandshakeInformation, error) {
306 getInfoCtx = ctx
307 return driver.HandshakeInformation{}, nil
308 },
309 finishHandshake: func(ctx context.Context, _ driver.Connection) error {
310 finishCtx = ctx
311 return nil
312 },
313 }
314
315 connOpts := []ConnectionOption{
316 WithConnectTimeout(func(time.Duration) time.Duration {
317 return tc.connectTimeout
318 }),
319 WithDialer(func(Dialer) Dialer {
320 return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
321 return &net.TCPConn{}, nil
322 })
323 }),
324 WithHandshaker(func(Handshaker) Handshaker {
325 return handshaker
326 }),
327 }
328 conn := newConnection("", connOpts...)
329
330 err := conn.connect(context.Background())
331 assert.Nil(t, err, "connect error: %v", err)
332
333 assertNoContextTimeout := func(t *testing.T, ctx context.Context) {
334 t.Helper()
335 dl, ok := ctx.Deadline()
336 assert.False(t, ok, "expected context to have no deadline, but got deadline %v", dl)
337 }
338 assertNoContextTimeout(t, getInfoCtx)
339 assertNoContextTimeout(t, finishCtx)
340 })
341 }
342 })
343 })
344 t.Run("writeWireMessage", func(t *testing.T) {
345 t.Run("closed connection", func(t *testing.T) {
346 conn := &connection{id: "foobar"}
347 want := ConnectionError{ConnectionID: "foobar", message: "connection is closed"}
348 got := conn.writeWireMessage(context.Background(), []byte{})
349 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
350 t.Errorf("errors do not match. got %v; want %v", got, want)
351 }
352 })
353 t.Run("deadlines", func(t *testing.T) {
354 testCases := []struct {
355 name string
356 ctxDeadline time.Duration
357 timeout time.Duration
358 deadline time.Time
359 }{
360 {"no deadline", 0, 0, time.Now().Add(1 * time.Second)},
361 {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)},
362 {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)},
363 {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)},
364 {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)},
365 }
366
367 for _, tc := range testCases {
368 t.Run(tc.name, func(t *testing.T) {
369 ctx := context.Background()
370 if tc.ctxDeadline > 0 {
371 var cancel context.CancelFunc
372 ctx, cancel = context.WithTimeout(ctx, tc.ctxDeadline)
373 defer cancel()
374 }
375 want := ConnectionError{
376 ConnectionID: "foobar",
377 Wrapped: errors.New("set writeDeadline error"),
378 message: "failed to set write deadline",
379 }
380 tnc := &testNetConn{deadlineerr: errors.New("set writeDeadline error")}
381 conn := &connection{id: "foobar", nc: tnc, writeTimeout: tc.timeout, state: connConnected}
382 got := conn.writeWireMessage(ctx, []byte{})
383 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
384 t.Errorf("errors do not match. got %v; want %v", got, want)
385 }
386 if !tc.deadline.After(tnc.writeDeadline) {
387 t.Errorf("write deadline not properly set. got %v; want %v", tnc.writeDeadline, tc.deadline)
388 }
389 })
390 }
391 })
392 t.Run("Write", func(t *testing.T) {
393 writeErrMsg := "unable to write wire message to network"
394
395 t.Run("error", func(t *testing.T) {
396 err := errors.New("Write error")
397 tnc := &testNetConn{writeerr: err}
398 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
399 listener := newTestCancellationListener(false)
400 conn.cancellationListener = listener
401
402 want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: writeErrMsg}
403 got := conn.writeWireMessage(context.Background(), []byte{})
404 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
405 t.Errorf("errors do not match. got %v; want %v", got, want)
406 }
407 if !tnc.closed {
408 t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
409 }
410 listener.assertCalledOnce(t)
411 })
412 t.Run("success", func(t *testing.T) {
413 tnc := &testNetConn{}
414 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
415 listener := newTestCancellationListener(false)
416 conn.cancellationListener = listener
417
418 want := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
419 err := conn.writeWireMessage(context.Background(), want)
420 noerr(t, err)
421 got := tnc.buf
422 if !cmp.Equal(got, want) {
423 t.Errorf("writeWireMessage did not write the proper bytes. got %v; want %v", got, want)
424 }
425 listener.assertCalledOnce(t)
426 })
427 t.Run("cancel in-progress write", func(t *testing.T) {
428
429
430 nc := newCancellationWriteConn(&testNetConn{}, 0)
431 conn := &connection{id: "foobar", nc: nc, state: connConnected}
432 listener := newTestCancellationListener(false)
433 conn.cancellationListener = listener
434
435 ctx, cancel := context.WithCancel(context.Background())
436 var err error
437
438 var wg sync.WaitGroup
439 wg.Add(1)
440 go func() {
441 defer wg.Done()
442 err = conn.writeWireMessage(ctx, []byte("foobar"))
443 }()
444
445 <-nc.operationStartedChan
446 cancel()
447 nc.continueChan <- struct{}{}
448
449 wg.Wait()
450 want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
451 assert.Equal(t, want, err, "expected error %v, got %v", want, err)
452 assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
453 conn.state)
454 })
455 t.Run("connection is closed if context is cancelled even if network write succeeds", func(t *testing.T) {
456
457
458
459
460 tnc := &testNetConn{}
461 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
462 listener := newTestCancellationListener(true)
463 conn.cancellationListener = listener
464
465 want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: writeErrMsg}
466 err := conn.writeWireMessage(context.Background(), []byte("foobar"))
467 assert.Equal(t, want, err, "expected error %v, got %v", want, err)
468 assert.Equal(t, conn.state, connDisconnected, "expected connection state %v, got %v", connDisconnected,
469 conn.state)
470 })
471 })
472 })
473 t.Run("readWireMessage", func(t *testing.T) {
474 t.Run("closed connection", func(t *testing.T) {
475 conn := &connection{id: "foobar"}
476 want := ConnectionError{ConnectionID: "foobar", message: "connection is closed"}
477 _, got := conn.readWireMessage(context.Background())
478 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
479 t.Errorf("errors do not match. got %v; want %v", got, want)
480 }
481 })
482 t.Run("deadlines", func(t *testing.T) {
483 testCases := []struct {
484 name string
485 ctxDeadline time.Duration
486 timeout time.Duration
487 deadline time.Time
488 }{
489 {"no deadline", 0, 0, time.Now().Add(1 * time.Second)},
490 {"ctx deadline", 5 * time.Second, 0, time.Now().Add(6 * time.Second)},
491 {"timeout", 0, 10 * time.Second, time.Now().Add(11 * time.Second)},
492 {"both (ctx wins)", 15 * time.Second, 20 * time.Second, time.Now().Add(16 * time.Second)},
493 {"both (timeout wins)", 30 * time.Second, 25 * time.Second, time.Now().Add(26 * time.Second)},
494 }
495
496 for _, tc := range testCases {
497 t.Run(tc.name, func(t *testing.T) {
498 ctx := context.Background()
499 if tc.ctxDeadline > 0 {
500 var cancel context.CancelFunc
501 ctx, cancel = context.WithTimeout(ctx, tc.ctxDeadline)
502 defer cancel()
503 }
504 want := ConnectionError{
505 ConnectionID: "foobar",
506 Wrapped: errors.New("set readDeadline error"),
507 message: "failed to set read deadline",
508 }
509 tnc := &testNetConn{deadlineerr: errors.New("set readDeadline error")}
510 conn := &connection{id: "foobar", nc: tnc, readTimeout: tc.timeout, state: connConnected}
511 _, got := conn.readWireMessage(ctx)
512 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
513 t.Errorf("errors do not match. got %v; want %v", got, want)
514 }
515 if !tc.deadline.After(tnc.readDeadline) {
516 t.Errorf("read deadline not properly set. got %v; want %v", tnc.readDeadline, tc.deadline)
517 }
518 })
519 }
520 })
521 t.Run("Read", func(t *testing.T) {
522 t.Run("size read errors", func(t *testing.T) {
523 err := errors.New("Read error")
524 tnc := &testNetConn{readerr: err}
525 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
526 listener := newTestCancellationListener(false)
527 conn.cancellationListener = listener
528
529 want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "incomplete read of message header"}
530 _, got := conn.readWireMessage(context.Background())
531 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
532 t.Errorf("errors do not match. got %v; want %v", got, want)
533 }
534 if !tnc.closed {
535 t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
536 }
537 listener.assertCalledOnce(t)
538 })
539 t.Run("full message read errors", func(t *testing.T) {
540 err := errors.New("Read error")
541 tnc := &testNetConn{readerr: err, buf: []byte{0x11, 0x00, 0x00, 0x00}}
542 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
543 listener := newTestCancellationListener(false)
544 conn.cancellationListener = listener
545
546 want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: "incomplete read of full message"}
547 _, got := conn.readWireMessage(context.Background())
548 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
549 t.Errorf("errors do not match. got %v; want %v", got, want)
550 }
551 if !tnc.closed {
552 t.Errorf("failed to closeConnection net.Conn after error writing bytes.")
553 }
554 listener.assertCalledOnce(t)
555 })
556 t.Run("message too large errors", func(t *testing.T) {
557 testCases := []struct {
558 name string
559 buffer []byte
560 desc description.Server
561 }{
562 {
563 "message too large errors with small max message size",
564 []byte{0x0A, 0x00, 0x00, 0x00},
565 description.Server{MaxMessageSize: 9},
566 },
567 {
568 "message too large errors with default max message size",
569 []byte{0x01, 0x6C, 0xDC, 0x02},
570 description.Server{},
571 },
572 }
573 for _, tc := range testCases {
574 t.Run(tc.name, func(t *testing.T) {
575 err := errors.New("length of read message too large")
576 tnc := &testNetConn{buf: make([]byte, len(tc.buffer))}
577 copy(tnc.buf, tc.buffer)
578 conn := &connection{id: "foobar", nc: tnc, state: connConnected, desc: tc.desc}
579 listener := newTestCancellationListener(false)
580 conn.cancellationListener = listener
581
582 want := ConnectionError{ConnectionID: "foobar", Wrapped: err, message: err.Error()}
583 _, got := conn.readWireMessage(context.Background())
584 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
585 t.Errorf("errors do not match. got %v; want %v", got, want)
586 }
587 listener.assertCalledOnce(t)
588 })
589 }
590 })
591 t.Run("success", func(t *testing.T) {
592 want := []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}
593 tnc := &testNetConn{buf: make([]byte, len(want))}
594 copy(tnc.buf, want)
595 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
596 listener := newTestCancellationListener(false)
597 conn.cancellationListener = listener
598
599 got, err := conn.readWireMessage(context.Background())
600 noerr(t, err)
601 if !cmp.Equal(got, want) {
602 t.Errorf("did not read full wire message. got %v; want %v", got, want)
603 }
604 listener.assertCalledOnce(t)
605 })
606 t.Run("cancel in-progress read", func(t *testing.T) {
607
608
609
610 testCases := []struct {
611 name string
612 skip int
613 errmsg string
614 }{
615 {"cancel size read", 0, "incomplete read of message header"},
616 {"cancel full message read", 1, "incomplete read of full message"},
617 }
618 for _, tc := range testCases {
619 t.Run(tc.name, func(t *testing.T) {
620
621
622 readBuf := []byte{10, 0, 0, 0}
623 nc := newCancellationReadConn(&testNetConn{}, tc.skip, readBuf)
624
625 conn := &connection{id: "foobar", nc: nc, state: connConnected}
626 listener := newTestCancellationListener(false)
627 conn.cancellationListener = listener
628
629 ctx, cancel := context.WithCancel(context.Background())
630 var err error
631
632 var wg sync.WaitGroup
633 wg.Add(1)
634 go func() {
635 defer wg.Done()
636 _, err = conn.readWireMessage(ctx)
637 }()
638
639 <-nc.operationStartedChan
640 cancel()
641 nc.continueChan <- struct{}{}
642
643 wg.Wait()
644 want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: tc.errmsg}
645 assert.Equal(t, want, err, "expected error %v, got %v", want, err)
646 assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
647 conn.state)
648 })
649 }
650 })
651 t.Run("closes connection if context is cancelled even if the socket read succeeds", func(t *testing.T) {
652 tnc := &testNetConn{buf: []byte{0x0A, 0x00, 0x00, 0x00, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A}}
653 conn := &connection{id: "foobar", nc: tnc, state: connConnected}
654 listener := newTestCancellationListener(true)
655 conn.cancellationListener = listener
656
657 want := ConnectionError{ConnectionID: conn.id, Wrapped: context.Canceled, message: "unable to read server response"}
658 _, err := conn.readWireMessage(context.Background())
659 assert.Equal(t, want, err, "expected error %v, got %v", want, err)
660 assert.Equal(t, connDisconnected, conn.state, "expected connection state %v, got %v", connDisconnected,
661 conn.state)
662 })
663 })
664 })
665 t.Run("close", func(t *testing.T) {
666 t.Run("can close a connection that failed handshaking", func(t *testing.T) {
667 conn := newConnection(address.Address(""),
668 WithHandshaker(func(Handshaker) Handshaker {
669 return &testHandshaker{
670 finishHandshake: func(context.Context, driver.Connection) error {
671 return errors.New("handshake err")
672 },
673 }
674 }),
675 WithDialer(func(Dialer) Dialer {
676 return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
677 return &net.TCPConn{}, nil
678 })
679 }),
680 )
681
682 err := conn.connect(context.Background())
683 assert.NotNil(t, err, "expected handshake error from connect, got nil")
684 connState := atomic.LoadInt64(&conn.state)
685 assert.Equal(t, connDisconnected, connState, "expected connection state %v, got %v", connDisconnected, connState)
686
687 err = conn.close()
688 assert.Nil(t, err, "close error: %v", err)
689 })
690 })
691 t.Run("cancellation listener callback", func(t *testing.T) {
692 t.Run("closes connection", func(t *testing.T) {
693 tnc := &testNetConn{}
694 conn := &connection{state: connConnected, nc: tnc}
695
696 conn.cancellationListenerCallback()
697 assert.True(t, conn.state == connDisconnected, "expected connection state %v, got %v", connDisconnected,
698 conn.state)
699 assert.True(t, tnc.closed, "expected net.Conn to be closed but was not")
700 })
701 })
702 })
703 t.Run("Connection", func(t *testing.T) {
704 t.Run("nil connection does not panic", func(t *testing.T) {
705 conn := &Connection{}
706 defer func() {
707 if r := recover(); r != nil {
708 t.Fatalf("Methods on a Connection with a nil *connection should not panic, but panicked with %v", r)
709 }
710 }()
711
712 var want, got interface{}
713
714 want = ErrConnectionClosed
715 got = conn.WriteWireMessage(context.Background(), nil)
716 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
717 t.Errorf("errors do not match. got %v; want %v", got, want)
718 }
719 _, got = conn.ReadWireMessage(context.Background())
720 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
721 t.Errorf("errors do not match. got %v; want %v", got, want)
722 }
723
724 want = description.Server{}
725 got = conn.Description()
726 if !cmp.Equal(got, want) {
727 t.Errorf("descriptions do not match. got %v; want %v", got, want)
728 }
729
730 want = nil
731 got = conn.Close()
732 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
733 t.Errorf("errors do not match. got %v; want %v", got, want)
734 }
735
736 got = conn.Expire()
737 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
738 t.Errorf("errors do not match. got %v; want %v", got, want)
739 }
740
741 want = false
742 got = conn.Alive()
743 if !cmp.Equal(got, want) {
744 t.Errorf("Alive does not match. got %v; want %v", got, want)
745 }
746
747 want = "<closed>"
748 got = conn.ID()
749 if !cmp.Equal(got, want) {
750 t.Errorf("IDs do not match. got %v; want %v", got, want)
751 }
752
753 want = address.Address("0.0.0.0")
754 got = conn.Address()
755 if !cmp.Equal(got, want) {
756 t.Errorf("Addresses do not match. got %v; want %v", got, want)
757 }
758
759 want = address.Address("0.0.0.0")
760 got = conn.LocalAddress()
761 if !cmp.Equal(got, want) {
762 t.Errorf("LocalAddresses do not match. got %v; want %v", got, want)
763 }
764
765 want = (*int64)(nil)
766 got = conn.ServerConnectionID()
767 if !cmp.Equal(got, want) {
768 t.Errorf("ServerConnectionIDs do not match. got %v; want %v", got, want)
769 }
770 })
771
772 t.Run("pinning", func(t *testing.T) {
773 makeMultipleConnections := func(t *testing.T, numConns int) (*pool, []*Connection, func()) {
774 t.Helper()
775
776 addr := bootstrapConnections(t, numConns, func(nc net.Conn) {})
777 pool := newPool(poolConfig{
778 Address: address.Address(addr.String()),
779 })
780 err := pool.ready()
781 assert.Nil(t, err, "pool.connect() error: %v", err)
782
783 conns := make([]*Connection, 0, numConns)
784 for i := 0; i < numConns; i++ {
785 conn, err := pool.checkOut(context.Background())
786 assert.Nil(t, err, "checkOut error: %v", err)
787 conns = append(conns, &Connection{connection: conn})
788 }
789 disconnect := func() {
790 pool.close(context.Background())
791 }
792 return pool, conns, disconnect
793 }
794 makeOneConnection := func(t *testing.T) (*pool, *Connection, func()) {
795 t.Helper()
796
797 pool, conns, disconnect := makeMultipleConnections(t, 1)
798 return pool, conns[0], disconnect
799 }
800
801 assertPoolPinnedStats := func(t *testing.T, p *pool, cursorConns, txnConns uint64) {
802 t.Helper()
803
804 assert.Equal(t, cursorConns, p.pinnedCursorConnections, "expected %d connections to be pinned to cursors, got %d",
805 cursorConns, p.pinnedCursorConnections)
806 assert.Equal(t, txnConns, p.pinnedTransactionConnections, "expected %d connections to be pinned to transactions, got %d",
807 txnConns, p.pinnedTransactionConnections)
808 }
809
810 t.Run("cursors", func(t *testing.T) {
811 pool, conn, disconnect := makeOneConnection(t)
812 defer disconnect()
813
814 err := conn.PinToCursor()
815 assert.Nil(t, err, "PinToCursor error: %v", err)
816 assertPoolPinnedStats(t, pool, 1, 0)
817
818 err = conn.UnpinFromCursor()
819 assert.Nil(t, err, "UnpinFromCursor error: %v", err)
820
821 err = conn.Close()
822 assert.Nil(t, err, "Close error: %v", err)
823 assertPoolPinnedStats(t, pool, 0, 0)
824 })
825 t.Run("transactions", func(t *testing.T) {
826 pool, conn, disconnect := makeOneConnection(t)
827 defer disconnect()
828
829 err := conn.PinToTransaction()
830 assert.Nil(t, err, "PinToTransaction error: %v", err)
831 assertPoolPinnedStats(t, pool, 0, 1)
832
833 err = conn.UnpinFromTransaction()
834 assert.Nil(t, err, "UnpinFromTransaction error: %v", err)
835
836 err = conn.Close()
837 assert.Nil(t, err, "Close error: %v", err)
838 assertPoolPinnedStats(t, pool, 0, 0)
839 })
840 t.Run("pool is only updated for first reference", func(t *testing.T) {
841 pool, conn, disconnect := makeOneConnection(t)
842 defer disconnect()
843
844 err := conn.PinToTransaction()
845 assert.Nil(t, err, "PinToTransaction error: %v", err)
846 assertPoolPinnedStats(t, pool, 0, 1)
847
848 err = conn.PinToCursor()
849 assert.Nil(t, err, "PinToCursor error: %v", err)
850 assertPoolPinnedStats(t, pool, 0, 1)
851
852 err = conn.UnpinFromCursor()
853 assert.Nil(t, err, "UnpinFromCursor error: %v", err)
854 assertPoolPinnedStats(t, pool, 0, 1)
855
856 err = conn.UnpinFromTransaction()
857 assert.Nil(t, err, "UnpinFromTransaction error: %v", err)
858 assertPoolPinnedStats(t, pool, 0, 1)
859
860 err = conn.Close()
861 assert.Nil(t, err, "Close error: %v", err)
862 assertPoolPinnedStats(t, pool, 0, 0)
863 })
864 t.Run("multiple connections from a pool", func(t *testing.T) {
865 pool, conns, disconnect := makeMultipleConnections(t, 2)
866 defer disconnect()
867
868 first, second := conns[0], conns[1]
869
870 err := first.PinToTransaction()
871 assert.Nil(t, err, "PinToTransaction error: %v", err)
872 err = second.PinToCursor()
873 assert.Nil(t, err, "PinToCursor error: %v", err)
874 assertPoolPinnedStats(t, pool, 1, 1)
875
876 err = first.UnpinFromTransaction()
877 assert.Nil(t, err, "UnpinFromTransaction error: %v", err)
878 err = first.Close()
879 assert.Nil(t, err, "Close error: %v", err)
880 assertPoolPinnedStats(t, pool, 1, 0)
881
882 err = second.UnpinFromCursor()
883 assert.Nil(t, err, "UnpinFromCursor error: %v", err)
884 err = second.Close()
885 assert.Nil(t, err, "Close error: %v", err)
886 assertPoolPinnedStats(t, pool, 0, 0)
887 })
888 t.Run("close is ignored if connection is pinned", func(t *testing.T) {
889 pool, conn, disconnect := makeOneConnection(t)
890 defer disconnect()
891
892 err := conn.PinToCursor()
893 assert.Nil(t, err, "PinToCursor error: %v", err)
894
895 err = conn.Close()
896 assert.Nil(t, err, "Close error")
897 assert.NotNil(t, conn.connection, "expected connection to be pinned but it was released to the pool")
898 assertPoolPinnedStats(t, pool, 1, 0)
899 })
900 t.Run("expire forcefully returns connection to pool", func(t *testing.T) {
901 pool, conn, disconnect := makeOneConnection(t)
902 defer disconnect()
903
904 err := conn.PinToCursor()
905 assert.Nil(t, err, "PinToCursor error: %v", err)
906
907 err = conn.Expire()
908 assert.Nil(t, err, "Expire error")
909 assert.Nil(t, conn.connection, "expected connection to be released to the pool but was not")
910 assertPoolPinnedStats(t, pool, 0, 0)
911 })
912 })
913 })
914 }
915
916 func BenchmarkConnection(b *testing.B) {
917 b.Run("CompressWireMessage CompressorNoOp", func(b *testing.B) {
918 buf := make([]byte, 256)
919 _, err := rand.Read(buf)
920 if err != nil {
921 b.Log(err)
922 b.FailNow()
923 }
924 conn := Connection{connection: &connection{compressor: wiremessage.CompressorNoOp}}
925 for i := 0; i < b.N; i++ {
926 _, err := conn.CompressWireMessage(buf, nil)
927 if err != nil {
928 b.Error(err)
929 }
930 }
931 })
932 }
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
// cancellationTestNetConn wraps a net.Conn so a test can interrupt reads and writes deterministically.
// The first shouldSkip calls to Read/Write (counted together) succeed immediately. Every later call
// first sends on operationStartedChan, then blocks until the test sends on continueChan, and finally
// returns an error. This gives the test a window to cancel a context while an I/O operation is in flight.
type cancellationTestNetConn struct {
	net.Conn

	// shouldSkip is how many Read/Write calls succeed before operations start blocking.
	shouldSkip int
	// skipCount is how many calls have been allowed to succeed so far.
	skipCount int
	// readBuf is copied into the caller's buffer by every skipped Read.
	readBuf []byte
	// operationStartedChan receives a value when a Read/Write is about to block.
	operationStartedChan chan struct{}
	// continueChan must be sent on by the test to release the blocked operation.
	continueChan chan struct{}
}

// newCancellationWriteConn returns a cancellationTestNetConn around nc whose first skip Writes succeed
// and whose following Write blocks until released. Unbuffered channels are used so each side
// rendezvouses with the test.
func newCancellationWriteConn(nc net.Conn, skip int) *cancellationTestNetConn {
	return &cancellationTestNetConn{
		Conn:                 nc,
		shouldSkip:           skip,
		operationStartedChan: make(chan struct{}),
		continueChan:         make(chan struct{}),
	}
}

// newCancellationReadConn returns a cancellationTestNetConn around nc whose first skip Reads return the
// contents of readBuf and whose following Read blocks until released.
func newCancellationReadConn(nc net.Conn, skip int, readBuf []byte) *cancellationTestNetConn {
	return &cancellationTestNetConn{
		Conn:                 nc,
		shouldSkip:           skip,
		readBuf:              readBuf,
		operationStartedChan: make(chan struct{}),
		continueChan:         make(chan struct{}),
	}
}

// Read copies readBuf into b for skipped calls. Otherwise it signals operationStartedChan, waits on
// continueChan and returns an error.
func (c *cancellationTestNetConn) Read(b []byte) (int, error) {
	if c.skipCount < c.shouldSkip {
		c.skipCount++
		// Report the bytes actually copied. io.Reader requires n <= len(b), and returning
		// len(c.readBuf) would overstate n whenever b is shorter than readBuf.
		n := copy(b, c.readBuf)
		return n, nil
	}

	c.operationStartedChan <- struct{}{}
	<-c.continueChan
	return 0, errors.New("cancelled read")
}

// Write reports a full write for skipped calls without touching the wrapped conn. Otherwise it signals
// operationStartedChan, waits on continueChan and returns an error.
func (c *cancellationTestNetConn) Write(b []byte) (int, error) {
	if c.skipCount < c.shouldSkip {
		c.skipCount++
		return len(b), nil
	}

	c.operationStartedChan <- struct{}{}
	<-c.continueChan
	return 0, errors.New("cancelled write")
}
1007
// testNetConn is a scriptable net.Conn for tests.
//
// When nc is nil it acts as an in-memory loopback: Write appends to buf and
// Read drains buf. When nc is non-nil, Read drains buf first and then forwards
// to nc, while Write always forwards to nc.
//
// The remaining fields control and record behavior:
//   - The *err fields inject failures.
//   - closed records that Close was called.
//   - The deadline fields record the most recent value passed to the matching
//     Set*Deadline method, even when deadlineerr is set.
type testNetConn struct {
	nc  net.Conn // optional wrapped connection; nil means in-memory only
	buf []byte   // bytes Read returns before consulting readerr or nc

	deadlineerr error // if non-nil, returned by every Set*Deadline method
	writeerr    error // if non-nil, returned by Write before any data is accepted
	readerr     error // if non-nil, returned by Read once buf is empty
	closed      bool  // set to true by Close

	deadline      time.Time // last value passed to SetDeadline
	readDeadline  time.Time // last value passed to SetReadDeadline
	writeDeadline time.Time // last value passed to SetWriteDeadline
}

// Read serves buffered bytes first, then readerr if set, then the wrapped nc.
//
// NOTE(review): when buf is empty, readerr is nil, and nc is nil, Read returns
// (0, nil). io.Reader discourages that result, and io.ReadFull keeps retrying
// on it. Tests using the in-memory mode must not read past the data they
// buffered.
func (tnc *testNetConn) Read(b []byte) (n int, err error) {
	if len(tnc.buf) > 0 {
		n = copy(b, tnc.buf)
		tnc.buf = tnc.buf[n:]
		return n, nil
	}
	if tnc.readerr != nil {
		return 0, tnc.readerr
	}
	if tnc.nc == nil {
		return 0, nil
	}
	return tnc.nc.Read(b)
}

// Write fails with writeerr if it is set. Otherwise it either appends a copy
// of b to buf (in-memory mode) or forwards b to nc.
func (tnc *testNetConn) Write(b []byte) (n int, err error) {
	if tnc.writeerr != nil {
		return 0, tnc.writeerr
	}
	if tnc.nc == nil {
		tnc.buf = append(tnc.buf, b...)
		return len(b), nil
	}
	return tnc.nc.Write(b)
}

// Close marks the connection as closed and closes nc if there is one.
func (tnc *testNetConn) Close() error {
	tnc.closed = true
	if tnc.nc == nil {
		return nil
	}
	return tnc.nc.Close()
}

// LocalAddr returns nc's local address, or nil when there is no nc.
func (tnc *testNetConn) LocalAddr() net.Addr {
	if tnc.nc == nil {
		return nil
	}
	return tnc.nc.LocalAddr()
}

// RemoteAddr returns nc's remote address, or nil when there is no nc.
func (tnc *testNetConn) RemoteAddr() net.Addr {
	if tnc.nc == nil {
		return nil
	}
	return tnc.nc.RemoteAddr()
}

// SetDeadline records t, then returns deadlineerr if it is set. Otherwise it
// forwards t to nc when there is one.
func (tnc *testNetConn) SetDeadline(t time.Time) error {
	tnc.deadline = t
	if tnc.deadlineerr != nil {
		return tnc.deadlineerr
	}
	if tnc.nc == nil {
		return nil
	}
	return tnc.nc.SetDeadline(t)
}

// SetReadDeadline records t, then returns deadlineerr if it is set. Otherwise
// it forwards t to nc when there is one.
func (tnc *testNetConn) SetReadDeadline(t time.Time) error {
	tnc.readDeadline = t
	if tnc.deadlineerr != nil {
		return tnc.deadlineerr
	}
	if tnc.nc == nil {
		return nil
	}
	return tnc.nc.SetReadDeadline(t)
}

// SetWriteDeadline records t, then returns deadlineerr if it is set. Otherwise
// it forwards t to nc when there is one.
func (tnc *testNetConn) SetWriteDeadline(t time.Time) error {
	tnc.writeDeadline = t
	if tnc.deadlineerr != nil {
		return tnc.deadlineerr
	}
	if tnc.nc == nil {
		return nil
	}
	return tnc.nc.SetWriteDeadline(t)
}
1104
1105
1106
1107
// bootstrapConnections listens on an ephemeral localhost TCP port and returns
// the listener's address so tests can dial it.
//
// A background goroutine accepts up to num connections and hands each one to
// run in its own goroutine. The listener is closed once num connections have
// been accepted, or after the first Accept error, which is reported through
// t.Errorf.
func bootstrapConnections(t *testing.T, num int, run func(net.Conn)) net.Addr {
	t.Helper()
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("Could not set up a listener: %v", err)
	}
	go func() {
		defer func() { _ = l.Close() }()
		for i := 0; i < num; i++ {
			c, err := l.Accept()
			if err != nil {
				// c is nil here, so passing it to run would panic. Retrying a
				// failed listener would most likely just repeat the error.
				t.Errorf("Could not accept a connection: %v", err)
				return
			}
			go run(c)
		}
	}()
	return l.Addr()
}
1126
1127 type netconn struct {
1128 net.Conn
1129 closed chan struct{}
1130 d *dialer
1131 }
1132
1133 func (nc *netconn) Close() error {
1134 nc.closed <- struct{}{}
1135 nc.d.connclosed(nc)
1136 return nc.Conn.Close()
1137 }
1138
// writeFailConn is a net.Conn whose Write always fails. SetWriteDeadline is
// overridden to succeed without touching the embedded connection.
type writeFailConn struct {
	net.Conn
}

// Write ignores its input, reports zero bytes written, and returns an error.
func (w *writeFailConn) Write(_ []byte) (n int, err error) {
	err = errors.New("Write error")
	return n, err
}

// SetWriteDeadline discards the deadline and always succeeds.
func (w *writeFailConn) SetWriteDeadline(_ time.Time) error { return nil }
1150
1151 type dialer struct {
1152 Dialer
1153 opened map[*netconn]struct{}
1154 closed map[*netconn]struct{}
1155 closeCallBack func()
1156 sync.Mutex
1157 }
1158
1159 func newdialer(d Dialer) *dialer {
1160 return &dialer{Dialer: d, opened: make(map[*netconn]struct{}), closed: make(map[*netconn]struct{})}
1161 }
1162
1163 func (d *dialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
1164 d.Lock()
1165 defer d.Unlock()
1166 c, err := d.Dialer.DialContext(ctx, network, address)
1167 if err != nil {
1168 return nil, err
1169 }
1170 nc := &netconn{Conn: c, closed: make(chan struct{}, 1), d: d}
1171 d.opened[nc] = struct{}{}
1172 return nc, nil
1173 }
1174
1175 func (d *dialer) connclosed(nc *netconn) {
1176 d.Lock()
1177 defer d.Unlock()
1178 d.closed[nc] = struct{}{}
1179 if d.closeCallBack != nil {
1180 d.closeCallBack()
1181 }
1182 }
1183
1184 func (d *dialer) lenopened() int {
1185 d.Lock()
1186 defer d.Unlock()
1187 return len(d.opened)
1188 }
1189
1190 func (d *dialer) lenclosed() int {
1191 d.Lock()
1192 defer d.Unlock()
1193 return len(d.closed)
1194 }
1195
1196 type testCancellationListener struct {
1197 listener *cancellListener
1198 numListen int
1199 numStopListening int
1200 aborted bool
1201 }
1202
1203
1204
1205 func newTestCancellationListener(aborted bool) *testCancellationListener {
1206 return &testCancellationListener{
1207 listener: newCancellListener(),
1208 aborted: aborted,
1209 }
1210 }
1211
1212 func (tcl *testCancellationListener) Listen(ctx context.Context, abortFn func()) {
1213 tcl.numListen++
1214 tcl.listener.Listen(ctx, abortFn)
1215 }
1216
1217 func (tcl *testCancellationListener) StopListening() bool {
1218 tcl.numStopListening++
1219 tcl.listener.StopListening()
1220 return tcl.aborted
1221 }
1222
1223 func (tcl *testCancellationListener) assertCalledOnce(t *testing.T) {
1224 assert.Equal(t, 1, tcl.numListen, "expected Listen to be called once, got %d", tcl.numListen)
1225 assert.Equal(t, 1, tcl.numStopListening, "expected StopListening to be called once, got %d", tcl.numListen)
1226 }
1227
View as plain text