1
2
3
4
5
6
7 package options
8
9 import (
10 "bytes"
11 "context"
12 "crypto/tls"
13 "crypto/x509"
14 "encoding/pem"
15 "errors"
16 "fmt"
17 "io/ioutil"
18 "net"
19 "net/http"
20 "os"
21 "reflect"
22 "testing"
23 "time"
24
25 "github.com/google/go-cmp/cmp"
26 "github.com/google/go-cmp/cmp/cmpopts"
27 "go.mongodb.org/mongo-driver/bson"
28 "go.mongodb.org/mongo-driver/bson/bsoncodec"
29 "go.mongodb.org/mongo-driver/event"
30 "go.mongodb.org/mongo-driver/internal/assert"
31 "go.mongodb.org/mongo-driver/internal/httputil"
32 "go.mongodb.org/mongo-driver/mongo/readconcern"
33 "go.mongodb.org/mongo-driver/mongo/readpref"
34 "go.mongodb.org/mongo-driver/mongo/writeconcern"
35 "go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
36 )
37
38 var tClientOptions = reflect.TypeOf(&ClientOptions{})
39
40 func TestClientOptions(t *testing.T) {
41 t.Run("ApplyURI/doesn't overwrite previous errors", func(t *testing.T) {
42 uri := "not-mongo-db-uri://"
43 want := fmt.Errorf(
44 "error parsing uri: %w",
45 errors.New(`scheme must be "mongodb" or "mongodb+srv"`))
46 co := Client().ApplyURI(uri).ApplyURI("mongodb://localhost/")
47 got := co.Validate()
48 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
49 t.Errorf("Did not received expected error. got %v; want %v", got, want)
50 }
51 })
52 t.Run("Validate/returns error", func(t *testing.T) {
53 want := errors.New("validate error")
54 co := &ClientOptions{err: want}
55 got := co.Validate()
56 if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
57 t.Errorf("Did not receive expected error. got %v; want %v", got, want)
58 }
59 })
60 t.Run("Set", func(t *testing.T) {
61 testCases := []struct {
62 name string
63 fn interface{}
64 arg interface{}
65 field string
66 dereference bool
67 }{
68 {"AppName", (*ClientOptions).SetAppName, "example-application", "AppName", true},
69 {"Auth", (*ClientOptions).SetAuth, Credential{Username: "foo", Password: "bar"}, "Auth", true},
70 {"Compressors", (*ClientOptions).SetCompressors, []string{"zstd", "snappy", "zlib"}, "Compressors", true},
71 {"ConnectTimeout", (*ClientOptions).SetConnectTimeout, 5 * time.Second, "ConnectTimeout", true},
72 {"Dialer", (*ClientOptions).SetDialer, testDialer{Num: 12345}, "Dialer", true},
73 {"HeartbeatInterval", (*ClientOptions).SetHeartbeatInterval, 5 * time.Second, "HeartbeatInterval", true},
74 {"Hosts", (*ClientOptions).SetHosts, []string{"localhost:27017", "localhost:27018", "localhost:27019"}, "Hosts", true},
75 {"LocalThreshold", (*ClientOptions).SetLocalThreshold, 5 * time.Second, "LocalThreshold", true},
76 {"MaxConnIdleTime", (*ClientOptions).SetMaxConnIdleTime, 5 * time.Second, "MaxConnIdleTime", true},
77 {"MaxPoolSize", (*ClientOptions).SetMaxPoolSize, uint64(250), "MaxPoolSize", true},
78 {"MinPoolSize", (*ClientOptions).SetMinPoolSize, uint64(10), "MinPoolSize", true},
79 {"MaxConnecting", (*ClientOptions).SetMaxConnecting, uint64(10), "MaxConnecting", true},
80 {"PoolMonitor", (*ClientOptions).SetPoolMonitor, &event.PoolMonitor{}, "PoolMonitor", false},
81 {"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false},
82 {"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false},
83 {"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false},
84 {"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false},
85 {"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true},
86 {"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true},
87 {"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true},
88 {"Direct", (*ClientOptions).SetDirect, true, "Direct", true},
89 {"SocketTimeout", (*ClientOptions).SetSocketTimeout, 5 * time.Second, "SocketTimeout", true},
90 {"TLSConfig", (*ClientOptions).SetTLSConfig, &tls.Config{}, "TLSConfig", false},
91 {"WriteConcern", (*ClientOptions).SetWriteConcern, writeconcern.New(writeconcern.WMajority()), "WriteConcern", false},
92 {"ZlibLevel", (*ClientOptions).SetZlibLevel, 6, "ZlibLevel", true},
93 {"DisableOCSPEndpointCheck", (*ClientOptions).SetDisableOCSPEndpointCheck, true, "DisableOCSPEndpointCheck", true},
94 {"LoadBalanced", (*ClientOptions).SetLoadBalanced, true, "LoadBalanced", true},
95 }
96
97 opt1, opt2, optResult := Client(), Client(), Client()
98 for idx, tc := range testCases {
99 t.Run(tc.name, func(t *testing.T) {
100 fn := reflect.ValueOf(tc.fn)
101 if fn.Kind() != reflect.Func {
102 t.Fatal("fn argument must be a function")
103 }
104 if fn.Type().NumIn() < 2 || fn.Type().In(0) != tClientOptions {
105 t.Fatal("fn argument must have a *ClientOptions as the first argument and one other argument")
106 }
107 if _, exists := tClientOptions.Elem().FieldByName(tc.field); !exists {
108 t.Fatalf("field (%s) does not exist in ClientOptions", tc.field)
109 }
110 args := make([]reflect.Value, 2)
111 client := reflect.New(tClientOptions.Elem())
112 args[0] = client
113 want := reflect.ValueOf(tc.arg)
114 args[1] = want
115
116 if !want.IsValid() || !want.CanInterface() {
117 t.Fatal("arg property of test case must be valid")
118 }
119
120 _ = fn.Call(args)
121
122
123
124
125
126
127
128 if idx%2 != 0 {
129 args[0] = reflect.ValueOf(opt1)
130 _ = fn.Call(args)
131 }
132 if idx%2 == 0 || idx%3 == 0 {
133 args[0] = reflect.ValueOf(opt2)
134 _ = fn.Call(args)
135 }
136 args[0] = reflect.ValueOf(optResult)
137 _ = fn.Call(args)
138
139 got := client.Elem().FieldByName(tc.field)
140 if !got.IsValid() || !got.CanInterface() {
141 t.Fatal("cannot create concrete instance from retrieved field")
142 }
143
144 if got.Kind() == reflect.Ptr && tc.dereference {
145 got = got.Elem()
146 }
147
148 if !cmp.Equal(
149 got.Interface(), want.Interface(),
150 cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
151 cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
152 cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }),
153 cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }),
154 ) {
155 t.Errorf("Field not set properly. got %v; want %v", got.Interface(), want.Interface())
156 }
157 })
158 }
159 t.Run("MergeClientOptions/all set", func(t *testing.T) {
160 want := optResult
161 got := MergeClientOptions(nil, opt1, opt2)
162 if diff := cmp.Diff(
163 got, want,
164 cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
165 cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
166 cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }),
167 cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }),
168 cmp.AllowUnexported(ClientOptions{}),
169 cmpopts.IgnoreFields(http.Client{}, "Transport"),
170 ); diff != "" {
171 t.Errorf("diff:\n%s", diff)
172 t.Errorf("Merged client options do not match. got %v; want %v", got, want)
173 }
174 })
175
176
177
178 t.Run("MergeClientOptions/err", func(t *testing.T) {
179 opt1, opt2 := Client(), Client()
180 opt1.err = errors.New("Test error")
181
182 got := MergeClientOptions(nil, opt1, opt2)
183 if got.err.Error() != "Test error" {
184 t.Errorf("Merged client options do not match. got %v; want %v", got.err.Error(), opt1.err.Error())
185 }
186 })
187 })
188 t.Run("ApplyURI", func(t *testing.T) {
189 baseClient := func() *ClientOptions {
190 return Client().SetHosts([]string{"localhost"})
191 }
192 testCases := []struct {
193 name string
194 uri string
195 result *ClientOptions
196 }{
197 {
198 "ParseError",
199 "not-mongo-db-uri://",
200 &ClientOptions{
201 err: fmt.Errorf(
202 "error parsing uri: %w",
203 errors.New(`scheme must be "mongodb" or "mongodb+srv"`)),
204 HTTPClient: httputil.DefaultHTTPClient,
205 },
206 },
207 {
208 "ReadPreference Invalid Mode",
209 "mongodb://localhost/?maxStaleness=200",
210 &ClientOptions{
211 err: fmt.Errorf("unknown read preference %v", ""),
212 Hosts: []string{"localhost"},
213 HTTPClient: httputil.DefaultHTTPClient,
214 },
215 },
216 {
217 "ReadPreference Primary With Options",
218 "mongodb://localhost/?readPreference=Primary&maxStaleness=200",
219 &ClientOptions{
220 err: errors.New("can not specify tags, max staleness, or hedge with mode primary"),
221 Hosts: []string{"localhost"},
222 HTTPClient: httputil.DefaultHTTPClient,
223 },
224 },
225 {
226 "TLS addCertFromFile error",
227 "mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/doesntexist",
228 &ClientOptions{
229 err: &os.PathError{Op: "open", Path: "testdata/doesntexist"},
230 Hosts: []string{"localhost"},
231 HTTPClient: httputil.DefaultHTTPClient,
232 },
233 },
234 {
235 "TLS ClientCertificateKey",
236 "mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/doesntexist",
237 &ClientOptions{
238 err: &os.PathError{Op: "open", Path: "testdata/doesntexist"},
239 Hosts: []string{"localhost"},
240 HTTPClient: httputil.DefaultHTTPClient,
241 },
242 },
243 {
244 "AppName",
245 "mongodb://localhost/?appName=awesome-example-application",
246 baseClient().SetAppName("awesome-example-application"),
247 },
248 {
249 "AuthMechanism",
250 "mongodb://localhost/?authMechanism=mongodb-x509",
251 baseClient().SetAuth(Credential{AuthSource: "$external", AuthMechanism: "mongodb-x509"}),
252 },
253 {
254 "AuthMechanismProperties",
255 "mongodb://foo@localhost/?authMechanism=gssapi&authMechanismProperties=SERVICE_NAME:mongodb-fake",
256 baseClient().SetAuth(Credential{
257 AuthSource: "$external",
258 AuthMechanism: "gssapi",
259 AuthMechanismProperties: map[string]string{"SERVICE_NAME": "mongodb-fake"},
260 Username: "foo",
261 }),
262 },
263 {
264 "AuthSource",
265 "mongodb://foo@localhost/?authSource=random-database-example",
266 baseClient().SetAuth(Credential{AuthSource: "random-database-example", Username: "foo"}),
267 },
268 {
269 "Username",
270 "mongodb://foo@localhost/",
271 baseClient().SetAuth(Credential{AuthSource: "admin", Username: "foo"}),
272 },
273 {
274 "Unescaped slash in username",
275 "mongodb:///:pwd@localhost",
276 &ClientOptions{
277 err: fmt.Errorf(
278 "error parsing uri: %w",
279 errors.New("unescaped slash in username")),
280 HTTPClient: httputil.DefaultHTTPClient,
281 },
282 },
283 {
284 "Password",
285 "mongodb://foo:bar@localhost/",
286 baseClient().SetAuth(Credential{
287 AuthSource: "admin", Username: "foo",
288 Password: "bar", PasswordSet: true,
289 }),
290 },
291 {
292 "Single character username and password",
293 "mongodb://f:b@localhost/",
294 baseClient().SetAuth(Credential{
295 AuthSource: "admin", Username: "f",
296 Password: "b", PasswordSet: true,
297 }),
298 },
299 {
300 "Connect",
301 "mongodb://localhost/?connect=direct",
302 baseClient().SetDirect(true),
303 },
304 {
305 "ConnectTimeout",
306 "mongodb://localhost/?connectTimeoutms=5000",
307 baseClient().SetConnectTimeout(5 * time.Second),
308 },
309 {
310 "Compressors",
311 "mongodb://localhost/?compressors=zlib,snappy",
312 baseClient().SetCompressors([]string{"zlib", "snappy"}).SetZlibLevel(6),
313 },
314 {
315 "DatabaseNoAuth",
316 "mongodb://localhost/example-database",
317 baseClient(),
318 },
319 {
320 "DatabaseAsDefault",
321 "mongodb://foo@localhost/example-database",
322 baseClient().SetAuth(Credential{AuthSource: "example-database", Username: "foo"}),
323 },
324 {
325 "HeartbeatInterval",
326 "mongodb://localhost/?heartbeatIntervalms=12000",
327 baseClient().SetHeartbeatInterval(12 * time.Second),
328 },
329 {
330 "Hosts",
331 "mongodb://localhost:27017,localhost:27018,localhost:27019/",
332 baseClient().SetHosts([]string{"localhost:27017", "localhost:27018", "localhost:27019"}),
333 },
334 {
335 "LocalThreshold",
336 "mongodb://localhost/?localThresholdMS=200",
337 baseClient().SetLocalThreshold(200 * time.Millisecond),
338 },
339 {
340 "MaxConnIdleTime",
341 "mongodb://localhost/?maxIdleTimeMS=300000",
342 baseClient().SetMaxConnIdleTime(5 * time.Minute),
343 },
344 {
345 "MaxPoolSize",
346 "mongodb://localhost/?maxPoolSize=256",
347 baseClient().SetMaxPoolSize(256),
348 },
349 {
350 "MinPoolSize",
351 "mongodb://localhost/?minPoolSize=256",
352 baseClient().SetMinPoolSize(256),
353 },
354 {
355 "MaxConnecting",
356 "mongodb://localhost/?maxConnecting=10",
357 baseClient().SetMaxConnecting(10),
358 },
359 {
360 "ReadConcern",
361 "mongodb://localhost/?readConcernLevel=linearizable",
362 baseClient().SetReadConcern(readconcern.Linearizable()),
363 },
364 {
365 "ReadPreference",
366 "mongodb://localhost/?readPreference=secondaryPreferred",
367 baseClient().SetReadPreference(readpref.SecondaryPreferred()),
368 },
369 {
370 "ReadPreferenceTagSets",
371 "mongodb://localhost/?readPreference=secondaryPreferred&readPreferenceTags=foo:bar",
372 baseClient().SetReadPreference(readpref.SecondaryPreferred(readpref.WithTags("foo", "bar"))),
373 },
374 {
375 "MaxStaleness",
376 "mongodb://localhost/?readPreference=secondaryPreferred&maxStaleness=250",
377 baseClient().SetReadPreference(readpref.SecondaryPreferred(readpref.WithMaxStaleness(250 * time.Second))),
378 },
379 {
380 "RetryWrites",
381 "mongodb://localhost/?retryWrites=true",
382 baseClient().SetRetryWrites(true),
383 },
384 {
385 "ReplicaSet",
386 "mongodb://localhost/?replicaSet=rs01",
387 baseClient().SetReplicaSet("rs01"),
388 },
389 {
390 "ServerSelectionTimeout",
391 "mongodb://localhost/?serverSelectionTimeoutMS=45000",
392 baseClient().SetServerSelectionTimeout(45 * time.Second),
393 },
394 {
395 "SocketTimeout",
396 "mongodb://localhost/?socketTimeoutMS=15000",
397 baseClient().SetSocketTimeout(15 * time.Second),
398 },
399 {
400 "TLS CACertificate",
401 "mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/ca.pem",
402 baseClient().SetTLSConfig(&tls.Config{
403 RootCAs: createCertPool(t, "testdata/ca.pem"),
404 }),
405 },
406 {
407 "TLS Insecure",
408 "mongodb://localhost/?ssl=true&sslInsecure=true",
409 baseClient().SetTLSConfig(&tls.Config{InsecureSkipVerify: true}),
410 },
411 {
412 "TLS ClientCertificateKey",
413 "mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/nopass/certificate.pem",
414 baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
415 },
416 {
417 "TLS ClientCertificateKey with password",
418 "mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/certificate.pem&sslClientCertificateKeyPassword=passphrase",
419 baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
420 },
421 {
422 "TLS Username",
423 "mongodb://localhost/?ssl=true&authMechanism=mongodb-x509&sslClientCertificateKeyFile=testdata/nopass/certificate.pem",
424 baseClient().SetAuth(Credential{
425 AuthMechanism: "mongodb-x509", AuthSource: "$external",
426 Username: `C=US,ST=New York,L=New York City, Inc,O=MongoDB\,OU=WWW`,
427 }),
428 },
429 {
430 "WriteConcern J",
431 "mongodb://localhost/?journal=true",
432 baseClient().SetWriteConcern(writeconcern.New(writeconcern.J(true))),
433 },
434 {
435 "WriteConcern WString",
436 "mongodb://localhost/?w=majority",
437 baseClient().SetWriteConcern(writeconcern.New(writeconcern.WMajority())),
438 },
439 {
440 "WriteConcern W",
441 "mongodb://localhost/?w=3",
442 baseClient().SetWriteConcern(writeconcern.New(writeconcern.W(3))),
443 },
444 {
445 "WriteConcern WTimeout",
446 "mongodb://localhost/?wTimeoutMS=45000",
447 baseClient().SetWriteConcern(writeconcern.New(writeconcern.WTimeout(45 * time.Second))),
448 },
449 {
450 "ZLibLevel",
451 "mongodb://localhost/?zlibCompressionLevel=4",
452 baseClient().SetZlibLevel(4),
453 },
454 {
455 "TLS tlsCertificateFile and tlsPrivateKeyFile",
456 "mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem",
457 baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
458 },
459 {
460 "TLS only tlsCertificateFile",
461 "mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem",
462 &ClientOptions{
463 err: fmt.Errorf(
464 "error validating uri: %w",
465 errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified")),
466 HTTPClient: httputil.DefaultHTTPClient,
467 },
468 },
469 {
470 "TLS only tlsPrivateKeyFile",
471 "mongodb://localhost/?tlsPrivateKeyFile=testdata/nopass/key.pem",
472 &ClientOptions{
473 err: fmt.Errorf(
474 "error validating uri: %w",
475 errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified")),
476 HTTPClient: httputil.DefaultHTTPClient,
477 },
478 },
479 {
480 "TLS tlsCertificateFile and tlsPrivateKeyFile and tlsCertificateKeyFile",
481 "mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem&tlsCertificateKeyFile=testdata/nopass/certificate.pem",
482 &ClientOptions{
483 err: fmt.Errorf(
484 "error validating uri: %w",
485 errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided "+
486 "along with tlsCertificateFile or tlsPrivateKeyFile")),
487 HTTPClient: httputil.DefaultHTTPClient,
488 },
489 },
490 {
491 "disable OCSP endpoint check",
492 "mongodb://localhost/?tlsDisableOCSPEndpointCheck=true",
493 baseClient().SetDisableOCSPEndpointCheck(true),
494 },
495 {
496 "directConnection",
497 "mongodb://localhost/?directConnection=true",
498 baseClient().SetDirect(true),
499 },
500 {
501 "TLS CA file with multiple certificiates",
502 "mongodb://localhost/?tlsCAFile=testdata/ca-with-intermediates.pem",
503 baseClient().SetTLSConfig(&tls.Config{
504 RootCAs: createCertPool(t, "testdata/ca-with-intermediates-first.pem",
505 "testdata/ca-with-intermediates-second.pem", "testdata/ca-with-intermediates-third.pem"),
506 }),
507 },
508 {
509 "TLS empty CA file",
510 "mongodb://localhost/?tlsCAFile=testdata/empty-ca.pem",
511 &ClientOptions{
512 Hosts: []string{"localhost"},
513 HTTPClient: httputil.DefaultHTTPClient,
514 err: errors.New("the specified CA file does not contain any valid certificates"),
515 },
516 },
517 {
518 "TLS CA file with no certificates",
519 "mongodb://localhost/?tlsCAFile=testdata/ca-key.pem",
520 &ClientOptions{
521 Hosts: []string{"localhost"},
522 HTTPClient: httputil.DefaultHTTPClient,
523 err: errors.New("the specified CA file does not contain any valid certificates"),
524 },
525 },
526 {
527 "TLS malformed CA file",
528 "mongodb://localhost/?tlsCAFile=testdata/malformed-ca.pem",
529 &ClientOptions{
530 Hosts: []string{"localhost"},
531 HTTPClient: httputil.DefaultHTTPClient,
532 err: errors.New("the specified CA file does not contain any valid certificates"),
533 },
534 },
535 {
536 "loadBalanced=true",
537 "mongodb://localhost/?loadBalanced=true",
538 baseClient().SetLoadBalanced(true),
539 },
540 {
541 "loadBalanced=false",
542 "mongodb://localhost/?loadBalanced=false",
543 baseClient().SetLoadBalanced(false),
544 },
545 {
546 "srvServiceName",
547 "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
548 baseClient().SetSRVServiceName("customname").
549 SetHosts([]string{"localhost.test.build.10gen.cc:27017", "localhost.test.build.10gen.cc:27018"}),
550 },
551 {
552 "srvMaxHosts",
553 "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=2",
554 baseClient().SetSRVMaxHosts(2).
555 SetHosts([]string{"localhost.test.build.10gen.cc:27017", "localhost.test.build.10gen.cc:27018"}),
556 },
557 {
558 "GODRIVER-2263 regression test",
559 "mongodb://localhost/?tlsCertificateKeyFile=testdata/one-pk-multiple-certs.pem",
560 baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
561 },
562 {
563 "GODRIVER-2650 X509 certificate",
564 "mongodb://localhost/?ssl=true&authMechanism=mongodb-x509&sslClientCertificateKeyFile=testdata/one-pk-multiple-certs.pem",
565 baseClient().SetAuth(Credential{
566 AuthMechanism: "mongodb-x509", AuthSource: "$external",
567
568 Username: `C=US,ST=New York,L=New York City,O=MongoDB,OU=Drivers,CN=localhost`,
569 }).SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
570 },
571 }
572
573 for _, tc := range testCases {
574 t.Run(tc.name, func(t *testing.T) {
575 result := Client().ApplyURI(tc.uri)
576
577
578
579 cs, err := connstring.ParseAndValidate(tc.uri)
580 if err == nil {
581 tc.result.cs = cs
582 }
583
584
585 stringLess := func(a, b string) bool { return a < b }
586 if diff := cmp.Diff(
587 tc.result, result,
588 cmp.AllowUnexported(ClientOptions{}, readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
589 cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
590 cmp.Comparer(compareTLSConfig),
591 cmp.Comparer(compareErrors),
592 cmpopts.SortSlices(stringLess),
593 cmpopts.IgnoreFields(connstring.ConnString{}, "SSLClientCertificateKeyPassword"),
594 cmpopts.IgnoreFields(http.Client{}, "Transport"),
595 ); diff != "" {
596 t.Errorf("URI did not apply correctly: (-want +got)\n%s", diff)
597 }
598 })
599 }
600 })
601 t.Run("direct connection validation", func(t *testing.T) {
602 t.Run("multiple hosts", func(t *testing.T) {
603 expectedErr := errors.New("a direct connection cannot be made if multiple hosts are specified")
604
605 testCases := []struct {
606 name string
607 opts *ClientOptions
608 }{
609 {"hosts in URI", Client().ApplyURI("mongodb://localhost,localhost2")},
610 {"hosts in options", Client().SetHosts([]string{"localhost", "localhost2"})},
611 }
612 for _, tc := range testCases {
613 t.Run(tc.name, func(t *testing.T) {
614 err := tc.opts.SetDirect(true).Validate()
615 assert.NotNil(t, err, "expected error, got nil")
616 assert.Equal(t, expectedErr.Error(), err.Error(), "expected error %v, got %v", expectedErr, err)
617 })
618 }
619 })
620 t.Run("srv", func(t *testing.T) {
621 expectedErr := errors.New("a direct connection cannot be made if an SRV URI is used")
622
623 opts := Client().ApplyURI("mongodb://localhost:27017")
624 opts.cs.Scheme = connstring.SchemeMongoDBSRV
625
626 err := opts.SetDirect(true).Validate()
627 assert.NotNil(t, err, "expected error, got nil")
628 assert.Equal(t, expectedErr.Error(), err.Error(), "expected error %v, got %v", expectedErr, err)
629 })
630 })
631 t.Run("loadBalanced validation", func(t *testing.T) {
632 testCases := []struct {
633 name string
634 opts *ClientOptions
635 err error
636 }{
637 {"multiple hosts in URI", Client().ApplyURI("mongodb://foo,bar"), connstring.ErrLoadBalancedWithMultipleHosts},
638 {"multiple hosts in options", Client().SetHosts([]string{"foo", "bar"}), connstring.ErrLoadBalancedWithMultipleHosts},
639 {"replica set name", Client().SetReplicaSet("foo"), connstring.ErrLoadBalancedWithReplicaSet},
640 {"directConnection=true", Client().SetDirect(true), connstring.ErrLoadBalancedWithDirectConnection},
641 }
642 for _, tc := range testCases {
643 t.Run(tc.name, func(t *testing.T) {
644
645 err := tc.opts.Validate()
646 assert.Nil(t, err, "Validate error when loadBalanced is unset: %v", err)
647
648 tc.opts.SetLoadBalanced(false)
649 err = tc.opts.Validate()
650 assert.Nil(t, err, "Validate error when loadBalanced=false: %v", err)
651
652 tc.opts.SetLoadBalanced(true)
653 err = tc.opts.Validate()
654 assert.Equal(t, tc.err, err, "expected error %v when loadBalanced=true, got %v", tc.err, err)
655 })
656 }
657 })
658 t.Run("minPoolSize validation", func(t *testing.T) {
659 testCases := []struct {
660 name string
661 opts *ClientOptions
662 err error
663 }{
664 {
665 "minPoolSize < maxPoolSize",
666 Client().SetMinPoolSize(128).SetMaxPoolSize(256),
667 nil,
668 },
669 {
670 "minPoolSize == maxPoolSize",
671 Client().SetMinPoolSize(128).SetMaxPoolSize(128),
672 nil,
673 },
674 {
675 "minPoolSize > maxPoolSize",
676 Client().SetMinPoolSize(64).SetMaxPoolSize(32),
677 errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=64 maxPoolSize=32"),
678 },
679 {
680 "maxPoolSize == 0",
681 Client().SetMinPoolSize(128).SetMaxPoolSize(0),
682 nil,
683 },
684 }
685 for _, tc := range testCases {
686 t.Run(tc.name, func(t *testing.T) {
687 err := tc.opts.Validate()
688 assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
689 })
690 }
691 })
692 t.Run("srvMaxHosts validation", func(t *testing.T) {
693 testCases := []struct {
694 name string
695 opts *ClientOptions
696 err error
697 }{
698 {"replica set name", Client().SetReplicaSet("foo"), connstring.ErrSRVMaxHostsWithReplicaSet},
699 {"loadBalanced=true", Client().SetLoadBalanced(true), connstring.ErrSRVMaxHostsWithLoadBalanced},
700 {"loadBalanced=false", Client().SetLoadBalanced(false), nil},
701 }
702 for _, tc := range testCases {
703 t.Run(tc.name, func(t *testing.T) {
704 err := tc.opts.Validate()
705 assert.Nil(t, err, "Validate error when srvMxaHosts is unset: %v", err)
706
707 tc.opts.SetSRVMaxHosts(0)
708 err = tc.opts.Validate()
709 assert.Nil(t, err, "Validate error when srvMaxHosts is 0: %v", err)
710
711 tc.opts.SetSRVMaxHosts(2)
712 err = tc.opts.Validate()
713 assert.Equal(t, tc.err, err, "expected error %v when srvMaxHosts > 0, got %v", tc.err, err)
714 })
715 }
716 })
717 t.Run("srvMaxHosts validation", func(t *testing.T) {
718 t.Parallel()
719
720 testCases := []struct {
721 name string
722 opts *ClientOptions
723 err error
724 }{
725 {
726 name: "valid ServerAPI",
727 opts: Client().SetServerAPIOptions(ServerAPI(ServerAPIVersion1)),
728 err: nil,
729 },
730 {
731 name: "invalid ServerAPI",
732 opts: Client().SetServerAPIOptions(ServerAPI("nope")),
733 err: errors.New(`api version "nope" not supported; this driver version only supports API version "1"`),
734 },
735 {
736 name: "invalid ServerAPI with other invalid options",
737 opts: Client().SetServerAPIOptions(ServerAPI("nope")).SetSRVMaxHosts(1).SetReplicaSet("foo"),
738 err: errors.New(`api version "nope" not supported; this driver version only supports API version "1"`),
739 },
740 }
741 for _, tc := range testCases {
742 tc := tc
743
744 t.Run(tc.name, func(t *testing.T) {
745 t.Parallel()
746
747 err := tc.opts.Validate()
748 assert.Equal(t, tc.err, err, "want error %v, got error %v", tc.err, err)
749 })
750 }
751 })
752 t.Run("server monitoring mode validation", func(t *testing.T) {
753 t.Parallel()
754
755 testCases := []struct {
756 name string
757 opts *ClientOptions
758 err error
759 }{
760 {
761 name: "undefined",
762 opts: Client(),
763 err: nil,
764 },
765 {
766 name: "auto",
767 opts: Client().SetServerMonitoringMode(ServerMonitoringModeAuto),
768 err: nil,
769 },
770 {
771 name: "poll",
772 opts: Client().SetServerMonitoringMode(ServerMonitoringModePoll),
773 err: nil,
774 },
775 {
776 name: "stream",
777 opts: Client().SetServerMonitoringMode(ServerMonitoringModeStream),
778 err: nil,
779 },
780 {
781 name: "invalid",
782 opts: Client().SetServerMonitoringMode("invalid"),
783 err: errors.New("invalid server monitoring mode: \"invalid\""),
784 },
785 }
786
787 for _, tc := range testCases {
788 tc := tc
789
790 t.Run(tc.name, func(t *testing.T) {
791 t.Parallel()
792
793 err := tc.opts.Validate()
794 assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
795 })
796 }
797 })
798 }
799
800 func createCertPool(t *testing.T, paths ...string) *x509.CertPool {
801 t.Helper()
802
803 pool := x509.NewCertPool()
804 for _, path := range paths {
805 pool.AddCert(loadCert(t, path))
806 }
807 return pool
808 }
809
810 func loadCert(t *testing.T, file string) *x509.Certificate {
811 t.Helper()
812
813 data := readFile(t, file)
814 block, _ := pem.Decode(data)
815 cert, err := x509.ParseCertificate(block.Bytes)
816 assert.Nil(t, err, "ParseCertificate error for %s: %v", file, err)
817 return cert
818 }
819
820 func readFile(t *testing.T, path string) []byte {
821 data, err := ioutil.ReadFile(path)
822 assert.Nil(t, err, "ReadFile error for %s: %v", path, err)
823 return data
824 }
825
// testDialer is a stub dialer for the SetDialer test case. Num holds an
// arbitrary value so the stored dialer can be compared with the one passed in.
type testDialer struct{ Num int }

// DialContext never dials. It always returns a nil connection and a nil
// error. It presumably exists only so testDialer satisfies the dialer
// interface that SetDialer accepts.
func (testDialer) DialContext(_ context.Context, _ string, _ string) (net.Conn, error) {
	var conn net.Conn
	return conn, nil
}
833
// compareTLSConfig reports whether two *tls.Config values are equivalent for
// these tests. It compares only the fields the test expectations set:
//   - the subjects of RootCAs, in order;
//   - the number of Certificates;
//   - InsecureSkipVerify.
//
// NOTE(review): when exactly one side is nil, the configs are treated as
// equal. This leniency is preserved on purpose, because some ApplyURI
// expectations may omit TLSConfig even though their URI enables TLS.
// Confirm those expectations before tightening it to return false.
func compareTLSConfig(cfg1, cfg2 *tls.Config) bool {
	if cfg1 == nil && cfg2 == nil {
		return true
	}

	if cfg1 == nil || cfg2 == nil {
		return true
	}

	// A root pool on only one side is a mismatch. Checking this first also
	// guarantees that both pools are non-nil before Subjects() is called on
	// them below; calling Subjects on a nil *CertPool panics.
	if (cfg1.RootCAs == nil) != (cfg2.RootCAs == nil) {
		return false
	}

	if cfg1.RootCAs != nil {
		cfg1Subjects := cfg1.RootCAs.Subjects()
		cfg2Subjects := cfg2.RootCAs.Subjects()
		if len(cfg1Subjects) != len(cfg2Subjects) {
			return false
		}

		for idx, firstSubject := range cfg1Subjects {
			if !bytes.Equal(firstSubject, cfg2Subjects[idx]) {
				return false
			}
		}
	}

	if len(cfg1.Certificates) != len(cfg2.Certificates) {
		return false
	}

	if cfg1.InsecureSkipVerify != cfg2.InsecureSkipVerify {
		return false
	}

	return true
}
871
// compareErrors reports whether two errors should be considered equal in
// these tests.
//   - Two nil errors match; a nil and a non-nil error never match.
//   - If both error chains contain an *os.PathError, only their Op and Path
//     fields are compared and the underlying cause is ignored.
//   - Otherwise, the Error() strings must be identical.
func compareErrors(err1, err2 error) bool {
	if err1 == nil || err2 == nil {
		return err1 == nil && err2 == nil
	}

	var pathErr1, pathErr2 *os.PathError
	bothPathErrors := errors.As(err1, &pathErr1) && errors.As(err2, &pathErr2)
	if bothPathErrors {
		return pathErr1.Op == pathErr2.Op && pathErr1.Path == pathErr2.Path
	}

	return err1.Error() == err2.Error()
}
892
View as plain text