1
2
3
4
5
6
7 package integration
8
9 import (
10 "context"
11 "fmt"
12 "net"
13 "os"
14 "reflect"
15 "strings"
16 "testing"
17 "time"
18
19 "go.mongodb.org/mongo-driver/bson"
20 "go.mongodb.org/mongo-driver/bson/bsoncodec"
21 "go.mongodb.org/mongo-driver/bson/bsonrw"
22 "go.mongodb.org/mongo-driver/bson/primitive"
23 "go.mongodb.org/mongo-driver/event"
24 "go.mongodb.org/mongo-driver/internal/assert"
25 "go.mongodb.org/mongo-driver/internal/eventtest"
26 "go.mongodb.org/mongo-driver/internal/handshake"
27 "go.mongodb.org/mongo-driver/internal/integtest"
28 "go.mongodb.org/mongo-driver/internal/require"
29 "go.mongodb.org/mongo-driver/mongo"
30 "go.mongodb.org/mongo-driver/mongo/integration/mtest"
31 "go.mongodb.org/mongo-driver/mongo/options"
32 "go.mongodb.org/mongo-driver/mongo/readpref"
33 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
34 "go.mongodb.org/mongo-driver/x/mongo/driver"
35 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
36 "golang.org/x/sync/errgroup"
37 )
38
39 var noClientOpts = mtest.NewOptions().CreateClient(false)
40
41 type negateCodec struct {
42 ID int64 `bson:"_id"`
43 }
44
45 func (e *negateCodec) EncodeValue(_ bsoncodec.EncodeContext, vw bsonrw.ValueWriter, val reflect.Value) error {
46 return vw.WriteInt64(val.Int())
47 }
48
49
50 func (e *negateCodec) DecodeValue(_ bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
51 i, err := vr.ReadInt64()
52 if err != nil {
53 return err
54 }
55
56 val.SetInt(i * -1)
57 return nil
58 }
59
60 var _ options.ContextDialer = &slowConnDialer{}
61
62
63 type slowConnDialer struct {
64 dialer *net.Dialer
65 delay time.Duration
66 }
67
68 var slowConnDialerDelay = 300 * time.Millisecond
69 var reducedHeartbeatInterval = 100 * time.Millisecond
70
71 func newSlowConnDialer(delay time.Duration) *slowConnDialer {
72 return &slowConnDialer{
73 dialer: &net.Dialer{},
74 delay: delay,
75 }
76 }
77
78 func (scd *slowConnDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
79 conn, err := scd.dialer.DialContext(ctx, network, address)
80 if err != nil {
81 return nil, err
82 }
83 return &slowConn{
84 Conn: conn,
85 delay: scd.delay,
86 }, nil
87 }
88
// Compile-time check that *slowConn satisfies net.Conn.
var _ net.Conn = (*slowConn)(nil)

// slowConn wraps a net.Conn and delays every Read call by a fixed duration.
// All other net.Conn methods are provided by the embedded connection.
type slowConn struct {
	net.Conn
	delay time.Duration
}

