...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/mtest/mongotest.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration/mtest

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mtest
     8  
import (
	"context"
	"errors"
	"fmt"
	"strings"
	"sync"
	"sync/atomic"
	"testing"
	"time"
	"unicode/utf8"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/event"
	"go.mongodb.org/mongo-driver/internal/assert"
	"go.mongodb.org/mongo-driver/internal/csfle"
	"go.mongodb.org/mongo-driver/mongo"
	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/mongo/readconcern"
	"go.mongodb.org/mongo-driver/mongo/readpref"
	"go.mongodb.org/mongo-driver/mongo/writeconcern"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
	"go.mongodb.org/mongo-driver/x/mongo/driver"
)
    31  
    32  var (
    33  	// MajorityWc is the majority write concern.
    34  	MajorityWc = writeconcern.New(writeconcern.WMajority())
    35  	// PrimaryRp is the primary read preference.
    36  	PrimaryRp = readpref.Primary()
    37  	// SecondaryRp is the secondary read preference.
    38  	SecondaryRp = readpref.Secondary()
    39  	// LocalRc is the local read concern
    40  	LocalRc = readconcern.Local()
    41  	// MajorityRc is the majority read concern
    42  	MajorityRc = readconcern.Majority()
    43  )
    44  
    45  const (
    46  	namespaceExistsErrCode int32 = 48
    47  )
    48  
// FailPoint is a representation of a server fail point.
// See https://github.com/mongodb/specifications/tree/HEAD/source/transactions/tests#server-fail-point
// for more information regarding fail points.
//
// NOTE(review): ConfigureFailPoint is presumably kept as the first field because the server
// expects the command name (configureFailPoint) to be the first key of the marshaled command
// document — confirm before reordering fields.
type FailPoint struct {
	// ConfigureFailPoint is the name of the fail point to configure. T.SetFailPoint records it
	// so the fail point can be turned off again by T.ClearFailPoints.
	ConfigureFailPoint string `bson:"configureFailPoint"`
	// Mode should be a string, FailPointMode, or map[string]interface{}
	Mode interface{}   `bson:"mode"`
	// Data is the fail point configuration, marshaled as the "data" field.
	Data FailPointData `bson:"data"`
}
    58  
// FailPointMode is a representation of the FailPoint.Mode field. A FailPointMode value can be
// assigned to FailPoint.Mode.
//
// Neither field is tagged omitempty, so both "times" and "skip" are always present in the
// marshaled document, even when zero.
type FailPointMode struct {
	Times int32 `bson:"times"`
	Skip  int32 `bson:"skip"`
}
    64  
// FailPointData is a representation of the FailPoint.Data field.
//
// Every field is tagged omitempty, so zero values (false, 0, "", nil) are left out of the
// marshaled document.
// NOTE(review): ErrorLabels is presumably a pointer so that an explicitly empty label list can be
// sent instead of omitting the field entirely — confirm.
type FailPointData struct {
	FailCommands                  []string               `bson:"failCommands,omitempty"`
	CloseConnection               bool                   `bson:"closeConnection,omitempty"`
	ErrorCode                     int32                  `bson:"errorCode,omitempty"`
	FailBeforeCommitExceptionCode int32                  `bson:"failBeforeCommitExceptionCode,omitempty"`
	ErrorLabels                   *[]string              `bson:"errorLabels,omitempty"`
	WriteConcernError             *WriteConcernErrorData `bson:"writeConcernError,omitempty"`
	BlockConnection               bool                   `bson:"blockConnection,omitempty"`
	BlockTimeMS                   int32                  `bson:"blockTimeMS,omitempty"`
	AppName                       string                 `bson:"appName,omitempty"`
}
    77  
// WriteConcernErrorData is a representation of the FailPoint.Data.WriteConcernError field.
//
// Code, Name (marshaled as "codeName"), and Errmsg are always included in the marshaled document;
// ErrorLabels and ErrInfo are omitted when empty.
type WriteConcernErrorData struct {
	Code        int32     `bson:"code"`
	Name        string    `bson:"codeName"`
	Errmsg      string    `bson:"errmsg"`
	ErrorLabels *[]string `bson:"errorLabels,omitempty"`
	ErrInfo     bson.Raw  `bson:"errInfo,omitempty"`
}
    86  
