...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/operation_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver

     1  // Copyright (C) MongoDB, Inc. 2022-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package driver
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"errors"
    13  	"math"
    14  	"testing"
    15  	"time"
    16  
    17  	"github.com/google/go-cmp/cmp"
    18  	"go.mongodb.org/mongo-driver/bson/bsontype"
    19  	"go.mongodb.org/mongo-driver/bson/primitive"
    20  	"go.mongodb.org/mongo-driver/internal/assert"
    21  	"go.mongodb.org/mongo-driver/internal/csot"
    22  	"go.mongodb.org/mongo-driver/internal/handshake"
    23  	"go.mongodb.org/mongo-driver/internal/require"
    24  	"go.mongodb.org/mongo-driver/internal/uuid"
    25  	"go.mongodb.org/mongo-driver/mongo/address"
    26  	"go.mongodb.org/mongo-driver/mongo/description"
    27  	"go.mongodb.org/mongo-driver/mongo/readconcern"
    28  	"go.mongodb.org/mongo-driver/mongo/readpref"
    29  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    30  	"go.mongodb.org/mongo-driver/tag"
    31  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    32  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    33  	"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
    34  )
    35  
    36  func noerr(t *testing.T, err error) {
    37  	t.Helper()
    38  	if err != nil {
    39  		t.Errorf("Unexpected error: %v", err)
    40  		t.FailNow()
    41  	}
    42  }
    43  
    44  func compareErrors(err1, err2 error) bool {
    45  	if err1 == nil && err2 == nil {
    46  		return true
    47  	}
    48  
    49  	if err1 == nil || err2 == nil {
    50  		return false
    51  	}
    52  
    53  	if err1.Error() != err2.Error() {
    54  		return false
    55  	}
    56  
    57  	return true
    58  }
    59  
    60  func TestOperation(t *testing.T) {
    61  	int64ToPtr := func(i64 int64) *int64 { return &i64 }
    62  
    63  	t.Run("selectServer", func(t *testing.T) {
    64  		t.Run("returns validation error", func(t *testing.T) {
    65  			op := &Operation{}
    66  			_, err := op.selectServer(context.Background(), 1, nil)
    67  			if err == nil {
    68  				t.Error("Expected a validation error from selectServer, but got <nil>")
    69  			}
    70  		})
    71  		t.Run("uses specified server selector", func(t *testing.T) {
    72  			want := new(mockServerSelector)
    73  			d := new(mockDeployment)
    74  			op := &Operation{
    75  				CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
    76  				Deployment: d,
    77  				Database:   "testing",
    78  				Selector:   want,
    79  			}
    80  			_, err := op.selectServer(context.Background(), 1, nil)
    81  			noerr(t, err)
    82  
    83  			// Assert the the selector is an operation selector wrapper.
    84  			oss, ok := d.params.selector.(*opServerSelector)
    85  			require.True(t, ok)
    86  
    87  			if !cmp.Equal(oss.selector, want) {
    88  				t.Errorf("Did not get expected server selector. got %v; want %v", oss.selector, want)
    89  			}
    90  		})
    91  		t.Run("uses a default server selector", func(t *testing.T) {
    92  			d := new(mockDeployment)
    93  			op := &Operation{
    94  				CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
    95  				Deployment: d,
    96  				Database:   "testing",
    97  			}
    98  			_, err := op.selectServer(context.Background(), 1, nil)
    99  			noerr(t, err)
   100  			if d.params.selector == nil {
   101  				t.Error("The selectServer method should use a default selector when not specified on Operation, but it passed <nil>.")
   102  			}
   103  		})
   104  	})
   105  	t.Run("Validate", func(t *testing.T) {
   106  		cmdFn := func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil }
   107  		d := new(mockDeployment)
   108  		testCases := []struct {
   109  			name string
   110  			op   *Operation
   111  			err  error
   112  		}{
   113  			{"CommandFn", &Operation{}, InvalidOperationError{MissingField: "CommandFn"}},
   114  			{"Deployment", &Operation{CommandFn: cmdFn}, InvalidOperationError{MissingField: "Deployment"}},
   115  			{"Database", &Operation{CommandFn: cmdFn, Deployment: d}, errDatabaseNameEmpty},
   116  			{"<nil>", &Operation{CommandFn: cmdFn, Deployment: d, Database: "test"}, nil},
   117  		}
   118  
   119  		for _, tc := range testCases {
   120  			t.Run(tc.name, func(t *testing.T) {
   121  				if tc.op == nil {
   122  					t.Fatal("op cannot be <nil>")
   123  				}
   124  				want := tc.err
   125  				got := tc.op.Validate()
   126  				if !cmp.Equal(got, want, cmp.Comparer(compareErrors)) {
   127  					t.Errorf("Did not validate properly. got %v; want %v", got, want)
   128  				}
   129  			})
   130  		}
   131  	})
   132  	t.Run("retryableWrite", func(t *testing.T) {
   133  		sessPool := session.NewPool(nil)
   134  		id, err := uuid.New()
   135  		noerr(t, err)
   136  
   137  		sess, err := session.NewClientSession(sessPool, id)
   138  		noerr(t, err)
   139  
   140  		sessStartingTransaction, err := session.NewClientSession(sessPool, id)
   141  		noerr(t, err)
   142  		err = sessStartingTransaction.StartTransaction(nil)
   143  		noerr(t, err)
   144  
   145  		sessInProgressTransaction, err := session.NewClientSession(sessPool, id)
   146  		noerr(t, err)
   147  		err = sessInProgressTransaction.StartTransaction(nil)
   148  		noerr(t, err)
   149  		err = sessInProgressTransaction.ApplyCommand(description.Server{})
   150  		noerr(t, err)
   151  
   152  		wcAck := writeconcern.New(writeconcern.WMajority())
   153  		wcUnack := writeconcern.New(writeconcern.W(0))
   154  
   155  		descRetryable := description.Server{
   156  			WireVersion:              &description.VersionRange{Min: 6, Max: 21},
   157  			SessionTimeoutMinutes:    1,
   158  			SessionTimeoutMinutesPtr: int64ToPtr(1),
   159  		}
   160  
   161  		descNotRetryableWireVersion := description.Server{
   162  			WireVersion:              &description.VersionRange{Min: 6, Max: 21},
   163  			SessionTimeoutMinutes:    1,
   164  			SessionTimeoutMinutesPtr: int64ToPtr(1),
   165  		}
   166  
   167  		descNotRetryableStandalone := description.Server{
   168  			WireVersion:              &description.VersionRange{Min: 6, Max: 21},
   169  			SessionTimeoutMinutes:    1,
   170  			SessionTimeoutMinutesPtr: int64ToPtr(1),
   171  			Kind:                     description.Standalone,
   172  		}
   173  
   174  		testCases := []struct {
   175  			name string
   176  			op   Operation
   177  			desc description.Server
   178  			want Type
   179  		}{
   180  			{"deployment doesn't support", Operation{}, description.Server{}, Type(0)},
   181  			{"wire version too low", Operation{Client: sess, WriteConcern: wcAck}, descNotRetryableWireVersion, Type(0)},
   182  			{"standalone not supported", Operation{Client: sess, WriteConcern: wcAck}, descNotRetryableStandalone, Type(0)},
   183  			{
   184  				"transaction in progress",
   185  				Operation{Client: sessInProgressTransaction, WriteConcern: wcAck},
   186  				descRetryable, Type(0),
   187  			},
   188  			{
   189  				"transaction starting",
   190  				Operation{Client: sessStartingTransaction, WriteConcern: wcAck},
   191  				descRetryable, Type(0),
   192  			},
   193  			{"unacknowledged write concern", Operation{Client: sess, WriteConcern: wcUnack}, descRetryable, Type(0)},
   194  			{
   195  				"acknowledged write concern",
   196  				Operation{Client: sess, WriteConcern: wcAck, Type: Write},
   197  				descRetryable, Write,
   198  			},
   199  		}
   200  
   201  		for _, tc := range testCases {
   202  			t.Run(tc.name, func(t *testing.T) {
   203  				got := tc.op.retryable(tc.desc)
   204  				if got != (tc.want != Type(0)) {
   205  					t.Errorf("Did not receive expected Type. got %v; want %v", got, tc.want)
   206  				}
   207  			})
   208  		}
   209  	})
   210  	t.Run("addReadConcern", func(t *testing.T) {
   211  		majorityRc := bsoncore.AppendDocumentElement(nil, "readConcern", bsoncore.BuildDocument(nil,
   212  			bsoncore.AppendStringElement(nil, "level", "majority"),
   213  		))
   214  
   215  		testCases := []struct {
   216  			name string
   217  			rc   *readconcern.ReadConcern
   218  			want bsoncore.Document
   219  		}{
   220  			{"nil", nil, nil},
   221  			{"empty", readconcern.New(), nil},
   222  			{"non-empty", readconcern.Majority(), majorityRc},
   223  		}
   224  
   225  		for _, tc := range testCases {
   226  			got, err := Operation{ReadConcern: tc.rc}.addReadConcern(nil, description.SelectedServer{})
   227  			noerr(t, err)
   228  			if !bytes.Equal(got, tc.want) {
   229  				t.Errorf("ReadConcern elements do not match. got %v; want %v", got, tc.want)
   230  			}
   231  		}
   232  	})
   233  	t.Run("addWriteConcern", func(t *testing.T) {
   234  		want := bsoncore.AppendDocumentElement(nil, "writeConcern", bsoncore.BuildDocumentFromElements(
   235  			nil, bsoncore.AppendStringElement(nil, "w", "majority"),
   236  		))
   237  		got, err := Operation{WriteConcern: writeconcern.New(writeconcern.WMajority())}.addWriteConcern(nil, description.SelectedServer{})
   238  		noerr(t, err)
   239  		if !bytes.Equal(got, want) {
   240  			t.Errorf("WriteConcern elements do not match. got %v; want %v", got, want)
   241  		}
   242  	})
   243  	t.Run("addSession", func(t *testing.T) { t.Skip("These tests should be covered by spec tests.") })
   244  	t.Run("addClusterTime", func(t *testing.T) {
   245  		t.Run("adds max cluster time", func(t *testing.T) {
   246  			want := bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,
   247  				bsoncore.AppendTimestampElement(nil, "clusterTime", 1234, 5678),
   248  			))
   249  			newer := bsoncore.BuildDocumentFromElements(nil, want)
   250  			older := bsoncore.BuildDocumentFromElements(nil,
   251  				bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,
   252  					bsoncore.AppendTimestampElement(nil, "clusterTime", 1234, 5670),
   253  				)),
   254  			)
   255  
   256  			clusterClock := new(session.ClusterClock)
   257  			clusterClock.AdvanceClusterTime(newer)
   258  			sessPool := session.NewPool(nil)
   259  			id, err := uuid.New()
   260  			noerr(t, err)
   261  
   262  			sess, err := session.NewClientSession(sessPool, id)
   263  			noerr(t, err)
   264  			err = sess.AdvanceClusterTime(older)
   265  			noerr(t, err)
   266  
   267  			got := Operation{Client: sess, Clock: clusterClock}.addClusterTime(nil, description.SelectedServer{
   268  				Server: description.Server{WireVersion: &description.VersionRange{Min: 6, Max: 21}},
   269  			})
   270  			if !bytes.Equal(got, want) {
   271  				t.Errorf("ClusterTimes do not match. got %v; want %v", got, want)
   272  			}
   273  		})
   274  	})
   275  	t.Run("calculateMaxTimeMS", func(t *testing.T) {
   276  		timeout := 5 * time.Second
   277  		maxTime := 2 * time.Second
   278  		negMaxTime := -2 * time.Second
   279  		shortRTT := 50 * time.Millisecond
   280  		longRTT := 10 * time.Second
   281  		timeoutCtx, cancel := csot.MakeTimeoutContext(context.Background(), timeout)
   282  		defer cancel()
   283  
   284  		testCases := []struct {
   285  			name  string
   286  			op    Operation
   287  			ctx   context.Context
   288  			rtt90 time.Duration
   289  			want  uint64
   290  			err   error
   291  		}{
   292  			{
   293  				name:  "uses context deadline and rtt90 with timeout",
   294  				op:    Operation{MaxTime: &maxTime},
   295  				ctx:   timeoutCtx,
   296  				rtt90: shortRTT,
   297  				want:  5000,
   298  				err:   nil,
   299  			},
   300  			{
   301  				name:  "uses MaxTime without timeout",
   302  				op:    Operation{MaxTime: &maxTime},
   303  				ctx:   context.Background(),
   304  				rtt90: longRTT,
   305  				want:  2000,
   306  				err:   nil,
   307  			},
   308  			{
   309  				name:  "errors when remaining timeout is less than rtt90",
   310  				op:    Operation{MaxTime: &maxTime},
   311  				ctx:   timeoutCtx,
   312  				rtt90: timeout,
   313  				want:  0,
   314  				err:   ErrDeadlineWouldBeExceeded,
   315  			},
   316  			{
   317  				name:  "errors when MaxTime is negative",
   318  				op:    Operation{MaxTime: &negMaxTime},
   319  				ctx:   context.Background(),
   320  				rtt90: longRTT,
   321  				want:  0,
   322  				err:   ErrNegativeMaxTime,
   323  			},
   324  		}
   325  		for _, tc := range testCases {
   326  			// Capture test-case for parallel sub-test.
   327  			tc := tc
   328  			t.Run(tc.name, func(t *testing.T) {
   329  				t.Parallel()
   330  
   331  				got, err := tc.op.calculateMaxTimeMS(tc.ctx, mockRTTMonitor{p90: tc.rtt90})
   332  
   333  				// Assert that the calculated maxTimeMS is less than or equal to the expected value. A few
   334  				// milliseconds will have elapsed toward the context deadline, and (remainingTimeout
   335  				// - rtt90) will be slightly smaller than the expected value.
   336  				if got > tc.want {
   337  					t.Errorf("maxTimeMS value higher than expected. got %v; wanted at most %v", got, tc.want)
   338  				}
   339  				if !errors.Is(err, tc.err) {
   340  					t.Errorf("error values do not match. got %v; want %v", err, tc.err)
   341  				}
   342  			})
   343  		}
   344  	})
   345  	t.Run("updateClusterTimes", func(t *testing.T) {
   346  		clustertime := bsoncore.BuildDocumentFromElements(nil,
   347  			bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocumentFromElements(nil,
   348  				bsoncore.AppendTimestampElement(nil, "clusterTime", 1234, 5678),
   349  			)),
   350  		)
   351  
   352  		clusterClock := new(session.ClusterClock)
   353  		sessPool := session.NewPool(nil)
   354  		id, err := uuid.New()
   355  		noerr(t, err)
   356  
   357  		sess, err := session.NewClientSession(sessPool, id)
   358  		noerr(t, err)
   359  		Operation{Client: sess, Clock: clusterClock}.updateClusterTimes(clustertime)
   360  
   361  		got := sess.ClusterTime
   362  		if !bytes.Equal(got, clustertime) {
   363  			t.Errorf("ClusterTimes do not match. got %v; want %v", got, clustertime)
   364  		}
   365  		got = clusterClock.GetClusterTime()
   366  		if !bytes.Equal(got, clustertime) {
   367  			t.Errorf("ClusterTimes do not match. got %v; want %v", got, clustertime)
   368  		}
   369  
   370  		Operation{}.updateClusterTimes(bsoncore.BuildDocumentFromElements(nil)) // should do nothing
   371  	})
   372  	t.Run("updateOperationTime", func(t *testing.T) {
   373  		want := primitive.Timestamp{T: 1234, I: 4567}
   374  
   375  		sessPool := session.NewPool(nil)
   376  		id, err := uuid.New()
   377  		noerr(t, err)
   378  
   379  		sess, err := session.NewClientSession(sessPool, id)
   380  		noerr(t, err)
   381  		if sess.OperationTime != nil {
   382  			t.Fatal("OperationTime should not be set on new session.")
   383  		}
   384  		response := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendTimestampElement(nil, "operationTime", want.T, want.I))
   385  		Operation{Client: sess}.updateOperationTime(response)
   386  		got := sess.OperationTime
   387  		if got.T != want.T || got.I != want.I {
   388  			t.Errorf("OperationTimes do not match. got %v; want %v", got, want)
   389  		}
   390  
   391  		response = bsoncore.BuildDocumentFromElements(nil)
   392  		Operation{Client: sess}.updateOperationTime(response)
   393  		got = sess.OperationTime
   394  		if got.T != want.T || got.I != want.I {
   395  			t.Errorf("OperationTimes do not match. got %v; want %v", got, want)
   396  		}
   397  
   398  		Operation{}.updateOperationTime(response) // should do nothing
   399  	})
   400  	t.Run("createReadPref", func(t *testing.T) {
   401  		rpWithTags := bsoncore.BuildDocumentFromElements(nil,
   402  			bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
   403  			bsoncore.BuildArrayElement(nil, "tags",
   404  				bsoncore.Value{Type: bsontype.EmbeddedDocument,
   405  					Data: bsoncore.BuildDocumentFromElements(nil,
   406  						bsoncore.AppendStringElement(nil, "disk", "ssd"),
   407  						bsoncore.AppendStringElement(nil, "use", "reporting"),
   408  					),
   409  				},
   410  			),
   411  		)
   412  		rpWithMaxStaleness := bsoncore.BuildDocumentFromElements(nil,
   413  			bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
   414  			bsoncore.AppendInt32Element(nil, "maxStalenessSeconds", 25),
   415  		)
   416  		// Hedged read preference: {mode: "secondaryPreferred", hedge: {enabled: true}}
   417  		rpWithHedge := bsoncore.BuildDocumentFromElements(nil,
   418  			bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
   419  			bsoncore.AppendDocumentElement(nil, "hedge", bsoncore.BuildDocumentFromElements(nil,
   420  				bsoncore.AppendBooleanElement(nil, "enabled", true),
   421  			)),
   422  		)
   423  		rpWithAllOptions := bsoncore.BuildDocumentFromElements(nil,
   424  			bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"),
   425  			bsoncore.BuildArrayElement(nil, "tags",
   426  				bsoncore.Value{Type: bsontype.EmbeddedDocument,
   427  					Data: bsoncore.BuildDocumentFromElements(nil,
   428  						bsoncore.AppendStringElement(nil, "disk", "ssd"),
   429  						bsoncore.AppendStringElement(nil, "use", "reporting"),
   430  					),
   431  				},
   432  			),
   433  			bsoncore.AppendInt32Element(nil, "maxStalenessSeconds", 25),
   434  			bsoncore.AppendDocumentElement(nil, "hedge", bsoncore.BuildDocumentFromElements(nil,
   435  				bsoncore.AppendBooleanElement(nil, "enabled", false),
   436  			)),
   437  		)
   438  
   439  		rpPrimaryPreferred := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "primaryPreferred"))
   440  		rpSecondaryPreferred := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "secondaryPreferred"))
   441  		rpSecondary := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "secondary"))
   442  		rpNearest := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendStringElement(nil, "mode", "nearest"))
   443  
   444  		testCases := []struct {
   445  			name       string
   446  			rp         *readpref.ReadPref
   447  			serverKind description.ServerKind
   448  			topoKind   description.TopologyKind
   449  			opQuery    bool
   450  			want       bsoncore.Document
   451  		}{
   452  			{"nil/single/mongos", nil, description.Mongos, description.Single, false, nil},
   453  			{"nil/single/secondary", nil, description.RSSecondary, description.Single, false, rpPrimaryPreferred},
   454  			{"primary/mongos", readpref.Primary(), description.Mongos, description.Sharded, false, nil},
   455  			{"primary/single", readpref.Primary(), description.RSPrimary, description.Single, false, rpPrimaryPreferred},
   456  			{"primary/primary", readpref.Primary(), description.RSPrimary, description.ReplicaSet, false, nil},
   457  			{"primaryPreferred", readpref.PrimaryPreferred(), description.RSSecondary, description.ReplicaSet, false, rpPrimaryPreferred},
   458  			{"secondaryPreferred/mongos/opquery", readpref.SecondaryPreferred(), description.Mongos, description.Sharded, true, nil},
   459  			{"secondaryPreferred", readpref.SecondaryPreferred(), description.RSSecondary, description.ReplicaSet, false, rpSecondaryPreferred},
   460  			{"secondary", readpref.Secondary(), description.RSSecondary, description.ReplicaSet, false, rpSecondary},
   461  			{"nearest", readpref.Nearest(), description.RSSecondary, description.ReplicaSet, false, rpNearest},
   462  			{
   463  				"secondaryPreferred/withTags",
   464  				readpref.SecondaryPreferred(readpref.WithTags("disk", "ssd", "use", "reporting")),
   465  				description.RSSecondary, description.ReplicaSet, false, rpWithTags,
   466  			},
   467  			// GODRIVER-2205: Ensure empty tag sets are written as an empty document in the read
   468  			// preference document. Empty tag sets match any server and are used as a fallback when
   469  			// no other tag sets match any servers.
   470  			{
   471  				"secondaryPreferred/withTags/emptyTagSet",
   472  				readpref.SecondaryPreferred(readpref.WithTagSets(
   473  					tag.Set{{Name: "disk", Value: "ssd"}},
   474  					tag.Set{})),
   475  				description.RSSecondary,
   476  				description.ReplicaSet,
   477  				false,
   478  				bsoncore.NewDocumentBuilder().
   479  					AppendString("mode", "secondaryPreferred").
   480  					AppendArray("tags", bsoncore.NewArrayBuilder().
   481  						AppendDocument(bsoncore.NewDocumentBuilder().AppendString("disk", "ssd").Build()).
   482  						AppendDocument(bsoncore.NewDocumentBuilder().Build()).
   483  						Build()).
   484  					Build(),
   485  			},
   486  			{
   487  				"secondaryPreferred/withMaxStaleness",
   488  				readpref.SecondaryPreferred(readpref.WithMaxStaleness(25 * time.Second)),
   489  				description.RSSecondary, description.ReplicaSet, false, rpWithMaxStaleness,
   490  			},
   491  			{
   492  				// A read preference document is generated for SecondaryPreferred if the hedge document is non-nil.
   493  				"secondaryPreferred with hedge to mongos using OP_QUERY",
   494  				readpref.SecondaryPreferred(readpref.WithHedgeEnabled(true)),
   495  				description.Mongos,
   496  				description.Sharded,
   497  				true,
   498  				rpWithHedge,
   499  			},
   500  			{
   501  				"secondaryPreferred with all options",
   502  				readpref.SecondaryPreferred(
   503  					readpref.WithTags("disk", "ssd", "use", "reporting"),
   504  					readpref.WithMaxStaleness(25*time.Second),
   505  					readpref.WithHedgeEnabled(false),
   506  				),
   507  				description.RSSecondary,
   508  				description.ReplicaSet,
   509  				false,
   510  				rpWithAllOptions,
   511  			},
   512  		}
   513  
   514  		for _, tc := range testCases {
   515  			tc := tc
   516  			t.Run(tc.name, func(t *testing.T) {
   517  				desc := description.SelectedServer{Kind: tc.topoKind, Server: description.Server{Kind: tc.serverKind}}
   518  				got, err := Operation{ReadPreference: tc.rp}.createReadPref(desc, tc.opQuery)
   519  				if err != nil {
   520  					t.Fatalf("error creating read pref: %v", err)
   521  				}
   522  				if !bytes.Equal(got, tc.want) {
   523  					t.Errorf("Returned documents do not match. got %v; want %v", got, tc.want)
   524  				}
   525  			})
   526  		}
   527  	})
   528  	t.Run("secondaryOK", func(t *testing.T) {
   529  		t.Run("description.SelectedServer", func(t *testing.T) {
   530  			want := wiremessage.SecondaryOK
   531  			desc := description.SelectedServer{
   532  				Kind:   description.Single,
   533  				Server: description.Server{Kind: description.RSSecondary},
   534  			}
   535  			got := Operation{}.secondaryOK(desc)
   536  			if got != want {
   537  				t.Errorf("Did not receive expected query flags. got %v; want %v", got, want)
   538  			}
   539  		})
   540  		t.Run("readPreference", func(t *testing.T) {
   541  			want := wiremessage.SecondaryOK
   542  			got := Operation{ReadPreference: readpref.Secondary()}.secondaryOK(description.SelectedServer{})
   543  			if got != want {
   544  				t.Errorf("Did not receive expected query flags. got %v; want %v", got, want)
   545  			}
   546  		})
   547  		t.Run("not secondaryOK", func(t *testing.T) {
   548  			var want wiremessage.QueryFlag
   549  			got := Operation{}.secondaryOK(description.SelectedServer{})
   550  			if got != want {
   551  				t.Errorf("Did not receive expected query flags. got %v; want %v", got, want)
   552  			}
   553  		})
   554  	})
   555  	t.Run("ExecuteExhaust", func(t *testing.T) {
   556  		t.Run("errors if connection is not streaming", func(t *testing.T) {
   557  			conn := &mockConnection{
   558  				rStreaming: false,
   559  			}
   560  			err := Operation{}.ExecuteExhaust(context.TODO(), conn)
   561  			assert.NotNil(t, err, "expected error, got nil")
   562  		})
   563  	})
   564  	t.Run("exhaustAllowed and moreToCome", func(t *testing.T) {
   565  		// Test the interaction between exhaustAllowed and moreToCome on requests/responses when using the Execute
   566  		// and ExecuteExhaust methods.
   567  
   568  		// Create a server response wire message that has moreToCome=false.
   569  		serverResponseDoc := bsoncore.BuildDocumentFromElements(nil,
   570  			bsoncore.AppendInt32Element(nil, "ok", 1),
   571  		)
   572  		nonStreamingResponse := createExhaustServerResponse(serverResponseDoc, false)
   573  
   574  		// Create a connection that reports that it cannot stream messages.
   575  		conn := &mockConnection{
   576  			rDesc: description.Server{
   577  				WireVersion: &description.VersionRange{
   578  					Max: 6,
   579  				},
   580  			},
   581  			rReadWM:    nonStreamingResponse,
   582  			rCanStream: false,
   583  		}
   584  		op := Operation{
   585  			CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
   586  				return bsoncore.AppendInt32Element(dst, handshake.LegacyHello, 1), nil
   587  			},
   588  			Database:   "admin",
   589  			Deployment: SingleConnectionDeployment{conn},
   590  		}
   591  		err := op.Execute(context.TODO())
   592  		assert.Nil(t, err, "Execute error: %v", err)
   593  
   594  		// The wire message sent to the server should not have exhaustAllowed=true. After execution, the connection
   595  		// should not be in a streaming state.
   596  		assertExhaustAllowedSet(t, conn.pWriteWM, false)
   597  		assert.False(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be false")
   598  
   599  		// Modify the connection to report that it can stream and create a new server response with moreToCome=true.
   600  		streamingResponse := createExhaustServerResponse(serverResponseDoc, true)
   601  		conn.rReadWM = streamingResponse
   602  		conn.rCanStream = true
   603  		err = op.Execute(context.TODO())
   604  		assert.Nil(t, err, "Execute error: %v", err)
   605  		assertExhaustAllowedSet(t, conn.pWriteWM, true)
   606  		assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
   607  
   608  		// Reset the server response and go through ExecuteExhaust to mimic streaming the next response. After
   609  		// execution, the connection should still be in a streaming state.
   610  		conn.rReadWM = streamingResponse
   611  		err = op.ExecuteExhaust(context.TODO(), conn)
   612  		assert.Nil(t, err, "ExecuteExhaust error: %v", err)
   613  		assert.True(t, conn.CurrentlyStreaming(), "expected CurrentlyStreaming to be true")
   614  	})
   615  	t.Run("context deadline exceeded not marked as TransientTransactionError", func(t *testing.T) {
   616  		conn := new(mockConnection)
   617  		// Create a context that's already timed out.
   618  		ctx, cancel := context.WithDeadline(context.Background(), time.Unix(893934480, 0))
   619  		defer cancel()
   620  
   621  		op := Operation{
   622  			Database:   "foobar",
   623  			Deployment: SingleConnectionDeployment{C: conn},
   624  			CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
   625  				dst = bsoncore.AppendInt32Element(dst, "ping", 1)
   626  				return dst, nil
   627  			},
   628  		}
   629  
   630  		err := op.Execute(ctx)
   631  		assert.NotNil(t, err, "expected an error from Execute(), got nil")
   632  		// Assert that error is just context deadline exceeded and is therefore not a driver.Error marked
   633  		// with the TransientTransactionError label.
   634  		assert.Equal(t, err, context.DeadlineExceeded, "expected context.DeadlineExceeded error, got %v", err)
   635  	})
   636  	t.Run("canceled context not marked as TransientTransactionError", func(t *testing.T) {
   637  		conn := new(mockConnection)
   638  		// Create a context and cancel it immediately.
   639  		ctx, cancel := context.WithCancel(context.Background())
   640  		cancel()
   641  
   642  		op := Operation{
   643  			Database:   "foobar",
   644  			Deployment: SingleConnectionDeployment{C: conn},
   645  			CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
   646  				dst = bsoncore.AppendInt32Element(dst, "ping", 1)
   647  				return dst, nil
   648  			},
   649  		}
   650  
   651  		err := op.Execute(ctx)
   652  		assert.NotNil(t, err, "expected an error from Execute(), got nil")
   653  		// Assert that error is just context canceled and is therefore not a driver.Error marked with
   654  		// the TransientTransactionError label.
   655  		assert.Equal(t, err, context.Canceled, "expected context.Canceled error, got %v", err)
   656  	})
   657  	t.Run("ErrDeadlineWouldBeExceeded wraps context.DeadlineExceeded", func(t *testing.T) {
   658  		// Create a deployment that returns a server that reports a 90th
   659  		// percentile RTT of 1 minute.
   660  		d := new(mockDeployment)
   661  		d.returns.server = mockServer{
   662  			conn:       new(mockConnection),
   663  			rttMonitor: mockRTTMonitor{p90: 1 * time.Minute},
   664  		}
   665  
   666  		// Create an operation with a Timeout specified to enable CSOT behavior.
   667  		var dur time.Duration
   668  		op := Operation{
   669  			Database:   "foobar",
   670  			Deployment: d,
   671  			CommandFn: func(dst []byte, desc description.SelectedServer) ([]byte, error) {
   672  				return dst, nil
   673  			},
   674  			Timeout: &dur,
   675  		}
   676  
   677  		// Call the operation with a context with a deadline less than the 90th
   678  		// percentile RTT configured above.
   679  		ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
   680  		defer cancel()
   681  		err := op.Execute(ctx)
   682  
   683  		assert.ErrorIs(t, err, ErrDeadlineWouldBeExceeded)
   684  		assert.ErrorIs(t, err, context.DeadlineExceeded)
   685  	})
   686  }
   687  
   688  func createExhaustServerResponse(response bsoncore.Document, moreToCome bool) []byte {
   689  	const psuedoRequestID = 1
   690  	idx, wm := wiremessage.AppendHeaderStart(nil, 0, psuedoRequestID, wiremessage.OpMsg)
   691  	var flags wiremessage.MsgFlag
   692  	if moreToCome {
   693  		flags = wiremessage.MoreToCome
   694  	}
   695  	wm = wiremessage.AppendMsgFlags(wm, flags)
   696  	wm = wiremessage.AppendMsgSectionType(wm, wiremessage.SingleDocument)
   697  	wm = bsoncore.AppendDocument(wm, response)
   698  	return bsoncore.UpdateLength(wm, idx, int32(len(wm)))
   699  }
   700  
   701  func assertExhaustAllowedSet(t *testing.T, wm []byte, expected bool) {
   702  	t.Helper()
   703  	_, _, _, _, wm, ok := wiremessage.ReadHeader(wm)
   704  	if !ok {
   705  		t.Fatal("could not read wm header")
   706  	}
   707  	flags, wm, ok := wiremessage.ReadMsgFlags(wm)
   708  	if !ok {
   709  		t.Fatal("could not read wm flags")
   710  	}
   711  
   712  	actual := flags&wiremessage.ExhaustAllowed > 0
   713  	assert.Equal(t, expected, actual, "expected exhaustAllowed set %v, got %v", expected, actual)
   714  }
   715  
   716  type mockDeployment struct {
   717  	params struct {
   718  		selector description.ServerSelector
   719  	}
   720  	returns struct {
   721  		server Server
   722  		err    error
   723  		retry  bool
   724  		kind   description.TopologyKind
   725  	}
   726  }
   727  
   728  func (m *mockDeployment) SelectServer(_ context.Context, desc description.ServerSelector) (Server, error) {
   729  	m.params.selector = desc
   730  	return m.returns.server, m.returns.err
   731  }
   732  
   733  func (m *mockDeployment) Kind() description.TopologyKind { return m.returns.kind }
   734  
   735  type mockServerSelector struct{}
   736  
   737  func (m *mockServerSelector) SelectServer(description.Topology, []description.Server) ([]description.Server, error) {
   738  	panic("not implemented")
   739  }
   740  
   741  func (m *mockServerSelector) String() string {
   742  	panic("not implemented")
   743  }
   744  
   745  type mockServer struct {
   746  	conn       Connection
   747  	err        error
   748  	rttMonitor RTTMonitor
   749  }
   750  
   751  func (ms mockServer) Connection(context.Context) (Connection, error) { return ms.conn, ms.err }
   752  func (ms mockServer) RTTMonitor() RTTMonitor                         { return ms.rttMonitor }
   753  
   754  type mockRTTMonitor struct {
   755  	ewma  time.Duration
   756  	min   time.Duration
   757  	p90   time.Duration
   758  	stats string
   759  }
   760  
   761  func (mrm mockRTTMonitor) EWMA() time.Duration { return mrm.ewma }
   762  func (mrm mockRTTMonitor) Min() time.Duration  { return mrm.min }
   763  func (mrm mockRTTMonitor) P90() time.Duration  { return mrm.p90 }
   764  func (mrm mockRTTMonitor) Stats() string       { return mrm.stats }
   765  
   766  type mockConnection struct {
   767  	// parameters
   768  	pWriteWM []byte
   769  
   770  	// returns
   771  	rWriteErr     error
   772  	rReadWM       []byte
   773  	rReadErr      error
   774  	rDesc         description.Server
   775  	rCloseErr     error
   776  	rID           string
   777  	rServerConnID *int64
   778  	rAddr         address.Address
   779  	rCanStream    bool
   780  	rStreaming    bool
   781  }
   782  
   783  func (m *mockConnection) Description() description.Server { return m.rDesc }
   784  func (m *mockConnection) Close() error                    { return m.rCloseErr }
   785  func (m *mockConnection) ID() string                      { return m.rID }
   786  func (m *mockConnection) ServerConnectionID() *int64      { return m.rServerConnID }
   787  func (m *mockConnection) Address() address.Address        { return m.rAddr }
   788  func (m *mockConnection) SupportsStreaming() bool         { return m.rCanStream }
   789  func (m *mockConnection) CurrentlyStreaming() bool        { return m.rStreaming }
   790  func (m *mockConnection) SetStreaming(streaming bool)     { m.rStreaming = streaming }
   791  func (m *mockConnection) Stale() bool                     { return false }
   792  
   793  // TODO:(GODRIVER-2824) replace return type with int64.
   794  func (m *mockConnection) DriverConnectionID() uint64 { return 0 }
   795  
   796  func (m *mockConnection) WriteWireMessage(_ context.Context, wm []byte) error {
   797  	m.pWriteWM = wm
   798  	return m.rWriteErr
   799  }
   800  
   801  func (m *mockConnection) ReadWireMessage(_ context.Context) ([]byte, error) {
   802  	return m.rReadWM, m.rReadErr
   803  }
   804  
   805  type retryableError struct {
   806  	error
   807  }
   808  
   809  func (retryableError) Retryable() bool { return true }
   810  
   811  var _ RetryablePoolError = retryableError{}
   812  
   813  // mockRetryServer is used to test retry of connection checkout. Returns a retryable error from
   814  // Connection().
   815  type mockRetryServer struct {
   816  	numCallsToConnection int
   817  }
   818  
   819  // Connection records the number of calls and returns retryable errors until the provided context
   820  // times out or is cancelled, then returns the context error.
   821  func (ms *mockRetryServer) Connection(ctx context.Context) (Connection, error) {
   822  	ms.numCallsToConnection++
   823  
   824  	if ctx.Err() != nil {
   825  		return nil, ctx.Err()
   826  	}
   827  
   828  	time.Sleep(1 * time.Millisecond)
   829  	return nil, retryableError{error: errors.New("test error")}
   830  }
   831  
   832  func (ms *mockRetryServer) RTTMonitor() RTTMonitor {
   833  	return &csot.ZeroRTTMonitor{}
   834  }
   835  
   836  func TestRetry(t *testing.T) {
   837  	t.Run("retries multiple times with RetryContext", func(t *testing.T) {
   838  		d := new(mockDeployment)
   839  		ms := new(mockRetryServer)
   840  		d.returns.server = ms
   841  
   842  		ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
   843  		defer cancel()
   844  
   845  		retry := RetryContext
   846  		err := Operation{
   847  			CommandFn:  func([]byte, description.SelectedServer) ([]byte, error) { return nil, nil },
   848  			Deployment: d,
   849  			Database:   "testing",
   850  			RetryMode:  &retry,
   851  			Type:       Read,
   852  		}.Execute(ctx)
   853  		assert.NotNil(t, err, "expected an error from Execute()")
   854  
   855  		// Expect Connection() to be called at least 3 times. The first call is the initial attempt
   856  		// to run the operation and the second is the retry. The third indicates that we retried
   857  		// more than once, which is the behavior we want to assert.
   858  		assert.True(t,
   859  			ms.numCallsToConnection >= 3,
   860  			"expected Connection() to be called at least 3 times")
   861  
   862  		deadline, _ := ctx.Deadline()
   863  		assert.True(t,
   864  			time.Now().After(deadline),
   865  			"expected operation to complete only after the context deadline is exceeded")
   866  	})
   867  }
   868  
   869  func TestConvertI64PtrToI32Ptr(t *testing.T) {
   870  	t.Parallel()
   871  
   872  	newI64 := func(i64 int64) *int64 { return &i64 }
   873  	newI32 := func(i32 int32) *int32 { return &i32 }
   874  
   875  	tests := []struct {
   876  		name string
   877  		i64  *int64
   878  		want *int32
   879  	}{
   880  		{
   881  			name: "empty",
   882  			want: nil,
   883  		},
   884  		{
   885  			name: "in bounds",
   886  			i64:  newI64(1),
   887  			want: newI32(1),
   888  		},
   889  		{
   890  			name: "out of bounds negative",
   891  			i64:  newI64(math.MinInt32 - 1),
   892  		},
   893  		{
   894  			name: "out of bounds positive",
   895  			i64:  newI64(math.MaxInt32 + 1),
   896  		},
   897  		{
   898  			name: "exact min int32",
   899  			i64:  newI64(math.MinInt32),
   900  			want: newI32(math.MinInt32),
   901  		},
   902  		{
   903  			name: "exact max int32",
   904  			i64:  newI64(math.MaxInt32),
   905  			want: newI32(math.MaxInt32),
   906  		},
   907  	}
   908  
   909  	for _, test := range tests {
   910  		test := test
   911  
   912  		t.Run(test.name, func(t *testing.T) {
   913  			t.Parallel()
   914  
   915  			got := convertInt64PtrToInt32Ptr(test.i64)
   916  			assert.Equal(t, test.want, got)
   917  		})
   918  	}
   919  }
   920  
   921  func TestDecodeOpReply(t *testing.T) {
   922  	t.Parallel()
   923  
   924  	// GODRIVER-2869: Prevent infinite loop caused by malformatted wiremessage with length of 0.
   925  	t.Run("malformatted wiremessage with length of 0", func(t *testing.T) {
   926  		t.Parallel()
   927  
   928  		var wm []byte
   929  		wm = wiremessage.AppendReplyFlags(wm, 0)
   930  		wm = wiremessage.AppendReplyCursorID(wm, int64(0))
   931  		wm = wiremessage.AppendReplyStartingFrom(wm, 0)
   932  		wm = wiremessage.AppendReplyNumberReturned(wm, 0)
   933  		idx, wm := bsoncore.ReserveLength(wm)
   934  		wm = bsoncore.UpdateLength(wm, idx, 0)
   935  		reply := Operation{}.decodeOpReply(wm)
   936  		assert.Equal(t, []bsoncore.Document(nil), reply.documents)
   937  	})
   938  }
   939  
   940  func TestFilterDeprioritizedServers(t *testing.T) {
   941  	t.Parallel()
   942  
   943  	tests := []struct {
   944  		name          string
   945  		deprioritized []description.Server
   946  		candidates    []description.Server
   947  		want          []description.Server
   948  	}{
   949  		{
   950  			name:       "empty",
   951  			candidates: []description.Server{},
   952  			want:       []description.Server{},
   953  		},
   954  		{
   955  			name:       "nil candidates",
   956  			candidates: nil,
   957  			want:       []description.Server{},
   958  		},
   959  		{
   960  			name: "nil deprioritized server list",
   961  			candidates: []description.Server{
   962  				{
   963  					Addr: address.Address("mongodb://localhost:27017"),
   964  				},
   965  			},
   966  			want: []description.Server{
   967  				{
   968  					Addr: address.Address("mongodb://localhost:27017"),
   969  				},
   970  			},
   971  		},
   972  		{
   973  			name: "deprioritize single server candidate list",
   974  			candidates: []description.Server{
   975  				{
   976  					Addr: address.Address("mongodb://localhost:27017"),
   977  				},
   978  			},
   979  			deprioritized: []description.Server{
   980  				{
   981  					Addr: address.Address("mongodb://localhost:27017"),
   982  				},
   983  			},
   984  			want: []description.Server{
   985  				// Since all available servers were deprioritized, then the selector
   986  				// should return all candidates.
   987  				{
   988  					Addr: address.Address("mongodb://localhost:27017"),
   989  				},
   990  			},
   991  		},
   992  		{
   993  			name: "depriotirize one server in multi server candidate list",
   994  			candidates: []description.Server{
   995  				{
   996  					Addr: address.Address("mongodb://localhost:27017"),
   997  				},
   998  				{
   999  					Addr: address.Address("mongodb://localhost:27018"),
  1000  				},
  1001  				{
  1002  					Addr: address.Address("mongodb://localhost:27019"),
  1003  				},
  1004  			},
  1005  			deprioritized: []description.Server{
  1006  				{
  1007  					Addr: address.Address("mongodb://localhost:27017"),
  1008  				},
  1009  			},
  1010  			want: []description.Server{
  1011  				{
  1012  					Addr: address.Address("mongodb://localhost:27018"),
  1013  				},
  1014  				{
  1015  					Addr: address.Address("mongodb://localhost:27019"),
  1016  				},
  1017  			},
  1018  		},
  1019  		{
  1020  			name: "depriotirize multiple servers in multi server candidate list",
  1021  			deprioritized: []description.Server{
  1022  				{
  1023  					Addr: address.Address("mongodb://localhost:27017"),
  1024  				},
  1025  				{
  1026  					Addr: address.Address("mongodb://localhost:27018"),
  1027  				},
  1028  			},
  1029  			candidates: []description.Server{
  1030  				{
  1031  					Addr: address.Address("mongodb://localhost:27017"),
  1032  				},
  1033  				{
  1034  					Addr: address.Address("mongodb://localhost:27018"),
  1035  				},
  1036  				{
  1037  					Addr: address.Address("mongodb://localhost:27019"),
  1038  				},
  1039  			},
  1040  			want: []description.Server{
  1041  				{
  1042  					Addr: address.Address("mongodb://localhost:27019"),
  1043  				},
  1044  			},
  1045  		},
  1046  	}
  1047  
  1048  	for _, tc := range tests {
  1049  		tc := tc // Capture the range variable.
  1050  
  1051  		t.Run(tc.name, func(t *testing.T) {
  1052  			t.Parallel()
  1053  
  1054  			got := filterDeprioritizedServers(tc.candidates, tc.deprioritized)
  1055  			assert.ElementsMatch(t, got, tc.want)
  1056  		})
  1057  	}
  1058  }
  1059  

View as plain text