...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/unified_spec_test.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package integration
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"io/ioutil"
    14  	"os"
    15  	"path"
    16  	"reflect"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  	"unsafe"
    21  
    22  	"go.mongodb.org/mongo-driver/bson"
    23  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    24  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    25  	"go.mongodb.org/mongo-driver/bson/bsontype"
    26  	"go.mongodb.org/mongo-driver/event"
    27  	"go.mongodb.org/mongo-driver/internal/assert"
    28  	"go.mongodb.org/mongo-driver/internal/bsonutil"
    29  	"go.mongodb.org/mongo-driver/internal/integtest"
    30  	"go.mongodb.org/mongo-driver/mongo"
    31  	"go.mongodb.org/mongo-driver/mongo/address"
    32  	"go.mongodb.org/mongo-driver/mongo/gridfs"
    33  	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
    34  	"go.mongodb.org/mongo-driver/mongo/options"
    35  	"go.mongodb.org/mongo-driver/mongo/readconcern"
    36  	"go.mongodb.org/mongo-driver/mongo/readpref"
    37  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    38  	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
    39  )
    40  
    41  const (
    42  	gridFSFiles            = "fs.files"
    43  	gridFSChunks           = "fs.chunks"
    44  	spec1403SkipReason     = "servers less than 4.2 do not have mongocryptd; see SPEC-1403"
    45  	godriver2123SkipReason = "failpoints and timeouts together cause failures; see GODRIVER-2123"
    46  	godriver2413SkipReason = "encryptedFields argument is not supported on Collection.Drop; see GODRIVER-2413"
    47  )
    48  
    49  var (
    50  	defaultHeartbeatInterval = 50 * time.Millisecond
    51  	skippedTestDescriptions  = map[string]string{
    52  		// SPEC-1403: This test checks to see if the correct error is thrown when auto encrypting with a server < 4.2.
    53  		// Currently, the test will fail because a server < 4.2 wouldn't have mongocryptd, so Client construction
    54  		// would fail with a mongocryptd spawn error.
    55  		"operation fails with maxWireVersion < 8": spec1403SkipReason,
    56  		// GODRIVER-2123: The two tests below use a failpoint and a socket or server selection timeout.
    57  		// The timeout causes the eventual clearing of the failpoint in the test runner to fail with an
    58  		// i/o timeout.
    59  		"Ignore network timeout error on find":             godriver2123SkipReason,
    60  		"Network error on minPoolSize background creation": godriver2123SkipReason,
    61  		"CreateCollection from encryptedFields.":           godriver2413SkipReason,
    62  		"DropCollection from encryptedFields":              godriver2413SkipReason,
    63  		"DropCollection from remote encryptedFields":       godriver2413SkipReason,
    64  	}
    65  )
    66  
    67  type testFile struct { // testFile is one decoded legacy spec test file (see runSpecTestFile).
    68  	RunOn           []mtest.RunOnBlock `bson:"runOn"` // server constraints passed to mtest RunOn
    69  	DatabaseName    string             `bson:"database_name"` // database used by runSpecTestCase
    70  	CollectionName  string             `bson:"collection_name"` // collection used by runSpecTestCase
    71  	BucketName      string             `bson:"bucket_name"` // GridFS bucket name; createBucket does nothing when empty
    72  	Data            testData           `bson:"data"` // decoded by decodeTestData via specTestRegistry
    73  	JSONSchema      bson.Raw           `bson:"json_schema"` // if non-empty, installed as a $jsonSchema collection validator
    74  	KeyVaultData    []bson.Raw         `bson:"key_vault_data"` // NOTE(review): key vault documents; not read in the visible code
    75  	Tests           []*testCase        `bson:"tests"` // test cases, run in order by runSpecTestFile
    76  	EncryptedFields bson.Raw           `bson:"encrypted_fields"` // if non-empty, passed to CreateCollection via SetEncryptedFields
    77  }
    78  
    79  type testData struct { // testData is the "data" section: either a document list or GridFS files/chunks (see decodeTestData).
    80  	Documents  []bson.Raw // filled when "data" is a BSON array
    81  	GridFSData struct { // filled when "data" is an embedded document
    82  		Files  []bson.Raw `bson:"fs.files"`
    83  		Chunks []bson.Raw `bson:"fs.chunks"`
    84  	}
    85  }
    86  
    // decodeTestData is the custom decoder for testData, registered in specTestRegistry.