// T is a wrapper around testing.T.
//
// In addition to the embedded *testing.T, it holds the per-test client, database, and collection,
// the collections and fail points created during the test (so they can be cleaned up), and the
// command monitoring events recorded by the client created in createTestClient.
type T struct {
	// connsCheckedOut is the net number of connections checked out during test execution.
	// It must be accessed using the atomic package and should be at the beginning of the struct.
	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
	// - suggested layout: https://go101.org/article/memory-layout.html
	connsCheckedOut int64

	*testing.T

	// members for only this T instance; unlike the options below, these are not copied into
	// baseOpts, so sub-tests do not inherit them
	createClient      *bool
	createCollection  *bool
	runOn             []RunOnBlock
	mockDeployment    *mockDeployment // nil if the test is not being run against a mock
	mockResponses     []bson.D
	createdColls      []*Collection // collections created in this test
	proxyDialer       *proxyDialer
	dbName, collName  string
	failPointNames    []string
	minServerVersion  string
	maxServerVersion  string
	validTopologies   []TopologyKind
	auth              *bool
	enterprise        *bool
	dataLake          *bool
	ssl               *bool
	collCreateOpts    *options.CreateCollectionOptions
	requireAPIVersion *bool

	// options copied to sub-tests
	clientType  ClientType
	clientOpts  *options.ClientOptions
	collOpts    *options.CollectionOptions
	shareClient *bool

	baseOpts *Options // used to create subtests

	// command monitoring event lists (slices, not channels); the command monitor installed by
	// createTestClient holds monitorLock while appending to them
	monitorLock sync.Mutex
	started     []*event.CommandStartedEvent
	succeeded   []*event.CommandSucceededEvent
	failed      []*event.CommandFailedEvent

	// Client, DB, and Coll are the client, database, and default collection available to the test
	// callback. Client stays nil when no client is created for the test.
	Client *mongo.Client
	DB     *mongo.Database
	Coll   *mongo.Collection
}
   135  
   136  func newT(wrapped *testing.T, opts ...*Options) *T {
   137  	t := &T{
   138  		T: wrapped,
   139  	}
   140  	for _, opt := range opts {
   141  		for _, optFn := range opt.optFuncs {
   142  			optFn(t)
   143  		}
   144  	}
   145  
   146  	if err := t.verifyConstraints(); err != nil {
   147  		t.Skipf("skipping due to environmental constraints: %v", err)
   148  	}
   149  
   150  	if t.collName == "" {
   151  		t.collName = t.Name()
   152  	}
   153  	if t.dbName == "" {
   154  		t.dbName = TestDb
   155  	}
   156  	t.collName = sanitizeCollectionName(t.dbName, t.collName)
   157  
   158  	// create a set of base options for sub-tests
   159  	t.baseOpts = NewOptions().ClientOptions(t.clientOpts).CollectionOptions(t.collOpts).ClientType(t.clientType)
   160  	if t.shareClient != nil {
   161  		t.baseOpts.ShareClient(*t.shareClient)
   162  	}
   163  
   164  	return t
   165  }
   166  
   167  // New creates a new T instance with the given options. If the current environment does not satisfy constraints
   168  // specified in the options, the test will be skipped automatically.
   169  func New(wrapped *testing.T, opts ...*Options) *T {
   170  	// All tests that use mtest.New() are expected to be integration tests, so skip them when the
   171  	// -short flag is included in the "go test" command.
   172  	if testing.Short() {
   173  		wrapped.Skip("skipping mtest integration test in short mode")
   174  	}
   175  
   176  	t := newT(wrapped, opts...)
   177  
   178  	// only create a client if it needs to be shared in sub-tests
   179  	// otherwise, a new client will be created for each subtest
   180  	if t.shareClient != nil && *t.shareClient {
   181  		t.createTestClient()
   182  	}
   183  
   184  	wrapped.Cleanup(t.cleanup)
   185  
   186  	return t
   187  }
   188  
   189  // cleanup cleans up any resources associated with a T. It is intended to be
   190  // called by [testing.T.Cleanup].
   191  func (t *T) cleanup() {
   192  	if t.Client == nil {
   193  		return
   194  	}
   195  
   196  	// only clear collections and fail points if the test is not running against a mock
   197  	if t.clientType != Mock {
   198  		t.ClearCollections()
   199  		t.ClearFailPoints()
   200  	}
   201  
   202  	// always disconnect the client regardless of clientType because Client.Disconnect will work against
   203  	// all deployments
   204  	_ = t.Client.Disconnect(context.Background())
   205  }
   206  
   207  // Run creates a new T instance for a sub-test and runs the given callback. It also creates a new collection using the
   208  // given name which is available to the callback through the T.Coll variable and is dropped after the callback
   209  // returns.
   210  func (t *T) Run(name string, callback func(mt *T)) {
   211  	t.RunOpts(name, NewOptions(), callback)
   212  }
   213  
   214  // RunOpts creates a new T instance for a sub-test with the given options. If the current environment does not satisfy
   215  // constraints specified in the options, the new sub-test will be skipped automatically. If the test is not skipped,
   216  // the callback will be run with the new T instance. RunOpts creates a new collection with the given name which is
   217  // available to the callback through the T.Coll variable and is dropped after the callback returns.
   218  func (t *T) RunOpts(name string, opts *Options, callback func(mt *T)) {
   219  	t.T.Run(name, func(wrapped *testing.T) {
   220  		sub := newT(wrapped, t.baseOpts, opts)
   221  
   222  		// add any mock responses for this test
   223  		if sub.clientType == Mock && len(sub.mockResponses) > 0 {
   224  			sub.AddMockResponses(sub.mockResponses...)
   225  		}
   226  
   227  		// for shareClient, inherit the client from the parent
   228  		if sub.shareClient != nil && *sub.shareClient && sub.clientType == t.clientType {
   229  			sub.Client = t.Client
   230  		}
   231  		// only create a client if not already set
   232  		if sub.Client == nil {
   233  			if sub.createClient == nil || *sub.createClient {
   234  				sub.createTestClient()
   235  			}
   236  		}
   237  		// create a collection for this test
   238  		if sub.Client != nil {
   239  			sub.createTestCollection()
   240  		}
   241  
   242  		// defer dropping all collections if the test is using a client
   243  		defer func() {
   244  			if sub.Client == nil {
   245  				return
   246  			}
   247  
   248  			// store number of sessions and connections checked out here but assert that they're equal to 0 after
   249  			// cleaning up test resources to make sure resources are always cleared
   250  			sessions := sub.Client.NumberSessionsInProgress()
   251  			conns := sub.NumberConnectionsCheckedOut()
   252  
   253  			if sub.clientType != Mock {
   254  				sub.ClearFailPoints()
   255  				sub.ClearCollections()
   256  			}
   257  			// only disconnect client if it's not being shared
   258  			if sub.shareClient == nil || !*sub.shareClient {
   259  				_ = sub.Client.Disconnect(context.Background())
   260  			}
   261  			assert.Equal(sub, 0, sessions, "%v sessions checked out", sessions)
   262  			assert.Equal(sub, 0, conns, "%v connections checked out", conns)
   263  		}()
   264  
   265  		// clear any events that may have happened during setup and run the test
   266  		sub.ClearEvents()
   267  		callback(sub)
   268  	})
   269  }
   270  
   271  // AddMockResponses adds responses to be returned by the mock deployment. This should only be used if T is being run
   272  // against a mock deployment.
   273  func (t *T) AddMockResponses(responses ...bson.D) {
   274  	t.mockDeployment.addResponses(responses...)
   275  }
   276  
   277  // ClearMockResponses clears all responses in the mock deployment.
   278  func (t *T) ClearMockResponses() {
   279  	t.mockDeployment.clearResponses()
   280  }
   281  
   282  // GetStartedEvent returns the most recent CommandStartedEvent, or nil if one is not present.
   283  // This can only be called once per event.
   284  func (t *T) GetStartedEvent() *event.CommandStartedEvent {
   285  	// TODO(GODRIVER-2075): GetStartedEvent documents that it returns the most recent event, but actually returns the first
   286  	// TODO event. Update either the documentation or implementation.
   287  	if len(t.started) == 0 {
   288  		return nil
   289  	}
   290  	e := t.started[0]
   291  	t.started = t.started[1:]
   292  	return e
   293  }
   294  
   295  // GetSucceededEvent returns the most recent CommandSucceededEvent, or nil if one is not present.
   296  // This can only be called once per event.
   297  func (t *T) GetSucceededEvent() *event.CommandSucceededEvent {
   298  	// TODO(GODRIVER-2075): GetSucceededEvent documents that it returns the most recent event, but actually returns the
   299  	// TODO first event. Update either the documentation or implementation.
   300  	if len(t.succeeded) == 0 {
   301  		return nil
   302  	}
   303  	e := t.succeeded[0]
   304  	t.succeeded = t.succeeded[1:]
   305  	return e
   306  }
   307  
   308  // GetFailedEvent returns the most recent CommandFailedEvent, or nil if one is not present.
   309  // This can only be called once per event.
   310  func (t *T) GetFailedEvent() *event.CommandFailedEvent {
   311  	// TODO(GODRIVER-2075): GetFailedEvent documents that it returns the most recent event, but actually  returns the first
   312  	// TODO event. Update either the documentation or implementation.
   313  	if len(t.failed) == 0 {
   314  		return nil
   315  	}
   316  	e := t.failed[0]
   317  	t.failed = t.failed[1:]
   318  	return e
   319  }
   320  
   321  // GetAllStartedEvents returns a slice of all CommandStartedEvent instances for this test. This can be called multiple
   322  // times.
   323  func (t *T) GetAllStartedEvents() []*event.CommandStartedEvent {
   324  	return t.started
   325  }
   326  
   327  // GetAllSucceededEvents returns a slice of all CommandSucceededEvent instances for this test. This can be called multiple
   328  // times.
   329  func (t *T) GetAllSucceededEvents() []*event.CommandSucceededEvent {
   330  	return t.succeeded
   331  }
   332  
   333  // GetAllFailedEvents returns a slice of all CommandFailedEvent instances for this test. This can be called multiple
   334  // times.
   335  func (t *T) GetAllFailedEvents() []*event.CommandFailedEvent {
   336  	return t.failed
   337  }
   338  
   339  // FilterStartedEvents filters the existing CommandStartedEvent instances for this test using the provided filter
   340  // callback. An event will be retained if the filter returns true. The list of filtered events will be used to overwrite
   341  // the list of events for this test and will therefore change the output of t.GetAllStartedEvents().
   342  func (t *T) FilterStartedEvents(filter func(*event.CommandStartedEvent) bool) {
   343  	var newEvents []*event.CommandStartedEvent
   344  	for _, evt := range t.started {
   345  		if filter(evt) {
   346  			newEvents = append(newEvents, evt)
   347  		}
   348  	}
   349  	t.started = newEvents
   350  }
   351  
   352  // FilterSucceededEvents filters the existing CommandSucceededEvent instances for this test using the provided filter
   353  // callback. An event will be retained if the filter returns true. The list of filtered events will be used to overwrite
   354  // the list of events for this test and will therefore change the output of t.GetAllSucceededEvents().
   355  func (t *T) FilterSucceededEvents(filter func(*event.CommandSucceededEvent) bool) {
   356  	var newEvents []*event.CommandSucceededEvent
   357  	for _, evt := range t.succeeded {
   358  		if filter(evt) {
   359  			newEvents = append(newEvents, evt)
   360  		}
   361  	}
   362  	t.succeeded = newEvents
   363  }
   364  
   365  // FilterFailedEvents filters the existing CommandFailedEVent instances for this test using the provided filter
   366  // callback. An event will be retained if the filter returns true. The list of filtered events will be used to overwrite
   367  // the list of events for this test and will therefore change the output of t.GetAllFailedEvents().
   368  func (t *T) FilterFailedEvents(filter func(*event.CommandFailedEvent) bool) {
   369  	var newEvents []*event.CommandFailedEvent
   370  	for _, evt := range t.failed {
   371  		if filter(evt) {
   372  			newEvents = append(newEvents, evt)
   373  		}
   374  	}
   375  	t.failed = newEvents
   376  }
   377  
   378  // GetProxiedMessages returns the messages proxied to the server by the test. If the client type is not Proxy, this
   379  // returns nil.
   380  func (t *T) GetProxiedMessages() []*ProxyMessage {
   381  	if t.proxyDialer == nil {
   382  		return nil
   383  	}
   384  	return t.proxyDialer.Messages()
   385  }
   386  
   387  // NumberConnectionsCheckedOut returns the number of connections checked out from the test Client.
   388  func (t *T) NumberConnectionsCheckedOut() int {
   389  	return int(atomic.LoadInt64(&t.connsCheckedOut))
   390  }
   391  
   392  // ClearEvents clears the existing command monitoring events.
   393  func (t *T) ClearEvents() {
   394  	t.started = t.started[:0]
   395  	t.succeeded = t.succeeded[:0]
   396  	t.failed = t.failed[:0]
   397  }
   398  
   399  // ResetClient resets the existing client with the given options. If opts is nil, the existing options will be used.
   400  // If t.Coll is not-nil, it will be reset to use the new client. Should only be called if the existing client is
   401  // not nil. This will Disconnect the existing client but will not drop existing collections. To do so, ClearCollections
   402  // must be called before calling ResetClient.
   403  func (t *T) ResetClient(opts *options.ClientOptions) {
   404  	if opts != nil {
   405  		t.clientOpts = opts
   406  	}
   407  
   408  	_ = t.Client.Disconnect(context.Background())
   409  	t.createTestClient()
   410  	t.DB = t.Client.Database(t.dbName)
   411  	t.Coll = t.DB.Collection(t.collName, t.collOpts)
   412  
   413  	for _, coll := range t.createdColls {
   414  		// If the collection was created using a different Client, it doesn't need to be reset.
   415  		if coll.hasDifferentClient {
   416  			continue
   417  		}
   418  
   419  		// If the namespace is the same as t.Coll, we can use t.Coll.
   420  		if coll.created.Name() == t.collName && coll.created.Database().Name() == t.dbName {
   421  			coll.created = t.Coll
   422  			continue
   423  		}
   424  
   425  		// Otherwise, reset the collection to use the new Client.
   426  		coll.created = t.Client.Database(coll.DB).Collection(coll.Name, coll.Opts)
   427  	}
   428  }
   429  
