...

Source file src/go.mongodb.org/mongo-driver/mongo/with_transactions_test.go

Documentation: go.mongodb.org/mongo-driver/mongo

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mongo
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"math"
    14  	"os"
    15  	"strconv"
    16  	"strings"
    17  	"testing"
    18  	"time"
    19  
    20  	"go.mongodb.org/mongo-driver/bson"
    21  	"go.mongodb.org/mongo-driver/event"
    22  	"go.mongodb.org/mongo-driver/internal/assert"
    23  	"go.mongodb.org/mongo-driver/internal/integtest"
    24  	"go.mongodb.org/mongo-driver/mongo/description"
    25  	"go.mongodb.org/mongo-driver/mongo/options"
    26  	"go.mongodb.org/mongo-driver/mongo/readpref"
    27  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    28  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    29  	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
    30  )
    31  
    32  var (
    33  	connsCheckedOut  int
    34  	errorInterrupted int32 = 11601
    35  )
    36  
    37  func TestConvenientTransactions(t *testing.T) {
    38  	if testing.Short() {
    39  		t.Skip("skipping integration test in short mode")
    40  	}
    41  	if os.Getenv("DOCKER_RUNNING") != "" {
    42  		t.Skip("skipping test in docker environment")
    43  	}
    44  
    45  	client := setupConvenientTransactions(t)
    46  	db := client.Database("TestConvenientTransactions")
    47  	dbAdmin := client.Database("admin")
    48  
    49  	defer func() {
    50  		sessions := client.NumberSessionsInProgress()
    51  		conns := connsCheckedOut
    52  
    53  		err := dbAdmin.RunCommand(bgCtx, bson.D{
    54  			{"killAllSessions", bson.A{}},
    55  		}).Err()
    56  		if err != nil {
    57  			if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
    58  				t.Fatalf("killAllSessions error: %v", err)
    59  			}
    60  		}
    61  
    62  		_ = db.Drop(bgCtx)
    63  		_ = client.Disconnect(bgCtx)
    64  
    65  		assert.Equal(t, 0, sessions, "%v sessions checked out", sessions)
    66  		assert.Equal(t, 0, conns, "%v connections checked out", conns)
    67  	}()
    68  
    69  	t.Run("callback raises custom error", func(t *testing.T) {
    70  		coll := db.Collection(t.Name())
    71  		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
    72  		assert.Nil(t, err, "InsertOne error: %v", err)
    73  
    74  		sess, err := client.StartSession()
    75  		assert.Nil(t, err, "StartSession error: %v", err)
    76  		defer sess.EndSession(context.Background())
    77  
    78  		testErr := errors.New("test error")
    79  		_, err = sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
    80  			return nil, testErr
    81  		})
    82  		assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err)
    83  	})
    84  	t.Run("callback returns value", func(t *testing.T) {
    85  		coll := db.Collection(t.Name())
    86  		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
    87  		assert.Nil(t, err, "InsertOne error: %v", err)
    88  
    89  		sess, err := client.StartSession()
    90  		assert.Nil(t, err, "StartSession error: %v", err)
    91  		defer sess.EndSession(context.Background())
    92  
    93  		res, err := sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
    94  			return false, nil
    95  		})
    96  		assert.Nil(t, err, "WithTransaction error: %v", err)
    97  		resBool, ok := res.(bool)
    98  		assert.True(t, ok, "expected result type %T, got %T", false, res)
    99  		assert.False(t, resBool, "expected result false, got %v", resBool)
   100  	})
   101  	t.Run("retry timeout enforced", func(t *testing.T) {
   102  		withTransactionTimeout = time.Second
   103  
   104  		coll := db.Collection(t.Name())
   105  		_, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
   106  		assert.Nil(t, err, "InsertOne error: %v", err)
   107  
   108  		t.Run("transient transaction error", func(t *testing.T) {
   109  			sess, err := client.StartSession()
   110  			assert.Nil(t, err, "StartSession error: %v", err)
   111  			defer sess.EndSession(context.Background())
   112  
   113  			_, err = sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
   114  				return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}}
   115  			})
   116  			assert.NotNil(t, err, "expected WithTransaction error, got nil")
   117  			cmdErr, ok := err.(CommandError)
   118  			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
   119  			assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
   120  				"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
   121  		})
   122  		t.Run("unknown transaction commit result", func(t *testing.T) {
   123  			//set failpoint
   124  			failpoint := bson.D{{"configureFailPoint", "failCommand"},
   125  				{"mode", "alwaysOn"},
   126  				{"data", bson.D{
   127  					{"failCommands", bson.A{"commitTransaction"}},
   128  					{"closeConnection", true},
   129  				}},
   130  			}
   131  			err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
   132  			assert.Nil(t, err, "error setting failpoint: %v", err)
   133  			defer func() {
   134  				err = dbAdmin.RunCommand(bgCtx, bson.D{
   135  					{"configureFailPoint", "failCommand"},
   136  					{"mode", "off"},
   137  				}).Err()
   138  				assert.Nil(t, err, "error turning off failpoint: %v", err)
   139  			}()
   140  
   141  			sess, err := client.StartSession()
   142  			assert.Nil(t, err, "StartSession error: %v", err)
   143  			defer sess.EndSession(context.Background())
   144  
   145  			_, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) {
   146  				_, err := coll.InsertOne(ctx, bson.D{{"x", 1}})
   147  				return nil, err
   148  			})
   149  			assert.NotNil(t, err, "expected WithTransaction error, got nil")
   150  			cmdErr, ok := err.(CommandError)
   151  			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
   152  			assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult),
   153  				"expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr)
   154  		})
   155  		t.Run("commit transient transaction error", func(t *testing.T) {
   156  			//set failpoint
   157  			failpoint := bson.D{{"configureFailPoint", "failCommand"},
   158  				{"mode", "alwaysOn"},
   159  				{"data", bson.D{
   160  					{"failCommands", bson.A{"commitTransaction"}},
   161  					{"errorCode", 251},
   162  				}},
   163  			}
   164  			err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
   165  			assert.Nil(t, err, "error setting failpoint: %v", err)
   166  			defer func() {
   167  				err = dbAdmin.RunCommand(bgCtx, bson.D{
   168  					{"configureFailPoint", "failCommand"},
   169  					{"mode", "off"},
   170  				}).Err()
   171  				assert.Nil(t, err, "error turning off failpoint: %v", err)
   172  			}()
   173  
   174  			sess, err := client.StartSession()
   175  			assert.Nil(t, err, "StartSession error: %v", err)
   176  			defer sess.EndSession(context.Background())
   177  
   178  			_, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) {
   179  				_, err := coll.InsertOne(ctx, bson.D{{"x", 1}})
   180  				return nil, err
   181  			})
   182  			assert.NotNil(t, err, "expected WithTransaction error, got nil")
   183  			cmdErr, ok := err.(CommandError)
   184  			assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
   185  			assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
   186  				"expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
   187  		})
   188  	})
   189  	t.Run("abortTransaction does not time out", func(t *testing.T) {
   190  		// Create a special CommandMonitor that only records information about abortTransaction events and also
   191  		// records the Context used in the CommandStartedEvent listener.
   192  		var abortStarted []*event.CommandStartedEvent
   193  		var abortSucceeded []*event.CommandSucceededEvent
   194  		var abortFailed []*event.CommandFailedEvent
   195  		var abortCtx context.Context
   196  		monitor := &event.CommandMonitor{
   197  			Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
   198  				if evt.CommandName == "abortTransaction" {
   199  					abortStarted = append(abortStarted, evt)
   200  					if abortCtx == nil {
   201  						abortCtx = ctx
   202  					}
   203  				}
   204  			},
   205  			Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
   206  				if evt.CommandName == "abortTransaction" {
   207  					abortSucceeded = append(abortSucceeded, evt)
   208  				}
   209  			},
   210  			Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
   211  				if evt.CommandName == "abortTransaction" {
   212  					abortFailed = append(abortFailed, evt)
   213  				}
   214  			},
   215  		}
   216  
   217  		// Set up a new Client using the command monitor defined above get a handle to a collection. The collection
   218  		// needs to be explicitly created on the server because implicit collection creation is not allowed in
   219  		// transactions for server versions <= 4.2.
   220  		client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
   221  		db := client.Database("foo")
   222  		coll := db.Collection("bar")
   223  		err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
   224  		assert.Nil(t, err, "error creating collection on server: %v\n", err)
   225  
   226  		sess, err := client.StartSession()
   227  		assert.Nil(t, err, "StartSession error: %v", err)
   228  		defer func() {
   229  			sess.EndSession(bgCtx)
   230  			_ = coll.Drop(bgCtx)
   231  			_ = client.Disconnect(bgCtx)
   232  		}()
   233  
   234  		// Create a cancellable Context with a value for ctxKey.
   235  		type ctxKey struct{}
   236  		ctx, cancel := context.WithCancel(context.WithValue(context.Background(), ctxKey{}, "foobar"))
   237  		defer cancel()
   238  
   239  		// The WithTransaction callback does an Insert to ensure that the txn has been started server-side. After the
   240  		// insert succeeds, it cancels the Context created above and returns a non-retryable error, which forces
   241  		// WithTransaction to abort the txn.
   242  		callbackErr := errors.New("error")
   243  		callback := func(sc SessionContext) (interface{}, error) {
   244  			_, err = coll.InsertOne(sc, bson.D{{"x", 1}})
   245  			if err != nil {
   246  				return nil, err
   247  			}
   248  
   249  			cancel()
   250  			return nil, callbackErr
   251  		}
   252  
   253  		_, err = sess.WithTransaction(ctx, callback)
   254  		assert.Equal(t, callbackErr, err, "expected WithTransaction error %v, got %v", callbackErr, err)
   255  
   256  		// Assert that abortTransaction was sent once and succeede.
   257  		assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
   258  		assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
   259  			len(abortSucceeded))
   260  		assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed event, got %d", len(abortFailed))
   261  
   262  		// Assert that the Context propagated to the CommandStartedEvent listener for abortTransaction contained a value
   263  		// for ctxKey.
   264  		ctxValue, ok := abortCtx.Value(ctxKey{}).(string)
   265  		assert.True(t, ok, "expected context for abortTransaction to contain ctxKey")
   266  		assert.Equal(t, "foobar", ctxValue, "expected value for ctxKey to be 'world', got %s", ctxValue)
   267  	})
   268  	t.Run("commitTransaction timeout allows abortTransaction", func(t *testing.T) {
   269  		// Create a special CommandMonitor that only records information about abortTransaction events.
   270  		var abortStarted []*event.CommandStartedEvent
   271  		var abortSucceeded []*event.CommandSucceededEvent
   272  		var abortFailed []*event.CommandFailedEvent
   273  		monitor := &event.CommandMonitor{
   274  			Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
   275  				if evt.CommandName == "abortTransaction" {
   276  					abortStarted = append(abortStarted, evt)
   277  				}
   278  			},
   279  			Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
   280  				if evt.CommandName == "abortTransaction" {
   281  					abortSucceeded = append(abortSucceeded, evt)
   282  				}
   283  			},
   284  			Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
   285  				if evt.CommandName == "abortTransaction" {
   286  					abortFailed = append(abortFailed, evt)
   287  				}
   288  			},
   289  		}
   290  
   291  		// Set up a new Client using the command monitor defined above get a handle to a collection. The collection
   292  		// needs to be explicitly created on the server because implicit collection creation is not allowed in
   293  		// transactions for server versions <= 4.2.
   294  		client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
   295  		db := client.Database("foo")
   296  		coll := db.Collection("test")
   297  		defer func() {
   298  			_ = coll.Drop(bgCtx)
   299  		}()
   300  
   301  		err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
   302  		assert.Nil(t, err, "error creating collection on server: %v", err)
   303  
   304  		// Start session.
   305  		session, err := client.StartSession()
   306  		defer session.EndSession(bgCtx)
   307  		assert.Nil(t, err, "StartSession error: %v", err)
   308  
   309  		_ = WithSession(bgCtx, session, func(sessionContext SessionContext) error {
   310  			// Start transaction.
   311  			err = session.StartTransaction()
   312  			assert.Nil(t, err, "StartTransaction error: %v", err)
   313  
   314  			// Insert a document.
   315  			_, err := coll.InsertOne(sessionContext, bson.D{{"val", 17}})
   316  			assert.Nil(t, err, "InsertOne error: %v", err)
   317  
   318  			// Set a timeout of 0 for commitTransaction.
   319  			commitTimeoutCtx, commitCancel := context.WithTimeout(sessionContext, 0)
   320  			defer commitCancel()
   321  
   322  			// CommitTransaction results in context.DeadlineExceeded.
   323  			commitErr := session.CommitTransaction(commitTimeoutCtx)
   324  			assert.True(t, IsTimeout(commitErr),
   325  				"expected timeout error error; got %v", commitErr)
   326  
   327  			// Assert session state is not Committed.
   328  			clientSession := session.(XSession).ClientSession()
   329  			assert.False(t, clientSession.TransactionCommitted(), "expected session state to not be Committed")
   330  
   331  			// AbortTransaction without error.
   332  			abortErr := session.AbortTransaction(context.Background())
   333  			assert.Nil(t, abortErr, "AbortTransaction error: %v", abortErr)
   334  
   335  			// Assert that AbortTransaction was started once and succeeded.
   336  			assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
   337  			assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
   338  				len(abortSucceeded))
   339  			assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed))
   340  
   341  			return nil
   342  		})
   343  	})
   344  	t.Run("context error before commitTransaction does not retry and aborts", func(t *testing.T) {
   345  		withTransactionTimeout = 2 * time.Second
   346  
   347  		// Create a special CommandMonitor that only records information about abortTransaction events.
   348  		var abortStarted []*event.CommandStartedEvent
   349  		var abortSucceeded []*event.CommandSucceededEvent
   350  		var abortFailed []*event.CommandFailedEvent
   351  		monitor := &event.CommandMonitor{
   352  			Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
   353  				if evt.CommandName == "abortTransaction" {
   354  					abortStarted = append(abortStarted, evt)
   355  				}
   356  			},
   357  			Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
   358  				if evt.CommandName == "abortTransaction" {
   359  					abortSucceeded = append(abortSucceeded, evt)
   360  				}
   361  			},
   362  			Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
   363  				if evt.CommandName == "abortTransaction" {
   364  					abortFailed = append(abortFailed, evt)
   365  				}
   366  			},
   367  		}
   368  
   369  		// Set up a new Client using the command monitor defined above get a handle to a collection. The collection
   370  		// needs to be explicitly created on the server because implicit collection creation is not allowed in
   371  		// transactions for server versions <= 4.2.
   372  		client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
   373  		db := client.Database("foo")
   374  		coll := db.Collection("test")
   375  		// Explicitly create the collection on server because implicit collection creation is not allowed in
   376  		// transactions for server versions <= 4.2.
   377  		err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
   378  		assert.Nil(t, err, "error creating collection on server: %v", err)
   379  		defer func() {
   380  			_ = coll.Drop(bgCtx)
   381  		}()
   382  
   383  		// Start session.
   384  		sess, err := client.StartSession()
   385  		assert.Nil(t, err, "StartSession error: %v", err)
   386  		defer sess.EndSession(context.Background())
   387  
   388  		// Defer running killAllSessions to manually close open transaction.
   389  		defer func() {
   390  			err := dbAdmin.RunCommand(bgCtx, bson.D{
   391  				{"killAllSessions", bson.A{}},
   392  			}).Err()
   393  			if err != nil {
   394  				if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
   395  					t.Fatalf("killAllSessions error: %v", err)
   396  				}
   397  			}
   398  		}()
   399  
   400  		// Insert a document within a session and manually cancel context before
   401  		// "commitTransaction" can be sent.
   402  		callback := func(ctx context.Context) {
   403  			transactionCtx, cancel := context.WithCancel(ctx)
   404  
   405  			_, _ = sess.WithTransaction(transactionCtx, func(ctx SessionContext) (interface{}, error) {
   406  				_, err := coll.InsertOne(ctx, bson.M{"x": 1})
   407  				assert.Nil(t, err, "InsertOne error: %v", err)
   408  				cancel()
   409  				return nil, nil
   410  			})
   411  		}
   412  
   413  		// Assert that transaction is canceled within 500ms and not 2 seconds.
   414  		assert.Soon(t, callback, 500*time.Millisecond)
   415  
   416  		// Assert that AbortTransaction was started once and succeeded.
   417  		assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
   418  		assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
   419  			len(abortSucceeded))
   420  		assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed))
   421  	})
   422  	t.Run("wrapped transient transaction error retried", func(t *testing.T) {
   423  		sess, err := client.StartSession()
   424  		assert.Nil(t, err, "StartSession error: %v", err)
   425  		defer sess.EndSession(context.Background())
   426  
   427  		// returnError tracks whether or not the callback is being retried
   428  		returnError := true
   429  		res, err := sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
   430  			if returnError {
   431  				returnError = false
   432  				return nil, fmt.Errorf("%w",
   433  					CommandError{
   434  						Name:   "test Error",
   435  						Labels: []string{driver.TransientTransactionError},
   436  					},
   437  				)
   438  			}
   439  			return false, nil
   440  		})
   441  		assert.Nil(t, err, "WithTransaction error: %v", err)
   442  		resBool, ok := res.(bool)
   443  		assert.True(t, ok, "expected result type %T, got %T", false, res)
   444  		assert.False(t, resBool, "expected result false, got %v", resBool)
   445  	})
   446  	t.Run("expired context before callback does not retry", func(t *testing.T) {
   447  		withTransactionTimeout = 2 * time.Second
   448  
   449  		coll := db.Collection("test")
   450  		// Explicitly create the collection on server because implicit collection creation is not allowed in
   451  		// transactions for server versions <= 4.2.
   452  		err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
   453  		assert.Nil(t, err, "error creating collection on server: %v", err)
   454  		defer func() {
   455  			_ = coll.Drop(bgCtx)
   456  		}()
   457  
   458  		sess, err := client.StartSession()
   459  		assert.Nil(t, err, "StartSession error: %v", err)
   460  		defer sess.EndSession(context.Background())
   461  
   462  		callback := func(ctx context.Context) {
   463  			// Create transaction context with short timeout.
   464  			withTransactionContext, cancel := context.WithTimeout(ctx, time.Nanosecond)
   465  			defer cancel()
   466  
   467  			_, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) {
   468  				_, err := coll.InsertOne(ctx, bson.D{{}})
   469  				return nil, err
   470  			})
   471  		}
   472  
   473  		// Assert that transaction fails within 500ms and not 2 seconds.
   474  		assert.Soon(t, callback, 500*time.Millisecond)
   475  	})
   476  	t.Run("canceled context before callback does not retry", func(t *testing.T) {
   477  		withTransactionTimeout = 2 * time.Second
   478  
   479  		coll := db.Collection("test")
   480  		// Explicitly create the collection on server because implicit collection creation is not allowed in
   481  		// transactions for server versions <= 4.2.
   482  		err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
   483  		assert.Nil(t, err, "error creating collection on server: %v", err)
   484  		defer func() {
   485  			_ = coll.Drop(bgCtx)
   486  		}()
   487  
   488  		sess, err := client.StartSession()
   489  		assert.Nil(t, err, "StartSession error: %v", err)
   490  		defer sess.EndSession(context.Background())
   491  
   492  		callback := func(ctx context.Context) {
   493  			// Create transaction context and cancel it immediately.
   494  			withTransactionContext, cancel := context.WithTimeout(ctx, 2*time.Second)
   495  			cancel()
   496  
   497  			_, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) {
   498  				_, err := coll.InsertOne(ctx, bson.D{{}})
   499  				return nil, err
   500  			})
   501  		}
   502  
   503  		// Assert that transaction fails within 500ms and not 2 seconds.
   504  		assert.Soon(t, callback, 500*time.Millisecond)
   505  	})
   506  	t.Run("slow operation in callback retries", func(t *testing.T) {
   507  		withTransactionTimeout = 2 * time.Second
   508  
   509  		coll := db.Collection("test")
   510  		// Explicitly create the collection on server because implicit collection creation is not allowed in
   511  		// transactions for server versions <= 4.2.
   512  		err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
   513  		assert.Nil(t, err, "error creating collection on server: %v", err)
   514  		defer func() {
   515  			_ = coll.Drop(bgCtx)
   516  		}()
   517  
   518  		// Set failpoint to block insertOne once for 500ms.
   519  		failpoint := bson.D{{"configureFailPoint", "failCommand"},
   520  			{"mode", bson.D{
   521  				{"times", 1},
   522  			}},
   523  			{"data", bson.D{
   524  				{"failCommands", bson.A{"insert"}},
   525  				{"blockConnection", true},
   526  				{"blockTimeMS", 500},
   527  			}},
   528  		}
   529  		err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
   530  		assert.Nil(t, err, "error setting failpoint: %v", err)
   531  		defer func() {
   532  			err = dbAdmin.RunCommand(bgCtx, bson.D{
   533  				{"configureFailPoint", "failCommand"},
   534  				{"mode", "off"},
   535  			}).Err()
   536  			assert.Nil(t, err, "error turning off failpoint: %v", err)
   537  		}()
   538  
   539  		sess, err := client.StartSession()
   540  		assert.Nil(t, err, "StartSession error: %v", err)
   541  		defer sess.EndSession(context.Background())
   542  
   543  		callback := func(ctx context.Context) {
   544  			_, err = sess.WithTransaction(ctx, func(ctx SessionContext) (interface{}, error) {
   545  				// Set a timeout of 300ms to cause a timeout on first insertOne
   546  				// and force a retry.
   547  				c, cancel := context.WithTimeout(ctx, 300*time.Millisecond)
   548  				defer cancel()
   549  
   550  				_, err := coll.InsertOne(c, bson.D{{}})
   551  				return nil, err
   552  			})
   553  			assert.Nil(t, err, "WithTransaction error: %v", err)
   554  		}
   555  
   556  		// Assert that transaction passes within 2 seconds.
   557  		assert.Soon(t, callback, 2*time.Second)
   558  	})
   559  }
   560  
   561  func setupConvenientTransactions(t *testing.T, extraClientOpts ...*options.ClientOptions) *Client {
   562  	cs := integtest.ConnString(t)
   563  	poolMonitor := &event.PoolMonitor{
   564  		Event: func(evt *event.PoolEvent) {
   565  			switch evt.Type {
   566  			case event.GetSucceeded:
   567  				connsCheckedOut++
   568  			case event.ConnectionReturned:
   569  				connsCheckedOut--
   570  			}
   571  		},
   572  	}
   573  
   574  	baseClientOpts := options.Client().
   575  		ApplyURI(cs.Original).
   576  		SetReadPreference(readpref.Primary()).
   577  		SetWriteConcern(writeconcern.New(writeconcern.WMajority())).
   578  		SetPoolMonitor(poolMonitor)
   579  	integtest.AddTestServerAPIVersion(baseClientOpts)
   580  	fullClientOpts := []*options.ClientOptions{baseClientOpts}
   581  	fullClientOpts = append(fullClientOpts, extraClientOpts...)
   582  
   583  	client, err := Connect(bgCtx, fullClientOpts...)
   584  	assert.Nil(t, err, "Connect error: %v", err)
   585  
   586  	version, err := getServerVersion(client.Database("admin"))
   587  	assert.Nil(t, err, "getServerVersion error: %v", err)
   588  	topoKind := client.deployment.(*topology.Topology).Kind()
   589  	if compareVersions(version, "4.1") < 0 || topoKind == description.Single {
   590  		t.Skip("skipping standalones and versions < 4.1")
   591  	}
   592  
   593  	if topoKind != description.Sharded {
   594  		return client
   595  	}
   596  
   597  	// For sharded clusters, disconnect the previous Client and create a new one that's pinned to a single mongos.
   598  	_ = client.Disconnect(bgCtx)
   599  	fullClientOpts = append(fullClientOpts, options.Client().SetHosts([]string{cs.Hosts[0]}))
   600  	client, err = Connect(bgCtx, fullClientOpts...)
   601  	assert.Nil(t, err, "Connect error: %v", err)
   602  	return client
   603  }
   604  
   605  func getServerVersion(db *Database) (string, error) {
   606  	serverStatus, err := db.RunCommand(
   607  		context.Background(),
   608  		bson.D{{"serverStatus", 1}},
   609  	).Raw()
   610  	if err != nil {
   611  		return "", err
   612  	}
   613  
   614  	version, err := serverStatus.LookupErr("version")
   615  	if err != nil {
   616  		return "", err
   617  	}
   618  
   619  	return version.StringValue(), nil
   620  }
   621  
   622  // compareVersions compares two version number strings (i.e. positive integers separated by
   623  // periods). Comparisons are done to the lesser precision of the two versions. For example, 3.2 is
   624  // considered equal to 3.2.11, whereas 3.2.0 is considered less than 3.2.11.
   625  //
   626  // Returns a positive int if version1 is greater than version2, a negative int if version1 is less
   627  // than version2, and 0 if version1 is equal to version2.
   628  func compareVersions(v1 string, v2 string) int {
   629  	n1 := strings.Split(v1, ".")
   630  	n2 := strings.Split(v2, ".")
   631  
   632  	for i := 0; i < int(math.Min(float64(len(n1)), float64(len(n2)))); i++ {
   633  		i1, err := strconv.Atoi(n1[i])
   634  		if err != nil {
   635  			return 1
   636  		}
   637  
   638  		i2, err := strconv.Atoi(n2[i])
   639  		if err != nil {
   640  			return -1
   641  		}
   642  
   643  		difference := i1 - i2
   644  		if difference != 0 {
   645  			return difference
   646  		}
   647  	}
   648  
   649  	return 0
   650  }
   651  

View as plain text