...

Source file src/go.mongodb.org/mongo-driver/mongo/client_test.go

Documentation: go.mongodb.org/mongo-driver/mongo

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mongo
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"math"
    13  	"os"
    14  	"testing"
    15  	"time"
    16  
    17  	"go.mongodb.org/mongo-driver/bson"
    18  	"go.mongodb.org/mongo-driver/event"
    19  	"go.mongodb.org/mongo-driver/internal/assert"
    20  	"go.mongodb.org/mongo-driver/internal/integtest"
    21  	"go.mongodb.org/mongo-driver/mongo/options"
    22  	"go.mongodb.org/mongo-driver/mongo/readconcern"
    23  	"go.mongodb.org/mongo-driver/mongo/readpref"
    24  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    25  	"go.mongodb.org/mongo-driver/tag"
    26  	"go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
    27  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    28  	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
    29  )
    30  
    31  var bgCtx = context.Background()
    32  
    33  func setupClient(opts ...*options.ClientOptions) *Client {
    34  	if len(opts) == 0 {
    35  		clientOpts := options.Client().ApplyURI("mongodb://localhost:27017")
    36  		integtest.AddTestServerAPIVersion(clientOpts)
    37  		opts = append(opts, clientOpts)
    38  	}
    39  	client, _ := NewClient(opts...)
    40  	return client
    41  }
    42  
    43  func TestClient(t *testing.T) {
    44  	t.Run("new client", func(t *testing.T) {
    45  		client := setupClient()
    46  		assert.NotNil(t, client.deployment, "expected valid deployment, got nil")
    47  	})
    48  	t.Run("database", func(t *testing.T) {
    49  		dbName := "foo"
    50  		client := setupClient()
    51  		db := client.Database(dbName)
    52  		assert.Equal(t, dbName, db.Name(), "expected db name %v, got %v", dbName, db.Name())
    53  		assert.Equal(t, client, db.Client(), "expected client %v, got %v", client, db.Client())
    54  	})
    55  	t.Run("replace topology error", func(t *testing.T) {
    56  		client := setupClient()
    57  
    58  		_, err := client.StartSession()
    59  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    60  
    61  		_, err = client.ListDatabases(bgCtx, bson.D{})
    62  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    63  
    64  		err = client.Ping(bgCtx, nil)
    65  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    66  
    67  		err = client.Disconnect(bgCtx)
    68  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    69  
    70  		_, err = client.Watch(bgCtx, []bson.D{})
    71  		assert.Equal(t, ErrClientDisconnected, err, "expected error %v, got %v", ErrClientDisconnected, err)
    72  	})
    73  	t.Run("nil document error", func(t *testing.T) {
    74  		// manually set session pool to non-nil because Watch will return ErrClientDisconnected
    75  		client := setupClient()
    76  		client.sessionPool = &session.Pool{}
    77  
    78  		_, err := client.Watch(bgCtx, nil)
    79  		watchErr := errors.New("can only marshal slices and arrays into aggregation pipelines, but got invalid")
    80  		assert.Equal(t, watchErr, err, "expected error %v, got %v", watchErr, err)
    81  
    82  		_, err = client.ListDatabases(bgCtx, nil)
    83  		assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
    84  
    85  		_, err = client.ListDatabaseNames(bgCtx, nil)
    86  		assert.Equal(t, ErrNilDocument, err, "expected error %v, got %v", ErrNilDocument, err)
    87  	})
    88  	t.Run("read preference", func(t *testing.T) {
    89  		t.Run("absent", func(t *testing.T) {
    90  			client := setupClient()
    91  			gotMode := client.readPreference.Mode()
    92  			wantMode := readpref.PrimaryMode
    93  			assert.Equal(t, gotMode, wantMode, "expected mode %v, got %v", wantMode, gotMode)
    94  			_, flag := client.readPreference.MaxStaleness()
    95  			assert.False(t, flag, "expected max staleness to not be set but was")
    96  		})
    97  		t.Run("specified", func(t *testing.T) {
    98  			tags := []tag.Set{
    99  				{
   100  					tag.Tag{
   101  						Name:  "one",
   102  						Value: "1",
   103  					},
   104  				},
   105  				{
   106  					tag.Tag{
   107  						Name:  "two",
   108  						Value: "2",
   109  					},
   110  				},
   111  			}
   112  			cs := "mongodb://localhost:27017/"
   113  			cs += "?readpreference=secondary&readPreferenceTags=one:1&readPreferenceTags=two:2&maxStaleness=5"
   114  
   115  			client := setupClient(options.Client().ApplyURI(cs))
   116  			gotMode := client.readPreference.Mode()
   117  			assert.Equal(t, gotMode, readpref.SecondaryMode, "expected mode %v, got %v", readpref.SecondaryMode, gotMode)
   118  			gotTags := client.readPreference.TagSets()
   119  			assert.Equal(t, gotTags, tags, "expected tags %v, got %v", tags, gotTags)
   120  			gotStaleness, flag := client.readPreference.MaxStaleness()
   121  			assert.True(t, flag, "expected max staleness to be set but was not")
   122  			wantStaleness := time.Duration(5) * time.Second
   123  			assert.Equal(t, gotStaleness, wantStaleness, "expected staleness %v, got %v", wantStaleness, gotStaleness)
   124  		})
   125  	})
   126  	t.Run("localThreshold", func(t *testing.T) {
   127  		testCases := []struct {
   128  			name              string
   129  			opts              *options.ClientOptions
   130  			expectedThreshold time.Duration
   131  		}{
   132  			{"default", options.Client(), defaultLocalThreshold},
   133  			{"custom", options.Client().SetLocalThreshold(10 * time.Second), 10 * time.Second},
   134  		}
   135  		for _, tc := range testCases {
   136  			t.Run(tc.name, func(t *testing.T) {
   137  				client := setupClient(tc.opts)
   138  				assert.Equal(t, tc.expectedThreshold, client.localThreshold,
   139  					"expected localThreshold %v, got %v", tc.expectedThreshold, client.localThreshold)
   140  			})
   141  		}
   142  	})
   143  	t.Run("read concern", func(t *testing.T) {
   144  		rc := readconcern.Majority()
   145  		client := setupClient(options.Client().SetReadConcern(rc))
   146  		assert.Equal(t, rc, client.readConcern, "expected read concern %v, got %v", rc, client.readConcern)
   147  	})
   148  	t.Run("min pool size from Set*PoolSize()", func(t *testing.T) {
   149  		testCases := []struct {
   150  			name string
   151  			opts *options.ClientOptions
   152  			err  error
   153  		}{
   154  			{
   155  				name: "minPoolSize < default maxPoolSize",
   156  				opts: options.Client().SetMinPoolSize(64),
   157  				err:  nil,
   158  			},
   159  			{
   160  				name: "minPoolSize > default maxPoolSize",
   161  				opts: options.Client().SetMinPoolSize(128),
   162  				err:  errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=128 maxPoolSize=100"),
   163  			},
   164  			{
   165  				name: "minPoolSize < maxPoolSize",
   166  				opts: options.Client().SetMinPoolSize(128).SetMaxPoolSize(256),
   167  				err:  nil,
   168  			},
   169  			{
   170  				name: "minPoolSize == maxPoolSize",
   171  				opts: options.Client().SetMinPoolSize(128).SetMaxPoolSize(128),
   172  				err:  nil,
   173  			},
   174  			{
   175  				name: "minPoolSize > maxPoolSize",
   176  				opts: options.Client().SetMinPoolSize(64).SetMaxPoolSize(32),
   177  				err:  errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=64 maxPoolSize=32"),
   178  			},
   179  			{
   180  				name: "maxPoolSize == 0",
   181  				opts: options.Client().SetMinPoolSize(128).SetMaxPoolSize(0),
   182  				err:  nil,
   183  			},
   184  		}
   185  		for _, tc := range testCases {
   186  			t.Run(tc.name, func(t *testing.T) {
   187  				_, err := NewClient(tc.opts)
   188  				assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
   189  			})
   190  		}
   191  	})
   192  	t.Run("min pool size from ApplyURI()", func(t *testing.T) {
   193  		testCases := []struct {
   194  			name string
   195  			opts *options.ClientOptions
   196  			err  error
   197  		}{
   198  			{
   199  				name: "minPoolSize < default maxPoolSize",
   200  				opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=64"),
   201  				err:  nil,
   202  			},
   203  			{
   204  				name: "minPoolSize > default maxPoolSize",
   205  				opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128"),
   206  				err:  errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=128 maxPoolSize=100"),
   207  			},
   208  			{
   209  				name: "minPoolSize < maxPoolSize",
   210  				opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128&maxPoolSize=256"),
   211  				err:  nil,
   212  			},
   213  			{
   214  				name: "minPoolSize == maxPoolSize",
   215  				opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128&maxPoolSize=128"),
   216  				err:  nil,
   217  			},
   218  			{
   219  				name: "minPoolSize > maxPoolSize",
   220  				opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=64&maxPoolSize=32"),
   221  				err:  errors.New("minPoolSize must be less than or equal to maxPoolSize, got minPoolSize=64 maxPoolSize=32"),
   222  			},
   223  			{
   224  				name: "maxPoolSize == 0",
   225  				opts: options.Client().ApplyURI("mongodb://localhost:27017/?minPoolSize=128&maxPoolSize=0"),
   226  				err:  nil,
   227  			},
   228  		}
   229  		for _, tc := range testCases {
   230  			t.Run(tc.name, func(t *testing.T) {
   231  				_, err := NewClient(tc.opts)
   232  				assert.Equal(t, tc.err, err, "expected error %v, got %v", tc.err, err)
   233  			})
   234  		}
   235  	})
   236  	t.Run("retry writes", func(t *testing.T) {
   237  		retryWritesURI := "mongodb://localhost:27017/?retryWrites=false"
   238  		retryWritesErrorURI := "mongodb://localhost:27017/?retryWrites=foobar"
   239  
   240  		testCases := []struct {
   241  			name          string
   242  			opts          *options.ClientOptions
   243  			expectErr     bool
   244  			expectedRetry bool
   245  		}{
   246  			{"default", options.Client(), false, true},
   247  			{"custom options", options.Client().SetRetryWrites(false), false, false},
   248  			{"custom URI", options.Client().ApplyURI(retryWritesURI), false, false},
   249  			{"custom URI error", options.Client().ApplyURI(retryWritesErrorURI), true, false},
   250  		}
   251  		for _, tc := range testCases {
   252  			t.Run(tc.name, func(t *testing.T) {
   253  				client, err := NewClient(tc.opts)
   254  				if tc.expectErr {
   255  					assert.NotNil(t, err, "expected error, got nil")
   256  					return
   257  				}
   258  				assert.Nil(t, err, "configuration error: %v", err)
   259  				assert.Equal(t, tc.expectedRetry, client.retryWrites, "expected retryWrites %v, got %v",
   260  					tc.expectedRetry, client.retryWrites)
   261  			})
   262  		}
   263  	})
   264  	t.Run("retry reads", func(t *testing.T) {
   265  		retryReadsURI := "mongodb://localhost:27017/?retryReads=false"
   266  		retryReadsErrorURI := "mongodb://localhost:27017/?retryReads=foobar"
   267  
   268  		testCases := []struct {
   269  			name          string
   270  			opts          *options.ClientOptions
   271  			expectErr     bool
   272  			expectedRetry bool
   273  		}{
   274  			{"default", options.Client(), false, true},
   275  			{"custom options", options.Client().SetRetryReads(false), false, false},
   276  			{"custom URI", options.Client().ApplyURI(retryReadsURI), false, false},
   277  			{"custom URI error", options.Client().ApplyURI(retryReadsErrorURI), true, false},
   278  		}
   279  		for _, tc := range testCases {
   280  			t.Run(tc.name, func(t *testing.T) {
   281  				client, err := NewClient(tc.opts)
   282  				if tc.expectErr {
   283  					assert.NotNil(t, err, "expected error, got nil")
   284  					return
   285  				}
   286  				assert.Nil(t, err, "configuration error: %v", err)
   287  				assert.Equal(t, tc.expectedRetry, client.retryReads, "expected retryReads %v, got %v",
   288  					tc.expectedRetry, client.retryReads)
   289  			})
   290  		}
   291  	})
   292  	t.Run("write concern", func(t *testing.T) {
   293  		wc := writeconcern.New(writeconcern.WMajority())
   294  		client := setupClient(options.Client().SetWriteConcern(wc))
   295  		assert.Equal(t, wc, client.writeConcern, "mismatch; expected write concern %v, got %v", wc, client.writeConcern)
   296  	})
   297  	t.Run("server monitor", func(t *testing.T) {
   298  		monitor := &event.ServerMonitor{}
   299  		client := setupClient(options.Client().SetServerMonitor(monitor))
   300  		assert.Equal(t, monitor, client.serverMonitor, "expected sdam monitor %v, got %v", monitor, client.serverMonitor)
   301  	})
   302  	t.Run("GetURI", func(t *testing.T) {
   303  		t.Run("ApplyURI not called", func(t *testing.T) {
   304  			opts := options.Client().SetHosts([]string{"localhost:27017"})
   305  			uri := opts.GetURI()
   306  			assert.Equal(t, "", uri, "expected GetURI to return empty string, got %v", uri)
   307  		})
   308  		t.Run("ApplyURI called with empty string", func(t *testing.T) {
   309  			opts := options.Client().ApplyURI("")
   310  			uri := opts.GetURI()
   311  			assert.Equal(t, "", uri, "expected GetURI to return empty string, got %v", uri)
   312  		})
   313  		t.Run("ApplyURI called with non-empty string", func(t *testing.T) {
   314  			uri := "mongodb://localhost:27017/foobar"
   315  			opts := options.Client().ApplyURI(uri)
   316  			got := opts.GetURI()
   317  			assert.Equal(t, uri, got, "expected GetURI to return %v, got %v", uri, got)
   318  		})
   319  	})
   320  	t.Run("endSessions", func(t *testing.T) {
   321  		cs := integtest.ConnString(t)
   322  		originalBatchSize := endSessionsBatchSize
   323  		endSessionsBatchSize = 2
   324  		defer func() {
   325  			endSessionsBatchSize = originalBatchSize
   326  		}()
   327  
   328  		testCases := []struct {
   329  			name            string
   330  			numSessions     int
   331  			eventBatchSizes []int
   332  		}{
   333  			{"number of sessions divides evenly", endSessionsBatchSize * 2, []int{endSessionsBatchSize, endSessionsBatchSize}},
   334  			{"number of sessions does not divide evenly", endSessionsBatchSize + 1, []int{endSessionsBatchSize, 1}},
   335  		}
   336  		for _, tc := range testCases {
   337  			if testing.Short() {
   338  				t.Skip("skipping integration test in short mode")
   339  			}
   340  			if os.Getenv("DOCKER_RUNNING") != "" {
   341  				t.Skip("skipping test in docker environment")
   342  			}
   343  
   344  			t.Run(tc.name, func(t *testing.T) {
   345  				// Setup a client and skip the test based on server version.
   346  				var started []*event.CommandStartedEvent
   347  				var failureReasons []string
   348  				cmdMonitor := &event.CommandMonitor{
   349  					Started: func(_ context.Context, evt *event.CommandStartedEvent) {
   350  						if evt.CommandName == "endSessions" {
   351  							started = append(started, evt)
   352  						}
   353  					},
   354  					Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
   355  						if evt.CommandName == "endSessions" {
   356  							failureReasons = append(failureReasons, evt.Failure)
   357  						}
   358  					},
   359  				}
   360  				clientOpts := options.Client().ApplyURI(cs.Original).SetReadPreference(readpref.Primary()).
   361  					SetWriteConcern(writeconcern.New(writeconcern.WMajority())).SetMonitor(cmdMonitor)
   362  				integtest.AddTestServerAPIVersion(clientOpts)
   363  				client, err := Connect(bgCtx, clientOpts)
   364  				assert.Nil(t, err, "Connect error: %v", err)
   365  				defer func() {
   366  					_ = client.Disconnect(bgCtx)
   367  				}()
   368  
   369  				serverVersion, err := getServerVersion(client.Database("admin"))
   370  				assert.Nil(t, err, "getServerVersion error: %v", err)
   371  				if compareVersions(serverVersion, "3.6.0") < 1 {
   372  					t.Skip("skipping server version < 3.6")
   373  				}
   374  
   375  				coll := client.Database("foo").Collection("bar")
   376  				defer func() {
   377  					_ = coll.Drop(bgCtx)
   378  				}()
   379  
   380  				// Do an application operation and create the number of sessions specified by the test.
   381  				_, err = coll.CountDocuments(bgCtx, bson.D{})
   382  				assert.Nil(t, err, "CountDocuments error: %v", err)
   383  				var sessions []Session
   384  				for i := 0; i < tc.numSessions; i++ {
   385  					sess, err := client.StartSession()
   386  					assert.Nil(t, err, "StartSession error at index %d: %v", i, err)
   387  					sessions = append(sessions, sess)
   388  				}
   389  				for _, sess := range sessions {
   390  					sess.EndSession(bgCtx)
   391  				}
   392  
   393  				client.endSessions(bgCtx)
   394  				divisionResult := float64(tc.numSessions) / float64(endSessionsBatchSize)
   395  				numEventsExpected := int(math.Ceil(divisionResult))
   396  				assert.Equal(t, len(started), numEventsExpected, "expected %d started events, got %d", numEventsExpected,
   397  					len(started))
   398  				assert.Equal(t, len(failureReasons), 0, "endSessions errors: %v", failureReasons)
   399  
   400  				for i := 0; i < numEventsExpected; i++ {
   401  					sentArray := started[i].Command.Lookup("endSessions").Array()
   402  					values, _ := sentArray.Values()
   403  					expectedNumValues := tc.eventBatchSizes[i]
   404  					assert.Equal(t, len(values), expectedNumValues,
   405  						"batch size mismatch at index %d; expected %d sessions in batch, got %d", i, expectedNumValues,
   406  						len(values))
   407  				}
   408  			})
   409  		}
   410  	})
   411  	t.Run("serverAPI version", func(t *testing.T) {
   412  		getServerAPIOptions := func() *options.ServerAPIOptions {
   413  			return options.ServerAPI(options.ServerAPIVersion1).
   414  				SetStrict(false).SetDeprecationErrors(false)
   415  		}
   416  
   417  		t.Run("success with all options", func(t *testing.T) {
   418  			serverAPIOptions := getServerAPIOptions()
   419  			client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
   420  			assert.Nil(t, err, "unexpected error from NewClient: %v", err)
   421  			convertedAPIOptions := topology.ConvertToDriverAPIOptions(serverAPIOptions)
   422  			assert.Equal(t, convertedAPIOptions, client.serverAPI,
   423  				"mismatch in serverAPI; expected %v, got %v", convertedAPIOptions, client.serverAPI)
   424  		})
   425  		t.Run("failure with unsupported version", func(t *testing.T) {
   426  			serverAPIOptions := options.ServerAPI("badVersion")
   427  			_, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
   428  			assert.NotNil(t, err, "expected error from NewClient, got nil")
   429  			errmsg := `api version "badVersion" not supported; this driver version only supports API version "1"`
   430  			assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error())
   431  		})
   432  		t.Run("cannot modify options after client creation", func(t *testing.T) {
   433  			serverAPIOptions := getServerAPIOptions()
   434  			client, err := NewClient(options.Client().SetServerAPIOptions(serverAPIOptions))
   435  			assert.Nil(t, err, "unexpected error from NewClient: %v", err)
   436  
   437  			expectedServerAPIOptions := getServerAPIOptions()
   438  			// modify passed-in options
   439  			serverAPIOptions.SetStrict(true).SetDeprecationErrors(true)
   440  			convertedAPIOptions := topology.ConvertToDriverAPIOptions(expectedServerAPIOptions)
   441  			assert.Equal(t, convertedAPIOptions, client.serverAPI,
   442  				"unexpected modification to serverAPI; expected %v, got %v", convertedAPIOptions, client.serverAPI)
   443  		})
   444  	})
   445  	t.Run("mongocryptd or crypt_shared", func(t *testing.T) {
   446  		cryptSharedLibPath := os.Getenv("CRYPT_SHARED_LIB_PATH")
   447  		if cryptSharedLibPath == "" {
   448  			t.Skip("CRYPT_SHARED_LIB_PATH not set, skipping")
   449  		}
   450  		if len(mongocrypt.Version()) == 0 {
   451  			t.Skip("Not built with cse flag")
   452  		}
   453  
   454  		testCases := []struct {
   455  			description       string
   456  			useCryptSharedLib bool
   457  		}{
   458  			{
   459  				description:       "when crypt_shared is loaded, should not attempt to spawn mongocryptd",
   460  				useCryptSharedLib: true,
   461  			},
   462  			{
   463  				description:       "when crypt_shared is not loaded, should attempt to spawn mongocryptd",
   464  				useCryptSharedLib: false,
   465  			},
   466  		}
   467  		for _, tc := range testCases {
   468  			t.Run(tc.description, func(t *testing.T) {
   469  				extraOptions := map[string]interface{}{
   470  					// Set a mongocryptd path that does not exist. If Connect() attempts to start
   471  					// mongocryptd, it will cause an error.
   472  					"mongocryptdPath": "/does/not/exist",
   473  				}
   474  
   475  				// If we're using the crypt_shared library, set the "cryptSharedLibRequired" option
   476  				// to true and the "cryptSharedLibPath" option to the crypt_shared library path from
   477  				// the CRYPT_SHARED_LIB_PATH environment variable. If we're not using the
   478  				// crypt_shared library, explicitly disable loading the crypt_shared library.
   479  				if tc.useCryptSharedLib {
   480  					extraOptions["cryptSharedLibRequired"] = true
   481  					extraOptions["cryptSharedLibPath"] = cryptSharedLibPath
   482  				} else {
   483  					extraOptions["__cryptSharedLibDisabledForTestOnly"] = true
   484  				}
   485  
   486  				_, err := NewClient(options.Client().
   487  					SetAutoEncryptionOptions(options.AutoEncryption().
   488  						SetKmsProviders(map[string]map[string]interface{}{
   489  							"local": {"key": make([]byte, 96)},
   490  						}).
   491  						SetExtraOptions(extraOptions)))
   492  
   493  				// If we're using the crypt_shared library, expect that Connect() doesn't attempt to spawn
   494  				// mongocryptd and no error is returned. If we're not using the crypt_shared library,
   495  				// expect that Connect() tries to spawn mongocryptd and returns an error.
   496  				if tc.useCryptSharedLib {
   497  					assert.Nil(t, err, "Connect() error: %v", err)
   498  				} else {
   499  					assert.NotNil(t, err, "expected Connect() error, but got nil")
   500  				}
   501  			})
   502  		}
   503  	})
   504  }
   505  

View as plain text