//
// The "data" field of a spec test file has one of two shapes:
//   - a BSON array of documents, decoded into testData.Documents;
//   - an embedded document holding "fs.files"/"fs.chunks" arrays, decoded into testData.GridFSData.
//
// Any other BSON type (e.g. null) leaves val at its zero value. The value is still consumed via vr.Skip so the
// reader remains positioned correctly for the fields that follow.
func decodeTestData(dc bsoncodec.DecodeContext, vr bsonrw.ValueReader, val reflect.Value) error {
	var target reflect.Value
	switch vr.Type() {
	case bsontype.Array:
		target = val.FieldByName("Documents")
	case bsontype.EmbeddedDocument:
		target = val.FieldByName("GridFSData")
	default:
		// Returning without reading the value would leave the reader mid-element; skip it instead.
		return vr.Skip()
	}

	decoder, err := dc.Registry.LookupDecoder(target.Type())
	if err != nil {
		return err
	}
	return decoder.DecodeValue(dc, vr, target)
}
   109  
   110  type testCase struct { // testCase is one entry of testFile.Tests, run by runSpecTestCase.
   111  	Description         string          `bson:"description"` // subtest name; also the key into skippedTestDescriptions
   112  	SkipReason          string          `bson:"skipReason"` // if non-empty, the test is skipped with this reason
   113  	FailPoint           *bson.Raw       `bson:"failPoint"` // optional fail point, set after setupTest and before the client reset
   114  	ClientOptions       bson.Raw        `bson:"clientOptions"` // converted by createClientOptions for the test Client
   115  	SessionOptions      bson.Raw        `bson:"sessionOptions"` // NOTE(review): presumably consumed by setupSessions (not visible here)
   116  	Operations          []*operation    `bson:"operations"` // executed in order via runOperation
   117  	Expectations        *[]*expectation `bson:"expectations"` // expected command events, passed to checkExpectations
   118  	UseMultipleMongoses bool            `bson:"useMultipleMongoses"` // when false on a sharded cluster, the client is pinned to one mongos
   119  	Outcome             *outcome        `bson:"outcome"` // optional expected collection state, checked by verifyTestOutcome
   120  
   121  	// GridFS state, set in code for GridFS tests (see createBucket)
   122  	chunkSize int32 // 0 means gridfs.DefaultChunkSize
   123  	bucket    *gridfs.Bucket // created by createBucket when testFile.BucketName is set
   124  
   125  	// per-run state, set in code while the test executes (see runSpecTestCase)
   126  	testTopology    *topology.Topology // topology of the test Client, recorded after mt.ResetClient
   127  	recordedPrimary address.Address // NOTE(review): presumably written by recordPrimary (not visible here)
   128  	monitor         *unifiedRunnerEventMonitor // pool and server (SDAM) event monitor installed on the test Client
   129  	routinesMap     sync.Map // maps thread name to *backgroundRoutine
   130  }
   131  
   132  type operation struct { // operation is one step of testCase.Operations, executed by runOperation.
   133  	Name              string      `bson:"name"` // operation name, dispatched by the execute* helpers
   134  	Object            string      `bson:"object"` // target object: "", "collection", "database", "client", "gridfsbucket", "session0", "session1" or "testRunner"
   135  	CollectionOptions bson.Raw    `bson:"collectionOptions"` // if set, mt's collection is cloned with these options first
   136  	DatabaseOptions   bson.Raw    `bson:"databaseOptions"` // if set, mt's database is cloned with these options first
   137  	Result            interface{} `bson:"result"` // expected result, or an expected-error description (see errorFromResult)
   138  	Arguments         bson.Raw    `bson:"arguments"` // operation arguments; may include "session"
   139  	Error             bool        `bson:"error"` // true if an error is expected even when Result is absent
   140  	CommandName       string      `bson:"command_name"` // NOTE(review): not read in the visible code; presumably used by runCommand
   141  
   142  	// set in code by runOperation from Result/Error; non-nil means an error is expected
   143  	opError *operationError
   144  }
   145  
   146  type expectation struct { // expectation is one expected command monitoring event; presumably one field is set per entry
   147  	CommandStartedEvent *struct {
   148  		CommandName  string                 `bson:"command_name"`
   149  		DatabaseName string                 `bson:"database_name"`
   150  		Command      bson.Raw               `bson:"command"`
   151  		Extra        map[string]interface{} `bson:",inline"` // inline map; presumably collects keys without a named field
   152  	} `bson:"command_started_event"`
   153  	CommandSucceededEvent *struct {
   154  		CommandName string                 `bson:"command_name"`
   155  		Reply       bson.Raw               `bson:"reply"`
   156  		Extra       map[string]interface{} `bson:",inline"` // inline map; presumably collects keys without a named field
   157  	} `bson:"command_succeeded_event"`
   158  	CommandFailedEvent *struct {
   159  		CommandName string                 `bson:"command_name"`
   160  		Extra       map[string]interface{} `bson:",inline"` // inline map; presumably collects keys without a named field
   161  	} `bson:"command_failed_event"`
   162  }
   163  
   164  type outcome struct { // outcome is the expected state after all operations have run.
   165  	Collection *outcomeCollection `bson:"collection"` // passed to verifyTestOutcome when Outcome is set
   166  }
   167  
   168  type outcomeCollection struct { // outcomeCollection is the expected contents of a collection.
   169  	Name string      `bson:"name"` // NOTE(review): presumably empty means the test's default collection
   170  	Data interface{} `bson:"data"` // expected documents
   171  }
   172  
   173  type operationError struct { // operationError describes an expected error; a zero value still means "some error" (see runOperation)
   174  	ErrorContains      *string  `bson:"errorContains"` // presumably a substring of the error message; see verifyError
   175  	ErrorCodeName      *string  `bson:"errorCodeName"` // presumably the server error code name; see verifyError
   176  	ErrorLabelsContain []string `bson:"errorLabelsContain"` // labels the error is presumably required to carry
   177  	ErrorLabelsOmit    []string `bson:"errorLabelsOmit"` // labels the error is presumably required not to carry
   178  }
   179  
   180  const dataPath string = "../../testdata/"
   181  
   182  var directories = []string{
   183  	"transactions/legacy",
   184  	"convenient-transactions",
   185  	"retryable-reads",
   186  	"read-write-concern/operation",
   187  	"server-discovery-and-monitoring/integration",
   188  	"atlas-data-lake-testing",
   189  }
   190  
   191  var checkOutcomeOpts = options.Collection().SetReadPreference(readpref.Primary()).SetReadConcern(readconcern.Local())
   192  var specTestRegistry = bson.NewRegistryBuilder().
   193  	RegisterTypeMapEntry(bson.TypeEmbeddedDocument, reflect.TypeOf(bson.Raw{})).
   194  	RegisterTypeDecoder(reflect.TypeOf(testData{}), bsoncodec.ValueDecoderFunc(decodeTestData)).Build()
   195  
   196  func TestUnifiedSpecs(t *testing.T) {
   197  	for _, specDir := range directories {
   198  		t.Run(specDir, func(t *testing.T) {
   199  			for _, fileName := range jsonFilesInDir(t, path.Join(dataPath, specDir)) {
   200  				t.Run(fileName, func(t *testing.T) {
   201  					runSpecTestFile(t, specDir, fileName)
   202  				})
   203  			}
   204  		})
   205  	}
   206  }
   207  
   // runSpecTestFile decodes one legacy spec test file and runs each of its test cases.
