1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12 "math"
13 "os"
14 "testing"
15 "time"
16
17 "go.mongodb.org/mongo-driver/bson"
18 "go.mongodb.org/mongo-driver/event"
19 "go.mongodb.org/mongo-driver/internal/assert"
20 "go.mongodb.org/mongo-driver/internal/integtest"
21 "go.mongodb.org/mongo-driver/mongo/options"
22 "go.mongodb.org/mongo-driver/mongo/readconcern"
23 "go.mongodb.org/mongo-driver/mongo/readpref"
24 "go.mongodb.org/mongo-driver/mongo/writeconcern"
25 "go.mongodb.org/mongo-driver/tag"
26 "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
27 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
28 "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
29 )
30
// bgCtx is the background context shared by the client tests in this file;
// it is never cancelled and carries no deadline or values.
var (
	bgCtx = context.Background()
)
32
33 func setupClient(opts ...*options.ClientOptions) *Client {
34 if len(opts) == 0 {
35 clientOpts := options.Client().ApplyURI("mongodb://localhost:27017")
36 integtest.AddTestServerAPIVersion(clientOpts)
37 opts = append(opts, clientOpts)
38 }
39 client, _ := NewClient(opts...)
40 return client
41 }
42
43 func TestClient(t *testing.T) {
44 t.Run("new client", func(t *testing.T) {
45 client := setupClient()
46 assert.NotNil(t, client.deployment, "expected valid deployment, got nil")
47 })
48 t.Run("database", func(t *testing.T) {
49 dbName := "foo"
50 client := setupClient()
51 db := client.Database(dbName)
52 assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name())
53 assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client())
54 })
55 t.Run("replace topology error", func(t *testing.T) {
56 client := setupClient()
57
58 _, err := client.StartSession()
59 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
60
61 _, err = client.ListDatabases(bgCtx, bson.D{})
62 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
63
64 err = client.Ping(bgCtx, nil)
65 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
66
67 err = client.Disconnect(bgCtx)
68 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
69
70 _, err = client.Watch(bgCtx, []bson.D{})
71 assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
72 })
73 t.Run("nil document error", func(t *testing.T) {
74
75 client := setupClient()
76 client.sessionPool = &session.Pool{}
77
78 _, err := client.Watch(bgCtx, nil)
79 watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
80 assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err)
81
82 _, err = client.ListDatabases(bgCtx, nil)
83 assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
84
85 _, err = client.ListDatabaseNames(bgCtx, nil)
86 assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
87 })
88 t.Run("read preference", func(t *testing.T) {
89 t.Run("absent", func(t *testing.T) {
90 client := setupClient()
91 gotMode := client.readPreference.Mode()
92 wantMode := readpref.PrimaryMode
93 assert.Equal(t, gotMode, wantMode, "expected mode %v, got %v", wantMode, gotMode)
94 _, flag := client.readPreference.MaxStaleness()
95 assert.False(t, flag, "expected max staleness to not be set but was")
96 })
97 t.Run("specified", func(t *testing.T) {
98 tags := []tag.Set{
99 {
100 tag.Tag{
101 Name: "one",
102 Value: "1",
103 },
104 },
105 {
106 tag.Tag{
107 Name: "two",
108 Value: "2",
109 },
110 },
111 }
112 cs := "mongodb://localhost:27017/"
113 cs += "?readpreference=secondary&readPreferenceTags=one:1&readPreferenceTags=two:2&maxStaleness=5"
114
115 client := setupClient(options.Client().ApplyURI(cs))
116 gotMode := client.readPreference.Mode()
117 assert.Equal(t, gotMode, readpref.SecondaryMode, "expected mode %v, got %v", readpref.SecondaryMode, gotMode)
118 gotTags := client.readPreference.TagSets()
119 assert.Equal(t, gotTags, tags, "expected tags %v, got %v", tags, gotTags)
120 gotStaleness, flag := client.readPreference.MaxStaleness()
121 assert.True(t, flag, "expected max staleness to be set but was not")
122 wantStaleness := time.Duration(5) * time.Second
123 assert.Equal(t, gotStaleness, wantStaleness, "expected staleness %v, got %v", wantStaleness, gotStaleness)
124 })
125 })
126 t.Run("localThreshold", func(t *testing.T) {
127 testCases := []struct {
128 name string
129 opts *options.ClientOptions
130 expectedThreshold time.Duration
131 }{
132 {"default", options.Client(), defaultLocalThreshold},
133 {"custom", options.Client().SetLocalThreshold(10 * time.Second), 10 * time.Second},
134 }
135 for _, tc := range testCases {
136 t.Run(tc.name, func(t *testing.T) {
137 client := setupClient(tc.opts)
138 assert.Equal(t, tc.expectedThreshold, client.localThreshold,
139 "expected localThreshold %v, got %v", tc.expectedThreshold, client.localThreshold)
140 })
141 }
142 })
143 t.Run("read concern", func(t *testing.T) {
144 rc := readconcern.Majority()
145 client := setupClient(options.Client().SetReadConcern(rc))
146 assert.Equal(t, rc, client.readConcern, "expected read concern %v, got %v", rc, client.readConcern)
147 })
148 t.Run("min pool size from Set*PoolSize()", func(t *testing.T) {
149 testCases := []struct {
150 name string
151 opts *options.ClientOptions
152 err error
153 }{
154 {
155 name: "minPoolSize < default maxPoolSize",
156 opts: options.Client().SetMinPoolSize(64),
157 err: nil,
158 },
159 {
160 name: "minPoolSize > default maxPoolSize",
161 opts: options.Client().SetMinPoolSize(128),
162 err: errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=128 maxPoolSize=100"),
163 },
164 {
165 name: "minPoolSize < maxPoolSize",
166 opts: options.Client().SetMinPoolSize(128).SetMaxPoolSize(256),
167 err: nil,
168 },
169 {
170 name: "minPoolSize == maxPoolSize",
171 opts: options.Client().SetMinPoolSize(128).SetMaxPoolSize(128),
172 err: nil,
173 },
174 {
175 name: "minPoolSize > maxPoolSize",
176 opts: options.Client().SetMinPoolSize(64).SetMaxPoolSize(32),
177 err: errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=64 maxPoolSize=32"),
178 },
179 {
180 name: "maxPoolSize == 0",
181 opts: options.Client().SetMinPoolSize(128).SetMaxPoolSize(0),
182 err: nil,
183 },
184 }
185 for _, tc := range testCases {
186 t.Run(tc.name, func(t *testing.T) {
187 _, err := NewClient(tc.opts)
188 assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
189 })
190 }
191 })
192 t.Run("min pool size from ApplyURI()", func(t *testing.T) {
193 testCases := []struct {
194 name string
195 opts *options.ClientOptions
196 err error
197 }{
198 {
199 name: "minPoolSize < default maxPoolSize",
200 opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=64"),
201 err: nil,
202 },
203 {
204 name: "minPoolSize > default maxPoolSize",
205 opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128"),
206 err: errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=128 maxPoolSize=100"),
207 },
208 {
209 name: "minPoolSize < maxPoolSize",
210 opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128&maxPoolSize=256"),
211 err: nil,
212 },
213 {
214 name: "minPoolSize == maxPoolSize",
215 opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128&maxPoolSize=128"),
216 err: nil,
217 },
218 {
219 name: "minPoolSize > maxPoolSize",
220 opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=64&maxPoolSize=32"),
221 err: errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=64 maxPoolSize=32"),
222 },
223 {
224 name: "maxPoolSize == 0",
225 opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128&maxPoolSize=0"),
226 err: nil,
227 },
228 }
229 for _, tc := range testCases {
230 t.Run(tc.name, func(t *testing.T) {
231 _, err := NewClient(tc.opts)
232 assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
233 })
234 }
235 })
236 t.Run("retry writes", func(t *testing.T) {
237 retryWritesURI := "mongodb://localhost:27017/?retryWrites=false"
238 retryWritesErrorURI := "mongodb://localhost:27017/?retryWrites=foobar"
239
240 testCases := []struct {
241 name string
242 opts *options.ClientOptions
243 expectErr bool
244 expectedRetry bool
245 }{
246 {"default", options.Client(), false, true},
247 {"custom options", options.Client().SetRetryWrites(false), false, false},
248 {"custom URI", options.Client().ApplyURI(retryWritesURI), false, false},
249 {"custom URI error", options.Client().ApplyURI(retryWritesErrorURI), true, false},
250 }
251 for _, tc := range testCases {
252 t.Run(tc.name, func(t *testing.T) {
253 client, err := NewClient(tc.opts)
254 if tc.expectErr {
255 assert.NotNil(t, err, "expected error, got nil")
256 return
257 }
258 assert.Nil(t, err, "configuration error: %v", err)
259 assert.Equal(t, tc.expectedRetry, client.retryWrites, "expected retryWrites %v, got %v",
260 tc.expectedRetry, client.retryWrites)
261 })
262 }
263 })
264 t.Run("retry reads", func(t *testing.T) {
265 retryReadsURI := "mongodb://localhost:27017/?retryReads=false"
266 retryReadsErrorURI := "mongodb://localhost:27017/?retryReads=foobar"
267
268 testCases := []struct {
269 name string
270 opts *options.ClientOptions
271 expectErr bool
272 expectedRetry bool
273 }{
274 {"default", options.Client(), false, true},
275 {"custom options", options.Client().SetRetryReads(false), false, false},
276 {"custom URI", options.Client().ApplyURI(retryReadsURI), false, false},
277 {"custom URI error", options.Client().ApplyURI(retryReadsErrorURI), true, false},
278 }
279 for _, tc := range testCases {
280 t.Run(tc.name, func(t *testing.T) {
281 client, err := NewClient(tc.opts)
282 if tc.expectErr {
283 assert.NotNil(t, err, "expected error, got nil")
284 return
285 }
286 assert.Nil(t, err, "configuration error: %v", err)
287 assert.Equal(t, tc.expectedRetry, client.retryReads, "expected retryReads %v, got %v",
288 tc.expectedRetry, client.retryReads)
289 })
290 }
291 })
292 t.Run("write concern", func(t *testing.T) {
293 wc := writeconcern.New(writeconcern.WMajority())
294 client := setupClient(options.Client().SetWriteConcern(wc))
295 assert.Equal(t, wc, client.writeConcern, "mismatch; expected write concern %v, got %v", wc, client.writeConcern)
296 })
297 t.Run("server monitor", func(t *testing.T) {
298 monitor := &event.ServerMonitor{}
299 client := setupClient(options.Client().SetServerMonitor(monitor))
300 assert.Equal(t, monitor, client.serverMonitor, "expected sdam monitor %v, got %v", monitor, client.serverMonitor)
301 })
302 t.Run("GetURI", func(t *testing.T) {
303 t.Run("ApplyURI not called", func(t *testing.T) {
304 opts := options.Client().SetHosts([]string{"localhost:27017"})
305 uri := opts.GetURI()
306 assert.Equal(t, "", uri, "expected GetURI to return empty string, got %v", uri)
307 })
308 t.Run("ApplyURI called with empty string", func(t *testing.T) {
309 opts := options.Client().ApplyURI("")
310 uri := opts.GetURI()
311 assert.Equal(t, "", uri, "expected GetURI to return empty string, got %v", uri)
312 })
313 t.Run("ApplyURI called with non-empty string", func(t *testing.T) {
314 uri := "mongodb://localhost:27017/foobar"
315 opts := options.Client().ApplyURI(uri)
316 got := opts.GetURI()
317 assert.Equal(t, uri, got, "expected GetURI to return %v, got %v", uri, got)
318 })
319 })
320 t.Run("endSessions", func(t *testing.T) {
321 cs := integtest.ConnString(t)
322 originalBatchSize := endSessionsBatchSize
323 endSessionsBatchSize = 2
324 defer func() {
325 endSessionsBatchSize = originalBatchSize
326 }()
327
328 testCases := []struct {
329 name string
330 numSessions int
331 eventBatchSizes []int
332 }{
333 {"number of sessions divides evenly", endSessionsBatchSize * 2, []int{endSessionsBatchSize, endSessionsBatchSize}},
334 {"number of sessions does not divide evenly", endSessionsBatchSize + 1, []int{endSessionsBatchSize, 1}},
335 }
336 for _, tc := range testCases {
337 if testing.Short() {
338 t.Skip("skipping integration test in short mode")
339 }
340 if os.Getenv("DOCKER_RUNNING") != "" {
341 t.Skip("skipping test in docker environment")
342 }
343
344 t.Run(tc.name, func(t *testing.T) {
345
346 var started []*event.CommandStartedEvent
347 var failureReasons []string
348 cmdMonitor := &event.CommandMonitor{
349 Started: func(_ context.Context, evt *event.CommandStartedEvent) {
350 if evt.CommandName == "endSessions" {
351 started = append(started, evt)
352 }
353 },
354 Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
355 if evt.CommandName == "endSessions" {
356 failureReasons = append(failureReasons, evt.Failure)
357 }
358 },
359 }
360 clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()).
361 SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetMonitor(cmdMonitor)
362 integtest.AddTestServerAPIVersion(clientOpts)
363 client, err := Connect(bgCtx, clientOpts)
364 assert.Nil(t, err, "Connect error: %v", err)
365 defer func() {
366 _ = client.Disconnect(bgCtx)
367 }()
368
369 serverVersion, err := getServerVersion(client.Database("admin"))
370 assert.Nil(t, err, "getServerVersion error: %v", err)
371 if compareVersions(serverVersion, "3.6.0") < 1 {
372 t.Skip("skipping server version < 3.6")
373 }
374
375 coll := client.Database("foo").Collection("bar")
376 defer func() {
377 _ = coll.Drop(bgCtx)
378 }()
379
380
381 _, err = coll.CountDocuments(bgCtx, bson.D{})
382 assert.Nil(t, err, "CountDocuments error: %v", err)
383 var sessions []Session
384 for i := 0; i < tc.numSessions; i++ {
385 sess, err := client.StartSession()
386 assert.Nil(t, err, "StartSession error at index %d: %v", i, err)
387 sessions = append(sessions, sess)
388 }
389 for _, sess := range sessions {
390 sess.EndSession(bgCtx)
391 }
392
393 client.endSessions(bgCtx)
394 divisionResult := float64(tc.numSessions) / float64(endSessionsBatchSize)
395 numEventsExpected := int(math.Ceil(divisionResult))
396 assert.Equal(t, len(started), numEventsExpected, "expected %d started events, got %d", numEventsExpected,
397 len(started))
398 assert.Equal(t, len(failureReasons), 0, "endSessions errors: %v", failureReasons)
399
400 for i := 0; i < numEventsExpected; i++ {
401 sentArray := started[i].Command.Lookup("endSessions").Array()
402 values, _ := sentArray.Values()
403 expectedNumValues := tc.eventBatchSizes[i]
404 assert.Equal(t, len(values), expectedNumValues,
405 "batch size mismatch at index %d; expected %d sessions in batch, got %d", i, expectedNumValues,
406 len(values))
407 }
408 })
409 }
410 })
411 t.Run("serverAPI version", func(t *testing.T) {
412 getServerAPIOptions := func() *options.ServerAPIOptions {
413 return options.ServerAPI(options.ServerAPIVersion1).
414 SetStrict(false).SetDeprecationErrors(false)
415 }
416
417 t.Run("success with all options", func(t *testing.T) {
418 serverAPIOptions := getServerAPIOptions()
419 client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
420 assert.Nil(t, err, "unexpected error from NewClient: %v", err)
421 convertedAPIOptions := topology.ConvertToDriverAPIOptions(serverAPIOptions)
422 assert.Equal(t, convertedAPIOptions, client.serverAPI,
423 "mismatch in serverAPI; expected %v, got %v", convertedAPIOptions, client.serverAPI)
424 })
425 t.Run("failure with unsupported version", func(t *testing.T) {
426 serverAPIOptions := options.ServerAPI("badVersion")
427 _, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
428 assert.NotNil(t, err, "expected error from NewClient, got nil")
429 errmsg := `api version "badVersion" not supported; this driver version only supports API version "1"`
430 assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error())
431 })
432 t.Run("cannot modify options after client creation", func(t *testing.T) {
433 serverAPIOptions := getServerAPIOptions()
434 client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
435 assert.Nil(t, err, "unexpected error from NewClient: %v", err)
436
437 expectedServerAPIOptions := getServerAPIOptions()
438
439 serverAPIOptions.SetStrict(true).SetDeprecationErrors(true)
440 convertedAPIOptions := topology.ConvertToDriverAPIOptions(expectedServerAPIOptions)
441 assert.Equal(t, convertedAPIOptions, client.serverAPI,
442 "unexpected modification to serverAPI; expected %v, got %v", convertedAPIOptions, client.serverAPI)
443 })
444 })
445 t.Run("mongocryptd or crypt_shared", func(t *testing.T) {
446 cryptSharedLibPath := os.Getenv("CRYPT_SHARED_LIB_PATH")
447 if cryptSharedLibPath == "" {
448 t.Skip("CRYPT_SHARED_LIB_PATH not set, skipping")
449 }
450 if len(mongocrypt.Version()) == 0 {
451 t.Skip("Not built with cse flag")
452 }
453
454 testCases := []struct {
455 description string
456 useCryptSharedLib bool
457 }{
458 {
459 description: "when crypt_shared is loaded, should not attempt to spawn mongocryptd",
460 useCryptSharedLib: true,
461 },
462 {
463 description: "when crypt_shared is not loaded, should attempt to spawn mongocryptd",
464 useCryptSharedLib: false,
465 },
466 }
467 for _, tc := range testCases {
468 t.Run(tc.description, func(t *testing.T) {
469 extraOptions := map[string]interface{}{
470
471
472 "mongocryptdPath": "/does/not/exist",
473 }
474
475
476
477
478
479 if tc.useCryptSharedLib {
480 extraOptions["cryptSharedLibRequired"] = true
481 extraOptions["cryptSharedLibPath"] = cryptSharedLibPath
482 } else {
483 extraOptions["__cryptSharedLibDisabledForTestOnly"] = true
484 }
485
486 _, err := NewClient(options.Client().
487 SetAutoEncryptionOptions(options.AutoEncryption().
488 SetKmsProviders(map[string]map[string]interface{}{
489 "local": {"key": make([]byte, 96)},
490 }).
491 SetExtraOptions(extraOptions)))
492
493
494
495
496 if tc.useCryptSharedLib {
497 assert.Nil(t, err, "Connect() error: %v", err)
498 } else {
499 assert.NotNil(t, err, "expected Connect() error, but got nil")
500 }
501 })
502 }
503 })
504 }
505
View as plain text