...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/sessions_test.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package integration
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"errors"
    13  	"fmt"
    14  	"reflect"
    15  	"sync"
    16  	"testing"
    17  	"time"
    18  
    19  	"go.mongodb.org/mongo-driver/bson"
    20  	"go.mongodb.org/mongo-driver/internal/assert"
    21  	"go.mongodb.org/mongo-driver/internal/require"
    22  	"go.mongodb.org/mongo-driver/mongo"
    23  	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
    24  	"go.mongodb.org/mongo-driver/mongo/options"
    25  	"go.mongodb.org/mongo-driver/mongo/readpref"
    26  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    27  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    28  	"golang.org/x/sync/errgroup"
    29  )
    30  
    31  func TestSessionPool(t *testing.T) {
    32  	mt := mtest.New(t, mtest.NewOptions().MinServerVersion("3.6").CreateClient(false))
    33  
    34  	mt.Run("last use time updated", func(mt *mtest.T) {
    35  		sess, err := mt.Client.StartSession()
    36  		assert.Nil(mt, err, "StartSession error: %v", err)
    37  		defer sess.EndSession(context.Background())
    38  		initialLastUsedTime := getSessionLastUsedTime(mt, sess)
    39  
    40  		err = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
    41  			return mt.Client.Ping(sc, readpref.Primary())
    42  		})
    43  		assert.Nil(mt, err, "WithSession error: %v", err)
    44  
    45  		newLastUsedTime := getSessionLastUsedTime(mt, sess)
    46  		assert.True(mt, newLastUsedTime.After(initialLastUsedTime),
    47  			"last used time %s is not after the initial last used time %s", newLastUsedTime, initialLastUsedTime)
    48  	})
    49  }
    50  
    51  func TestSessions(t *testing.T) {
    52  	mtOpts := mtest.NewOptions().MinServerVersion("3.6").Topologies(mtest.ReplicaSet, mtest.Sharded).
    53  		CreateClient(false)
    54  	mt := mtest.New(t, mtOpts)
    55  
    56  	mt.Run("imperative API", func(mt *mtest.T) {
    57  		mt.Run("round trip Session object", func(mt *mtest.T) {
    58  			// Roundtrip a Session object through NewSessionContext/ContextFromSession and assert that it is correctly
    59  			// stored/retrieved.
    60  
    61  			sess, err := mt.Client.StartSession()
    62  			assert.Nil(mt, err, "StartSession error: %v", err)
    63  			defer sess.EndSession(context.Background())
    64  
    65  			ctx := mongo.NewSessionContext(context.Background(), sess)
    66  			assert.Equal(mt, sess.ID(), ctx.ID(), "expected Session ID %v, got %v", sess.ID(), ctx.ID())
    67  
    68  			gotSess := mongo.SessionFromContext(ctx)
    69  			assert.NotNil(mt, gotSess, "expected SessionFromContext to return non-nil value, got nil")
    70  			assert.Equal(mt, sess.ID(), gotSess.ID(), "expected Session ID %v, got %v", sess.ID(), gotSess.ID())
    71  		})
    72  
    73  		txnOpts := mtest.NewOptions().RunOn(
    74  			mtest.RunOnBlock{Topology: []mtest.TopologyKind{mtest.ReplicaSet}, MinServerVersion: "4.0"},
    75  			mtest.RunOnBlock{Topology: []mtest.TopologyKind{mtest.Sharded}, MinServerVersion: "4.2"},
    76  		)
    77  		mt.RunOpts("run transaction", txnOpts, func(mt *mtest.T) {
    78  			// Test that the imperative sessions API can be used to run a transaction.
    79  
    80  			createSessionContext := func(mt *mtest.T) mongo.SessionContext {
    81  				sess, err := mt.Client.StartSession()
    82  				assert.Nil(mt, err, "StartSession error: %v", err)
    83  
    84  				return mongo.NewSessionContext(context.Background(), sess)
    85  			}
    86  
    87  			ctx := createSessionContext(mt)
    88  			sess := mongo.SessionFromContext(ctx)
    89  			assert.NotNil(mt, sess, "expected SessionFromContext to return non-nil value, got nil")
    90  			defer sess.EndSession(context.Background())
    91  
    92  			err := sess.StartTransaction()
    93  			assert.Nil(mt, err, "StartTransaction error: %v", err)
    94  
    95  			numDocs := 2
    96  			for i := 0; i < numDocs; i++ {
    97  				_, err = mt.Coll.InsertOne(ctx, bson.D{{"x", 1}})
    98  				assert.Nil(mt, err, "InsertOne error at index %d: %v", i, err)
    99  			}
   100  
   101  			// Assert that the collection count is 0 before committing and numDocs after. This tests that the InsertOne
   102  			// calls were actually executed in the transaction because the pre-commit count does not include them.
   103  			assertCollectionCount(mt, 0)
   104  			err = sess.CommitTransaction(ctx)
   105  			assert.Nil(mt, err, "CommitTransaction error: %v", err)
   106  			assertCollectionCount(mt, int64(numDocs))
   107  		})
   108  	})
   109  
   110  	unackWcOpts := options.Collection().SetWriteConcern(writeconcern.New(writeconcern.W(0)))
   111  	mt.RunOpts("unacknowledged write", mtest.NewOptions().CollectionOptions(unackWcOpts), func(mt *mtest.T) {
   112  		// unacknowledged write during a session should result in an error
   113  		sess, err := mt.Client.StartSession()
   114  		assert.Nil(mt, err, "StartSession error: %v", err)
   115  		defer sess.EndSession(context.Background())
   116  
   117  		err = mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
   118  			_, err := mt.Coll.InsertOne(sc, bson.D{{"x", 1}})
   119  			return err
   120  		})
   121  
   122  		assert.Equal(mt, err, mongo.ErrUnacknowledgedWrite,
   123  			"expected ErrUnacknowledgedWrite on unacknowledged write in session, got %v", err)
   124  	})
   125  
   126  	// Regression test for GODRIVER-2533. Note that this test assumes the race
   127  	// detector is enabled (GODRIVER-2072).
   128  	mt.Run("NumberSessionsInProgress data race", func(mt *mtest.T) {
   129  		// Use two goroutines to execute a few simultaneous runs of NumberSessionsInProgress
   130  		// and a basic collection operation (CountDocuments).
   131  		var wg sync.WaitGroup
   132  		wg.Add(2)
   133  
   134  		go func() {
   135  			defer wg.Done()
   136  
   137  			for i := 0; i < 100; i++ {
   138  				time.Sleep(100 * time.Microsecond)
   139  				_ = mt.Client.NumberSessionsInProgress()
   140  			}
   141  		}()
   142  		go func() {
   143  			defer wg.Done()
   144  
   145  			for i := 0; i < 100; i++ {
   146  				time.Sleep(100 * time.Microsecond)
   147  				_, err := mt.Coll.CountDocuments(context.Background(), bson.D{})
   148  				assert.Nil(mt, err, "CountDocument error: %v", err)
   149  			}
   150  		}()
   151  
   152  		wg.Wait()
   153  	})
   154  }
   155  
   156  func TestSessionsProse(t *testing.T) {
   157  	mtOpts := mtest.
   158  		NewOptions().
   159  		MinServerVersion("3.6").
   160  		Topologies(mtest.ReplicaSet, mtest.Sharded).
   161  		CreateClient(false)
   162  
   163  	mt := mtest.New(t, mtOpts)
   164  
   165  	hosts := options.Client().ApplyURI(mtest.ClusterURI()).Hosts
   166  
   167  	mt.Run("1 setting both snapshot and causalConsistency to true is not allowed", func(mt *mtest.T) {
   168  		// causalConsistency and snapshot are mutually exclusive
   169  		sessOpts := options.Session().SetCausalConsistency(true).SetSnapshot(true)
   170  		_, err := mt.Client.StartSession(sessOpts)
   171  		assert.NotNil(mt, err, "expected StartSession error, got nil")
   172  		expectedErr := errors.New("causal consistency and snapshot cannot both be set for a session")
   173  		assert.Equal(mt, expectedErr, err, "expected error %v, got %v", expectedErr, err)
   174  	})
   175  
   176  	mt.Run("2 pool is LIFO", func(mt *mtest.T) {
   177  		aSess, err := mt.Client.StartSession()
   178  		assert.Nil(mt, err, "StartSession error: %v", err)
   179  		bSess, err := mt.Client.StartSession()
   180  		assert.Nil(mt, err, "StartSession error: %v", err)
   181  
   182  		// end the sessions to return them to the pool
   183  		aSess.EndSession(context.Background())
   184  		bSess.EndSession(context.Background())
   185  
   186  		firstSess, err := mt.Client.StartSession()
   187  		assert.Nil(mt, err, "StartSession error: %v", err)
   188  		defer firstSess.EndSession(context.Background())
   189  		want := bSess.ID()
   190  		got := firstSess.ID()
   191  		assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
   192  
   193  		secondSess, err := mt.Client.StartSession()
   194  		assert.Nil(mt, err, "StartSession error: %v", err)
   195  		defer secondSess.EndSession(context.Background())
   196  		want = aSess.ID()
   197  		got = secondSess.ID()
   198  		assert.True(mt, sessionIDsEqual(mt, want, got), "expected session ID %v, got %v", want, got)
   199  	})
   200  
   201  	// Pin to a single mongos so heartbeats/handshakes to other mongoses
   202  	// won't cause errors.
   203  	clusterTimeOpts := mtest.NewOptions().
   204  		ClientOptions(options.Client().SetHeartbeatInterval(50 * time.Second)).
   205  		ClientType(mtest.Pinned).
   206  		CreateClient(false)
   207  
   208  	mt.RunOpts("3 clusterTime in commands", clusterTimeOpts, func(mt *mtest.T) {
   209  		serverStatus := sessionFunction{"server status", "database", "RunCommand", []interface{}{bson.D{{"serverStatus", 1}}}}
   210  		insert := sessionFunction{"insert one", "collection", "InsertOne", []interface{}{bson.D{{"x", 1}}}}
   211  		agg := sessionFunction{"aggregate", "collection", "Aggregate", []interface{}{mongo.Pipeline{}}}
   212  		find := sessionFunction{"find", "collection", "Find", []interface{}{bson.D{}}}
   213  
   214  		sessionFunctions := []sessionFunction{serverStatus, insert, agg, find}
   215  		for _, sf := range sessionFunctions {
   216  			mt.Run(sf.name, func(mt *mtest.T) {
   217  				err := sf.execute(mt, nil)
   218  				assert.Nil(mt, err, "%v error: %v", sf.name, err)
   219  
   220  				// assert $clusterTime was sent to server
   221  				started := mt.GetStartedEvent()
   222  				assert.NotNil(mt, started, "expected started event, got nil")
   223  				_, err = started.Command.LookupErr("$clusterTime")
   224  				assert.Nil(mt, err, "$clusterTime not sent")
   225  
   226  				// record response cluster time
   227  				succeeded := mt.GetSucceededEvent()
   228  				assert.NotNil(mt, succeeded, "expected succeeded event, got nil")
   229  				replyClusterTimeVal, err := succeeded.Reply.LookupErr("$clusterTime")
   230  				assert.Nil(mt, err, "$clusterTime not found in response")
   231  
   232  				// call function again
   233  				err = sf.execute(mt, nil)
   234  				assert.Nil(mt, err, "%v error: %v", sf.name, err)
   235  
   236  				// find cluster time sent to server and assert it is the same as the one in the previous response
   237  				sentClusterTimeVal, err := mt.GetStartedEvent().Command.LookupErr("$clusterTime")
   238  				assert.Nil(mt, err, "$clusterTime not sent")
   239  				replyClusterTimeDoc := replyClusterTimeVal.Document()
   240  				sentClusterTimeDoc := sentClusterTimeVal.Document()
   241  				assert.Equal(mt, replyClusterTimeDoc, sentClusterTimeDoc,
   242  					"expected cluster time %v, got %v", replyClusterTimeDoc, sentClusterTimeDoc)
   243  			})
   244  		}
   245  	})
   246  
   247  	mt.RunOpts("4 explicit and implicit session arguments", noClientOpts, func(mt *mtest.T) {
   248  		// lsid is included in commands with explicit and implicit sessions
   249  
   250  		sessionFunctions := createFunctionsSlice()
   251  		for _, sf := range sessionFunctions {
   252  			mt.Run(sf.name, func(mt *mtest.T) {
   253  				// explicit session
   254  				sess, err := mt.Client.StartSession()
   255  				assert.Nil(mt, err, "StartSession error: %v", err)
   256  				defer sess.EndSession(context.Background())
   257  				mt.ClearEvents()
   258  
   259  				_ = sf.execute(mt, sess) // don't check error because we only care about lsid
   260  				_, wantID := sess.ID().Lookup("id").Binary()
   261  				gotID := extractSentSessionID(mt)
   262  				assert.True(mt, bytes.Equal(wantID, gotID), "expected session ID %v, got %v", wantID, gotID)
   263  
   264  				// implicit session
   265  				_ = sf.execute(mt, nil)
   266  				gotID = extractSentSessionID(mt)
   267  				assert.NotNil(mt, gotID, "expected lsid, got nil")
   268  			})
   269  		}
   270  	})
   271  
   272  	mt.Run("5 session argument is for the right client", func(mt *mtest.T) {
   273  		// a session can only be used in commands associated with the client that created it
   274  
   275  		sessionFunctions := createFunctionsSlice()
   276  		sess, err := mt.Client.StartSession()
   277  		assert.Nil(mt, err, "StartSession error: %v", err)
   278  		defer sess.EndSession(context.Background())
   279  
   280  		for _, sf := range sessionFunctions {
   281  			mt.Run(sf.name, func(mt *mtest.T) {
   282  				err = sf.execute(mt, sess)
   283  				assert.Equal(mt, mongo.ErrWrongClient, err, "expected error %v, got %v", mongo.ErrWrongClient, err)
   284  			})
   285  		}
   286  	})
   287  
   288  	const proseTest6 = "6 no further operations can be performed using a session after endSession has been called"
   289  	mt.RunOpts(proseTest6, noClientOpts, func(mt *mtest.T) {
   290  		// an ended session cannot be used in commands
   291  
   292  		sessionFunctions := createFunctionsSlice()
   293  		for _, sf := range sessionFunctions {
   294  			mt.Run(sf.name, func(mt *mtest.T) {
   295  				sess, err := mt.Client.StartSession()
   296  				assert.Nil(mt, err, "StartSession error: %v", err)
   297  				sess.EndSession(context.Background())
   298  
   299  				err = sf.execute(mt, sess)
   300  				assert.Equal(mt, session.ErrSessionEnded, err, "expected error %v, got %v", session.ErrSessionEnded, err)
   301  			})
   302  		}
   303  	})
   304  
   305  	mt.Run("7 authenticating as multiple users suppresses implicit sessions", func(mt *mtest.T) {
   306  		mt.Skip("Go Driver does not allow simultaneous authentication with multiple users.")
   307  	})
   308  
   309  	mt.Run("8 client side cursor that exhausts the results on the initial query immediately returns the implicit session to the pool",
   310  		func(mt *mtest.T) {
   311  			// implicit sessions are returned to the server session pool
   312  
   313  			doc := bson.D{{"x", 1}}
   314  			_, err := mt.Coll.InsertOne(context.Background(), doc)
   315  			assert.Nil(mt, err, "InsertOne error: %v", err)
   316  			_, err = mt.Coll.InsertOne(context.Background(), doc)
   317  			assert.Nil(mt, err, "InsertOne error: %v", err)
   318  
   319  			// create a cursor that will hold onto an implicit session and record the sent session ID
   320  			mt.ClearEvents()
   321  			cursor, err := mt.Coll.Find(context.Background(), bson.D{})
   322  			assert.Nil(mt, err, "Find error: %v", err)
   323  			findID := extractSentSessionID(mt)
   324  			assert.True(mt, cursor.Next(context.Background()), "expected Next true, got false")
   325  
   326  			// execute another operation and verify the find session ID was reused
   327  			_, err = mt.Coll.DeleteOne(context.Background(), bson.D{})
   328  			assert.Nil(mt, err, "DeleteOne error: %v", err)
   329  			deleteID := extractSentSessionID(mt)
   330  			assert.Equal(mt, findID, deleteID, "expected session ID %v, got %v", findID, deleteID)
   331  		})
   332  
   333  	mt.Run("9 client side cursor that exhausts the results after a getMore immediately returns the implicit session to the pool",
   334  		func(mt *mtest.T) {
   335  			// Client-side cursor that exhausts the results after a getMore immediately returns the implicit session to the pool.
   336  
   337  			var docs []interface{}
   338  			for i := 0; i < 5; i++ {
   339  				docs = append(docs, bson.D{{"x", i}})
   340  			}
   341  
   342  			_, err := mt.Coll.InsertMany(context.Background(), docs)
   343  			assert.Nil(mt, err, "InsertMany error: %v", err)
   344  
   345  			// run a find that will hold onto the implicit session and record the session ID
   346  			mt.ClearEvents()
   347  			cursor, err := mt.Coll.Find(context.Background(), bson.D{}, options.Find().SetBatchSize(3))
   348  			assert.Nil(mt, err, "Find error: %v", err)
   349  			findID := extractSentSessionID(mt)
   350  
   351  			// iterate past 4 documents, forcing a getMore. session should be returned to pool after getMore
   352  			for i := 0; i < 4; i++ {
   353  				assert.True(mt, cursor.Next(context.Background()), "Next returned false on iteration %v", i)
   354  			}
   355  
   356  			// execute another operation and verify the find session ID was reused
   357  			_, err = mt.Coll.DeleteOne(context.Background(), bson.D{})
   358  			assert.Nil(mt, err, "DeleteOne error: %v", err)
   359  			deleteID := extractSentSessionID(mt)
   360  			assert.Equal(mt, findID, deleteID, "expected session ID %v, got %v", findID, deleteID)
   361  		})
   362  
   363  	mt.Run("10 no remaining sessions are checked out after each functional test", func(mt *mtest.T) {
   364  		mt.Skip("This is tested individually in each functional test.")
   365  	})
   366  
   367  	mt.Run("11 for every combination of topology and readPreference, ensure that find and getMore both send the same session id", func(mt *mtest.T) {
   368  		var docs []interface{}
   369  		for i := 0; i < 3; i++ {
   370  			docs = append(docs, bson.D{{"x", i}})
   371  		}
   372  		_, err := mt.Coll.InsertMany(context.Background(), docs)
   373  		assert.Nil(mt, err, "InsertMany error: %v", err)
   374  
   375  		// run a find that will hold onto an implicit session and record the session ID
   376  		mt.ClearEvents()
   377  		cursor, err := mt.Coll.Find(context.Background(), bson.D{}, options.Find().SetBatchSize(2))
   378  		assert.Nil(mt, err, "Find error: %v", err)
   379  		findID := extractSentSessionID(mt)
   380  		assert.NotNil(mt, findID, "expected session ID for find, got nil")
   381  
   382  		// iterate over all documents and record the session ID of the getMore
   383  		for i := 0; i < 3; i++ {
   384  			assert.True(mt, cursor.Next(context.Background()), "Next returned false on iteration %v", i)
   385  		}
   386  		getMoreID := extractSentSessionID(mt)
   387  		assert.Equal(mt, findID, getMoreID, "expected session ID %v, got %v", findID, getMoreID)
   388  	})
   389  
   390  	sessallocopts := mtest.NewOptions().ClientOptions(options.Client().SetMaxPoolSize(1).SetRetryWrites(true).
   391  		SetHosts(hosts[:1]))
   392  	mt.RunOpts("14 implicit session allocation", sessallocopts, func(mt *mtest.T) {
   393  		// TODO(GODRIVER-2844): Fix and unskip this test case.
   394  		mt.Skip("Test fails frequently, skipping. See GODRIVER-2844")
   395  
   396  		ops := map[string]func(ctx context.Context) error{
   397  			"insert": func(ctx context.Context) error {
   398  				_, err := mt.Coll.InsertOne(ctx, bson.D{})
   399  				return err
   400  			},
   401  			"delete": func(ctx context.Context) error {
   402  				_, err := mt.Coll.DeleteOne(ctx, bson.D{})
   403  				return err
   404  			},
   405  			"update": func(ctx context.Context) error {
   406  				_, err := mt.Coll.UpdateOne(ctx, bson.D{}, bson.D{{"$set", bson.D{{"a", 1}}}})
   407  				return err
   408  			},
   409  			"bulkWrite": func(ctx context.Context) error {
   410  				model := mongo.NewUpdateOneModel().
   411  					SetFilter(bson.D{}).
   412  					SetUpdate(bson.D{{"$set", bson.D{{"a", 1}}}})
   413  				_, err := mt.Coll.BulkWrite(ctx, []mongo.WriteModel{model})
   414  				return err
   415  			},
   416  			"findOneAndDelete": func(ctx context.Context) error {
   417  				result := mt.Coll.FindOneAndDelete(ctx, bson.D{})
   418  				if err := result.Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
   419  					return err
   420  				}
   421  				return nil
   422  			},
   423  			"findOneAndUpdate": func(ctx context.Context) error {
   424  				result := mt.Coll.FindOneAndUpdate(ctx, bson.D{},
   425  					bson.D{{"$set", bson.D{{"a", 1}}}})
   426  
   427  				if err := result.Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
   428  					return err
   429  				}
   430  				return nil
   431  			},
   432  			"findOneAndReplace": func(ctx context.Context) error {
   433  				result := mt.Coll.FindOneAndReplace(ctx, bson.D{}, bson.D{{"a", 1}})
   434  				if err := result.Err(); err != nil && !errors.Is(err, mongo.ErrNoDocuments) {
   435  					return err
   436  				}
   437  				return nil
   438  			},
   439  			"find": func(ctx context.Context) error {
   440  				cursor, err := mt.Coll.Find(ctx, bson.D{})
   441  				if err != nil {
   442  					return err
   443  				}
   444  				return cursor.All(ctx, &bson.A{})
   445  			},
   446  		}
   447  
   448  		// maintainedOneSession asserts that exactly one session is used for all operations at least once
   449  		// across the retries of this test.
   450  		var maintainedOneSession bool
   451  
   452  		// minimumSessionCount asserts the least amount of sessions used over all the retries of the
   453  		// operations. For example, if we retry 5 times we could result in session use { 1, 2, 1, 1, 6 }. In
   454  		// this case, minimumSessionCount should be 1.
   455  		var minimumSessionCount int
   456  
   457  		// limitedSessionUse asserts that the number of allocated sessions is strictly less than the number of
   458  		// concurrent operations in every retry of this test. In this instance it would be less than (but NOT
   459  		// equal to the number of operations).
   460  		limitedSessionUse := true
   461  
   462  		retrycount := 5
   463  		for i := 1; i <= retrycount; i++ {
   464  			errs, ctx := errgroup.WithContext(context.Background())
   465  
   466  			// Execute the ops list concurrently.
   467  			for cmd, op := range ops {
   468  				op := op
   469  				cmd := cmd
   470  				errs.Go(func() error {
   471  					if err := op(ctx); err != nil {
   472  						return fmt.Errorf("error running %s operation: %w", cmd, err)
   473  					}
   474  					return nil
   475  				})
   476  			}
   477  			err := errs.Wait()
   478  			assert.Nil(mt, err, "expected no error, got: %v", err)
   479  
   480  			// Get all started events and collect them by the session ID.
   481  			set := make(map[string]bool)
   482  			for _, event := range mt.GetAllStartedEvents() {
   483  				lsid := event.Command.Lookup("lsid")
   484  				set[lsid.String()] = true
   485  			}
   486  
   487  			setSize := len(set)
   488  			if setSize == 1 {
   489  				maintainedOneSession = true
   490  			} else if setSize < minimumSessionCount || minimumSessionCount == 0 {
   491  				// record the minimum number of sessions we used over all retries.
   492  				minimumSessionCount = setSize
   493  			}
   494  
   495  			if setSize >= len(ops) {
   496  				limitedSessionUse = false
   497  			}
   498  		}
   499  
   500  		oneSessMsg := "expected one session across all %v operations for at least 1/%v retries, got: %v"
   501  		assert.True(mt, maintainedOneSession, oneSessMsg, len(ops), retrycount, minimumSessionCount)
   502  
   503  		limitedSessMsg := "expected session count to be less than the number of operations: %v"
   504  		assert.True(mt, limitedSessionUse, limitedSessMsg, len(ops))
   505  
   506  	})
   507  }
   508  