//
// specDir is the name of a spec's directory under dataPath (for example "transactions/legacy") and fileName is
// the name of a JSON test file inside it. A failure to read or decode the file stops this subtest immediately
// via t.Fatalf, instead of going on to run with an empty or partially populated testFile.
func runSpecTestFile(t *testing.T, specDir, fileName string) {
	filePath := path.Join(dataPath, specDir, fileName)
	content, err := ioutil.ReadFile(filePath)
	if err != nil {
		t.Fatalf("unable to read spec test file %v: %v", filePath, err)
	}

	var testFile testFile
	if err = bson.UnmarshalExtJSONWithRegistry(specTestRegistry, content, false, &testFile); err != nil {
		t.Fatalf("unable to unmarshal spec test file at %v: %v", filePath, err)
	}

	// Create the mtest wrapper; mtest skips the tests if the RunOn constraints are not satisfied.
	mtOpts := mtest.NewOptions().
		RunOn(testFile.RunOn...).
		CreateClient(false)
	if specDir == "atlas-data-lake-testing" {
		mtOpts.AtlasDataLake(true)
	}
	mt := mtest.New(t, mtOpts)

	for _, test := range testFile.Tests {
		runSpecTestCase(mt, test, testFile)
	}
}
   232  
   // runSpecTestCase runs one legacy spec test case as a subtest of mt.
//
// The subtest is first set up with a default client: tracked collections, test data and the optional fail
// point. The client is then reset with the test's own options plus pool/server monitors, the operations are
// run, and finally the command expectations and the collection outcome are verified.
func runSpecTestCase(mt *mtest.T, test *testCase, testFile testFile) {
	opts := mtest.NewOptions().DatabaseName(testFile.DatabaseName).CollectionName(testFile.CollectionName)
	if mtest.ClusterTopologyKind() == mtest.Sharded && !test.UseMultipleMongoses {
		// pin to a single mongos
		opts = opts.ClientType(mtest.Pinned)
	}

	cco := options.CreateCollection()
	if len(testFile.JSONSchema) > 0 {
		validator := bson.D{
			{"$jsonSchema", testFile.JSONSchema},
		}
		cco.SetValidator(validator)
	}

	if len(testFile.EncryptedFields) > 0 {
		cco.SetEncryptedFields(testFile.EncryptedFields)
	}

	opts.CollectionCreateOptions(cco)

	// Start the test without setting client options so the setup will be done with a default client.
	mt.RunOpts(test.Description, opts, func(mt *mtest.T) {
		if len(test.SkipReason) > 0 {
			mt.Skip(test.SkipReason)
		}
		if skipReason, ok := skippedTestDescriptions[test.Description]; ok {
			mt.Skipf("skipping due to known failure: %v", skipReason)
		}

		// work around for SERVER-39704: run a non-transactional distinct against each shard in a sharded cluster
		if mtest.ClusterTopologyKind() == mtest.Sharded && test.Description == "distinct" {
			err := runCommandOnAllServers(func(mongosClient *mongo.Client) error {
				coll := mongosClient.Database(mt.DB.Name()).Collection(mt.Coll.Name())
				_, err := coll.Distinct(context.Background(), "x", bson.D{})
				return err
			})
			assert.Nil(mt, err, "error running distinct against all mongoses: %v", err)
		}

		// Defer killSessions to ensure it runs regardless of the state of the test because the client has already
		// been created and the collection drop in mongotest will hang for transactions to be aborted (60 seconds)
		// in error cases.
		defer killSessions(mt)

		// Test setup: create collections that are tracked by mtest, insert test data, and set the failpoint.
		setupTest(mt, &testFile, test)
		if test.FailPoint != nil {
			mt.SetFailPointFromDocument(*test.FailPoint)
		}

		// Reset the client using the client options specified in the test.
		testClientOpts := createClientOptions(mt, test.ClientOptions)

		// If AutoEncryptionOptions is set and AutoEncryption isn't disabled (neither
		// bypassAutoEncryption nor bypassQueryAnalysis are true), then add extra options to load
		// the crypt_shared library.
		if testClientOpts.AutoEncryptionOptions != nil {
			bypassAutoEncryption := testClientOpts.AutoEncryptionOptions.BypassAutoEncryption != nil &&
				*testClientOpts.AutoEncryptionOptions.BypassAutoEncryption
			bypassQueryAnalysis := testClientOpts.AutoEncryptionOptions.BypassQueryAnalysis != nil &&
				*testClientOpts.AutoEncryptionOptions.BypassQueryAnalysis
			if !bypassAutoEncryption && !bypassQueryAnalysis {
				if testClientOpts.AutoEncryptionOptions.ExtraOptions == nil {
					testClientOpts.AutoEncryptionOptions.ExtraOptions = make(map[string]interface{})
				}

				for k, v := range getCryptSharedLibExtraOptions() {
					testClientOpts.AutoEncryptionOptions.ExtraOptions[k] = v
				}
			}
		}

		test.monitor = newUnifiedRunnerEventMonitor()
		testClientOpts.SetPoolMonitor(&event.PoolMonitor{
			Event: test.monitor.handlePoolEvent,
		})
		testClientOpts.SetServerMonitor(test.monitor.sdamMonitor)
		if testClientOpts.HeartbeatInterval == nil {
			// If one isn't specified in the test, use a low heartbeat frequency so the Client will quickly recover when
			// using failpoints that cause SDAM state changes.
			testClientOpts.SetHeartbeatInterval(defaultHeartbeatInterval)
		}
		mt.ResetClient(testClientOpts)

		// Record the underlying topology for the test's Client.
		test.testTopology = getTopologyFromClient(mt.Client)

		// Create the GridFS bucket and sessions after resetting the client so it will be created with a connected
		// client.
		createBucket(mt, testFile, test)
		sess0, sess1 := setupSessions(mt, test)

		// Each session is nil-checked on its own so that a missing session cannot cause a panic during cleanup or
		// when collecting session IDs below.
		endSessions := func() {
			if sess0 != nil {
				sess0.EndSession(context.Background())
			}
			if sess1 != nil {
				sess1.EndSession(context.Background())
			}
		}
		defer endSessions()

		// run operations
		mt.ClearEvents()
		for idx, op := range test.Operations {
			err := runOperation(mt, test, op, sess0, sess1)
			assert.Nil(mt, err, "error running operation %q at index %d: %v", op.Name, idx, err)
		}

		// Needs to be done here (in spite of the defer) because some tests require the sessions to be ended
		// before expectations are checked. The deferred call ends them a second time, as the original code did.
		endSessions()
		mt.ClearFailPoints()

		var sess0ID, sess1ID bson.Raw
		if sess0 != nil {
			sess0ID = sess0.ID()
		}
		if sess1 != nil {
			sess1ID = sess1.ID()
		}
		checkExpectations(mt, test.Expectations, sess0ID, sess1ID)

		if test.Outcome != nil {
			verifyTestOutcome(mt, test.Outcome.Collection)
		}
	})
}
   352  
   353  func createBucket(mt *mtest.T, testFile testFile, testCase *testCase) {
   354  	if testFile.BucketName == "" {
   355  		return
   356  	}
   357  
   358  	bucketOpts := options.GridFSBucket()
   359  	if testFile.BucketName != "" {
   360  		bucketOpts.SetName(testFile.BucketName)
   361  	}
   362  	chunkSize := testCase.chunkSize
   363  	if chunkSize == 0 {
   364  		chunkSize = gridfs.DefaultChunkSize
   365  	}
   366  	bucketOpts.SetChunkSizeBytes(chunkSize)
   367  
   368  	var err error
   369  	testCase.bucket, err = gridfs.NewBucket(mt.DB, bucketOpts)
   370  	assert.Nil(mt, err, "NewBucket error: %v", err)
   371  }
   372  
   // runOperation executes one spec operation and checks its error against the expectation recorded in op.