// Collection is used to configure a new collection created during a test.
//
// It is passed to T.CreateCollection. The collections it creates are tracked by T so that
// ClearCollections can drop them and ResetClient can rebind them to a new client.
type Collection struct {
	// Name is the name of the collection (or view) to create.
	Name               string
	DB                 string        // defaults to mt.DB.Name() if not specified
	Client             *mongo.Client // defaults to mt.Client if not specified
	// Opts are the options used to obtain the *mongo.Collection handle.
	Opts               *options.CollectionOptions
	// CreateOpts are passed to the create command. When EncryptedFields is set, any existing
	// encrypted collection and its state collections are dropped before creation.
	CreateOpts         *options.CreateCollectionOptions
	// ViewOn, when non-empty, makes CreateCollection create a view on the named collection
	// using ViewPipeline instead of a regular collection.
	ViewOn             string
	ViewPipeline       interface{}
	hasDifferentClient bool
	created            *mongo.Collection // the actual collection that was created
}
   442  
   443  // CreateCollection creates a new collection with the given configuration. The collection will be dropped after the test
   444  // finishes running. If createOnServer is true, the function ensures that the collection has been created server-side
   445  // by running the create command. The create command will appear in command monitoring channels.
   446  func (t *T) CreateCollection(coll Collection, createOnServer bool) *mongo.Collection {
   447  	if coll.DB == "" {
   448  		coll.DB = t.DB.Name()
   449  	}
   450  	if coll.Client == nil {
   451  		coll.Client = t.Client
   452  	}
   453  	coll.hasDifferentClient = coll.Client != t.Client
   454  
   455  	db := coll.Client.Database(coll.DB)
   456  
   457  	if coll.CreateOpts != nil && coll.CreateOpts.EncryptedFields != nil {
   458  		// An encrypted collection consists of a data collection and three state collections.
   459  		// Aborted test runs may leave these collections.
   460  		// Drop all four collections to avoid a quiet failure to create all collections.
   461  		DropEncryptedCollection(t, db.Collection(coll.Name), coll.CreateOpts.EncryptedFields)
   462  	}
   463  
   464  	if createOnServer && t.clientType != Mock {
   465  		var err error
   466  		if coll.ViewOn != "" {
   467  			err = db.CreateView(context.Background(), coll.Name, coll.ViewOn, coll.ViewPipeline)
   468  		} else {
   469  			err = db.CreateCollection(context.Background(), coll.Name, coll.CreateOpts)
   470  		}
   471  
   472  		// ignore ErrUnacknowledgedWrite. Client may be configured with unacknowledged write concern.
   473  		if err != nil && !errors.Is(err, driver.ErrUnacknowledgedWrite) {
   474  			// ignore NamespaceExists errors for idempotency
   475  
   476  			var cmdErr mongo.CommandError
   477  			if !errors.As(err, &cmdErr) || cmdErr.Code != namespaceExistsErrCode {
   478  				t.Fatalf("error creating collection or view: %v on server: %v", coll.Name, err)
   479  			}
   480  		}
   481  	}
   482  
   483  	coll.created = db.Collection(coll.Name, coll.Opts)
   484  	t.createdColls = append(t.createdColls, &coll)
   485  	return coll.created
   486  }
   487  
   488  // DropEncryptedCollection drops a collection with EncryptedFields.
   489  // The EncryptedFields option is not supported in Collection.Drop(). See GODRIVER-2413.
   490  func DropEncryptedCollection(t *T, coll *mongo.Collection, encryptedFields interface{}) {
   491  	t.Helper()
   492  
   493  	var efBSON bsoncore.Document
   494  	efBSON, err := bson.Marshal(encryptedFields)
   495  	assert.Nil(t, err, "error in Marshal: %v", err)
   496  
   497  	// Drop the two encryption-related, associated collections: `escCollection` and `ecocCollection`.
   498  	// Drop ESCCollection.
   499  	escCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, coll.Name(), csfle.EncryptedStateCollection)
   500  	assert.Nil(t, err, "error in getEncryptedStateCollectionName: %v", err)
   501  	err = coll.Database().Collection(escCollection).Drop(context.Background())
   502  	assert.Nil(t, err, "error in Drop: %v", err)
   503  
   504  	// Drop ECOCCollection.
   505  	ecocCollection, err := csfle.GetEncryptedStateCollectionName(efBSON, coll.Name(), csfle.EncryptedCompactionCollection)
   506  	assert.Nil(t, err, "error in getEncryptedStateCollectionName: %v", err)
   507  	err = coll.Database().Collection(ecocCollection).Drop(context.Background())
   508  	assert.Nil(t, err, "error in Drop: %v", err)
   509  
   510  	// Drop the data collection.
   511  	err = coll.Drop(context.Background())
   512  	assert.Nil(t, err, "error in Drop: %v", err)
   513  }
   514  
   515  // ClearCollections drops all collections previously created by this test.
   516  func (t *T) ClearCollections() {
   517  	// Collections should not be dropped when testing against Atlas Data Lake because the data is pre-inserted.
   518  	if !testContext.dataLake {
   519  		for _, coll := range t.createdColls {
   520  			if coll.CreateOpts != nil && coll.CreateOpts.EncryptedFields != nil {
   521  				DropEncryptedCollection(t, coll.created, coll.CreateOpts.EncryptedFields)
   522  			}
   523  
   524  			err := coll.created.Drop(context.Background())
   525  			if errors.Is(err, mongo.ErrUnacknowledgedWrite) || errors.Is(err, driver.ErrUnacknowledgedWrite) {
   526  				// It's possible that a collection could have an unacknowledged write concern, which
   527  				// could prevent it from being dropped for sharded clusters. We can resolve this by
   528  				// re-instantiating the collection with a majority write concern before dropping.
   529  				collname := coll.created.Name()
   530  				wcm := writeconcern.New(writeconcern.WMajority(), writeconcern.WTimeout(1*time.Second))
   531  				wccoll := t.DB.Collection(collname, options.Collection().SetWriteConcern(wcm))
   532  				_ = wccoll.Drop(context.Background())
   533  
   534  			}
   535  		}
   536  	}
   537  	t.createdColls = t.createdColls[:0]
   538  }
   539  
   540  // SetFailPoint sets a fail point for the client associated with T. Commands to create the failpoint will appear
   541  // in command monitoring channels. The fail point will automatically be disabled after this test has run.
   542  func (t *T) SetFailPoint(fp FailPoint) {
   543  	// ensure mode fields are int32
   544  	if modeMap, ok := fp.Mode.(map[string]interface{}); ok {
   545  		var key string
   546  		var err error
   547  
   548  		if times, ok := modeMap["times"]; ok {
   549  			key = "times"
   550  			modeMap["times"], err = t.interfaceToInt32(times)
   551  		}
   552  		if skip, ok := modeMap["skip"]; ok {
   553  			key = "skip"
   554  			modeMap["skip"], err = t.interfaceToInt32(skip)
   555  		}
   556  
   557  		if err != nil {
   558  			t.Fatalf("error converting %s to int32: %v", key, err)
   559  		}
   560  	}
   561  
   562  	if err := SetFailPoint(fp, t.Client); err != nil {
   563  		t.Fatal(err)
   564  	}
   565  	t.failPointNames = append(t.failPointNames, fp.ConfigureFailPoint)
   566  }
   567  
   568  // SetFailPointFromDocument sets the fail point represented by the given document for the client associated with T. This
   569  // method assumes that the given document is in the form {configureFailPoint: <failPointName>, ...}. Commands to create
   570  // the failpoint will appear in command monitoring channels. The fail point will be automatically disabled after this
   571  // test has run.
   572  func (t *T) SetFailPointFromDocument(fp bson.Raw) {
   573  	if err := SetRawFailPoint(fp, t.Client); err != nil {
   574  		t.Fatal(err)
   575  	}
   576  
   577  	name := fp.Index(0).Value().StringValue()
   578  	t.failPointNames = append(t.failPointNames, name)
   579  }
   580  
   581  // TrackFailPoint adds the given fail point to the list of fail points to be disabled when the current test finishes.
   582  // This function does not create a fail point on the server.
   583  func (t *T) TrackFailPoint(fpName string) {
   584  	t.failPointNames = append(t.failPointNames, fpName)
   585  }
   586  
   587  // ClearFailPoints disables all previously set failpoints for this test.
   588  func (t *T) ClearFailPoints() {
   589  	db := t.Client.Database("admin")
   590  	for _, fp := range t.failPointNames {
   591  		cmd := bson.D{
   592  			{"configureFailPoint", fp},
   593  			{"mode", "off"},
   594  		}
   595  		err := db.RunCommand(context.Background(), cmd).Err()
   596  		if err != nil {
   597  			t.Fatalf("error clearing fail point %s: %v", fp, err)
   598  		}
   599  	}
   600  	t.failPointNames = t.failPointNames[:0]
   601  }
   602  
   603  // CloneDatabase modifies the default database for this test to match the given options.
   604  func (t *T) CloneDatabase(opts *options.DatabaseOptions) {
   605  	t.DB = t.Client.Database(t.dbName, opts)
   606  }
   607  
   608  // CloneCollection modifies the default collection for this test to match the given options.
   609  func (t *T) CloneCollection(opts *options.CollectionOptions) {
   610  	var err error
   611  	t.Coll, err = t.Coll.Clone(opts)
   612  	assert.Nil(t, err, "error cloning collection: %v", err)
   613  }
   614  
   615  func sanitizeCollectionName(db string, coll string) string {
   616  	// Collections can't have "$" in their names, so we substitute it with "%".
   617  	coll = strings.Replace(coll, "$", "%", -1)
   618  
   619  	// Namespaces can only have 120 bytes max.
   620  	if len(db+"."+coll) >= 120 {
   621  		// coll len must be <= remaining
   622  		remaining := 120 - (len(db) + 1) // +1 for "."
   623  		coll = coll[len(coll)-remaining:]
   624  	}
   625  	return coll
   626  }
   627  
   628  func (t *T) createTestClient() {
   629  	clientOpts := t.clientOpts
   630  	if clientOpts == nil {
   631  		// default opts
   632  		clientOpts = options.Client().SetWriteConcern(MajorityWc).SetReadPreference(PrimaryRp)
   633  	}
   634  	// set ServerAPIOptions to latest version if required
   635  	if clientOpts.Deployment == nil && t.clientType != Mock && clientOpts.ServerAPIOptions == nil && testContext.requireAPIVersion {
   636  		clientOpts.SetServerAPIOptions(options.ServerAPI(driver.TestServerAPIVersion))
   637  	}
   638  
   639  	// Setup command monitor
   640  	var customMonitor = clientOpts.Monitor
   641  	clientOpts.SetMonitor(&event.CommandMonitor{
   642  		Started: func(_ context.Context, cse *event.CommandStartedEvent) {
   643  			if customMonitor != nil && customMonitor.Started != nil {
   644  				customMonitor.Started(context.Background(), cse)
   645  			}
   646  			t.monitorLock.Lock()
   647  			defer t.monitorLock.Unlock()
   648  			t.started = append(t.started, cse)
   649  		},
   650  		Succeeded: func(_ context.Context, cse *event.CommandSucceededEvent) {
   651  			if customMonitor != nil && customMonitor.Succeeded != nil {
   652  				customMonitor.Succeeded(context.Background(), cse)
   653  			}
   654  			t.monitorLock.Lock()
   655  			defer t.monitorLock.Unlock()
   656  			t.succeeded = append(t.succeeded, cse)
   657  		},
   658  		Failed: func(_ context.Context, cfe *event.CommandFailedEvent) {
   659  			if customMonitor != nil && customMonitor.Failed != nil {
   660  				customMonitor.Failed(context.Background(), cfe)
   661  			}
   662  			t.monitorLock.Lock()
   663  			defer t.monitorLock.Unlock()
   664  			t.failed = append(t.failed, cfe)
   665  		},
   666  	})
   667  	// only specify connection pool monitor if no deployment is given
   668  	if clientOpts.Deployment == nil {
   669  		previousPoolMonitor := clientOpts.PoolMonitor
   670  
   671  		clientOpts.SetPoolMonitor(&event.PoolMonitor{
   672  			Event: func(evt *event.PoolEvent) {
   673  				if previousPoolMonitor != nil {
   674  					previousPoolMonitor.Event(evt)
   675  				}
   676  
   677  				switch evt.Type {
   678  				case event.GetSucceeded:
   679  					atomic.AddInt64(&t.connsCheckedOut, 1)
   680  				case event.ConnectionReturned:
   681  					atomic.AddInt64(&t.connsCheckedOut, -1)
   682  				}
   683  			},
   684  		})
   685  	}
   686  
   687  	var err error
   688  	switch t.clientType {
   689  	case Pinned:
   690  		// pin to first mongos
   691  		pinnedHostList := []string{testContext.connString.Hosts[0]}
   692  		uriOpts := options.Client().ApplyURI(testContext.connString.Original).SetHosts(pinnedHostList)
   693  		t.Client, err = mongo.NewClient(uriOpts, clientOpts)
   694  	case Mock:
   695  		// clear pool monitor to avoid configuration error
   696  		clientOpts.PoolMonitor = nil
   697  		t.mockDeployment = newMockDeployment()
   698  		clientOpts.Deployment = t.mockDeployment
   699  		t.Client, err = mongo.NewClient(clientOpts)
   700  	case Proxy:
   701  		t.proxyDialer = newProxyDialer()
   702  		clientOpts.SetDialer(t.proxyDialer)
   703  
   704  		// After setting the Dialer, fall-through to the Default case to apply the correct URI
   705  		fallthrough
   706  	case Default:
   707  		// Use a different set of options to specify the URI because clientOpts may already have a URI or host seedlist
   708  		// specified.
   709  		var uriOpts *options.ClientOptions
   710  		if clientOpts.Deployment == nil {
   711  			// Only specify URI if the deployment is not set to avoid setting topology/server options along with the
   712  			// deployment.
   713  			uriOpts = options.Client().ApplyURI(testContext.connString.Original)
   714  		}
   715  
   716  		// Pass in uriOpts first so clientOpts wins if there are any conflicting settings.
   717  		t.Client, err = mongo.NewClient(uriOpts, clientOpts)
   718  	}
   719  	if err != nil {
   720  		t.Fatalf("error creating client: %v", err)
   721  	}
   722  	if err := t.Client.Connect(context.Background()); err != nil {
   723  		t.Fatalf("error connecting client: %v", err)
   724  	}
   725  }
   726  
   727  func (t *T) createTestCollection() {
   728  	t.DB = t.Client.Database(t.dbName)
   729  	t.createdColls = t.createdColls[:0]
   730  
   731  	// Collections should not be explicitly created when testing against Atlas Data Lake because they already exist in
   732  	// the server with pre-seeded data.
   733  	createOnServer := (t.createCollection == nil || *t.createCollection) && !testContext.dataLake
   734  	t.Coll = t.CreateCollection(Collection{
   735  		Name:       t.collName,
   736  		CreateOpts: t.collCreateOpts,
   737  		Opts:       t.collOpts,
   738  	}, createOnServer)
   739  }
   740  
   741  // verifyVersionConstraints returns an error if the cluster's server version is not in the range [min, max]. Server
   742  // versions will only be checked if they are non-empty.
   743  func verifyVersionConstraints(min, max string) error {
   744  	if min != "" && CompareServerVersions(testContext.serverVersion, min) < 0 {
   745  		return fmt.Errorf("server version %q is lower than min required version %q", testContext.serverVersion, min)
   746  	}
   747  	if max != "" && CompareServerVersions(testContext.serverVersion, max) > 0 {
   748  		return fmt.Errorf("server version %q is higher than max version %q", testContext.serverVersion, max)
   749  	}
   750  	return nil
   751  }
   752  
   753  // verifyTopologyConstraints returns an error if the cluster's topology kind does not match one of the provided
   754  // kinds. If the topologies slice is empty, nil is returned without any additional checks.
   755  func verifyTopologyConstraints(topologies []TopologyKind) error {
   756  	if len(topologies) == 0 {
   757  		return nil
   758  	}
   759  
   760  	for _, topo := range topologies {
   761  		// For ShardedReplicaSet, we won't get an exact match because testContext.topoKind will be Sharded so we do an
   762  		// additional comparison with the testContext.shardedReplicaSet field.
   763  		if topo == testContext.topoKind || (topo == ShardedReplicaSet && testContext.shardedReplicaSet) {
   764  			return nil
   765  		}
   766  	}
   767  	return fmt.Errorf("topology kind %q does not match any of the required kinds %q", testContext.topoKind, topologies)
   768  }
   769  
   770  func verifyServerParametersConstraints(serverParameters map[string]bson.RawValue) error {
   771  	for param, expected := range serverParameters {
   772  		actual, err := testContext.serverParameters.LookupErr(param)
   773  		if err != nil {
   774  			return fmt.Errorf("server does not support parameter %q", param)
   775  		}
   776  		if !expected.Equal(actual) {
   777  			return fmt.Errorf("mismatched values for server parameter %q; expected %s, got %s", param, expected, actual)
   778  		}
   779  	}
   780  	return nil
   781  }
   782  
   783  func verifyAuthConstraint(expected *bool) error {
   784  	if expected != nil && *expected != testContext.authEnabled {
   785  		return fmt.Errorf("test requires auth value: %v, cluster auth value: %v", *expected, testContext.authEnabled)
   786  	}
   787  	return nil
   788  }
   789  
   790  func verifyServerlessConstraint(expected string) error {
   791  	switch expected {
   792  	case "require":
   793  		if !testContext.serverless {
   794  			return fmt.Errorf("test requires serverless")
   795  		}
   796  	case "forbid":
   797  		if testContext.serverless {
   798  			return fmt.Errorf("test forbids serverless")
   799  		}
   800  	case "allow", "":
   801  	default:
   802  		return fmt.Errorf("invalid value for serverless: %s", expected)
   803  	}
   804  	return nil
   805  }
   806  
   807  // verifyRunOnBlockConstraint returns an error if the current environment does not match the provided RunOnBlock.
   808  func verifyRunOnBlockConstraint(rob RunOnBlock) error {
   809  	if err := verifyVersionConstraints(rob.MinServerVersion, rob.MaxServerVersion); err != nil {
   810  		return err
   811  	}
   812  	if err := verifyTopologyConstraints(rob.Topology); err != nil {
   813  		return err
   814  	}
   815  
   816  	// Tests in the unified test format have runOn.auth to indicate whether the
   817  	// test should be run against an auth-enabled configuration. SDAM integration
   818  	// spec tests have runOn.authEnabled to indicate the same thing. Use whichever
   819  	// is set for verifyAuthConstraint().
   820  	auth := rob.Auth
   821  	if rob.AuthEnabled != nil {
   822  		if auth != nil {
   823  			return fmt.Errorf("runOnBlock cannot specify both auth and authEnabled")
   824  		}
   825  		auth = rob.AuthEnabled
   826  	}
   827  	if err := verifyAuthConstraint(auth); err != nil {
   828  		return err
   829  	}
   830  
   831  	if err := verifyServerlessConstraint(rob.Serverless); err != nil {
   832  		return err
   833  	}
   834  	if err := verifyServerParametersConstraints(rob.ServerParameters); err != nil {
   835  		return err
   836  	}
   837  
   838  	if rob.CSFLE != nil {
   839  		if *rob.CSFLE && !IsCSFLEEnabled() {
   840  			return fmt.Errorf("runOnBlock requires CSFLE to be enabled. Build with the cse tag to enable")
   841  		} else if !*rob.CSFLE && IsCSFLEEnabled() {
   842  			return fmt.Errorf("runOnBlock requires CSFLE to be disabled. Build without the cse tag to disable")
   843  		}
   844  		if *rob.CSFLE {
   845  			if err := verifyVersionConstraints("4.2", ""); err != nil {
   846  				return err
   847  			}
   848  		}
   849  	}
   850  	return nil
   851  }
   852  
   853  // verifyConstraints returns an error if the current environment does not match the constraints specified for the test.
   854  func (t *T) verifyConstraints() error {
   855  	// Check constraints not specified as runOn blocks
   856  	if err := verifyVersionConstraints(t.minServerVersion, t.maxServerVersion); err != nil {
   857  		return err
   858  	}
   859  	if err := verifyTopologyConstraints(t.validTopologies); err != nil {
   860  		return err
   861  	}
   862  	if err := verifyAuthConstraint(t.auth); err != nil {
   863  		return err
   864  	}
   865  	if t.ssl != nil && *t.ssl != testContext.sslEnabled {
   866  		return fmt.Errorf("test requires ssl value: %v, cluster ssl value: %v", *t.ssl, testContext.sslEnabled)
   867  	}
   868  	if t.enterprise != nil && *t.enterprise != testContext.enterpriseServer {
   869  		return fmt.Errorf("test requires enterprise value: %v, cluster enterprise value: %v", *t.enterprise,
   870  			testContext.enterpriseServer)
   871  	}
   872  	if t.dataLake != nil && *t.dataLake != testContext.dataLake {
   873  		return fmt.Errorf("test requires cluster to be data lake: %v, cluster is data lake: %v", *t.dataLake,
   874  			testContext.dataLake)
   875  	}
   876  	if t.requireAPIVersion != nil && *t.requireAPIVersion != testContext.requireAPIVersion {
   877  		return fmt.Errorf("test requires RequireAPIVersion value: %v, local RequireAPIVersion value: %v", *t.requireAPIVersion,
   878  			testContext.requireAPIVersion)
   879  	}
   880  
   881  	// Check runOn blocks. The test can be executed if there are no blocks or at least block matches the current test
   882  	// setup.
   883  	if len(t.runOn) == 0 {
   884  		return nil
   885  	}
   886  
   887  	// Stop once we find a RunOnBlock that matches the current environment. Record all errors as we go because if we
   888  	// don't find any matching blocks, we want to report the comparison errors for each block.
   889  	runOnErrors := make([]error, 0, len(t.runOn))
   890  	for _, runOn := range t.runOn {
   891  		err := verifyRunOnBlockConstraint(runOn)
   892  		if err == nil {
   893  			return nil
   894  		}
   895  
   896  		runOnErrors = append(runOnErrors, err)
   897  	}
   898  	return fmt.Errorf("no matching RunOnBlock; comparison errors: %v", runOnErrors)
   899  }
   900  
   901  func (t *T) interfaceToInt32(i interface{}) (int32, error) {
   902  	switch conv := i.(type) {
   903  	case int:
   904  		return int32(conv), nil
   905  	case int32:
   906  		return conv, nil
   907  	case int64:
   908  		return int32(conv), nil
   909  	case float64:
   910  		return int32(conv), nil
   911  	}
   912  
   913  	return 0, fmt.Errorf("type %T cannot be converted to int32", i)
   914  }
   915  

View as plain text