1
2
3
4
5
6
7 package driver
8
9 import (
10 "bytes"
11 "context"
12 "errors"
13 "math"
14 "testing"
15 "time"
16
17 "github.com/google/go-cmp/cmp"
18 "go.mongodb.org/mongo-driver/bson/bsontype"
19 "go.mongodb.org/mongo-driver/bson/primitive"
20 "go.mongodb.org/mongo-driver/internal/assert"
21 "go.mongodb.org/mongo-driver/internal/csot"
22 "go.mongodb.org/mongo-driver/internal/handshake"
23 "go.mongodb.org/mongo-driver/internal/require"
24 "go.mongodb.org/mongo-driver/internal/uuid"
25 "go.mongodb.org/mongo-driver/mongo/address"
26 "go.mongodb.org/mongo-driver/mongo/description"
27 "go.mongodb.org/mongo-driver/mongo/readconcern"
28 "go.mongodb.org/mongo-driver/mongo/readpref"
29 "go.mongodb.org/mongo-driver/mongo/writeconcern"
30 "go.mongodb.org/mongo-driver/tag"
31 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
32 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
33 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
34 )
35
// noerr stops the current test immediately when err is non-nil, reporting it
// as an unexpected error. A nil err is a no-op.
func noerr(t *testing.T, err error) {
	t.Helper()
	if err == nil {
		return
	}
	t.Fatalf("Unexpected error: %v", err)
}
43
// compareErrors reports whether two errors are equivalent for test purposes:
// both nil, or both non-nil with identical Error() strings. It is suitable for
// use with cmp.Comparer.
func compareErrors(err1, err2 error) bool {
	switch {
	case err1 == nil || err2 == nil:
		// Equal only when neither side carries an error.
		return err1 == nil && err2 == nil
	default:
		return err1.Error() == err2.Error()
	}
}
59
60 func TestOperation(t *testing.T) {
61 int64ToPtr := func(i64 int64) *int64 { return &i64 }
62
63 t.Run("selectServer", func(t *testing.T) {
64 t.Run("returns validation error", func(t *testing.T) {
65 op := &Operation{}
66 _, err := op.selectServer(context.Background(), 1, nil)
67 if err == nil {
68 t.Error("Expected a validation error from selectServer, but got <nil>")
69 }
70 })
71 t.Run("uses specified server selector", func(t *testing.T) {
72 want := new(mockServerSelector)
73 d := new(mockDeployment)
74 op := &Operation{
75 CommandFn: func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
76 Deployment: d,
77 Database: "testing",
78 Selector: want,
79 }
80 _, err := op.selectServer(context.Background(), 1, nil)
81 noerr(t, err)
82
83
84 oss, ok := d.params.selector.(*opServerSelector)
85 require.True(t, ok)
86
87 if !cmp.Equal(oss.selector, want) {
88 t.Errorf("Did not get expected server selector. got %v; want %v", oss.selector, want)
89 }
90 })
91 t.Run("uses a default server selector", func(t *testing.T) {
92 d := new(mockDeployment)
93 op := &Operation{
94 CommandFn: func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
95 Deployment: d,
96 Database: "testing",
97 }
98 _, err := op.selectServer(context.Background(), 1, nil)
99 noerr(t, err)
100 if d.params.selector == nil {
101 t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
102 }
103 })
104 })
105 t.Run("Validate", func(t *testing.T) {
106 cmdFn := func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil }
107 d := new(mockDeployment)
108 testCases := []struct {
109 name string
110 op *Operation
111 err error
112 }{
113 {"CommandFn", &Operation{}, InvalidOperationError{MissingField: "CommandFn"}},
114 {"Deployment", &Operation{CommandFn: cmdFn}, InvalidOperationError{MissingField: "Deployment"}},
115 {"Database", &Operation{CommandFn: cmdFn, Deployment: d}, errDatabaseNameEmpty},
116 {"<nil>", &Operation{CommandFn: cmdFn, Deployment: d, Database: "test"}, nil},
117 }
118
119 for _, tc := range testCases {
120 t.Run(tc.name, func(t *testing.T) {
121 if tc.op == nil {
122 t.Fatal("op cannot be <nil>")
123 }
124 want := tc.err
125 got := tc.op.Validate()
126 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
127 t.Errorf("Did not validate properly. got %v; want %v", got, want)
128 }
129 })
130 }
131 })
132 t.Run("retryableWrite", func(t *testing.T) {
133 sessPool := session.NewPool(nil)
134 id, err := uuid.New()
135 noerr(t, err)
136
137 sess, err := session.NewClientSession(sessPool, id)
138 noerr(t, err)
139
140 sessStartingTransaction, err := session.NewClientSession(sessPool, id)
141 noerr(t, err)
142 err = sessStartingTransaction.StartTransaction(nil)
143 noerr(t, err)
144
145 sessInProgressTransaction, err := session.NewClientSession(sessPool, id)
146 noerr(t, err)
147 err = sessInProgressTransaction.StartTransaction(nil)
148 noerr(t, err)
149 err = sessInProgressTransaction.ApplyCommand(description.Server{})
150 noerr(t, err)
151
152 wcAck := writeconcern.New(writeconcern.WMajority())
153 wcUnack := writeconcern.New(writeconcern.W(0))
154
155 descRetryable := description.Server{
156 WireVersion: &description.VersionRange{Min: 6, Max: 21},
157 SessionTimeoutMinutes: 1,
158 SessionTimeoutMinutesPtr: int64ToPtr(1),
159 }
160
161 descNotRetryableWireVersion := description.Server{
162 WireVersion: &description.VersionRange{Min: 6, Max: 21},
163 SessionTimeoutMinutes: 1,
164 SessionTimeoutMinutesPtr: int64ToPtr(1),
165 }
166
167 descNotRetryableStandalone := description.Server{
168 WireVersion: &description.VersionRange{Min: 6, Max: 21},
169 SessionTimeoutMinutes: 1,
170 SessionTimeoutMinutesPtr: int64ToPtr(1),
171 Kind: description.Standalone,
172 }
173
174 testCases := []struct {
175 name string
176 op Operation
177 desc description.Server
178 want Type
179 }{
180 {"deployment doesn't support", Operation{}, description.Server{}, Type(0)},
181 {"wire version too low", Operation{Client: sess, WriteConcern: wcAck}, descNotRetryableWireVersion, Type(0)},
182 {"standalone not supported", Operation{Client: sess, WriteConcern: wcAck}, descNotRetryableStandalone, Type(0)},
183 {
184 "transaction in progress",
185 Operation{Client: sessInProgressTransaction, WriteConcern: wcAck},
186 descRetryable, Type(0),
187 },
188 {
189 "transaction starting",
190 Operation{Client: sessStartingTransaction, WriteConcern: wcAck},
191 descRetryable, Type(0),
192 },
193 {"unacknowledged write concern", Operation{Client: sess, WriteConcern: wcUnack}, descRetryable, Type(0)},
194 {
195 "acknowledged write concern",
196 Operation{Client: sess, WriteConcern: wcAck, Type: Write},
197 descRetryable, Write,
198 },
199 }
200
201 for _, tc := range testCases {
202 t.Run(tc.name, func(t *testing.T) {
203 got := tc.op.retryable(tc.desc)
204 if got != (tc.want != Type(0)) {
205 t.Errorf("Did not receive expected Type. got %v; want %v", got, tc.want)
206 }
207 })
208 }
209 })
210 t.Run("addReadConcern", func(t *testing.T) {
211 majorityRc := bsoncore.AppendDocumentElement(nil, "readConcern", bsoncore.BuildDocument(nil,
212 bsoncore.AppendStringElement(nil, "level", "majority"),
213 ))
214
215 testCases := []struct {
216 name string
217 rc *readconcern.ReadConcern
218 want bsoncore.Document
219 }{
220 {"nil", nil, nil},
221 {"empty", readconcern.New(), nil},
222 {"non-empty", readconcern.Majority(), majorityRc},
223 }
224
225 for _, tc := range testCases {
226 got, err := Operation{ReadConcern: tc.rc}.addReadConcern(nil, description.SelectedServer{})
227 noerr(t, err)
228 if !bytes.Equal(got, tc.want) {
229 t.Errorf("ReadConcern elements do not match. got %v; want %v", got, tc.want)
230 }
231 }
232 })
233 t.Run("addWriteConcern", func(t *testing.T) {
234 want := bsoncore.AppendDocumentElement(nil, "writeConcern", bsoncore.BuildDocumentFromElements(
235 nil, bsoncore.AppendStringElement(nil, "w", "majority"),
236 ))
237 got, err := Operation{WriteConcern: writeconcern.New(writeconcern.WMajority())}.addWriteConcern(nil, description.SelectedServer{})
238 noerr(t, err)
239 if !bytes.Equal(got, want) {
240 t.Errorf("WriteConcern elements do not match. got %v; want %v", got, want)
241 }
242 })
243 t.Run("addSession", func(t *testing.T) { t.Skip("These tests should be covered by spec tests.") })
244 t.Run("addClusterTime", func(t *testing.T) {
245 t.Run("adds max cluster time", func(t *testing.T) {
246 want := bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,
247 bsoncore.AppendTimestampElement(nil, "clusterTime", 1234, 5678),
248 ))
249 newer := bsoncore.BuildDocumentFromElements(nil, want)
250 older := bsoncore.BuildDocumentFromElements(nil,
251 bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,
252 bsoncore.AppendTimestampElement(nil, "clusterTime", 1234, 5670),
253 )),
254 )
255
256 clusterClock := new(session.ClusterClock)
257 clusterClock.AdvanceClusterTime(newer)
258 sessPool := session.NewPool(nil)
259 id, err := uuid.New()
260 noerr(t, err)
261
262 sess, err := session.NewClientSession(sessPool, id)
263 noerr(t, err)
264 err = sess.AdvanceClusterTime(older)
265 noerr(t, err)
266
267 got := Operation{Client: sess, Clock: clusterClock}.addClusterTime(nil, description.SelectedServer{
268 Server: description.Server{WireVersion: &description.VersionRange{Min: 6, Max: 21}},
269 })
270 if !bytes.Equal(got, want) {
271 t.Errorf("ClusterTimes do not match. got %v; want %v", got, want)
272 }
273 })
274 })
275 t.Run("calculateMaxTimeMS", func(t *testing.T) {
276 timeout := 5 * time.Second
277 maxTime := 2 * time.Second
278 negMaxTime := -2 * time.Second
279 shortRTT := 50 * time.Millisecond
280 longRTT := 10 * time.Second
281 timeoutCtx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
282 defer cancel()
283
284 testCases := []struct {
285 name string
286 op Operation
287 ctx context.Context
288 rtt90 time.Duration
289 want uint64
290 err error
291 }{
292 {
293 name: "uses context deadline and rtt90 with timeout",
294 op: Operation{MaxTime: &maxTime},
295 ctx: timeoutCtx,
296 rtt90: shortRTT,
297 want: 5000,
298 err: nil,
299 },
300 {
301 name: "uses MaxTime without timeout",
302 op: Operation{MaxTime: &maxTime},
303 ctx: context.Background(),
304 rtt90: longRTT,
305 want: 2000,
306 err: nil,
307 },
308 {
309 name: "errors when remaining timeout is less than rtt90",
310 op: Operation{MaxTime: &maxTime},
311 ctx: timeoutCtx,
312 rtt90: timeout,
313 want: 0,
314 err: ErrDeadlineWouldBeExceeded,
315 },
316 {
317 name: "errors when MaxTime is negative",
318 op: Operation{MaxTime: &negMaxTime},
319 ctx: context.Background(),
320 rtt90: longRTT,
321 want: 0,
322 err: ErrNegativeMaxTime,
323 },
324 }
325 for _, tc := range testCases {
326
327 tc := tc
328 t.Run(tc.name, func(t *testing.T) {
329 t.Parallel()
330
331 got, err := tc.op.calculateMaxTimeMS(tc.ctx, mockRTTMonitor{p90: tc.rtt90})
332
333
334
335
336 if got > tc.want {
337 t.Errorf("maxTimeMS value higher than expected. got %v; wanted at most %v", got, tc.want)
338 }
339 if !errors.Is(err, tc.err) {
340 t.Errorf("error values do not match. got %v; want %v", err, tc.err)
341 }
342 })
343 }
344 })
345 t.Run("updateClusterTimes", func(t *testing.T) {
346 clustertime := bsoncore.BuildDocumentFromElements(nil,
347 bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,
348 bsoncore.AppendTimestampElement(nil, "clusterTime", 1234, 5678),
349 )),
350 )
351
352 clusterClock := new(session.ClusterClock)
353 sessPool := session.NewPool(nil)
354 id, err := uuid.New()
355 noerr(t, err)
356
357 sess, err := session.NewClientSession(sessPool, id)
358 noerr(t, err)
359 Operation{Client: sess, Clock: clusterClock}.updateClusterTimes(clustertime)
360
361 got := sess.ClusterTime
362 if !bytes.Equal(got, clustertime) {
363 t.Errorf("ClusterTimes do not match. got %v; want %v", got, clustertime)
364 }
365 got = clusterClock.GetClusterTime()
366 if !bytes.Equal(got, clustertime) {
367 t.Errorf("ClusterTimes do not match. got %v; want %v", got, clustertime)
368 }
369
370 Operation{}.updateClusterTimes(bsoncore.BuildDocumentFromElements(nil))
371 })
372 t.Run("updateOperationTime", func(t *testing.T) {
373 want := primitive.Timestamp{T: 1234, I: 4567}
374
375 sessPool := session.NewPool(nil)
376 id, err := uuid.New()
377 noerr(t, err)
378
379 sess, err := session.NewClientSession(sessPool, id)
380 noerr(t, err)
381 if sess.OperationTime != nil {
382 t.Fatal("OperationTime should not be set on new session.")
383 }
384 response := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendTimestampElement(nil, "operationTime", want.T, want.I))
385 Operation{Client: sess}.updateOperationTime(response)
386 got := sess.OperationTime
387 if got.T != want.T || got.I != want.I {
388 t.Errorf("OperationTimes do not match. got %v; want %v", got, want)
389 }
390
391 response = bsoncore.BuildDocumentFromElements(nil)
392 Operation{Client: sess}.updateOperationTime(response)
393 got = sess.OperationTime
394 if got.T != want.T || got.I != want.I {
395 t.Errorf("OperationTimes do not match. got %v; want %v", got, want)
396 }
397
398 Operation{}.updateOperationTime(response)
399 })
400 t.Run("createReadPref", func(t *testing.T) {
401 rpWithTags := bsoncore.BuildDocumentFromElements(nil,
402 bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
403 bsoncore.BuildArrayElement(nil, "tags",
404 bsoncore.Value{Type: bsontype.EmbeddedDocument,
405 Data: bsoncore.BuildDocumentFromElements(nil,
406 bsoncore.AppendStringElement(nil, "disk", "ssd"),
407 bsoncore.AppendStringElement(nil, "use", "reporting"),
408 ),
409 },
410 ),
411 )
412 rpWithMaxStaleness := bsoncore.BuildDocumentFromElements(nil,
413 bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
414 bsoncore.AppendInt32Element(nil, "maxStalenessSeconds", 25),
415 )
416
417 rpWithHedge := bsoncore.BuildDocumentFromElements(nil,
418 bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
419 bsoncore.AppendDocumentElement(nil, "hedge", bsoncore.BuildDocumentFromElements(nil,
420 bsoncore.AppendBooleanElement(nil, "enabled", true),
421 )),
422 )
423 rpWithAllOptions := bsoncore.BuildDocumentFromElements(nil,
424 bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
425 bsoncore.BuildArrayElement(nil, "tags",
426 bsoncore.Value{Type: bsontype.EmbeddedDocument,
427 Data: bsoncore.BuildDocumentFromElements(nil,
428 bsoncore.AppendStringElement(nil, "disk", "ssd"),
429 bsoncore.AppendStringElement(nil, "use", "reporting"),
430 ),
431 },
432 ),
433 bsoncore.AppendInt32Element(nil, "maxStalenessSeconds", 25),
434 bsoncore.AppendDocumentElement(nil, "hedge", bsoncore.BuildDocumentFromElements(nil,
435 bsoncore.AppendBooleanElement(nil, "enabled", false),
436 )),
437 )
438
439 rpPrimaryPreferred := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred"))
440 rpSecondaryPreferred := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"))
441 rpSecondary := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "secondary"))
442 rpNearest := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "nearest"))
443
444 testCases := []struct {
445 name string
446 rp *readpref.ReadPref
447 serverKind description.ServerKind
448 topoKind description.TopologyKind
449 opQuery bool
450 want bsoncore.Document
451 }{
452 {"nil/single/mongos", nil, description.Mongos, description.Single, false, nil},
453 {"nil/single/secondary", nil, description.RSSecondary, description.Single, false, rpPrimaryPreferred},
454 {"primary/mongos", readpref.Primary(), description.Mongos, description.Sharded, false, nil},
455 {"primary/single", readpref.Primary(), description.RSPrimary, description.Single, false, rpPrimaryPreferred},
456 {"primary/primary", readpref.Primary(), description.RSPrimary, description.ReplicaSet, false, nil},
457 {"primaryPreferred", readpref.PrimaryPreferred(), description.RSSecondary, description.ReplicaSet, false, rpPrimaryPreferred},
458 {"secondaryPreferred/mongos/opquery", readpref.SecondaryPreferred(), description.Mongos, description.Sharded, true, nil},
459 {"secondaryPreferred", readpref.SecondaryPreferred(), description.RSSecondary, description.ReplicaSet, false, rpSecondaryPreferred},
460 {"secondary", readpref.Secondary(), description.RSSecondary, description.ReplicaSet, false, rpSecondary},
461 {"nearest", readpref.Nearest(), description.RSSecondary, description.ReplicaSet, false, rpNearest},
462 {
463 "secondaryPreferred/withTags",
464 readpref.SecondaryPreferred(readpref.WithTags("disk", "ssd", "use", "reporting")),
465 description.RSSecondary, description.ReplicaSet, false, rpWithTags,
466 },
467
468
469
470 {
471 "secondaryPreferred/withTags/emptyTagSet",
472 readpref.SecondaryPreferred(readpref.WithTagSets(
473 tag.Set{{Name: "disk", Value: "ssd"}},
474 tag.Set{})),
475 description.RSSecondary,
476 description.ReplicaSet,
477 false,
478 bsoncore.NewDocumentBuilder().
479 AppendString("mode", "secondaryPreferred").
480 AppendArray("tags", bsoncore.NewArrayBuilder().
481 AppendDocument(bsoncore.NewDocumentBuilder().AppendString("disk", "ssd").Build()).
482 AppendDocument(bsoncore.NewDocumentBuilder().Build()).
483 Build()).
484 Build(),
485 },
486 {
487 "secondaryPreferred/withMaxStaleness",
488 readpref.SecondaryPreferred(readpref.WithMaxStaleness(25 * time.Second)),
489 description.RSSecondary, description.ReplicaSet, false, rpWithMaxStaleness,
490 },
491 {
492
493 "secondaryPreferred with hedge to mongos using OP_QUERY",
494 readpref.SecondaryPreferred(readpref.WithHedgeEnabled(true)),
495 description.Mongos,
496 description.Sharded,
497 true,
498 rpWithHedge,
499 },
500 {
501 "secondaryPreferred with all options",
502 readpref.SecondaryPreferred(
503 readpref.WithTags("disk", "ssd", "use", "reporting"),
504 readpref.WithMaxStaleness(25*time.Second),
505 readpref.WithHedgeEnabled(false),
506 ),
507 description.RSSecondary,
508 description.ReplicaSet,
509 false,
510 rpWithAllOptions,
511 },
512 }
513
514 for _, tc := range testCases {
515 tc := tc
516 t.Run(tc.name, func(t *testing.T) {
517 desc := description.SelectedServer{Kind: tc.topoKind, Server: description.Server{Kind: tc.serverKind}}
518 got, err := Operation{ReadPreference: tc.rp}.createReadPref(desc, tc.opQuery)
519 if err != nil {
520 t.Fatalf("error creating read pref: %v", err)
521 }
522 if !bytes.Equal(got, tc.want) {
523 t.Errorf("Returned documents do not match. got %v; want %v", got, tc.want)
524 }
525 })
526 }
527 })
528 t.Run("secondaryOK", func(t *testing.T) {
529 t.Run("description.SelectedServer", func(t *testing.T) {
530 want := wiremessage.SecondaryOK
531 desc := description.SelectedServer{
532 Kind: description.Single,
533 Server: description.Server{Kind: description.RSSecondary},
534 }
535 got := Operation{}.secondaryOK(desc)
536 if got != want {
537 t.Errorf("Did not receive expected query flags. got %v; want %v", got, want)
538 }
539 })
540 t.Run("readPreference", func(t *testing.T) {
541 want := wiremessage.SecondaryOK
542 got := Operation{ReadPreference: readpref.Secondary()}.secondaryOK(description.SelectedServer{})
543 if got != want {
544 t.Errorf("Did not receive expected query flags. got %v; want %v", got, want)
545 }
546 })
547 t.Run("not secondaryOK", func(t *testing.T) {
548 var want wiremessage.QueryFlag
549 got := Operation{}.secondaryOK(description.SelectedServer{})
550 if got != want {
551 t.Errorf("Did not receive expected query flags. got %v; want %v", got, want)
552 }
553 })
554 })
555 t.Run("ExecuteExhaust", func(t *testing.T) {
556 t.Run("errors if connection is not streaming", func(t *testing.T) {
557 conn := &mockConnection{
558 rStreaming: false,
559 }
560 err := Operation{}.ExecuteExhaust(context.TODO(), conn)
561 assert.NotNil(t, err, "expected error, got nil")
562 })
563 })
564 t.Run("exhaustAllowed and moreToCome", func(t *testing.T) {
565
566
567
568
569 serverResponseDoc := bsoncore.BuildDocumentFromElements(nil,
570 bsoncore.AppendInt32Element(nil, "ok", 1),
571 )
572 nonStreamingResponse := createExhaustServerResponse(serverResponseDoc, false)
573
574
575 conn := &mockConnection{
576 rDesc: description.Server{
577 WireVersion: &description.VersionRange{
578 Max: 6,
579 },
580 },
581 rReadWM: nonStreamingResponse,
582 rCanStream: false,
583 }
584 op := Operation{
585 CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
586 return bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1), nil
587 },
588 Database: "admin",
589 Deployment: SingleConnectionDeployment{conn},
590 }
591 err := op.Execute(context.TODO())
592 assert.Nil(t, err, "Execute error: %v", err)
593
594
595
596 assertExhaustAllowedSet(t, conn.pWriteWM, false)
597 assert.False(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be false")
598
599
600 streamingResponse := createExhaustServerResponse(serverResponseDoc, true)
601 conn.rReadWM = streamingResponse
602 conn.rCanStream = true
603 err = op.Execute(context.TODO())
604 assert.Nil(t, err, "Execute error: %v", err)
605 assertExhaustAllowedSet(t, conn.pWriteWM, true)
606 assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
607
608
609
610 conn.rReadWM = streamingResponse
611 err = op.ExecuteExhaust(context.TODO(), conn)
612 assert.Nil(t, err, "ExecuteExhaust error: %v", err)
613 assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
614 })
615 t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) {
616 conn := new(mockConnection)
617
618 ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0))
619 defer cancel()
620
621 op := Operation{
622 Database: "foobar",
623 Deployment: SingleConnectionDeployment{C: conn},
624 CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
625 dst = bsoncore.AppendInt32Element(dst, "ping", 1)
626 return dst, nil
627 },
628 }
629
630 err := op.Execute(ctx)
631 assert.NotNil(t, err, "expected an error from Execute(), got nil")
632
633
634 assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err)
635 })
636 t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) {
637 conn := new(mockConnection)
638
639 ctx, cancel := context.WithCancel(context.Background())
640 cancel()
641
642 op := Operation{
643 Database: "foobar",
644 Deployment: SingleConnectionDeployment{C: conn},
645 CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
646 dst = bsoncore.AppendInt32Element(dst, "ping", 1)
647 return dst, nil
648 },
649 }
650
651 err := op.Execute(ctx)
652 assert.NotNil(t, err, "expected an error from Execute(), got nil")
653
654
655 assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err)
656 })
657 t.Run("ErrDeadlineWouldBeExceeded wraps context.DeadlineExceeded", func(t *testing.T) {
658
659
660 d := new(mockDeployment)
661 d.returns.server = mockServer{
662 conn: new(mockConnection),
663 rttMonitor: mockRTTMonitor{p90: 1 * time.Minute},
664 }
665
666
667 var dur time.Duration
668 op := Operation{
669 Database: "foobar",
670 Deployment: d,
671 CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
672 return dst, nil
673 },
674 Timeout: &dur,
675 }
676
677
678
679 ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
680 defer cancel()
681 err := op.Execute(ctx)
682
683 assert.ErrorIs(t, err, ErrDeadlineWouldBeExceeded)
684 assert.ErrorIs(t, err, context.DeadlineExceeded)
685 })
686 }
687
688 func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {
689 const psuedoRequestID = 1
690 idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg)
691 var flags wiremessage.MsgFlag
692 if moreToCome {
693 flags = wiremessage.MoreToCome
694 }
695 wm = wiremessage.AppendMsgFlags(wm, flags)
696 wm = wiremessage.AppendMsgSectionType(wm, wiremessage.SingleDocument)
697 wm = bsoncore.AppendDocument(wm, response)
698 return bsoncore.UpdateLength(wm, idx, int32(len(wm)))
699 }
700
701 func assertExhaustAllowedSet(t *testing.T, wm []byte, expected bool) {
702 t.Helper()
703 _, _, _, _, wm, ok := wiremessage.ReadHeader(wm)
704 if !ok {
705 t.Fatal("could not read wm header")
706 }
707 flags, wm, ok := wiremessage.ReadMsgFlags(wm)
708 if !ok {
709 t.Fatal("could not read wm flags")
710 }
711
712 actual := flags&wiremessage.ExhaustAllowed > 0
713 assert.Equal(t, expected, actual, "expected exhaustAllowed set %v, got %v", expected, actual)
714 }
715
716 type mockDeployment struct {
717 params struct {
718 selector description.ServerSelector
719 }
720 returns struct {
721 server Server
722 err error
723 retry bool
724 kind description.TopologyKind
725 }
726 }
727
728 func (m *mockDeployment) SelectServer(_ context.Context, desc description.ServerSelector) (Server, error) {
729 m.params.selector = desc
730 return m.returns.server, m.returns.err
731 }
732
733 func (m *mockDeployment) Kind() description.TopologyKind { return m.returns.kind }
734
735 type mockServerSelector struct{}
736
737 func (m *mockServerSelector) SelectServer(description.Topology, []description.Server) ([]description.Server, error) {
738 panic("not implemented")
739 }
740
741 func (m *mockServerSelector) String() string {
742 panic("not implemented")
743 }
744
745 type mockServer struct {
746 conn Connection
747 err error
748 rttMonitor RTTMonitor
749 }
750
751 func (ms mockServer) Connection(context.Context) (Connection, error) { return ms.conn, ms.err }
752 func (ms mockServer) RTTMonitor() RTTMonitor { return ms.rttMonitor }
753
// mockRTTMonitor is an RTT monitor test double whose getters return fixed
// values taken from its fields.
type mockRTTMonitor struct {
	ewma  time.Duration
	min   time.Duration
	p90   time.Duration
	stats string
}