//
// sess0 and sess1 are the test's sessions; an operation selects one through its "session" argument
// ("session0" or "session1"). testRunner operations are dispatched to executeTestRunnerOperation and their error
// is returned as-is. For every other object, op.opError is derived from op.Result/op.Error BEFORE the operation
// runs. The execute* helpers read op.opError to decide whether a successful result must be verified, so it
// has to be populated first. The returned error comes from verifyError.
func runOperation(mt *mtest.T, testCase *testCase, op *operation, sess0, sess1 mongo.Session) error {
	if op.Name == "count" {
		mt.Skip("count has been deprecated")
	}

	var sess mongo.Session
	if sessVal, err := op.Arguments.LookupErr("session"); err == nil {
		sessStr := sessVal.StringValue()
		switch sessStr {
		case "session0":
			sess = sess0
		case "session1":
			sess = sess1
		default:
			return fmt.Errorf("unrecognized session identifier: %v", sessStr)
		}
	}

	if op.Object == "testRunner" {
		return executeTestRunnerOperation(mt, testCase, op, sess)
	}

	// Resolve the expected error before executing the operation (see the doc comment above).
	op.opError = errorFromResult(mt, op.Result)
	// Some tests (e.g. crud/v2) only specify that an error should occur via the op.Error field but do not specify
	// which error via the op.Result field. In this case, pass in an empty non-nil operationError so verifyError will
	// make the right assertions.
	if op.Error && op.Result == nil {
		op.opError = &operationError{}
	}

	if op.DatabaseOptions != nil {
		mt.CloneDatabase(createDatabaseOptions(mt, op.DatabaseOptions))
	}
	if op.CollectionOptions != nil {
		mt.CloneCollection(createCollectionOptions(mt, op.CollectionOptions))
	}

	// execute the command on the given object
	var err error
	switch op.Object {
	case "session0":
		err = executeSessionOperation(mt, op, sess0)
	case "session1":
		err = executeSessionOperation(mt, op, sess1)
	case "", "collection":
		// object defaults to "collection" if not specified
		err = executeCollectionOperation(mt, op, sess)
	case "database":
		err = executeDatabaseOperation(mt, op, sess)
	case "gridfsbucket":
		err = executeGridFSOperation(mt, testCase.bucket, op)
	case "client":
		err = executeClientOperation(mt, op, sess)
	default:
		return fmt.Errorf("unrecognized operation object: %v", op.Object)
	}

	return verifyError(op.opError, err)
}
   431  
   432  func executeGridFSOperation(mt *mtest.T, bucket *gridfs.Bucket, op *operation) error {
   433  	// no results for GridFS operations
   434  	assert.Nil(mt, op.Result, "unexpected result for GridFS operation")
   435  
   436  	switch op.Name {
   437  	case "download":
   438  		_, err := executeGridFSDownload(mt, bucket, op.Arguments)
   439  		return err
   440  	case "download_by_name":
   441  		_, err := executeGridFSDownloadByName(mt, bucket, op.Arguments)
   442  		return err
   443  	default:
   444  		mt.Fatalf("unrecognized gridfs operation: %v", op.Name)
   445  	}
   446  	return nil
   447  }
   448  
   // executeTestRunnerOperation runs an operation whose object is "testRunner": fail point setup, session and