// Read sleeps for the configured delay and then reads from the wrapped
// connection into b, returning whatever the wrapped Read returns.
func (sc *slowConn) Read(b []byte) (n int, err error) {
	time.Sleep(sc.delay)
	n, err = sc.Conn.Read(b)
	return n, err
}
102
103 func TestClient(t *testing.T) {
104 mt := mtest.New(t, noClientOpts)
105
106 registryOpts := options.Client().
107 SetRegistry(bson.NewRegistryBuilder().RegisterCodec(reflect.TypeOf(int64(0)), &negateCodec{}).Build())
108 mt.RunOpts("registry passed to cursors", mtest.NewOptions().ClientOptions(registryOpts), func(mt *mtest.T) {
109 _, err := mt.Coll.InsertOne(context.Background(), negateCodec{ID: 10})
110 assert.Nil(mt, err, "InsertOne error: %v", err)
111 var got negateCodec
112 err = mt.Coll.FindOne(context.Background(), bson.D{}).Decode(&got)
113 assert.Nil(mt, err, "Find error: %v", err)
114
115 assert.Equal(mt, int64(-10), got.ID, "expected ID -10, got %v", got.ID)
116 })
117 mt.RunOpts("tls connection", mtest.NewOptions().MinServerVersion("3.0").SSL(true), func(mt *mtest.T) {
118 var result bson.Raw
119 err := mt.Coll.Database().RunCommand(context.Background(), bson.D{
120 {"serverStatus", 1},
121 }).Decode(&result)
122 assert.Nil(mt, err, "serverStatus error: %v", err)
123
124 security := result.Lookup("security")
125 assert.Equal(mt, bson.TypeEmbeddedDocument, security.Type,
126 "expected security field to be type %v, got %v", bson.TypeMaxKey, security.Type)
127 _, found := security.Document().LookupErr("SSLServerSubjectName")
128 assert.Nil(mt, found, "SSLServerSubjectName not found in result")
129 })
130 mt.RunOpts("x509", mtest.NewOptions().Auth(true).SSL(true), func(mt *mtest.T) {
131 testCases := []struct {
132 certificate string
133 password string
134 }{
135 {
136 "MONGO_GO_DRIVER_KEY_FILE",
137 "",
138 },
139 {
140 "MONGO_GO_DRIVER_PKCS8_ENCRYPTED_KEY_FILE",
141 "&sslClientCertificateKeyPassword=password",
142 },
143 {
144 "MONGO_GO_DRIVER_PKCS8_UNENCRYPTED_KEY_FILE",
145 "",
146 },
147 }
148 for _, tc := range testCases {
149 mt.Run(tc.certificate, func(mt *mtest.T) {
150 const user = "C=US,ST=New York,L=New York City,O=MDB,OU=Drivers,CN=client"
151 db := mt.Client.Database("$external")
152
153
154 _ = db.RunCommand(
155 context.Background(),
156 bson.D{{"dropUser", user}},
157 )
158 err := db.RunCommand(
159 context.Background(),
160 bson.D{
161 {"createUser", user},
162 {"roles", bson.A{
163 bson.D{{"role", "readWrite"}, {"db", "test"}},
164 }},
165 },
166 ).Err()
167 assert.Nil(mt, err, "createUser error: %v", err)
168
169 baseConnString := mtest.ClusterURI()
170
171 revisedConnString := "mongodb://"
172 split := strings.Split(baseConnString, "@")
173 assert.Equal(t, 2, len(split), "expected 2 parts after split, got %v (connstring %v)", split, baseConnString)
174 revisedConnString += split[1]
175
176 cs := fmt.Sprintf(
177 "%s&sslClientCertificateKeyFile=%s&authMechanism=MONGODB-X509&authSource=$external%s",
178 revisedConnString,
179 os.Getenv(tc.certificate),
180 tc.password,
181 )
182 authClientOpts := options.Client().ApplyURI(cs)
183 integtest.AddTestServerAPIVersion(authClientOpts)
184 authClient, err := mongo.Connect(context.Background(), authClientOpts)
185 assert.Nil(mt, err, "authClient Connect error: %v", err)
186 defer func() { _ = authClient.Disconnect(context.Background()) }()
187
188 rdr, err := authClient.Database("test").RunCommand(context.Background(), bson.D{
189 {"connectionStatus", 1},
190 }).Raw()
191 assert.Nil(mt, err, "connectionStatus error: %v", err)
192 users, err := rdr.LookupErr("authInfo", "authenticatedUsers")
193 assert.Nil(mt, err, "authenticatedUsers not found in response")
194 elems, err := users.Array().Elements()
195 assert.Nil(mt, err, "error getting users elements: %v", err)
196
197 for _, userElem := range elems {
198 rdr := userElem.Value().Document()
199 var u struct {
200 User string
201 DB string
202 }
203
204 if err := bson.Unmarshal(rdr, &u); err != nil {
205 continue
206 }
207 if u.User == user && u.DB == "$external" {
208 return
209 }
210 }
211 mt.Fatal("unable to find authenticated user")
212 })
213 }
214 })
215 mt.RunOpts("list databases", noClientOpts, func(mt *mtest.T) {
216 mt.RunOpts("filter", noClientOpts, func(mt *mtest.T) {
217 testCases := []struct {
218 name string
219 filter bson.D
220 hasTestDb bool
221 minServerVersion string
222 }{
223 {"empty", bson.D{}, true, ""},
224 {"non-empty", bson.D{{"name", "foobar"}}, false, "3.6"},
225 }
226
227 for _, tc := range testCases {
228 opts := mtest.NewOptions()
229 if tc.minServerVersion != "" {
230 opts.MinServerVersion(tc.minServerVersion)
231 }
232
233 mt.RunOpts(tc.name, opts, func(mt *mtest.T) {
234 res, err := mt.Client.ListDatabases(context.Background(), tc.filter)
235 assert.Nil(mt, err, "ListDatabases error: %v", err)
236
237 var found bool
238 for _, db := range res.Databases {
239 if db.Name == mtest.TestDb {
240 found = true
241 break
242 }
243 }
244 assert.Equal(mt, tc.hasTestDb, found, "expected to find test db: %v, found: %v", tc.hasTestDb, found)
245 })
246 }
247 })
248 mt.Run("options", func(mt *mtest.T) {
249 allOpts := options.ListDatabases().SetNameOnly(true).SetAuthorizedDatabases(true)
250 mt.ClearEvents()
251
252 _, err := mt.Client.ListDatabases(context.Background(), bson.D{}, allOpts)
253 assert.Nil(mt, err, "ListDatabases error: %v", err)
254
255 evt := mt.GetStartedEvent()
256 assert.Equal(mt, "listDatabases", evt.CommandName, "expected ")
257
258 expectedDoc := bsoncore.BuildDocumentFromElements(nil,
259 bsoncore.AppendBooleanElement(nil, "nameOnly", true),
260 bsoncore.AppendBooleanElement(nil, "authorizedDatabases", true),
261 )
262 err = compareDocs(mt, expectedDoc, evt.Command)
263 assert.Nil(mt, err, "compareDocs error: %v", err)
264 })
265 })
266 mt.RunOpts("list database names", noClientOpts, func(mt *mtest.T) {
267 mt.RunOpts("filter", noClientOpts, func(mt *mtest.T) {
268 testCases := []struct {
269 name string
270 filter bson.D
271 hasTestDb bool
272 minServerVersion string
273 }{
274 {"no filter", bson.D{}, true, ""},
275 {"filter", bson.D{{"name", "foobar"}}, false, "3.6"},
276 }
277
278 for _, tc := range testCases {
279 opts := mtest.NewOptions()
280 if tc.minServerVersion != "" {
281 opts.MinServerVersion(tc.minServerVersion)
282 }
283
284 mt.RunOpts(tc.name, opts, func(mt *mtest.T) {
285 dbs, err := mt.Client.ListDatabaseNames(context.Background(), tc.filter)
286 assert.Nil(mt, err, "ListDatabaseNames error: %v", err)
287
288 var found bool
289 for _, db := range dbs {
290 if db == mtest.TestDb {
291 found = true
292 break
293 }
294 }
295 assert.Equal(mt, tc.hasTestDb, found, "expected to find test db: %v, found: %v", tc.hasTestDb, found)
296 })
297 }
298 })
299 mt.Run("options", func(mt *mtest.T) {
300 allOpts := options.ListDatabases().SetNameOnly(true).SetAuthorizedDatabases(true)
301 mt.ClearEvents()
302
303 _, err := mt.Client.ListDatabaseNames(context.Background(), bson.D{}, allOpts)
304 assert.Nil(mt, err, "ListDatabaseNames error: %v", err)
305
306 evt := mt.GetStartedEvent()
307 assert.Equal(mt, "listDatabases", evt.CommandName, "expected ")
308
309 expectedDoc := bsoncore.BuildDocumentFromElements(nil,
310 bsoncore.AppendBooleanElement(nil, "nameOnly", true),
311 bsoncore.AppendBooleanElement(nil, "authorizedDatabases", true),
312 )
313 err = compareDocs(mt, expectedDoc, evt.Command)
314 assert.Nil(mt, err, "compareDocs error: %v", err)
315 })
316 })
317 mt.RunOpts("ping", noClientOpts, func(mt *mtest.T) {
318 mt.Run("default read preference", func(mt *mtest.T) {
319 err := mt.Client.Ping(context.Background(), nil)
320 assert.Nil(mt, err, "Ping error: %v", err)
321 })
322 mt.Run("invalid host", func(mt *mtest.T) {
323
324
325 invalidClientOpts := options.Client().
326 SetServerSelectionTimeout(100 * time.Millisecond).SetHosts([]string{"invalid:123"}).
327 SetConnectTimeout(500 * time.Millisecond).SetSocketTimeout(500 * time.Millisecond)
328 integtest.AddTestServerAPIVersion(invalidClientOpts)
329 client, err := mongo.Connect(context.Background(), invalidClientOpts)
330 assert.Nil(mt, err, "Connect error: %v", err)
331 err = client.Ping(context.Background(), readpref.Primary())
332 assert.NotNil(mt, err, "expected error for pinging invalid host, got nil")
333 _ = client.Disconnect(context.Background())
334 })
335 })
336 mt.RunOpts("disconnect", noClientOpts, func(mt *mtest.T) {
337 mt.Run("nil context", func(mt *mtest.T) {
338 err := mt.Client.Disconnect(nil)
339 assert.Nil(mt, err, "Disconnect error: %v", err)
340 })
341 })
342 mt.RunOpts("watch", noClientOpts, func(mt *mtest.T) {
343 mt.Run("disconnected", func(mt *mtest.T) {
344 c, err := mongo.NewClient(options.Client().ApplyURI(mtest.ClusterURI()))
345 assert.Nil(mt, err, "NewClient error: %v", err)
346 _, err = c.Watch(context.Background(), mongo.Pipeline{})
347 assert.Equal(mt, mongo.ErrClientDisconnected, err, "expected error %v, got %v", mongo.ErrClientDisconnected, err)
348 })
349 })
350 mt.RunOpts("end sessions", mtest.NewOptions().MinServerVersion("3.6"), func(mt *mtest.T) {
351 _, err := mt.Client.ListDatabases(context.Background(), bson.D{})
352 assert.Nil(mt, err, "ListDatabases error: %v", err)
353
354 mt.ClearEvents()
355 err = mt.Client.Disconnect(context.Background())
356 assert.Nil(mt, err, "Disconnect error: %v", err)
357
358 started := mt.GetStartedEvent()
359 assert.Equal(mt, "endSessions", started.CommandName, "expected cmd name endSessions, got %v", started.CommandName)
360 })
361 mt.RunOpts("hello lastWriteDate", mtest.NewOptions().Topologies(mtest.ReplicaSet), func(mt *mtest.T) {
362 _, err := mt.Coll.InsertOne(context.Background(), bson.D{{"x", 1}})
363 assert.Nil(mt, err, "InsertOne error: %v", err)
364 })
365 sessionOpts := mtest.NewOptions().MinServerVersion("3.6.0").CreateClient(false)
366 mt.RunOpts("causal consistency", sessionOpts, func(mt *mtest.T) {
367 testCases := []struct {
368 name string
369 opts *options.SessionOptions
370 consistent bool
371 }{
372 {"default", options.Session(), true},
373 {"true", options.Session().SetCausalConsistency(true), true},
374 {"false", options.Session().SetCausalConsistency(false), false},
375 }
376
377 for _, tc := range testCases {
378 mt.Run(tc.name, func(mt *mtest.T) {
379 sess, err := mt.Client.StartSession(tc.opts)
380 assert.Nil(mt, err, "StartSession error: %v", err)
381 defer sess.EndSession(context.Background())
382 xs := sess.(mongo.XSession)
383 consistent := xs.ClientSession().Consistent
384 assert.Equal(mt, tc.consistent, consistent, "expected consistent to be %v, got %v", tc.consistent, consistent)
385 })
386 }
387 })
388 retryOpts := mtest.NewOptions().MinServerVersion("3.6.0").ClientType(mtest.Mock)
389 mt.RunOpts("retry writes error 20 wrapped", retryOpts, func(mt *mtest.T) {
390 writeErrorCode20 := mtest.CreateWriteErrorsResponse(mtest.WriteError{
391 Message: "Transaction numbers",
392 Code: 20,
393 })
394 writeErrorCode19 := mtest.CreateWriteErrorsResponse(mtest.WriteError{
395 Message: "Transaction numbers",
396 Code: 19,
397 })
398 writeErrorCode20WrongMsg := mtest.CreateWriteErrorsResponse(mtest.WriteError{
399 Message: "Not transaction numbers",
400 Code: 20,
401 })
402 cmdErrCode20 := mtest.CreateCommandErrorResponse(mtest.CommandError{
403 Message: "Transaction numbers",
404 Code: 20,
405 })
406 cmdErrCode19 := mtest.CreateCommandErrorResponse(mtest.CommandError{
407 Message: "Transaction numbers",
408 Code: 19,
409 })
410 cmdErrCode20WrongMsg := mtest.CreateCommandErrorResponse(mtest.CommandError{
411 Message: "Not transaction numbers",
412 Code: 20,
413 })
414
415 testCases := []struct {
416 name string
417 errResponse bson.D
418 expectUnsupportedMsg bool
419 }{
420 {"write error code 20", writeErrorCode20, true},
421 {"write error code 20 wrong msg", writeErrorCode20WrongMsg, false},
422 {"write error code 19 right msg", writeErrorCode19, false},
423 {"command error code 20", cmdErrCode20, true},
424 {"command error code 20 wrong msg", cmdErrCode20WrongMsg, false},
425 {"command error code 19 right msg", cmdErrCode19, false},
426 }
427 for _, tc := range testCases {
428 mt.Run(tc.name, func(mt *mtest.T) {
429 mt.ClearMockResponses()
430 mt.AddMockResponses(tc.errResponse)
431
432 sess, err := mt.Client.StartSession()
433 assert.Nil(mt, err, "StartSession error: %v", err)
434 defer sess.EndSession(context.Background())
435
436 _, err = mt.Coll.InsertOne(context.Background(), bson.D{{"x", 1}})
437 assert.NotNil(mt, err, "expected err but got nil")
438 if tc.expectUnsupportedMsg {
439 assert.Equal(mt, driver.ErrUnsupportedStorageEngine.Error(), err.Error(),
440 "expected error %v, got %v", driver.ErrUnsupportedStorageEngine, err)
441 return
442 }
443 assert.NotEqual(mt, driver.ErrUnsupportedStorageEngine.Error(), err.Error(),
444 "got ErrUnsupportedStorageEngine but wanted different error")
445 })
446 }
447 })
448
449 testAppName := "foo"
450 appNameClientOpts := options.Client().
451 SetAppName(testAppName)
452 appNameMtOpts := mtest.NewOptions().
453 ClientType(mtest.Proxy).
454 ClientOptions(appNameClientOpts).
455 Topologies(mtest.Single)
456 mt.RunOpts("app name is always sent", appNameMtOpts, func(mt *mtest.T) {
457 err := mt.Client.Ping(context.Background(), mtest.PrimaryRp)
458 assert.Nil(mt, err, "Ping error: %v", err)
459
460 msgPairs := mt.GetProxiedMessages()
461 assert.True(mt, len(msgPairs) >= 2, "expected at least 2 events sent, got %v", len(msgPairs))
462
463
464
465 for idx, pair := range msgPairs[:2] {
466 helloCommand := handshake.LegacyHello
467
468 if os.Getenv("REQUIRE_API_VERSION") == "true" {
469 helloCommand = "hello"
470 }
471 assert.Equal(mt, pair.CommandName, helloCommand, "expected command name %s at index %d, got %s", helloCommand, idx,
472 pair.CommandName)
473
474 sent := pair.Sent
475 appNameVal, err := sent.Command.LookupErr("client", "application", "name")
476 assert.Nil(mt, err, "expected command %s at index %d to contain app name", sent.Command, idx)
477 appName := appNameVal.StringValue()
478 assert.Equal(mt, testAppName, appName, "expected app name %v at index %d, got %v", testAppName, idx,
479 appName)
480 }
481 })
482
483
484 firstServerAddr := mtest.GlobalTopology().Description().Servers[0].Addr
485 directConnectionOpts := options.Client().
486 ApplyURI(fmt.Sprintf("mongodb://%s", firstServerAddr)).
487 SetReadPreference(readpref.Primary()).
488 SetDirect(true)
489 mtOpts := mtest.NewOptions().
490 ClientOptions(directConnectionOpts).
491 CreateCollection(false).
492 MinServerVersion("3.6").
493 Topologies(mtest.ReplicaSet)
494 mt.RunOpts("direct connection made", mtOpts, func(mt *mtest.T) {
495 _, err := mt.Coll.Find(context.Background(), bson.D{})
496 assert.Nil(mt, err, "Find error: %v", err)
497
498
499 evt := mt.GetStartedEvent()
500 assert.Equal(mt, "find", evt.CommandName, "expected 'find' event, got '%s'", evt.CommandName)
501
502
503
504 modeVal, err := evt.Command.LookupErr("$readPreference", "mode")
505 assert.Nil(mt, err, "expected command %s to include $readPreference", evt.Command)
506
507 mode := modeVal.StringValue()
508 assert.Equal(mt, mode, "primaryPreferred", "expected read preference mode primaryPreferred, got %v", mode)
509 })
510
511
512 mtOpts = mtest.NewOptions().ClientOptions(options.Client().SetMinPoolSize(5))
513 mt.RunOpts("minPoolSize", mtOpts, func(mt *mtest.T) {
514 err := mt.Client.Ping(context.Background(), readpref.Primary())
515 assert.Nil(t, err, "unexpected error calling Ping: %v", err)
516 })
517
518 mt.Run("minimum RTT is monitored", func(mt *mtest.T) {
519 mt.Parallel()
520
521
522
523 mt.ResetClient(options.Client().
524 SetDialer(newSlowConnDialer(slowConnDialerDelay)).
525 SetHeartbeatInterval(reducedHeartbeatInterval))
526
527
528 topo := getTopologyFromClient(mt.Client)
529 assert.Soon(mt, func(ctx context.Context) {
530 for {
531
532 select {
533 case <-ctx.Done():
534 return
535 default:
536 }
537
538 time.Sleep(100 * time.Millisecond)
539
540
541 done := true
542 for _, desc := range topo.Description().Servers {
543 server, err := topo.FindServer(desc)
544 assert.Nil(mt, err, "FindServer error: %v", err)
545 if server.RTTMonitor().Min() <= 250*time.Millisecond {
546 done = false
547 }
548 }
549 if done {
550 return
551 }
552 }
553 }, 10*time.Second)
554 })
555
556
557
558 mt.Run("minimum RTT used to prevent sending requests", func(mt *mtest.T) {
559 mt.Parallel()
560
561
562 ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
563 defer cancel()
564 err := mt.Client.Ping(ctx, nil)
565 assert.Nil(mt, err, "Ping error: %v", err)
566
567
568
569 tpm := eventtest.NewTestPoolMonitor()
570 mt.ResetClient(options.Client().
571 SetPoolMonitor(tpm.PoolMonitor).
572 SetDialer(newSlowConnDialer(slowConnDialerDelay)).
573 SetHeartbeatInterval(reducedHeartbeatInterval))
574
575
576 topo := getTopologyFromClient(mt.Client)
577 assert.Soon(mt, func(ctx context.Context) {
578 for {
579
580 select {
581 case <-ctx.Done():
582 return
583 default:
584 }
585
586 time.Sleep(100 * time.Millisecond)
587
588
589 done := true
590 for _, desc := range topo.Description().Servers {
591 server, err := topo.FindServer(desc)
592 assert.Nil(mt, err, "FindServer error: %v", err)
593 if server.RTTMonitor().Min() <= 250*time.Millisecond {
594 done = false
595 }
596 }
597 if done {
598 return
599 }
600 }
601 }, 10*time.Second)
602
603
604
605 for i := 0; i < 10; i++ {
606 ctx, cancel = context.WithTimeout(context.Background(), 250*time.Millisecond)
607 err := mt.Client.Ping(ctx, nil)
608 cancel()
609 assert.NotNil(mt, err, "expected Ping to return an error")
610 }
611
612
613 closed := len(tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.ConnectionClosed }))
614 assert.Equal(t, 0, closed, "expected no connections to be closed")
615 })
616
617 mt.Run("RTT90 is monitored", func(mt *mtest.T) {
618 mt.Parallel()
619
620
621
622 mt.ResetClient(options.Client().
623 SetDialer(newSlowConnDialer(slowConnDialerDelay)).
624 SetHeartbeatInterval(reducedHeartbeatInterval))
625
626
627 topo := getTopologyFromClient(mt.Client)
628 assert.Soon(mt, func(ctx context.Context) {
629 for {
630
631 select {
632 case <-ctx.Done():
633 return
634 default:
635 }
636
637 time.Sleep(100 * time.Millisecond)
638
639
640 done := true
641 for _, desc := range topo.Description().Servers {
642 server, err := topo.FindServer(desc)
643 assert.Nil(mt, err, "FindServer error: %v", err)
644 if server.RTTMonitor().P90() <= 300*time.Millisecond {
645 done = false
646 }
647 }
648 if done {
649 return
650 }
651 }
652 }, 10*time.Second)
653 })
654
655
656
657 mt.Run("RTT90 used to prevent sending requests", func(mt *mtest.T) {
658 mt.Parallel()
659
660
661 ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
662 defer cancel()
663 err := mt.Client.Ping(ctx, nil)
664 assert.Nil(mt, err, "Ping error: %v", err)
665
666
667
668
669
670 tpm := eventtest.NewTestPoolMonitor()
671 mt.ResetClient(options.Client().
672 SetPoolMonitor(tpm.PoolMonitor).
673 SetDialer(newSlowConnDialer(slowConnDialerDelay)).
674 SetHeartbeatInterval(reducedHeartbeatInterval).
675 SetTimeout(0))
676
677
678 topo := getTopologyFromClient(mt.Client)
679 assert.Soon(mt, func(ctx context.Context) {
680 for {
681
682 select {
683 case <-ctx.Done():
684 return
685 default:
686 }
687
688 time.Sleep(100 * time.Millisecond)
689
690
691 done := true
692 for _, desc := range topo.Description().Servers {
693 server, err := topo.FindServer(desc)
694 assert.Nil(mt, err, "FindServer error: %v", err)
695 if server.RTTMonitor().P90() <= 275*time.Millisecond {
696 done = false
697 }
698 }
699 if done {
700 return
701 }
702 }
703 }, 10*time.Second)
704
705
706
707 for i := 0; i < 10; i++ {
708 ctx, cancel = context.WithTimeout(context.Background(), 275*time.Millisecond)
709 err := mt.Client.Ping(ctx, nil)
710 cancel()
711 assert.NotNil(mt, err, "expected Ping to return an error")
712 assert.True(mt, mongo.IsTimeout(err), "expected a timeout error, got: %v", err)
713 }
714
715
716 closed := len(tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.ConnectionClosed }))
717 assert.Equal(t, 0, closed, "expected no connections to be closed")
718 })
719
720
721
722 opMsgOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("3.6").Auth(true).RequireAPIVersion(false)
723 mt.RunOpts("OP_MSG used for authentication on 3.6+", opMsgOpts, func(mt *mtest.T) {
724 err := mt.Client.Ping(context.Background(), mtest.PrimaryRp)
725 assert.Nil(mt, err, "Ping error: %v", err)
726
727 msgPairs := mt.GetProxiedMessages()
728 assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs))
729
730
731 pair := msgPairs[0]
732 assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s",
733 handshake.LegacyHello, pair.CommandName)
734 assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode,
735 "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String())
736
737
738
739 var saslContinueFound bool
740 for _, pair := range msgPairs[1:] {
741 if pair.CommandName == "saslContinue" {
742 saslContinueFound = true
743 assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode,
744 "expected 'OP_MSG' OpCode in wire message, got %s", pair.Sent.OpCode.String())
745 break
746 }
747 }
748 assert.True(mt, saslContinueFound, "did not find 'saslContinue' command in proxied messages")
749 })
750
751
752 opMsgSAPIOpts := mtest.NewOptions().ClientType(mtest.Proxy).MinServerVersion("5.0").RequireAPIVersion(true)
753 mt.RunOpts("OP_MSG used for handshakes when API version declared", opMsgSAPIOpts, func(mt *mtest.T) {
754 err := mt.Client.Ping(context.Background(), mtest.PrimaryRp)
755 assert.Nil(mt, err, "Ping error: %v", err)
756
757 msgPairs := mt.GetProxiedMessages()
758 assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs))
759
760
761
762 for idx, pair := range msgPairs[:3] {
763 assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx,
764 pair.CommandName)
765
766
767 assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode,
768 "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String())
769 }
770 })
771
772 opts := mtest.NewOptions().
773
774 Topologies(mtest.Single, mtest.ReplicaSet).
775 MinServerVersion("4.2").
776
777 ClientOptions(options.Client().SetRetryReads(true).SetRetryWrites(true))
778 mt.RunOpts("operations don't retry after a context timeout", opts, func(mt *mtest.T) {
779 testCases := []struct {
780 desc string
781 operation func(context.Context, *mongo.Collection) error
782 }{
783 {
784 desc: "read op",
785 operation: func(ctx context.Context, coll *mongo.Collection) error {
786 return coll.FindOne(ctx, bson.D{}).Err()
787 },
788 },
789 {
790 desc: "write op",
791 operation: func(ctx context.Context, coll *mongo.Collection) error {
792 _, err := coll.InsertOne(ctx, bson.D{})
793 return err
794 },
795 },
796 }
797
798 for _, tc := range testCases {
799 mt.Run(tc.desc, func(mt *mtest.T) {
800 _, err := mt.Coll.InsertOne(context.Background(), bson.D{})
801 require.NoError(mt, err)
802
803 mt.SetFailPoint(mtest.FailPoint{
804 ConfigureFailPoint: "failCommand",
805 Mode: "alwaysOn",
806 Data: mtest.FailPointData{
807 FailCommands: []string{"find", "insert"},
808 BlockConnection: true,
809 BlockTimeMS: 500,
810 },
811 })
812
813 mt.ClearEvents()
814
815 for i := 0; i < 50; i++ {
816
817
818
819
820
821 ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
822 err = tc.operation(ctx, mt.Coll)
823 cancel()
824 assert.ErrorIs(mt, err, context.DeadlineExceeded)
825 assert.True(mt, mongo.IsTimeout(err), "expected mongo.IsTimeout(err) to be true")
826
827
828
829
830 evts := mt.GetAllStartedEvents()
831 require.Len(mt,
832 mt.GetAllStartedEvents(),
833 1,
834 "expected exactly 1 command started event per operation, but got %d after %d iterations",
835 len(evts),
836 i)
837 mt.ClearEvents()
838 }
839 })
840 }
841 })
842 }
843
844 func TestClient_BSONOptions(t *testing.T) {
845 t.Parallel()
846
847 mt := mtest.New(t, noClientOpts)
848
849 type jsonTagsTest struct {
850 A string
851 B string `json:"x"`
852 C string `json:"y" bson:"3"`
853 }
854
855 testCases := []struct {
856 name string
857 bsonOpts *options.BSONOptions
858 doc interface{}
859 decodeInto func() interface{}
860 want interface{}
861 wantRaw bson.Raw
862 }{
863 {
864 name: "UseJSONStructTags",
865 bsonOpts: &options.BSONOptions{
866 UseJSONStructTags: true,
867 },
868 doc: jsonTagsTest{
869 A: "apple",
870 B: "banana",
871 C: "carrot",
872 },
873 decodeInto: func() interface{} { return &jsonTagsTest{} },
874 want: &jsonTagsTest{
875 A: "apple",
876 B: "banana",
877 C: "carrot",
878 },
879 wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
880 AppendString("a", "apple").
881 AppendString("x", "banana").
882 AppendString("3", "carrot").
883 Build()),
884 },
885 {
886 name: "IntMinSize",
887 bsonOpts: &options.BSONOptions{
888 IntMinSize: true,
889 },
890 doc: bson.D{{Key: "x", Value: int64(1)}},
891 decodeInto: func() interface{} { return &bson.D{} },
892 want: &bson.D{{Key: "x", Value: int32(1)}},
893 wantRaw: bson.Raw(bsoncore.NewDocumentBuilder().
894 AppendInt32("x", 1).
895 Build()),
896 },
897 {
898 name: "DefaultDocumentM",
899 bsonOpts: &options.BSONOptions{
900 DefaultDocumentM: true,
901 },
902 doc: bson.D{{Key: "doc", Value: bson.D{{Key: "a", Value: int64(1)}}}},
903 decodeInto: func() interface{} { return &bson.D{} },
904 want: &bson.D{{Key: "doc", Value: bson.M{"a": int64(1)}}},
905 },
906 }
907
908 for _, tc := range testCases {
909 opts := mtest.NewOptions().ClientOptions(
910 options.Client().SetBSONOptions(tc.bsonOpts))
911 mt.RunOpts(tc.name, opts, func(mt *mtest.T) {
912 res, err := mt.Coll.InsertOne(context.Background(), tc.doc)
913 require.NoError(mt, err, "InsertOne error")
914
915 sr := mt.Coll.FindOne(
916 context.Background(),
917 bson.D{{Key: "_id", Value: res.InsertedID}},
918
919
920 options.FindOne().SetProjection(bson.D{{Key: "_id", Value: 0}}))
921
922 if tc.want != nil {
923 got := tc.decodeInto()
924 err := sr.Decode(got)
925 require.NoError(mt, err, "Decode error")
926
927 assert.Equal(mt, tc.want, got, "expected and actual decoded result are different")
928 }
929
930 if tc.wantRaw != nil {
931 got, err := sr.Raw()
932 require.NoError(mt, err, "Raw error")
933
934 assert.EqualBSON(mt, tc.wantRaw, got)
935 }
936 })
937 }
938
939 opts := mtest.NewOptions().ClientOptions(
940 options.Client().SetBSONOptions(&options.BSONOptions{
941 ErrorOnInlineDuplicates: true,
942 }))
943 mt.RunOpts("ErrorOnInlineDuplicates", opts, func(mt *mtest.T) {
944 type inlineDupInner struct {
945 A string
946 }
947
948 type inlineDupOuter struct {
949 A string
950 B *inlineDupInner `bson:"b,inline"`
951 }
952
953 _, err := mt.Coll.InsertOne(context.Background(), inlineDupOuter{
954 A: "outer",
955 B: &inlineDupInner{
956 A: "inner",
957 },
958 })
959 require.Error(mt, err, "expected InsertOne to return an error")
960 })
961 }
962
963 func TestClientStress(t *testing.T) {
964 mtOpts := mtest.NewOptions().CreateClient(false)
965 mt := mtest.New(t, mtOpts)
966
967
968 mt.Run("Client recovers from traffic spike", func(mt *mtest.T) {
969 oid := primitive.NewObjectID()
970 doc := bson.D{{Key: "_id", Value: oid}, {Key: "key", Value: "value"}}
971 _, err := mt.Coll.InsertOne(context.Background(), doc)
972 assert.Nil(mt, err, "InsertOne error: %v", err)
973
974
975
976 findOne := func(coll *mongo.Collection, timeout time.Duration) error {
977 ctx, cancel := context.WithTimeout(context.Background(), timeout)
978 defer cancel()
979 var res map[string]interface{}
980 return coll.FindOne(ctx, bson.D{{Key: "_id", Value: oid}}).Decode(&res)
981 }
982
983
984
985 findOneFor := func(coll *mongo.Collection, timeout time.Duration, d time.Duration) []error {
986 errs := make([]error, 0)
987 start := time.Now()
988 for time.Since(start) <= d {
989 err := findOne(coll, timeout)
990 if err != nil {
991 errs = append(errs, err)
992 }
993 time.Sleep(10 * time.Microsecond)
994 }
995 return errs
996 }
997
998
999
1000 var maxRTT time.Duration
1001 for i := 0; i < 50; i++ {
1002 start := time.Now()
1003 err := findOne(mt.Coll, 10*time.Second)
1004 assert.Nil(t, err, "FindOne error: %v", err)
1005 duration := time.Since(start)
1006 if duration > maxRTT {
1007 maxRTT = duration
1008 }
1009 }
1010 assert.True(mt, maxRTT > 0, "RTT must be greater than 0")
1011
1012
1013
1014
1015 maxPoolSizes := []uint64{1, 10, 100}
1016 for _, maxPoolSize := range maxPoolSizes {
1017 tpm := eventtest.NewTestPoolMonitor()
1018 maxPoolSizeOpt := mtest.NewOptions().ClientOptions(
1019 options.Client().
1020 SetPoolMonitor(tpm.PoolMonitor).
1021 SetMaxPoolSize(maxPoolSize))
1022 mt.RunOpts(fmt.Sprintf("maxPoolSize %d", maxPoolSize), maxPoolSizeOpt, func(mt *mtest.T) {
1023
1024
1025 defer func() {
1026 created := len(tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.ConnectionCreated }))
1027 closed := len(tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.ConnectionClosed }))
1028 poolCleared := len(tpm.Events(func(e *event.PoolEvent) bool { return e.Type == event.PoolCleared }))
1029 mt.Logf("Connections created: %d, connections closed: %d, pool clears: %d", created, closed, poolCleared)
1030 }()
1031
1032 doc := bson.D{{Key: "_id", Value: oid}, {Key: "key", Value: "value"}}
1033 _, err := mt.Coll.InsertOne(context.Background(), doc)
1034 assert.Nil(mt, err, "InsertOne error: %v", err)
1035
1036
1037
1038 timeout := maxRTT * 100
1039 minTimeout := 100 * time.Millisecond
1040 if timeout < minTimeout {
1041 timeout = minTimeout
1042 }
1043 mt.Logf("Max RTT %v; using a timeout of %v", maxRTT, timeout)
1044
1045
1046
1047 _ = findOneFor(mt.Coll, timeout, 1*time.Second)
1048
1049
1050
1051 errs := findOneFor(mt.Coll, timeout, 1*time.Second)
1052 assert.True(mt, len(errs) == 0, "expected no errors, but got %d (%v)", len(errs), errs)
1053
1054
1055
1056 g := new(errgroup.Group)
1057 for i := 0; i < 1000; i++ {
1058 g.Go(func() error {
1059 errs := findOneFor(mt.Coll, timeout, 10*time.Second)
1060 if len(errs) == 0 {
1061 return nil
1062 }
1063 return errs[len(errs)-1]
1064 })
1065 }
1066 err = g.Wait()
1067 mt.Logf("Error from extreme traffic spike (errors are expected): %v", err)
1068
1069
1070
1071 _ = findOneFor(mt.Coll, timeout, 5*time.Second)
1072
1073
1074 errs = findOneFor(mt.Coll, timeout, 1*time.Second)
1075 assert.True(mt, len(errs) == 0, "expected no errors, but got %d (%v)", len(errs), errs)
1076 })
1077 }
1078 })
1079 }
1080
View as plain text