// EWMA returns the configured ewma value.
func (mrm mockRTTMonitor) EWMA() time.Duration {
	return mrm.ewma
}

// Min returns the configured min value.
func (mrm mockRTTMonitor) Min() time.Duration {
	return mrm.min
}

// P90 returns the configured p90 value.
func (mrm mockRTTMonitor) P90() time.Duration {
	return mrm.p90
}

// Stats returns the configured stats string.
func (mrm mockRTTMonitor) Stats() string {
	return mrm.stats
}
765
766 type mockConnection struct {
767
768 pWriteWM []byte
769
770
771 rWriteErr error
772 rReadWM []byte
773 rReadErr error
774 rDesc description.Server
775 rCloseErr error
776 rID string
777 rServerConnID *int64
778 rAddr address.Address
779 rCanStream bool
780 rStreaming bool
781 }
782
783 func (m *mockConnection) Description() description.Server { return m.rDesc }
784 func (m *mockConnection) Close() error { return m.rCloseErr }
785 func (m *mockConnection) ID() string { return m.rID }
786 func (m *mockConnection) ServerConnectionID() *int64 { return m.rServerConnID }
787 func (m *mockConnection) Address() address.Address { return m.rAddr }
788 func (m *mockConnection) SupportsStreaming() bool { return m.rCanStream }
789 func (m *mockConnection) CurrentlyStreaming() bool { return m.rStreaming }
790 func (m *mockConnection) SetStreaming(streaming bool) { m.rStreaming = streaming }
791 func (m *mockConnection) Stale() bool { return false }
792
793
794 func (m *mockConnection) DriverConnectionID() uint64 { return 0 }
795
796 func (m *mockConnection) WriteWireMessage(_ context.Context, wm []byte) error {
797 m.pWriteWM = wm
798 return m.rWriteErr
799 }
800
801 func (m *mockConnection) ReadWireMessage(_ context.Context) ([]byte, error) {
802 return m.rReadWM, m.rReadErr
803 }
804
805 type retryableError struct {
806 error
807 }
808
809 func (retryableError) Retryable() bool { return true }
810
811 var _ RetryablePoolError = retryableError{}
812
813
814
815 type mockRetryServer struct {
816 numCallsToConnection int
817 }
818
819
820
821 func (ms *mockRetryServer) Connection(ctx context.Context) (Connection, error) {
822 ms.numCallsToConnection++
823
824 if ctx.Err() != nil {
825 return nil, ctx.Err()
826 }
827
828 time.Sleep(1 * time.Millisecond)
829 return nil, retryableError{error: errors.New("test error")}
830 }
831
832 func (ms *mockRetryServer) RTTMonitor() RTTMonitor {
833 return &csot.ZeroRTTMonitor{}
834 }
835
836 func TestRetry(t *testing.T) {
837 t.Run("retries multiple times with RetryContext", func(t *testing.T) {
838 d := new(mockDeployment)
839 ms := new(mockRetryServer)
840 d.returns.server = ms
841
842 ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
843 defer cancel()
844
845 retry := RetryContext
846 err := Operation{
847 CommandFn: func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
848 Deployment: d,
849 Database: "testing",
850 RetryMode: &retry,
851 Type: Read,
852 }.Execute(ctx)
853 assert.NotNil(t, err, "expected an error from Execute()")
854
855
856
857
858 assert.True(t,
859 ms.numCallsToConnection >= 3,
860 "expected Connection() to be called at least 3 times")
861
862 deadline, _ := ctx.Deadline()
863 assert.True(t,
864 time.Now().After(deadline),
865 "expected operation to complete only after the context deadline is exceeded")
866 })
867 }
868
869 func TestConvertI64PtrToI32Ptr(t *testing.T) {
870 t.Parallel()
871
872 newI64 := func(i64 int64) *int64 { return &i64 }
873 newI32 := func(i32 int32) *int32 { return &i32 }
874
875 tests := []struct {
876 name string
877 i64 *int64
878 want *int32
879 }{
880 {
881 name: "empty",
882 want: nil,
883 },
884 {
885 name: "in bounds",
886 i64: newI64(1),
887 want: newI32(1),
888 },
889 {
890 name: "out of bounds negative",
891 i64: newI64(math.MinInt32 - 1),
892 },
893 {
894 name: "out of bounds positive",
895 i64: newI64(math.MaxInt32 + 1),
896 },
897 {
898 name: "exact min int32",
899 i64: newI64(math.MinInt32),
900 want: newI32(math.MinInt32),
901 },
902 {
903 name: "exact max int32",
904 i64: newI64(math.MaxInt32),
905 want: newI32(math.MaxInt32),
906 },
907 }
908
909 for _, test := range tests {
910 test := test
911
912 t.Run(test.name, func(t *testing.T) {
913 t.Parallel()
914
915 got := convertInt64PtrToInt32Ptr(test.i64)
916 assert.Equal(t, test.want, got)
917 })
918 }
919 }
920
921 func TestDecodeOpReply(t *testing.T) {
922 t.Parallel()
923
924
925 t.Run("malformatted wiremessage with length of 0", func(t *testing.T) {
926 t.Parallel()
927
928 var wm []byte
929 wm = wiremessage.AppendReplyFlags(wm, 0)
930 wm = wiremessage.AppendReplyCursorID(wm, int64(0))
931 wm = wiremessage.AppendReplyStartingFrom(wm, 0)
932 wm = wiremessage.AppendReplyNumberReturned(wm, 0)
933 idx, wm := bsoncore.ReserveLength(wm)
934 wm = bsoncore.UpdateLength(wm, idx, 0)
935 reply := Operation{}.decodeOpReply(wm)
936 assert.Equal(t, []bsoncore.Document(nil), reply.documents)
937 })
938 }
939
940 func TestFilterDeprioritizedServers(t *testing.T) {
941 t.Parallel()
942
943 tests := []struct {
944 name string
945 deprioritized []description.Server
946 candidates []description.Server
947 want []description.Server
948 }{
949 {
950 name: "empty",
951 candidates: []description.Server{},
952 want: []description.Server{},
953 },
954 {
955 name: "nil candidates",
956 candidates: nil,
957 want: []description.Server{},
958 },
959 {
960 name: "nil deprioritized server list",
961 candidates: []description.Server{
962 {
963 Addr: address.Address("mongodb://localhost:27017"),
964 },
965 },
966 want: []description.Server{
967 {
968 Addr: address.Address("mongodb://localhost:27017"),
969 },
970 },
971 },
972 {
973 name: "deprioritize single server candidate list",
974 candidates: []description.Server{
975 {
976 Addr: address.Address("mongodb://localhost:27017"),
977 },
978 },
979 deprioritized: []description.Server{
980 {
981 Addr: address.Address("mongodb://localhost:27017"),
982 },
983 },
984 want: []description.Server{
985
986
987 {
988 Addr: address.Address("mongodb://localhost:27017"),
989 },
990 },
991 },
992 {
993 name: "depriotirize one server in multi server candidate list",
994 candidates: []description.Server{
995 {
996 Addr: address.Address("mongodb://localhost:27017"),
997 },
998 {
999 Addr: address.Address("mongodb://localhost:27018"),
1000 },
1001 {
1002 Addr: address.Address("mongodb://localhost:27019"),
1003 },
1004 },
1005 deprioritized: []description.Server{
1006 {
1007 Addr: address.Address("mongodb://localhost:27017"),
1008 },
1009 },
1010 want: []description.Server{
1011 {
1012 Addr: address.Address("mongodb://localhost:27018"),
1013 },
1014 {
1015 Addr: address.Address("mongodb://localhost:27019"),
1016 },
1017 },
1018 },
1019 {
1020 name: "depriotirize multiple servers in multi server candidate list",
1021 deprioritized: []description.Server{
1022 {
1023 Addr: address.Address("mongodb://localhost:27017"),
1024 },
1025 {
1026 Addr: address.Address("mongodb://localhost:27018"),
1027 },
1028 },
1029 candidates: []description.Server{
1030 {
1031 Addr: address.Address("mongodb://localhost:27017"),
1032 },
1033 {
1034 Addr: address.Address("mongodb://localhost:27018"),
1035 },
1036 {
1037 Addr: address.Address("mongodb://localhost:27019"),
1038 },
1039 },
1040 want: []description.Server{
1041 {
1042 Addr: address.Address("mongodb://localhost:27019"),
1043 },
1044 },
1045 },
1046 }
1047
1048 for _, tc := range tests {
1049 tc := tc
1050
1051 t.Run(tc.name, func(t *testing.T) {
1052 t.Parallel()
1053
1054 got := filterDeprioritizedServers(tc.candidates, tc.deprioritized)
1055 assert.ElementsMatch(t, got, tc.want)
1056 })
1057 }
1058 }
1059
View as plain text