// namespace assertions, waits, event assertions, and background thread control.
//
// sess is the session selected by the operation's "session" argument, or nil. Operations that inspect the
// session require it to be present. Malformed arguments are reported as errors rather than panics where the
// function can check them.
func executeTestRunnerOperation(mt *mtest.T, testCase *testCase, op *operation, sess mongo.Session) error {
	var clientSession *session.Client
	if sess != nil {
		xsess, ok := sess.(mongo.XSession)
		if !ok {
			return fmt.Errorf("expected session type %T to implement mongo.XSession", sess)
		}
		clientSession = xsess.ClientSession()
	}

	switch op.Name {
	case "targetedFailPoint":
		fpVal, err := op.Arguments.LookupErr("failPoint")
		if err != nil {
			return fmt.Errorf("unable to find 'failPoint' in arguments: %w", err)
		}
		fpDoc, ok := fpVal.DocumentOK()
		if !ok {
			return errors.New("expected 'failPoint' argument to be a document")
		}

		var fp mtest.FailPoint
		if err = bson.Unmarshal(fpDoc, &fp); err != nil {
			return fmt.Errorf("Unmarshal error: %w", err)
		}

		if clientSession == nil {
			return errors.New("expected valid session, got nil")
		}
		// The fail point is set on the server the session is pinned to, so an unpinned session is an error.
		if clientSession.PinnedServer == nil {
			return errors.New("expected session to be pinned to a server, got nil")
		}
		targetHost := clientSession.PinnedServer.Addr.String()
		opts := options.Client().ApplyURI(mtest.ClusterURI()).SetHosts([]string{targetHost})
		integtest.AddTestServerAPIVersion(opts)
		client, err := mongo.Connect(context.Background(), opts)
		if err != nil {
			return fmt.Errorf("Connect error for targeted client: %w", err)
		}
		defer func() { _ = client.Disconnect(context.Background()) }()

		if err = client.Database("admin").RunCommand(context.Background(), fp).Err(); err != nil {
			return fmt.Errorf("error setting targeted fail point: %w", err)
		}
		mt.TrackFailPoint(fp.ConfigureFailPoint)
	case "configureFailPoint":
		fp, err := op.Arguments.LookupErr("failPoint")
		if err != nil {
			return fmt.Errorf("unable to find 'failPoint' in arguments: %w", err)
		}
		mt.SetFailPointFromDocument(fp.Document())
	case "assertSessionTransactionState":
		stateVal, err := op.Arguments.LookupErr("state")
		if err != nil {
			return fmt.Errorf("unable to find 'state' in arguments: %w", err)
		}
		expectedState, ok := stateVal.StringValueOK()
		if !ok {
			return errors.New("expected 'state' argument to be string")
		}

		if clientSession == nil {
			return errors.New("expected valid session, got nil")
		}
		actualState := clientSession.TransactionState.String()

		// actualState should match expectedState, but "in progress" is the same as
		// "in_progress".
		stateMatch := actualState == expectedState ||
			actualState == "in progress" && expectedState == "in_progress"
		if !stateMatch {
			return fmt.Errorf("expected transaction state %v, got %v", expectedState, actualState)
		}
	case "assertSessionPinned":
		if clientSession == nil {
			return errors.New("expected valid session, got nil")
		}
		if clientSession.PinnedServer == nil {
			return errors.New("expected pinned server, got nil")
		}
	case "assertSessionUnpinned":
		if clientSession == nil {
			return errors.New("expected valid session, got nil")
		}
		// We don't use a combined helper for assertSessionPinned and assertSessionUnpinned because the unpinned
		// case provides the pinned server address in the error msg for debugging.
		if clientSession.PinnedServer != nil {
			return fmt.Errorf("expected pinned server to be nil but got %q", clientSession.PinnedServer.Addr)
		}
	case "assertSameLsidOnLastTwoCommands":
		first, second := lastTwoIDs(mt)
		if !first.Equal(second) {
			return fmt.Errorf("expected last two lsids to be equal but got %v and %v", first, second)
		}
	case "assertDifferentLsidOnLastTwoCommands":
		first, second := lastTwoIDs(mt)
		if first.Equal(second) {
			return fmt.Errorf("expected last two lsids to be not equal but both were %v", first)
		}
	case "assertCollectionExists":
		return verifyCollectionState(op, true)
	case "assertCollectionNotExists":
		return verifyCollectionState(op, false)
	case "assertIndexExists":
		return verifyIndexState(op, true)
	case "assertIndexNotExists":
		return verifyIndexState(op, false)
	case "wait":
		time.Sleep(convertValueToMilliseconds(mt, op.Arguments.Lookup("ms")))
	case "waitForEvent":
		waitForEvent(mt, testCase, op)
	case "assertEventCount":
		assertEventCount(mt, testCase, op)
	case "recordPrimary":
		recordPrimary(mt, testCase)
	case "runAdminCommand":
		executeAdminCommand(mt, op)
	case "waitForPrimaryChange":
		waitForPrimaryChange(mt, testCase, op)
	case "startThread":
		startThread(mt, testCase, op)
	case "runOnThread":
		runOnThread(mt, testCase, op)
	case "waitForThread":
		waitForThread(mt, testCase, op)
	default:
		return fmt.Errorf("unrecognized testRunner operation %v", op.Name)
	}

	return nil
}
   570  
   571  func verifyIndexState(op *operation, shouldExist bool) error {
   572  	db := op.Arguments.Lookup("database").StringValue()
   573  	coll := op.Arguments.Lookup("collection").StringValue()
   574  	index := op.Arguments.Lookup("index").StringValue()
   575  
   576  	exists, err := indexExists(db, coll, index)
   577  	if err != nil {
   578  		return err
   579  	}
   580  	if exists != shouldExist {
   581  		return fmt.Errorf("index state mismatch for index %s in namespace %s.%s; should exist: %v, exists: %v",
   582  			index, db, coll, shouldExist, exists)
   583  	}
   584  	return nil
   585  }
   586  
   587  func indexExists(dbName, collName, indexName string) (bool, error) {
   588  	// Use global client because listIndexes cannot be executed inside a transaction.
   589  	iv := mtest.GlobalClient().Database(dbName).Collection(collName).Indexes()
   590  	cursor, err := iv.List(context.Background())
   591  	if err != nil {
   592  		return false, fmt.Errorf("IndexView.List error: %w", err)
   593  	}
   594  	defer cursor.Close(context.Background())
   595  
   596  	for cursor.Next(context.Background()) {
   597  		if cursor.Current.Lookup("name").StringValue() == indexName {
   598  			return true, nil
   599  		}
   600  	}
   601  	return false, cursor.Err()
   602  }
   603  
   604  func verifyCollectionState(op *operation, shouldExist bool) error {
   605  	db := op.Arguments.Lookup("database").StringValue()
   606  	coll := op.Arguments.Lookup("collection").StringValue()
   607  
   608  	exists, err := collectionExists(db, coll)
   609  	if err != nil {
   610  		return err
   611  	}
   612  	if exists != shouldExist {
   613  		return fmt.Errorf("collection state mismatch for %s.%s; should exist %v, exists: %v", db, coll, shouldExist,
   614  			exists)
   615  	}
   616  	return nil
   617  }
   618  
   619  func collectionExists(dbName, collName string) (bool, error) {
   620  	filter := bson.D{
   621  		{"name", collName},
   622  	}
   623  
   624  	// Use global client because listCollections cannot be executed inside a transaction.
   625  	collections, err := mtest.GlobalClient().Database(dbName).ListCollectionNames(context.Background(), filter)
   626  	if err != nil {
   627  		return false, fmt.Errorf("ListCollectionNames error: %w", err)
   628  	}
   629  
   630  	return len(collections) > 0, nil
   631  }
   632  
   633  func lastTwoIDs(mt *mtest.T) (bson.RawValue, bson.RawValue) {
   634  	events := mt.GetAllStartedEvents()
   635  	lastTwoEvents := events[len(events)-2:]
   636  
   637  	first := lastTwoEvents[0].Command.Lookup("lsid")
   638  	second := lastTwoEvents[1].Command.Lookup("lsid")
   639  	return first, second
   640  }
   641  
   // executeSessionOperation runs a transaction operation on sess: startTransaction (with optional "options"),