// sessionFunction describes one driver method that execute invokes through
// reflection. Code in this file builds values with positional composite
// literals, so the field order must not change.
type sessionFunction struct {
	// name is a human-readable label; tests in this file use it as the subtest
	// name, and execute special-cases the value "drop database".
	name string
	// target selects the receiver in execute: "client", "database",
	// "collection" or "indexView".
	target string
	// fnName is the method name looked up on the target via reflection.
	fnName string
	params []interface{} // should not include context
}
   515  
   516  func (sf sessionFunction) execute(mt *mtest.T, sess mongo.Session) error {
   517  	var target reflect.Value
   518  	switch sf.target {
   519  	case "client":
   520  		target = reflect.ValueOf(mt.Client)
   521  	case "database":
   522  		// use a different database for drops because any executed after the drop will get "database not found"
   523  		// errors on sharded clusters
   524  		if sf.name != "drop database" {
   525  			target = reflect.ValueOf(mt.DB)
   526  			break
   527  		}
   528  		target = reflect.ValueOf(mt.Client.Database("sessionsTestsDropDatabase"))
   529  	case "collection":
   530  		target = reflect.ValueOf(mt.Coll)
   531  	case "indexView":
   532  		target = reflect.ValueOf(mt.Coll.Indexes())
   533  	default:
   534  		mt.Fatalf("unrecognized target: %v", sf.target)
   535  	}
   536  
   537  	fn := target.MethodByName(sf.fnName)
   538  	paramsValues := interfaceSliceToValueSlice(sf.params)
   539  
   540  	if sess != nil {
   541  		return mongo.WithSession(context.Background(), sess, func(sc mongo.SessionContext) error {
   542  			valueArgs := []reflect.Value{reflect.ValueOf(sc)}
   543  			valueArgs = append(valueArgs, paramsValues...)
   544  			returnValues := fn.Call(valueArgs)
   545  			return extractReturnError(returnValues)
   546  		})
   547  	}
   548  	valueArgs := []reflect.Value{reflect.ValueOf(context.Background())}
   549  	valueArgs = append(valueArgs, paramsValues...)
   550  	returnValues := fn.Call(valueArgs)
   551  	return extractReturnError(returnValues)
   552  }
   553  
   554  func createFunctionsSlice() []sessionFunction {
   555  	insertManyDocs := []interface{}{bson.D{{"x", 1}}}
   556  	fooIndex := mongo.IndexModel{
   557  		Keys:    bson.D{{"foo", -1}},
   558  		Options: options.Index().SetName("fooIndex"),
   559  	}
   560  	manyIndexes := []mongo.IndexModel{fooIndex}
   561  	updateDoc := bson.D{{"$inc", bson.D{{"x", 1}}}}
   562  
   563  	return []sessionFunction{
   564  		{"list databases", "client", "ListDatabases", []interface{}{bson.D{}}},
   565  		{"insert one", "collection", "InsertOne", []interface{}{bson.D{{"x", 1}}}},
   566  		{"insert many", "collection", "InsertMany", []interface{}{insertManyDocs}},
   567  		{"delete one", "collection", "DeleteOne", []interface{}{bson.D{}}},
   568  		{"delete many", "collection", "DeleteMany", []interface{}{bson.D{}}},
   569  		{"update one", "collection", "UpdateOne", []interface{}{bson.D{}, updateDoc}},
   570  		{"update many", "collection", "UpdateMany", []interface{}{bson.D{}, updateDoc}},
   571  		{"replace one", "collection", "ReplaceOne", []interface{}{bson.D{}, bson.D{}}},
   572  		{"aggregate", "collection", "Aggregate", []interface{}{mongo.Pipeline{}}},
   573  		{"estimated document count", "collection", "EstimatedDocumentCount", nil},
   574  		{"distinct", "collection", "Distinct", []interface{}{"field", bson.D{}}},
   575  		{"find", "collection", "Find", []interface{}{bson.D{}}},
   576  		{"find one and delete", "collection", "FindOneAndDelete", []interface{}{bson.D{}}},
   577  		{"find one and replace", "collection", "FindOneAndReplace", []interface{}{bson.D{}, bson.D{}}},
   578  		{"find one and update", "collection", "FindOneAndUpdate", []interface{}{bson.D{}, updateDoc}},
   579  		{"drop collection", "collection", "Drop", nil},
   580  		{"list collections", "database", "ListCollections", []interface{}{bson.D{}}},
   581  		{"drop database", "database", "Drop", nil},
   582  		{"create one index", "indexView", "CreateOne", []interface{}{fooIndex}},
   583  		{"create many indexes", "indexView", "CreateMany", []interface{}{manyIndexes}},
   584  		{"drop one index", "indexView", "DropOne", []interface{}{"barIndex"}},
   585  		{"drop all indexes", "indexView", "DropAll", nil},
   586  		{"list indexes", "indexView", "List", nil},
   587  	}
   588  }
   589  
   590  func assertCollectionCount(mt *mtest.T, expectedCount int64) {
   591  	mt.Helper()
   592  
   593  	count, err := mt.Coll.CountDocuments(context.Background(), bson.D{})
   594  	require.NoError(mt, err, "CountDocuments error")
   595  	assert.Equal(mt, expectedCount, count, "expected CountDocuments result %v, got %v", expectedCount, count)
   596  }
   597  
   598  func sessionIDsEqual(mt *mtest.T, id1, id2 bson.Raw) bool {
   599  	first, err := id1.LookupErr("id")
   600  	assert.Nil(mt, err, "id not found in document %v", id1)
   601  	second, err := id2.LookupErr("id")
   602  	assert.Nil(mt, err, "id not found in document %v", id2)
   603  
   604  	_, firstUUID := first.Binary()
   605  	_, secondUUID := second.Binary()
   606  	return bytes.Equal(firstUUID, secondUUID)
   607  }
   608  
   609  func interfaceSliceToValueSlice(args []interface{}) []reflect.Value {
   610  	vals := make([]reflect.Value, 0, len(args))
   611  	for _, arg := range args {
   612  		vals = append(vals, reflect.ValueOf(arg))
   613  	}
   614  	return vals
   615  }
   616  
   617  func extractReturnError(returnValues []reflect.Value) error {
   618  	errVal := returnValues[len(returnValues)-1]
   619  	switch converted := errVal.Interface().(type) {
   620  	case error:
   621  		return converted
   622  	case *mongo.SingleResult:
   623  		return converted.Err()
   624  	default:
   625  		return nil
   626  	}
   627  }
   628  
   629  func extractSentSessionID(mt *mtest.T) []byte {
   630  	event := mt.GetStartedEvent()
   631  	if event == nil {
   632  		return nil
   633  	}
   634  	lsid, err := event.Command.LookupErr("lsid")
   635  	if err != nil {
   636  		return nil
   637  	}
   638  
   639  	_, data := lsid.Document().Lookup("id").Binary()
   640  	return data
   641  }
   642  
   643  func getSessionLastUsedTime(mt *mtest.T, sess mongo.Session) time.Time {
   644  	xsess, ok := sess.(mongo.XSession)
   645  	assert.True(mt, ok, "expected session to implement mongo.XSession, but got %T", sess)
   646  	return xsess.ClientSession().LastUsed
   647  }
   648  

View as plain text