...

Source file src/go.mongodb.org/mongo-driver/mongo/options/clientoptions_test.go

Documentation: go.mongodb.org/mongo-driver/mongo/options

     1  // Copyright (C) MongoDB, Inc. 2022-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package options
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"crypto/tls"
    13  	"crypto/x509"
    14  	"encoding/pem"
    15  	"errors"
    16  	"fmt"
    17  	"io/ioutil"
    18  	"net"
    19  	"net/http"
    20  	"os"
    21  	"reflect"
    22  	"testing"
    23  	"time"
    24  
    25  	"github.com/google/go-cmp/cmp"
    26  	"github.com/google/go-cmp/cmp/cmpopts"
    27  	"go.mongodb.org/mongo-driver/bson"
    28  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    29  	"go.mongodb.org/mongo-driver/event"
    30  	"go.mongodb.org/mongo-driver/internal/assert"
    31  	"go.mongodb.org/mongo-driver/internal/httputil"
    32  	"go.mongodb.org/mongo-driver/mongo/readconcern"
    33  	"go.mongodb.org/mongo-driver/mongo/readpref"
    34  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    35  	"go.mongodb.org/mongo-driver/x/mongo/driver/connstring"
    36  )
    37  
    38  var tClientOptions = reflect.TypeOf(&ClientOptions{})
    39  
    40  func TestClientOptions(t *testing.T) {
    41  	t.Run("ApplyURI/doesn't overwrite previous errors", func(t *testing.T) {
    42  		uri := "not-mongo-db-uri://"
    43  		want := fmt.Errorf(
    44  			"error parsing uri: %w",
    45  			errors.New(`scheme must be "mongodb" or "mongodb+srv"`))
    46  		co := Client().ApplyURI(uri).ApplyURI("mongodb://localhost/")
    47  		got := co.Validate()
    48  		if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
    49  			t.Errorf("Did not received expected error. got %v; want %v", got, want)
    50  		}
    51  	})
    52  	t.Run("Validate/returns error", func(t *testing.T) {
    53  		want := errors.New("validate error")
    54  		co := &ClientOptions{err: want}
    55  		got := co.Validate()
    56  		if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
    57  			t.Errorf("Did not receive expected error. got %v; want %v", got, want)
    58  		}
    59  	})
    60  	t.Run("Set", func(t *testing.T) {
    61  		testCases := []struct {
    62  			name        string
    63  			fn          interface{} // method to be run
    64  			arg         interface{} // argument for method
    65  			field       string      // field to be set
    66  			dereference bool        // Should we compare a pointer or the field
    67  		}{
    68  			{"AppName", (*ClientOptions).SetAppName, "example-application", "AppName", true},
    69  			{"Auth", (*ClientOptions).SetAuth, Credential{Username: "foo", Password: "bar"}, "Auth", true},
    70  			{"Compressors", (*ClientOptions).SetCompressors, []string{"zstd", "snappy", "zlib"}, "Compressors", true},
    71  			{"ConnectTimeout", (*ClientOptions).SetConnectTimeout, 5 * time.Second, "ConnectTimeout", true},
    72  			{"Dialer", (*ClientOptions).SetDialer, testDialer{Num: 12345}, "Dialer", true},
    73  			{"HeartbeatInterval", (*ClientOptions).SetHeartbeatInterval, 5 * time.Second, "HeartbeatInterval", true},
    74  			{"Hosts", (*ClientOptions).SetHosts, []string{"localhost:27017", "localhost:27018", "localhost:27019"}, "Hosts", true},
    75  			{"LocalThreshold", (*ClientOptions).SetLocalThreshold, 5 * time.Second, "LocalThreshold", true},
    76  			{"MaxConnIdleTime", (*ClientOptions).SetMaxConnIdleTime, 5 * time.Second, "MaxConnIdleTime", true},
    77  			{"MaxPoolSize", (*ClientOptions).SetMaxPoolSize, uint64(250), "MaxPoolSize", true},
    78  			{"MinPoolSize", (*ClientOptions).SetMinPoolSize, uint64(10), "MinPoolSize", true},
    79  			{"MaxConnecting", (*ClientOptions).SetMaxConnecting, uint64(10), "MaxConnecting", true},
    80  			{"PoolMonitor", (*ClientOptions).SetPoolMonitor, &event.PoolMonitor{}, "PoolMonitor", false},
    81  			{"Monitor", (*ClientOptions).SetMonitor, &event.CommandMonitor{}, "Monitor", false},
    82  			{"ReadConcern", (*ClientOptions).SetReadConcern, readconcern.Majority(), "ReadConcern", false},
    83  			{"ReadPreference", (*ClientOptions).SetReadPreference, readpref.SecondaryPreferred(), "ReadPreference", false},
    84  			{"Registry", (*ClientOptions).SetRegistry, bson.NewRegistryBuilder().Build(), "Registry", false},
    85  			{"ReplicaSet", (*ClientOptions).SetReplicaSet, "example-replicaset", "ReplicaSet", true},
    86  			{"RetryWrites", (*ClientOptions).SetRetryWrites, true, "RetryWrites", true},
    87  			{"ServerSelectionTimeout", (*ClientOptions).SetServerSelectionTimeout, 5 * time.Second, "ServerSelectionTimeout", true},
    88  			{"Direct", (*ClientOptions).SetDirect, true, "Direct", true},
    89  			{"SocketTimeout", (*ClientOptions).SetSocketTimeout, 5 * time.Second, "SocketTimeout", true},
    90  			{"TLSConfig", (*ClientOptions).SetTLSConfig, &tls.Config{}, "TLSConfig", false},
    91  			{"WriteConcern", (*ClientOptions).SetWriteConcern, writeconcern.New(writeconcern.WMajority()), "WriteConcern", false},
    92  			{"ZlibLevel", (*ClientOptions).SetZlibLevel, 6, "ZlibLevel", true},
    93  			{"DisableOCSPEndpointCheck", (*ClientOptions).SetDisableOCSPEndpointCheck, true, "DisableOCSPEndpointCheck", true},
    94  			{"LoadBalanced", (*ClientOptions).SetLoadBalanced, true, "LoadBalanced", true},
    95  		}
    96  
    97  		opt1, opt2, optResult := Client(), Client(), Client()
    98  		for idx, tc := range testCases {
    99  			t.Run(tc.name, func(t *testing.T) {
   100  				fn := reflect.ValueOf(tc.fn)
   101  				if fn.Kind() != reflect.Func {
   102  					t.Fatal("fn argument must be a function")
   103  				}
   104  				if fn.Type().NumIn() < 2 || fn.Type().In(0) != tClientOptions {
   105  					t.Fatal("fn argument must have a *ClientOptions as the first argument and one other argument")
   106  				}
   107  				if _, exists := tClientOptions.Elem().FieldByName(tc.field); !exists {
   108  					t.Fatalf("field (%s) does not exist in ClientOptions", tc.field)
   109  				}
   110  				args := make([]reflect.Value, 2)
   111  				client := reflect.New(tClientOptions.Elem())
   112  				args[0] = client
   113  				want := reflect.ValueOf(tc.arg)
   114  				args[1] = want
   115  
   116  				if !want.IsValid() || !want.CanInterface() {
   117  					t.Fatal("arg property of test case must be valid")
   118  				}
   119  
   120  				_ = fn.Call(args)
   121  
   122  				// To avoid duplication we're piggybacking on the Set* tests to make the
   123  				// MergeClientOptions test simpler and more thorough.
   124  				// To do this we set the odd numbered test cases to the first opt, the even and
   125  				// divisible by three test cases to the second, and the result of merging the two to
   126  				// the result option. This gives us coverage of options set by the first option, by
   127  				// the second, and by both.
   128  				if idx%2 != 0 {
   129  					args[0] = reflect.ValueOf(opt1)
   130  					_ = fn.Call(args)
   131  				}
   132  				if idx%2 == 0 || idx%3 == 0 {
   133  					args[0] = reflect.ValueOf(opt2)
   134  					_ = fn.Call(args)
   135  				}
   136  				args[0] = reflect.ValueOf(optResult)
   137  				_ = fn.Call(args)
   138  
   139  				got := client.Elem().FieldByName(tc.field)
   140  				if !got.IsValid() || !got.CanInterface() {
   141  					t.Fatal("cannot create concrete instance from retrieved field")
   142  				}
   143  
   144  				if got.Kind() == reflect.Ptr && tc.dereference {
   145  					got = got.Elem()
   146  				}
   147  
   148  				if !cmp.Equal(
   149  					got.Interface(), want.Interface(),
   150  					cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
   151  					cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
   152  					cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }),
   153  					cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }),
   154  				) {
   155  					t.Errorf("Field not set properly. got %v; want %v", got.Interface(), want.Interface())
   156  				}
   157  			})
   158  		}
   159  		t.Run("MergeClientOptions/all set", func(t *testing.T) {
   160  			want := optResult
   161  			got := MergeClientOptions(nil, opt1, opt2)
   162  			if diff := cmp.Diff(
   163  				got, want,
   164  				cmp.AllowUnexported(readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
   165  				cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
   166  				cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }),
   167  				cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }),
   168  				cmp.AllowUnexported(ClientOptions{}),
   169  				cmpopts.IgnoreFields(http.Client{}, "Transport"),
   170  			); diff != "" {
   171  				t.Errorf("diff:\n%s", diff)
   172  				t.Errorf("Merged client options do not match. got %v; want %v", got, want)
   173  			}
   174  		})
   175  
   176  		// go-cmp dont support error comparisons (https://github.com/google/go-cmp/issues/24)
   177  		// Use specifique test for this
   178  		t.Run("MergeClientOptions/err", func(t *testing.T) {
   179  			opt1, opt2 := Client(), Client()
   180  			opt1.err = errors.New("Test error")
   181  
   182  			got := MergeClientOptions(nil, opt1, opt2)
   183  			if got.err.Error() != "Test error" {
   184  				t.Errorf("Merged client options do not match. got %v; want %v", got.err.Error(), opt1.err.Error())
   185  			}
   186  		})
   187  	})
   188  	t.Run("ApplyURI", func(t *testing.T) {
   189  		baseClient := func() *ClientOptions {
   190  			return Client().SetHosts([]string{"localhost"})
   191  		}
   192  		testCases := []struct {
   193  			name   string
   194  			uri    string
   195  			result *ClientOptions
   196  		}{
   197  			{
   198  				"ParseError",
   199  				"not-mongo-db-uri://",
   200  				&ClientOptions{
   201  					err: fmt.Errorf(
   202  						"error parsing uri: %w",
   203  						errors.New(`scheme must be "mongodb" or "mongodb+srv"`)),
   204  					HTTPClient: httputil.DefaultHTTPClient,
   205  				},
   206  			},
   207  			{
   208  				"ReadPreference Invalid Mode",
   209  				"mongodb://localhost/?maxStaleness=200",
   210  				&ClientOptions{
   211  					err:        fmt.Errorf("unknown read preference %v", ""),
   212  					Hosts:      []string{"localhost"},
   213  					HTTPClient: httputil.DefaultHTTPClient,
   214  				},
   215  			},
   216  			{
   217  				"ReadPreference Primary With Options",
   218  				"mongodb://localhost/?readPreference=Primary&maxStaleness=200",
   219  				&ClientOptions{
   220  					err:        errors.New("can not specify tags, max staleness, or hedge with mode primary"),
   221  					Hosts:      []string{"localhost"},
   222  					HTTPClient: httputil.DefaultHTTPClient,
   223  				},
   224  			},
   225  			{
   226  				"TLS addCertFromFile error",
   227  				"mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/doesntexist",
   228  				&ClientOptions{
   229  					err:        &os.PathError{Op: "open", Path: "testdata/doesntexist"},
   230  					Hosts:      []string{"localhost"},
   231  					HTTPClient: httputil.DefaultHTTPClient,
   232  				},
   233  			},
   234  			{
   235  				"TLS ClientCertificateKey",
   236  				"mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/doesntexist",
   237  				&ClientOptions{
   238  					err:        &os.PathError{Op: "open", Path: "testdata/doesntexist"},
   239  					Hosts:      []string{"localhost"},
   240  					HTTPClient: httputil.DefaultHTTPClient,
   241  				},
   242  			},
   243  			{
   244  				"AppName",
   245  				"mongodb://localhost/?appName=awesome-example-application",
   246  				baseClient().SetAppName("awesome-example-application"),
   247  			},
   248  			{
   249  				"AuthMechanism",
   250  				"mongodb://localhost/?authMechanism=mongodb-x509",
   251  				baseClient().SetAuth(Credential{AuthSource: "$external", AuthMechanism: "mongodb-x509"}),
   252  			},
   253  			{
   254  				"AuthMechanismProperties",
   255  				"mongodb://foo@localhost/?authMechanism=gssapi&authMechanismProperties=SERVICE_NAME:mongodb-fake",
   256  				baseClient().SetAuth(Credential{
   257  					AuthSource:              "$external",
   258  					AuthMechanism:           "gssapi",
   259  					AuthMechanismProperties: map[string]string{"SERVICE_NAME": "mongodb-fake"},
   260  					Username:                "foo",
   261  				}),
   262  			},
   263  			{
   264  				"AuthSource",
   265  				"mongodb://foo@localhost/?authSource=random-database-example",
   266  				baseClient().SetAuth(Credential{AuthSource: "random-database-example", Username: "foo"}),
   267  			},
   268  			{
   269  				"Username",
   270  				"mongodb://foo@localhost/",
   271  				baseClient().SetAuth(Credential{AuthSource: "admin", Username: "foo"}),
   272  			},
   273  			{
   274  				"Unescaped slash in username",
   275  				"mongodb:///:pwd@localhost",
   276  				&ClientOptions{
   277  					err: fmt.Errorf(
   278  						"error parsing uri: %w",
   279  						errors.New("unescaped slash in username")),
   280  					HTTPClient: httputil.DefaultHTTPClient,
   281  				},
   282  			},
   283  			{
   284  				"Password",
   285  				"mongodb://foo:bar@localhost/",
   286  				baseClient().SetAuth(Credential{
   287  					AuthSource: "admin", Username: "foo",
   288  					Password: "bar", PasswordSet: true,
   289  				}),
   290  			},
   291  			{
   292  				"Single character username and password",
   293  				"mongodb://f:b@localhost/",
   294  				baseClient().SetAuth(Credential{
   295  					AuthSource: "admin", Username: "f",
   296  					Password: "b", PasswordSet: true,
   297  				}),
   298  			},
   299  			{
   300  				"Connect",
   301  				"mongodb://localhost/?connect=direct",
   302  				baseClient().SetDirect(true),
   303  			},
   304  			{
   305  				"ConnectTimeout",
   306  				"mongodb://localhost/?connectTimeoutms=5000",
   307  				baseClient().SetConnectTimeout(5 * time.Second),
   308  			},
   309  			{
   310  				"Compressors",
   311  				"mongodb://localhost/?compressors=zlib,snappy",
   312  				baseClient().SetCompressors([]string{"zlib", "snappy"}).SetZlibLevel(6),
   313  			},
   314  			{
   315  				"DatabaseNoAuth",
   316  				"mongodb://localhost/example-database",
   317  				baseClient(),
   318  			},
   319  			{
   320  				"DatabaseAsDefault",
   321  				"mongodb://foo@localhost/example-database",
   322  				baseClient().SetAuth(Credential{AuthSource: "example-database", Username: "foo"}),
   323  			},
   324  			{
   325  				"HeartbeatInterval",
   326  				"mongodb://localhost/?heartbeatIntervalms=12000",
   327  				baseClient().SetHeartbeatInterval(12 * time.Second),
   328  			},
   329  			{
   330  				"Hosts",
   331  				"mongodb://localhost:27017,localhost:27018,localhost:27019/",
   332  				baseClient().SetHosts([]string{"localhost:27017", "localhost:27018", "localhost:27019"}),
   333  			},
   334  			{
   335  				"LocalThreshold",
   336  				"mongodb://localhost/?localThresholdMS=200",
   337  				baseClient().SetLocalThreshold(200 * time.Millisecond),
   338  			},
   339  			{
   340  				"MaxConnIdleTime",
   341  				"mongodb://localhost/?maxIdleTimeMS=300000",
   342  				baseClient().SetMaxConnIdleTime(5 * time.Minute),
   343  			},
   344  			{
   345  				"MaxPoolSize",
   346  				"mongodb://localhost/?maxPoolSize=256",
   347  				baseClient().SetMaxPoolSize(256),
   348  			},
   349  			{
   350  				"MinPoolSize",
   351  				"mongodb://localhost/?minPoolSize=256",
   352  				baseClient().SetMinPoolSize(256),
   353  			},
   354  			{
   355  				"MaxConnecting",
   356  				"mongodb://localhost/?maxConnecting=10",
   357  				baseClient().SetMaxConnecting(10),
   358  			},
   359  			{
   360  				"ReadConcern",
   361  				"mongodb://localhost/?readConcernLevel=linearizable",
   362  				baseClient().SetReadConcern(readconcern.Linearizable()),
   363  			},
   364  			{
   365  				"ReadPreference",
   366  				"mongodb://localhost/?readPreference=secondaryPreferred",
   367  				baseClient().SetReadPreference(readpref.SecondaryPreferred()),
   368  			},
   369  			{
   370  				"ReadPreferenceTagSets",
   371  				"mongodb://localhost/?readPreference=secondaryPreferred&readPreferenceTags=foo:bar",
   372  				baseClient().SetReadPreference(readpref.SecondaryPreferred(readpref.WithTags("foo", "bar"))),
   373  			},
   374  			{
   375  				"MaxStaleness",
   376  				"mongodb://localhost/?readPreference=secondaryPreferred&maxStaleness=250",
   377  				baseClient().SetReadPreference(readpref.SecondaryPreferred(readpref.WithMaxStaleness(250 * time.Second))),
   378  			},
   379  			{
   380  				"RetryWrites",
   381  				"mongodb://localhost/?retryWrites=true",
   382  				baseClient().SetRetryWrites(true),
   383  			},
   384  			{
   385  				"ReplicaSet",
   386  				"mongodb://localhost/?replicaSet=rs01",
   387  				baseClient().SetReplicaSet("rs01"),
   388  			},
   389  			{
   390  				"ServerSelectionTimeout",
   391  				"mongodb://localhost/?serverSelectionTimeoutMS=45000",
   392  				baseClient().SetServerSelectionTimeout(45 * time.Second),
   393  			},
   394  			{
   395  				"SocketTimeout",
   396  				"mongodb://localhost/?socketTimeoutMS=15000",
   397  				baseClient().SetSocketTimeout(15 * time.Second),
   398  			},
   399  			{
   400  				"TLS CACertificate",
   401  				"mongodb://localhost/?ssl=true&sslCertificateAuthorityFile=testdata/ca.pem",
   402  				baseClient().SetTLSConfig(&tls.Config{
   403  					RootCAs: createCertPool(t, "testdata/ca.pem"),
   404  				}),
   405  			},
   406  			{
   407  				"TLS Insecure",
   408  				"mongodb://localhost/?ssl=true&sslInsecure=true",
   409  				baseClient().SetTLSConfig(&tls.Config{InsecureSkipVerify: true}),
   410  			},
   411  			{
   412  				"TLS ClientCertificateKey",
   413  				"mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/nopass/certificate.pem",
   414  				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
   415  			},
   416  			{
   417  				"TLS ClientCertificateKey with password",
   418  				"mongodb://localhost/?ssl=true&sslClientCertificateKeyFile=testdata/certificate.pem&sslClientCertificateKeyPassword=passphrase",
   419  				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
   420  			},
   421  			{
   422  				"TLS Username",
   423  				"mongodb://localhost/?ssl=true&authMechanism=mongodb-x509&sslClientCertificateKeyFile=testdata/nopass/certificate.pem",
   424  				baseClient().SetAuth(Credential{
   425  					AuthMechanism: "mongodb-x509", AuthSource: "$external",
   426  					Username: `C=US,ST=New York,L=New York City, Inc,O=MongoDB\,OU=WWW`,
   427  				}),
   428  			},
   429  			{
   430  				"WriteConcern J",
   431  				"mongodb://localhost/?journal=true",
   432  				baseClient().SetWriteConcern(writeconcern.New(writeconcern.J(true))),
   433  			},
   434  			{
   435  				"WriteConcern WString",
   436  				"mongodb://localhost/?w=majority",
   437  				baseClient().SetWriteConcern(writeconcern.New(writeconcern.WMajority())),
   438  			},
   439  			{
   440  				"WriteConcern W",
   441  				"mongodb://localhost/?w=3",
   442  				baseClient().SetWriteConcern(writeconcern.New(writeconcern.W(3))),
   443  			},
   444  			{
   445  				"WriteConcern WTimeout",
   446  				"mongodb://localhost/?wTimeoutMS=45000",
   447  				baseClient().SetWriteConcern(writeconcern.New(writeconcern.WTimeout(45 * time.Second))),
   448  			},
   449  			{
   450  				"ZLibLevel",
   451  				"mongodb://localhost/?zlibCompressionLevel=4",
   452  				baseClient().SetZlibLevel(4),
   453  			},
   454  			{
   455  				"TLS tlsCertificateFile and tlsPrivateKeyFile",
   456  				"mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem",
   457  				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
   458  			},
   459  			{
   460  				"TLS only tlsCertificateFile",
   461  				"mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem",
   462  				&ClientOptions{
   463  					err: fmt.Errorf(
   464  						"error validating uri: %w",
   465  						errors.New("the tlsPrivateKeyFile URI option must be provided if the tlsCertificateFile option is specified")),
   466  					HTTPClient: httputil.DefaultHTTPClient,
   467  				},
   468  			},
   469  			{
   470  				"TLS only tlsPrivateKeyFile",
   471  				"mongodb://localhost/?tlsPrivateKeyFile=testdata/nopass/key.pem",
   472  				&ClientOptions{
   473  					err: fmt.Errorf(
   474  						"error validating uri: %w",
   475  						errors.New("the tlsCertificateFile URI option must be provided if the tlsPrivateKeyFile option is specified")),
   476  					HTTPClient: httputil.DefaultHTTPClient,
   477  				},
   478  			},
   479  			{
   480  				"TLS tlsCertificateFile and tlsPrivateKeyFile and tlsCertificateKeyFile",
   481  				"mongodb://localhost/?tlsCertificateFile=testdata/nopass/cert.pem&tlsPrivateKeyFile=testdata/nopass/key.pem&tlsCertificateKeyFile=testdata/nopass/certificate.pem",
   482  				&ClientOptions{
   483  					err: fmt.Errorf(
   484  						"error validating uri: %w",
   485  						errors.New("the sslClientCertificateKeyFile/tlsCertificateKeyFile URI option cannot be provided "+
   486  							"along with tlsCertificateFile or tlsPrivateKeyFile")),
   487  					HTTPClient: httputil.DefaultHTTPClient,
   488  				},
   489  			},
   490  			{
   491  				"disable OCSP endpoint check",
   492  				"mongodb://localhost/?tlsDisableOCSPEndpointCheck=true",
   493  				baseClient().SetDisableOCSPEndpointCheck(true),
   494  			},
   495  			{
   496  				"directConnection",
   497  				"mongodb://localhost/?directConnection=true",
   498  				baseClient().SetDirect(true),
   499  			},
   500  			{
   501  				"TLS CA file with multiple certificiates",
   502  				"mongodb://localhost/?tlsCAFile=testdata/ca-with-intermediates.pem",
   503  				baseClient().SetTLSConfig(&tls.Config{
   504  					RootCAs: createCertPool(t, "testdata/ca-with-intermediates-first.pem",
   505  						"testdata/ca-with-intermediates-second.pem", "testdata/ca-with-intermediates-third.pem"),
   506  				}),
   507  			},
   508  			{
   509  				"TLS empty CA file",
   510  				"mongodb://localhost/?tlsCAFile=testdata/empty-ca.pem",
   511  				&ClientOptions{
   512  					Hosts:      []string{"localhost"},
   513  					HTTPClient: httputil.DefaultHTTPClient,
   514  					err:        errors.New("the specified CA file does not contain any valid certificates"),
   515  				},
   516  			},
   517  			{
   518  				"TLS CA file with no certificates",
   519  				"mongodb://localhost/?tlsCAFile=testdata/ca-key.pem",
   520  				&ClientOptions{
   521  					Hosts:      []string{"localhost"},
   522  					HTTPClient: httputil.DefaultHTTPClient,
   523  					err:        errors.New("the specified CA file does not contain any valid certificates"),
   524  				},
   525  			},
   526  			{
   527  				"TLS malformed CA file",
   528  				"mongodb://localhost/?tlsCAFile=testdata/malformed-ca.pem",
   529  				&ClientOptions{
   530  					Hosts:      []string{"localhost"},
   531  					HTTPClient: httputil.DefaultHTTPClient,
   532  					err:        errors.New("the specified CA file does not contain any valid certificates"),
   533  				},
   534  			},
   535  			{
   536  				"loadBalanced=true",
   537  				"mongodb://localhost/?loadBalanced=true",
   538  				baseClient().SetLoadBalanced(true),
   539  			},
   540  			{
   541  				"loadBalanced=false",
   542  				"mongodb://localhost/?loadBalanced=false",
   543  				baseClient().SetLoadBalanced(false),
   544  			},
   545  			{
   546  				"srvServiceName",
   547  				"mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname",
   548  				baseClient().SetSRVServiceName("customname").
   549  					SetHosts([]string{"localhost.test.build.10gen.cc:27017", "localhost.test.build.10gen.cc:27018"}),
   550  			},
   551  			{
   552  				"srvMaxHosts",
   553  				"mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=2",
   554  				baseClient().SetSRVMaxHosts(2).
   555  					SetHosts([]string{"localhost.test.build.10gen.cc:27017", "localhost.test.build.10gen.cc:27018"}),
   556  			},
   557  			{
   558  				"GODRIVER-2263 regression test",
   559  				"mongodb://localhost/?tlsCertificateKeyFile=testdata/one-pk-multiple-certs.pem",
   560  				baseClient().SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
   561  			},
   562  			{
   563  				"GODRIVER-2650 X509 certificate",
   564  				"mongodb://localhost/?ssl=true&authMechanism=mongodb-x509&sslClientCertificateKeyFile=testdata/one-pk-multiple-certs.pem",
   565  				baseClient().SetAuth(Credential{
   566  					AuthMechanism: "mongodb-x509", AuthSource: "$external",
   567  					// Subject name in the first certificate is used as the username for X509 auth.
   568  					Username: `C=US,ST=New York,L=New York City,O=MongoDB,OU=Drivers,CN=localhost`,
   569  				}).SetTLSConfig(&tls.Config{Certificates: make([]tls.Certificate, 1)}),
   570  			},
   571  		}
   572  
   573  		for _, tc := range testCases {
   574  			t.Run(tc.name, func(t *testing.T) {
   575  				result := Client().ApplyURI(tc.uri)
   576  
   577  				// Manually add the URI and ConnString to the test expectations to avoid adding them in each test
   578  				// definition. The ConnString should only be recorded if there was no error while parsing.
   579  				cs, err := connstring.ParseAndValidate(tc.uri)
   580  				if err == nil {
   581  					tc.result.cs = cs
   582  				}
   583  
   584  				// We have to sort string slices in comparison, as Hosts resolved from SRV URIs do not have a set order.
   585  				stringLess := func(a, b string) bool { return a < b }
   586  				if diff := cmp.Diff(
   587  					tc.result, result,
   588  					cmp.AllowUnexported(ClientOptions{}, readconcern.ReadConcern{}, writeconcern.WriteConcern{}, readpref.ReadPref{}),
   589  					cmp.Comparer(func(r1, r2 *bsoncodec.Registry) bool { return r1 == r2 }),
   590  					cmp.Comparer(compareTLSConfig),
   591  					cmp.Comparer(compareErrors),
   592  					cmpopts.SortSlices(stringLess),
   593  					cmpopts.IgnoreFields(connstring.ConnString{}, "SSLClientCertificateKeyPassword"),
   594  					cmpopts.IgnoreFields(http.Client{}, "Transport"),
   595  				); diff != "" {
   596  					t.Errorf("URI did not apply correctly: (-want +got)\n%s", diff)
   597  				}
   598  			})
   599  		}
   600  	})
   601  	t.Run("direct connection validation", func(t *testing.T) {
   602  		t.Run("multiple hosts", func(t *testing.T) {
   603  			expectedErr := errors.New("a direct connection cannot be made if multiple hosts are specified")
   604  
   605  			testCases := []struct {
   606  				name string
   607  				opts *ClientOptions
   608  			}{
   609  				{"hosts in URI", Client().ApplyURI("mongodb://localhost,localhost2")},
   610  				{"hosts in options", Client().SetHosts([]string{"localhost", "localhost2"})},
   611  			}
   612  			for _, tc := range testCases {
   613  				t.Run(tc.name, func(t *testing.T) {
   614  					err := tc.opts.SetDirect(true).Validate()
   615  					assert.NotNil(t, err, "expected error, got nil")
   616  					assert.Equal(t, expectedErr.Error(), err.Error(), "expected error %v, got %v", expectedErr, err)
   617  				})
   618  			}
   619  		})
   620  		t.Run("srv", func(t *testing.T) {
   621  			expectedErr := errors.New("a direct connection cannot be made if an SRV URI is used")
   622  			// Use a non-SRV URI and manually set the scheme because using an SRV URI would force an SRV lookup.
   623  			opts := Client().ApplyURI("mongodb://localhost:27017")
   624  			opts.cs.Scheme = connstring.SchemeMongoDBSRV
   625  
   626  			err := opts.SetDirect(true).Validate()
   627  			assert.NotNil(t, err, "expected error, got nil")
   628  			assert.Equal(t, expectedErr.Error(), err.Error(), "expected error %v, got %v", expectedErr, err)
   629  		})
   630  	})
   631  	t.Run("loadBalanced validation", func(t *testing.T) {
   632  		testCases := []struct {
   633  			name string
   634  			opts *ClientOptions
   635  			err  error
   636  		}{
   637  			{"multiple hosts in URI", Client().ApplyURI("mongodb://foo,bar"), connstring.ErrLoadBalancedWithMultipleHosts},
   638  			{"multiple hosts in options", Client().SetHosts([]string{"foo", "bar"}), connstring.ErrLoadBalancedWithMultipleHosts},
   639  			{"replica set name", Client().SetReplicaSet("foo"), connstring.ErrLoadBalancedWithReplicaSet},
   640  			{"directConnection=true", Client().SetDirect(true), connstring.ErrLoadBalancedWithDirectConnection},
   641  		}
   642  		for _, tc := range testCases {
   643  			t.Run(tc.name, func(t *testing.T) {
   644  				// The loadBalanced option should not be validated if it is unset or false.
   645  				err := tc.opts.Validate()
   646  				assert.Nil(t, err, "Validate error when loadBalanced is unset: %v", err)
   647  
   648  				tc.opts.SetLoadBalanced(false)
   649  				err = tc.opts.Validate()
   650  				assert.Nil(t, err, "Validate error when loadBalanced=false: %v", err)
   651  
   652  				tc.opts.SetLoadBalanced(true)
   653  				err = tc.opts.Validate()
   654  				assert.Equal(t, tc.err, err, "expected error %v when loadBalanced=true, got %v", tc.err, err)
   655  			})
   656  		}
   657  	})
   658  	t.Run("minPoolSize validation", func(t *testing.T) {
   659  		testCases := []struct {
   660  			name string
   661  			opts *ClientOptions
   662  			err  error
   663  		}{
   664  			{
   665  				"minPoolSize < maxPoolSize",
   666  				Client().SetMinPoolSize(128).SetMaxPoolSize(256),
   667  				nil,
   668  			},
   669  			{
   670  				"minPoolSize == maxPoolSize",
   671  				Client().SetMinPoolSize(128).SetMaxPoolSize(128),
   672  				nil,
   673  			},
   674  			{
   675  				"minPoolSize > maxPoolSize",
   676  				Client().SetMinPoolSize(64).SetMaxPoolSize(32),
   677  				errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=64 maxPoolSize=32"),
   678  			},
   679  			{
   680  				"maxPoolSize == 0",
   681  				Client().SetMinPoolSize(128).SetMaxPoolSize(0),
   682  				nil,
   683  			},
   684  		}
   685  		for _, tc := range testCases {
   686  			t.Run(tc.name, func(t *testing.T) {
   687  				err := tc.opts.Validate()
   688  				assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
   689  			})
   690  		}
   691  	})
   692  	t.Run("srvMaxHosts validation", func(t *testing.T) {
   693  		testCases := []struct {
   694  			name string
   695  			opts *ClientOptions
   696  			err  error
   697  		}{
   698  			{"replica set name", Client().SetReplicaSet("foo"), connstring.ErrSRVMaxHostsWithReplicaSet},
   699  			{"loadBalanced=true", Client().SetLoadBalanced(true), connstring.ErrSRVMaxHostsWithLoadBalanced},
   700  			{"loadBalanced=false", Client().SetLoadBalanced(false), nil},
   701  		}
   702  		for _, tc := range testCases {
   703  			t.Run(tc.name, func(t *testing.T) {
   704  				err := tc.opts.Validate()
   705  				assert.Nil(t, err, "Validate error when srvMxaHosts is unset: %v", err)
   706  
   707  				tc.opts.SetSRVMaxHosts(0)
   708  				err = tc.opts.Validate()
   709  				assert.Nil(t, err, "Validate error when srvMaxHosts is 0: %v", err)
   710  
   711  				tc.opts.SetSRVMaxHosts(2)
   712  				err = tc.opts.Validate()
   713  				assert.Equal(t, tc.err, err, "expected error %v when srvMaxHosts > 0, got %v", tc.err, err)
   714  			})
   715  		}
   716  	})
   717  	t.Run("srvMaxHosts validation", func(t *testing.T) {
   718  		t.Parallel()
   719  
   720  		testCases := []struct {
   721  			name string
   722  			opts *ClientOptions
   723  			err  error
   724  		}{
   725  			{
   726  				name: "valid ServerAPI",
   727  				opts: Client().SetServerAPIOptions(ServerAPI(ServerAPIVersion1)),
   728  				err:  nil,
   729  			},
   730  			{
   731  				name: "invalid ServerAPI",
   732  				opts: Client().SetServerAPIOptions(ServerAPI("nope")),
   733  				err:  errors.New(`api version "nope" not supported; this driver version only supports API version "1"`),
   734  			},
   735  			{
   736  				name: "invalid ServerAPI with other invalid options",
   737  				opts: Client().SetServerAPIOptions(ServerAPI("nope")).SetSRVMaxHosts(1).SetReplicaSet("foo"),
   738  				err:  errors.New(`api version "nope" not supported; this driver version only supports API version "1"`),
   739  			},
   740  		}
   741  		for _, tc := range testCases {
   742  			tc := tc // Capture range variable.
   743  
   744  			t.Run(tc.name, func(t *testing.T) {
   745  				t.Parallel()
   746  
   747  				err := tc.opts.Validate()
   748  				assert.Equal(t, tc.err, err, "want error %v, got error %v", tc.err, err)
   749  			})
   750  		}
   751  	})
   752  	t.Run("server monitoring mode validation", func(t *testing.T) {
   753  		t.Parallel()
   754  
   755  		testCases := []struct {
   756  			name string
   757  			opts *ClientOptions
   758  			err  error
   759  		}{
   760  			{
   761  				name: "undefined",
   762  				opts: Client(),
   763  				err:  nil,
   764  			},
   765  			{
   766  				name: "auto",
   767  				opts: Client().SetServerMonitoringMode(ServerMonitoringModeAuto),
   768  				err:  nil,
   769  			},
   770  			{
   771  				name: "poll",
   772  				opts: Client().SetServerMonitoringMode(ServerMonitoringModePoll),
   773  				err:  nil,
   774  			},
   775  			{
   776  				name: "stream",
   777  				opts: Client().SetServerMonitoringMode(ServerMonitoringModeStream),
   778  				err:  nil,
   779  			},
   780  			{
   781  				name: "invalid",
   782  				opts: Client().SetServerMonitoringMode("invalid"),
   783  				err:  errors.New("invalid server monitoring mode: \"invalid\""),
   784  			},
   785  		}
   786  
   787  		for _, tc := range testCases {
   788  			tc := tc // Capture the range variable
   789  
   790  			t.Run(tc.name, func(t *testing.T) {
   791  				t.Parallel()
   792  
   793  				err := tc.opts.Validate()
   794  				assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
   795  			})
   796  		}
   797  	})
   798  }
   799  
   800  func createCertPool(t *testing.T, paths ...string) *x509.CertPool {
   801  	t.Helper()
   802  
   803  	pool := x509.NewCertPool()
   804  	for _, path := range paths {
   805  		pool.AddCert(loadCert(t, path))
   806  	}
   807  	return pool
   808  }
   809  
   810  func loadCert(t *testing.T, file string) *x509.Certificate {
   811  	t.Helper()
   812  
   813  	data := readFile(t, file)
   814  	block, _ := pem.Decode(data)
   815  	cert, err := x509.ParseCertificate(block.Bytes)
   816  	assert.Nil(t, err, "ParseCertificate error for %s: %v", file, err)
   817  	return cert
   818  }
   819  
   820  func readFile(t *testing.T, path string) []byte {
   821  	data, err := ioutil.ReadFile(path)
   822  	assert.Nil(t, err, "ReadFile error for %s: %v", path, err)
   823  	return data
   824  }
   825  
   826  type testDialer struct {
   827  	Num int
   828  }
   829  
   830  func (testDialer) DialContext(context.Context, string, string) (net.Conn, error) {
   831  	return nil, nil
   832  }
   833  
   834  func compareTLSConfig(cfg1, cfg2 *tls.Config) bool {
   835  	if cfg1 == nil && cfg2 == nil {
   836  		return true
   837  	}
   838  
   839  	if cfg1 == nil || cfg2 == nil {
   840  		return true
   841  	}
   842  
   843  	if (cfg1.RootCAs == nil && cfg1.RootCAs != nil) || (cfg1.RootCAs != nil && cfg1.RootCAs == nil) {
   844  		return false
   845  	}
   846  
   847  	if cfg1.RootCAs != nil {
   848  		cfg1Subjects := cfg1.RootCAs.Subjects()
   849  		cfg2Subjects := cfg2.RootCAs.Subjects()
   850  		if len(cfg1Subjects) != len(cfg2Subjects) {
   851  			return false
   852  		}
   853  
   854  		for idx, firstSubject := range cfg1Subjects {
   855  			if !bytes.Equal(firstSubject, cfg2Subjects[idx]) {
   856  				return false
   857  			}
   858  		}
   859  	}
   860  
   861  	if len(cfg1.Certificates) != len(cfg2.Certificates) {
   862  		return false
   863  	}
   864  
   865  	if cfg1.InsecureSkipVerify != cfg2.InsecureSkipVerify {
   866  		return false
   867  	}
   868  
   869  	return true
   870  }
   871  
   872  func compareErrors(err1, err2 error) bool {
   873  	if err1 == nil && err2 == nil {
   874  		return true
   875  	}
   876  
   877  	if err1 == nil || err2 == nil {
   878  		return false
   879  	}
   880  
   881  	var ospe1, ospe2 *os.PathError
   882  	if errors.As(err1, &ospe1) && errors.As(err2, &ospe2) {
   883  		return ospe1.Op == ospe2.Op && ospe1.Path == ospe2.Path
   884  	}
   885  
   886  	if err1.Error() != err2.Error() {
   887  		return false
   888  	}
   889  
   890  	return true
   891  }
   892  

View as plain text