// commitTransaction, abortTransaction or withTransaction. Any other operation name is reported as an error.
func executeSessionOperation(mt *mtest.T, op *operation, sess mongo.Session) error {
	ctx := context.Background()
	switch op.Name {
	case "startTransaction":
		var txnOpts *options.TransactionOptions
		if rawOpts, lookupErr := op.Arguments.LookupErr("options"); lookupErr == nil {
			txnOpts = createTransactionOptions(mt, rawOpts.Document())
		}
		return sess.StartTransaction(txnOpts)
	case "commitTransaction":
		return sess.CommitTransaction(ctx)
	case "abortTransaction":
		return sess.AbortTransaction(ctx)
	case "withTransaction":
		return executeWithTransaction(mt, sess, op.Arguments)
	}
	return fmt.Errorf("unrecognized session operation: %v", op.Name)
}
   660  
   661  func executeCollectionOperation(mt *mtest.T, op *operation, sess mongo.Session) error {
   662  	switch op.Name {
   663  	case "countDocuments":
   664  		// no results to verify with count
   665  		res, err := executeCountDocuments(mt, sess, op.Arguments)
   666  		if op.opError == nil && err == nil {
   667  			verifyCountResult(mt, res, op.Result)
   668  		}
   669  		return err
   670  	case "distinct":
   671  		res, err := executeDistinct(mt, sess, op.Arguments)
   672  		if op.opError == nil && err == nil {
   673  			verifyDistinctResult(mt, res, op.Result)
   674  		}
   675  		return err
   676  	case "insertOne":
   677  		res, err := executeInsertOne(mt, sess, op.Arguments)
   678  		if op.opError == nil && err == nil {
   679  			verifyInsertOneResult(mt, res, op.Result)
   680  		}
   681  		return err
   682  	case "insertMany":
   683  		res, err := executeInsertMany(mt, sess, op.Arguments)
   684  		if op.opError == nil && err == nil {
   685  			verifyInsertManyResult(mt, res, op.Result)
   686  		}
   687  		return err
   688  	case "find":
   689  		cursor, err := executeFind(mt, sess, op.Arguments)
   690  		if op.opError == nil && err == nil {
   691  			verifyCursorResult(mt, cursor, op.Result)
   692  			_ = cursor.Close(context.Background())
   693  		}
   694  		return err
   695  	case "findOneAndDelete":
   696  		res := executeFindOneAndDelete(mt, sess, op.Arguments)
   697  		if op.opError == nil && res.Err() == nil {
   698  			verifySingleResult(mt, res, op.Result)
   699  		}
   700  		return res.Err()
   701  	case "findOneAndUpdate":
   702  		res := executeFindOneAndUpdate(mt, sess, op.Arguments)
   703  		if op.opError == nil && res.Err() == nil {
   704  			verifySingleResult(mt, res, op.Result)
   705  		}
   706  		return res.Err()
   707  	case "findOneAndReplace":
   708  		res := executeFindOneAndReplace(mt, sess, op.Arguments)
   709  		if op.opError == nil && res.Err() == nil {
   710  			verifySingleResult(mt, res, op.Result)
   711  		}
   712  		return res.Err()
   713  	case "deleteOne":
   714  		res, err := executeDeleteOne(mt, sess, op.Arguments)
   715  		if op.opError == nil && err == nil {
   716  			verifyDeleteResult(mt, res, op.Result)
   717  		}
   718  		return err
   719  	case "deleteMany":
   720  		res, err := executeDeleteMany(mt, sess, op.Arguments)
   721  		if op.opError == nil && err == nil {
   722  			verifyDeleteResult(mt, res, op.Result)
   723  		}
   724  		return err
   725  	case "updateOne":
   726  		res, err := executeUpdateOne(mt, sess, op.Arguments)
   727  		if op.opError == nil && err == nil {
   728  			verifyUpdateResult(mt, res, op.Result)
   729  		}
   730  		return err
   731  	case "updateMany":
   732  		res, err := executeUpdateMany(mt, sess, op.Arguments)
   733  		if op.opError == nil && err == nil {
   734  			verifyUpdateResult(mt, res, op.Result)
   735  		}
   736  		return err
   737  	case "replaceOne":
   738  		res, err := executeReplaceOne(mt, sess, op.Arguments)
   739  		if op.opError == nil && err == nil {
   740  			verifyUpdateResult(mt, res, op.Result)
   741  		}
   742  		return err
   743  	case "aggregate":
   744  		cursor, err := executeAggregate(mt, mt.Coll, sess, op.Arguments)
   745  		if op.opError == nil && err == nil {
   746  			verifyCursorResult(mt, cursor, op.Result)
   747  			_ = cursor.Close(context.Background())
   748  		}
   749  		return err
   750  	case "bulkWrite":
   751  		res, err := executeBulkWrite(mt, sess, op.Arguments)
   752  		if op.opError == nil && err == nil {
   753  			verifyBulkWriteResult(mt, res, op.Result)
   754  		}
   755  		return err
   756  	case "estimatedDocumentCount":
   757  		res, err := executeEstimatedDocumentCount(mt, sess, op.Arguments)
   758  		if op.opError == nil && err == nil {
   759  			verifyCountResult(mt, res, op.Result)
   760  		}
   761  		return err
   762  	case "findOne":
   763  		res := executeFindOne(mt, sess, op.Arguments)
   764  		if op.opError == nil && res.Err() == nil {
   765  			verifySingleResult(mt, res, op.Result)
   766  		}
   767  		return res.Err()
   768  	case "listIndexes":
   769  		cursor, err := executeListIndexes(mt, sess, op.Arguments)
   770  		if op.opError == nil && err == nil {
   771  			verifyCursorResult(mt, cursor, op.Result)
   772  			_ = cursor.Close(context.Background())
   773  		}
   774  		return err
   775  	case "watch":
   776  		stream, err := executeWatch(mt, mt.Coll, sess, op.Arguments)
   777  		if op.opError == nil && err == nil {
   778  			assert.Nil(mt, op.Result, "unexpected result for watch: %v", op.Result)
   779  			_ = stream.Close(context.Background())
   780  		}
   781  		return err
   782  	case "createIndex":
   783  		indexName, err := executeCreateIndex(mt, sess, op.Arguments)
   784  		if op.opError == nil && err == nil {
   785  			assert.Nil(mt, op.Result, "unexpected result for createIndex: %v", op.Result)
   786  			assert.True(mt, len(indexName) > 0, "expected valid index name, got empty string")
   787  			assert.True(mt, len(indexName) > 0, "created index has empty name")
   788  		}
   789  		return err
   790  	case "dropIndex":
   791  		res, err := executeDropIndex(mt, sess, op.Arguments)
   792  		if op.opError == nil && err == nil {
   793  			assert.Nil(mt, op.Result, "unexpected result for dropIndex: %v", op.Result)
   794  			assert.NotNil(mt, res, "expected result from dropIndex operation, got nil")
   795  		}
   796  		return err
   797  	case "listIndexNames", "mapReduce":
   798  		mt.Skipf("operation %v not implemented", op.Name)
   799  	default:
   800  		mt.Fatalf("unrecognized collection operation: %v", op.Name)
   801  	}
   802  	return nil
   803  }
   804  
   805  func executeDatabaseOperation(mt *mtest.T, op *operation, sess mongo.Session) error {
   806  	switch op.Name {
   807  	case "runCommand":
   808  		res := executeRunCommand(mt, sess, op.Arguments)
   809  		if op.opError == nil && res.Err() == nil {
   810  			verifySingleResult(mt, res, op.Result)
   811  		}
   812  		return res.Err()
   813  	case "aggregate":
   814  		cursor, err := executeAggregate(mt, mt.DB, sess, op.Arguments)
   815  		if op.opError == nil && err == nil {
   816  			verifyCursorResult(mt, cursor, op.Result)
   817  			_ = cursor.Close(context.Background())
   818  		}
   819  		return err
   820  	case "listCollections":
   821  		cursor, err := executeListCollections(mt, sess, op.Arguments)
   822  		if op.opError == nil && err == nil {
   823  			assert.Nil(mt, op.Result, "unexpected result for listCollections: %v", op.Result)
   824  			_ = cursor.Close(context.Background())
   825  		}
   826  		return err
   827  	case "listCollectionNames":
   828  		_, err := executeListCollectionNames(mt, sess, op.Arguments)
   829  		if op.opError == nil && err == nil {
   830  			assert.Nil(mt, op.Result, "unexpected result for listCollectionNames: %v", op.Result)
   831  		}
   832  		return err
   833  	case "watch":
   834  		stream, err := executeWatch(mt, mt.DB, sess, op.Arguments)
   835  		if op.opError == nil && err == nil {
   836  			assert.Nil(mt, op.Result, "unexpected result for watch: %v", op.Result)
   837  			_ = stream.Close(context.Background())
   838  		}
   839  		return err
   840  	case "dropCollection":
   841  		err := executeDropCollection(mt, sess, op.Arguments)
   842  		if op.opError == nil && err == nil {
   843  			assert.Nil(mt, op.Result, "unexpected result for dropCollection: %v", op.Result)
   844  		}
   845  		return err
   846  	case "createCollection":
   847  		err := executeCreateCollection(mt, sess, op.Arguments)
   848  		if op.opError == nil && err == nil {
   849  			assert.Nil(mt, op.Result, "unexpected result for createCollection: %v", op.Result)
   850  		}
   851  		return err
   852  	case "listCollectionObjects":
   853  		mt.Skipf("operation %v not implemented", op.Name)
   854  	default:
   855  		mt.Fatalf("unrecognized database operation: %v", op.Name)
   856  	}
   857  	return nil
   858  }
   859  
   860  func executeClientOperation(mt *mtest.T, op *operation, sess mongo.Session) error {
   861  	switch op.Name {
   862  	case "listDatabaseNames":
   863  		_, err := executeListDatabaseNames(mt, sess, op.Arguments)
   864  		if op.opError == nil && err == nil {
   865  			assert.Nil(mt, op.Result, "unexpected result for countDocuments: %v", op.Result)
   866  		}
   867  		return err
   868  	case "listDatabases":
   869  		res, err := executeListDatabases(mt, sess, op.Arguments)
   870  		if op.opError == nil && err == nil {
   871  			verifyListDatabasesResult(mt, res, op.Result)
   872  		}
   873  		return err
   874  	case "watch":
   875  		stream, err := executeWatch(mt, mt.Client, sess, op.Arguments)
   876  		if op.opError == nil && err == nil {
   877  			assert.Nil(mt, op.Result, "unexpected result for watch: %v", op.Result)
   878  			_ = stream.Close(context.Background())
   879  		}
   880  		return err
   881  	case "listDatabaseObjects":
   882  		mt.Skipf("operation %v not implemented", op.Name)
   883  	default:
   884  		mt.Fatalf("unrecognized client operation: %v", op.Name)
   885  	}
   886  	return nil
   887  }
   888  
   889  func setupSessions(mt *mtest.T, test *testCase) (mongo.Session, mongo.Session) {
   890  	mt.Helper()
   891  
   892  	var sess0Opts, sess1Opts *options.SessionOptions
   893  	if opts, err := test.SessionOptions.LookupErr("session0"); err == nil {
   894  		sess0Opts = createSessionOptions(mt, opts.Document())
   895  	}
   896  	if opts, err := test.SessionOptions.LookupErr("session1"); err == nil {
   897  		sess1Opts = createSessionOptions(mt, opts.Document())
   898  	}
   899  
   900  	sess0, err := mt.Client.StartSession(sess0Opts)
   901  	assert.Nil(mt, err, "error creating session0: %v", err)
   902  	sess1, err := mt.Client.StartSession(sess1Opts)
   903  	assert.Nil(mt, err, "error creating session1: %v", err)
   904  
   905  	return sess0, sess1
   906  }
   907  
   908  func insertDocuments(mt *mtest.T, coll *mongo.Collection, rawDocs []bson.Raw) {
   909  	mt.Helper()
   910  
   911  	docsToInsert := bsonutil.RawToInterfaces(rawDocs...)
   912  	if len(docsToInsert) == 0 {
   913  		return
   914  	}
   915  
   916  	_, err := coll.InsertMany(context.Background(), docsToInsert)
   917  	assert.Nil(mt, err, "InsertMany error for collection %v: %v", coll.Name(), err)
   918  }
   919  
   920  // load initial data into appropriate collections and set chunkSize for the test case if necessary
   921  func setupTest(mt *mtest.T, testFile *testFile, testCase *testCase) {
   922  	mt.Helper()
   923  
   924  	// key vault data
   925  	if len(testFile.KeyVaultData) > 0 {
   926  		// Drop the key vault collection in case it exists from a prior test run.
   927  		err := mt.Client.Database("keyvault").Collection("datakeys").Drop(context.Background())
   928  		assert.Nil(mt, err, "error dropping key vault collection")
   929  
   930  		keyVaultColl := mt.CreateCollection(mtest.Collection{
   931  			Name: "datakeys",
   932  			DB:   "keyvault",
   933  		}, false)
   934  
   935  		insertDocuments(mt, keyVaultColl, testFile.KeyVaultData)
   936  	}
   937  
   938  	// regular documents
   939  	if testFile.Data.Documents != nil {
   940  		insertDocuments(mt, mt.Coll, testFile.Data.Documents)
   941  		return
   942  	}
   943  
   944  	// GridFS data
   945  	gfsData := testFile.Data.GridFSData
   946  
   947  	if gfsData.Chunks != nil {
   948  		chunks := mt.CreateCollection(mtest.Collection{
   949  			Name: gridFSChunks,
   950  		}, false)
   951  		insertDocuments(mt, chunks, gfsData.Chunks)
   952  	}
   953  	if gfsData.Files != nil {
   954  		files := mt.CreateCollection(mtest.Collection{
   955  			Name: gridFSFiles,
   956  		}, false)
   957  		insertDocuments(mt, files, gfsData.Files)
   958  
   959  		csVal, err := gfsData.Files[0].LookupErr("chunkSize")
   960  		if err == nil {
   961  			testCase.chunkSize = csVal.Int32()
   962  		}
   963  	}
   964  }
   965  
   966  func verifyTestOutcome(mt *mtest.T, outcomeColl *outcomeCollection) {
   967  	// Outcome needs to be verified using the global client instead of the test client because certain client
   968  	// configurations will cause outcome checking to fail. For example, a client configured with auto encryption
   969  	// will decrypt results, causing comparisons to fail.
   970  
   971  	collName := mt.Coll.Name()
   972  	if outcomeColl.Name != "" {
   973  		collName = outcomeColl.Name
   974  	}
   975  	coll := mtest.GlobalClient().Database(mt.DB.Name()).Collection(collName, checkOutcomeOpts)
   976  
   977  	findOpts := options.Find().
   978  		SetSort(bson.M{"_id": 1})
   979  	cursor, err := coll.Find(context.Background(), bson.D{}, findOpts)
   980  	assert.Nil(mt, err, "Find error: %v", err)
   981  	verifyCursorResult(mt, cursor, outcomeColl.Data)
   982  }
   983  
   984  func getTopologyFromClient(client *mongo.Client) *topology.Topology {
   985  	clientElem := reflect.ValueOf(client).Elem()
   986  	deploymentField := clientElem.FieldByName("deployment")
   987  	deploymentField = reflect.NewAt(deploymentField.Type(), unsafe.Pointer(deploymentField.UnsafeAddr())).Elem()
   988  	return deploymentField.Interface().(*topology.Topology)
   989  }
   990  
   991  // getCryptSharedLibExtraOptions returns an AutoEncryption extra options map with crypt_shared
   992  // library path information if the CRYPT_SHARED_LIB_PATH environment variable is set.
   993  func getCryptSharedLibExtraOptions() map[string]interface{} {
   994  	path := os.Getenv("CRYPT_SHARED_LIB_PATH")
   995  	if path == "" {
   996  		return nil
   997  	}
   998  	return map[string]interface{}{
   999  		"cryptSharedLibRequired": true,
  1000  		"cryptSharedLibPath":     path,
  1001  	}
  1002  }
  1003  

View as plain text