...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/session/client_session_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/session

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package session
     8  
     9  import (
    10  	"bytes"
    11  	"errors"
    12  	"testing"
    13  
    14  	"go.mongodb.org/mongo-driver/bson/primitive"
    15  	"go.mongodb.org/mongo-driver/internal/assert"
    16  	"go.mongodb.org/mongo-driver/internal/require"
    17  	"go.mongodb.org/mongo-driver/internal/uuid"
    18  	"go.mongodb.org/mongo-driver/mongo/description"
    19  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    20  )
    21  
    22  var consistent = true
    23  var sessionOpts = &ClientOptions{
    24  	CausalConsistency: &consistent,
    25  }
    26  
    27  func compareOperationTimes(t *testing.T, expected *primitive.Timestamp, actual *primitive.Timestamp) {
    28  	if expected.T != actual.T {
    29  		t.Fatalf("T value mismatch; expected %d got %d", expected.T, actual.T)
    30  	}
    31  
    32  	if expected.I != actual.I {
    33  		t.Fatalf("I value mismatch; expected %d got %d", expected.I, actual.I)
    34  	}
    35  }
    36  
    37  func TestClientSession(t *testing.T) {
    38  	var clusterTime1 = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocument(nil, bsoncore.AppendTimestampElement(nil, "clusterTime", 10, 5))))
    39  	var clusterTime2 = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocument(nil, bsoncore.AppendTimestampElement(nil, "clusterTime", 5, 5))))
    40  	var clusterTime3 = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocument(nil, bsoncore.AppendTimestampElement(nil, "clusterTime", 5, 0))))
    41  
    42  	t.Run("TestMaxClusterTime", func(t *testing.T) {
    43  		maxTime := MaxClusterTime(clusterTime1, clusterTime2)
    44  		if !bytes.Equal(maxTime, clusterTime1) {
    45  			t.Errorf("Wrong max time")
    46  		}
    47  
    48  		maxTime = MaxClusterTime(clusterTime3, clusterTime2)
    49  		if !bytes.Equal(maxTime, clusterTime2) {
    50  			t.Errorf("Wrong max time")
    51  		}
    52  	})
    53  
    54  	t.Run("TestAdvanceClusterTime", func(t *testing.T) {
    55  		id, _ := uuid.New()
    56  		sess, err := NewClientSession(&Pool{}, id, sessionOpts)
    57  		require.Nil(t, err, "Unexpected error")
    58  		err = sess.AdvanceClusterTime(clusterTime2)
    59  		require.Nil(t, err, "Unexpected error")
    60  		if !bytes.Equal(sess.ClusterTime, clusterTime2) {
    61  			t.Errorf("Session cluster time incorrect, expected %v, received %v", clusterTime2, sess.ClusterTime)
    62  		}
    63  		err = sess.AdvanceClusterTime(clusterTime3)
    64  		require.Nil(t, err, "Unexpected error")
    65  		if !bytes.Equal(sess.ClusterTime, clusterTime2) {
    66  			t.Errorf("Session cluster time incorrect, expected %v, received %v", clusterTime2, sess.ClusterTime)
    67  		}
    68  		err = sess.AdvanceClusterTime(clusterTime1)
    69  		require.Nil(t, err, "Unexpected error")
    70  		if !bytes.Equal(sess.ClusterTime, clusterTime1) {
    71  			t.Errorf("Session cluster time incorrect, expected %v, received %v", clusterTime1, sess.ClusterTime)
    72  		}
    73  		sess.EndSession()
    74  	})
    75  
    76  	t.Run("TestEndSession", func(t *testing.T) {
    77  		id, _ := uuid.New()
    78  		sess, err := NewClientSession(&Pool{}, id, sessionOpts)
    79  		require.Nil(t, err, "Unexpected error")
    80  		sess.EndSession()
    81  		err = sess.UpdateUseTime()
    82  		require.NotNil(t, err, "Expected error, received nil")
    83  	})
    84  
    85  	t.Run("TestAdvanceOperationTime", func(t *testing.T) {
    86  		id, _ := uuid.New()
    87  		sess, err := NewClientSession(&Pool{}, id, sessionOpts)
    88  		require.Nil(t, err, "Unexpected error")
    89  
    90  		optime1 := &primitive.Timestamp{
    91  			T: 1,
    92  			I: 0,
    93  		}
    94  		err = sess.AdvanceOperationTime(optime1)
    95  		assert.Nil(t, err, "error updating first operation time: %s", err)
    96  		compareOperationTimes(t, optime1, sess.OperationTime)
    97  
    98  		optime2 := &primitive.Timestamp{
    99  			T: 2,
   100  			I: 0,
   101  		}
   102  		err = sess.AdvanceOperationTime(optime2)
   103  		assert.Nil(t, err, "error updating second operation time: %s", err)
   104  		compareOperationTimes(t, optime2, sess.OperationTime)
   105  
   106  		optime3 := &primitive.Timestamp{
   107  			T: 2,
   108  			I: 1,
   109  		}
   110  		err = sess.AdvanceOperationTime(optime3)
   111  		assert.Nil(t, err, "error updating third operation time: %s", err)
   112  		compareOperationTimes(t, optime3, sess.OperationTime)
   113  
   114  		err = sess.AdvanceOperationTime(&primitive.Timestamp{
   115  			T: 1,
   116  			I: 10,
   117  		})
   118  		assert.Nil(t, err, "error updating fourth operation time: %s", err)
   119  		compareOperationTimes(t, optime3, sess.OperationTime)
   120  		sess.EndSession()
   121  	})
   122  
   123  	t.Run("TestTransactionState", func(t *testing.T) {
   124  		id, _ := uuid.New()
   125  		sess, err := NewClientSession(&Pool{}, id, nil)
   126  		require.Nil(t, err, "Unexpected error")
   127  
   128  		err = sess.CommitTransaction()
   129  		if !errors.Is(err, ErrNoTransactStarted) {
   130  			t.Errorf("expected error, got %v", err)
   131  		}
   132  
   133  		err = sess.AbortTransaction()
   134  		if !errors.Is(err, ErrNoTransactStarted) {
   135  			t.Errorf("expected error, got %v", err)
   136  		}
   137  
   138  		if sess.TransactionState != None {
   139  			t.Errorf("incorrect session state, expected None, received %v", sess.TransactionState)
   140  		}
   141  
   142  		err = sess.StartTransaction(nil)
   143  		require.Nil(t, err, "error starting transaction: %s", err)
   144  		if sess.TransactionState != Starting {
   145  			t.Errorf("incorrect session state, expected Starting, received %v", sess.TransactionState)
   146  		}
   147  
   148  		err = sess.StartTransaction(nil)
   149  		if !errors.Is(err, ErrTransactInProgress) {
   150  			t.Errorf("expected error, got %v", err)
   151  		}
   152  
   153  		err = sess.ApplyCommand(description.Server{Kind: description.Standalone})
   154  		assert.Nil(t, err, "ApplyCommand error: %v", err)
   155  		if sess.TransactionState != InProgress {
   156  			t.Errorf("incorrect session state, expected InProgress, received %v", sess.TransactionState)
   157  		}
   158  
   159  		err = sess.StartTransaction(nil)
   160  		if !errors.Is(err, ErrTransactInProgress) {
   161  			t.Errorf("expected error, got %v", err)
   162  		}
   163  
   164  		err = sess.CommitTransaction()
   165  		require.Nil(t, err, "error committing transaction: %s", err)
   166  		if sess.TransactionState != Committed {
   167  			t.Errorf("incorrect session state, expected Committed, received %v", sess.TransactionState)
   168  		}
   169  
   170  		err = sess.AbortTransaction()
   171  		if !errors.Is(err, ErrAbortAfterCommit) {
   172  			t.Errorf("expected error, got %v", err)
   173  		}
   174  
   175  		err = sess.StartTransaction(nil)
   176  		require.Nil(t, err, "error starting transaction: %s", err)
   177  		if sess.TransactionState != Starting {
   178  			t.Errorf("incorrect session state, expected Starting, received %v", sess.TransactionState)
   179  		}
   180  
   181  		err = sess.AbortTransaction()
   182  		require.Nil(t, err, "error aborting transaction: %s", err)
   183  		if sess.TransactionState != Aborted {
   184  			t.Errorf("incorrect session state, expected Aborted, received %v", sess.TransactionState)
   185  		}
   186  
   187  		err = sess.AbortTransaction()
   188  		if !errors.Is(err, ErrAbortTwice) {
   189  			t.Errorf("expected error, got %v", err)
   190  		}
   191  
   192  		err = sess.CommitTransaction()
   193  		if !errors.Is(err, ErrCommitAfterAbort) {
   194  			t.Errorf("expected error, got %v", err)
   195  		}
   196  	})
   197  
   198  	t.Run("causal consistency and snapshot", func(t *testing.T) {
   199  		falseVal := false
   200  		trueVal := true
   201  
   202  		// A test for Consistent and Snapshot both being true and causing an error can be found
   203  		// in TestSessionsProse.
   204  		testCases := []struct {
   205  			description        string
   206  			consistent         *bool
   207  			snapshot           *bool
   208  			expectedConsistent bool
   209  			expectedSnapshot   bool
   210  		}{
   211  			{
   212  				"both unset",
   213  				nil,
   214  				nil,
   215  				true,
   216  				false,
   217  			},
   218  			{
   219  				"both false",
   220  				&falseVal,
   221  				&falseVal,
   222  				false,
   223  				false,
   224  			},
   225  			{
   226  				"cc unset snapshot true",
   227  				nil,
   228  				&trueVal,
   229  				false,
   230  				true,
   231  			},
   232  			{
   233  				"cc unset snapshot false",
   234  				nil,
   235  				&falseVal,
   236  				true,
   237  				false,
   238  			},
   239  			{
   240  				"cc true snapshot unset",
   241  				&trueVal,
   242  				nil,
   243  				true,
   244  				false,
   245  			},
   246  			{
   247  				"cc false snapshot unset",
   248  				&falseVal,
   249  				nil,
   250  				false,
   251  				false,
   252  			},
   253  			{
   254  				"cc false snapshot true",
   255  				&falseVal,
   256  				&trueVal,
   257  				false,
   258  				true,
   259  			},
   260  			{
   261  				"cc true snapshot false",
   262  				&trueVal,
   263  				&falseVal,
   264  				true,
   265  				false,
   266  			},
   267  		}
   268  
   269  		for _, tc := range testCases {
   270  			t.Run(tc.description, func(t *testing.T) {
   271  				sessOpts := &ClientOptions{
   272  					CausalConsistency: tc.consistent,
   273  					Snapshot:          tc.snapshot,
   274  				}
   275  
   276  				id, _ := uuid.New()
   277  				sess, err := NewClientSession(&Pool{}, id, sessOpts)
   278  				require.Nil(t, err, "unexpected NewClientSession error %v", err)
   279  
   280  				require.Equal(t, tc.expectedConsistent, sess.Consistent,
   281  					"expected Consistent to be %v, got %v", tc.expectedConsistent, sess.Consistent)
   282  				require.Equal(t, tc.expectedSnapshot, sess.Snapshot,
   283  					"expected Snapshot to be %v, got %v", tc.expectedSnapshot, sess.Snapshot)
   284  			})
   285  		}
   286  	})
   287  }
   288  
   289  func TestImplicitClientSession(t *testing.T) {
   290  	t.Parallel()
   291  
   292  	t.Run("causal consistency is false", func(t *testing.T) {
   293  		t.Parallel()
   294  
   295  		id, err := uuid.New()
   296  		require.NoError(t, err)
   297  
   298  		c := NewImplicitClientSession(&Pool{}, id)
   299  		assert.False(t, c.Consistent, "expected causal consistency to be false for implicit sessions")
   300  	})
   301  }
   302  

View as plain text