1
2
3
4
5
6
7 package session
8
9 import (
10 "bytes"
11 "errors"
12 "testing"
13
14 "go.mongodb.org/mongo-driver/bson/primitive"
15 "go.mongodb.org/mongo-driver/internal/assert"
16 "go.mongodb.org/mongo-driver/internal/require"
17 "go.mongodb.org/mongo-driver/internal/uuid"
18 "go.mongodb.org/mongo-driver/mongo/description"
19 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
20 )
21
22 var consistent = true
23 var sessionOpts = &ClientOptions{
24 CausalConsistency: &consistent,
25 }
26
27 func compareOperationTimes(t *testing.T, expected *primitive.Timestamp, actual *primitive.Timestamp) {
28 if expected.T != actual.T {
29 t.Fatalf("T value mismatch; expected %d got %d", expected.T, actual.T)
30 }
31
32 if expected.I != actual.I {
33 t.Fatalf("I value mismatch; expected %d got %d", expected.I, actual.I)
34 }
35 }
36
37 func TestClientSession(t *testing.T) {
38 var clusterTime1 = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocument(nil, bsoncore.AppendTimestampElement(nil, "clusterTime", 10, 5))))
39 var clusterTime2 = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocument(nil, bsoncore.AppendTimestampElement(nil, "clusterTime", 5, 5))))
40 var clusterTime3 = bsoncore.BuildDocument(nil, bsoncore.AppendDocumentElement(nil, "$clusterTime", bsoncore.BuildDocument(nil, bsoncore.AppendTimestampElement(nil, "clusterTime", 5, 0))))
41
42 t.Run("TestMaxClusterTime", func(t *testing.T) {
43 maxTime := MaxClusterTime(clusterTime1, clusterTime2)
44 if !bytes.Equal(maxTime, clusterTime1) {
45 t.Errorf("Wrong max time")
46 }
47
48 maxTime = MaxClusterTime(clusterTime3, clusterTime2)
49 if !bytes.Equal(maxTime, clusterTime2) {
50 t.Errorf("Wrong max time")
51 }
52 })
53
54 t.Run("TestAdvanceClusterTime", func(t *testing.T) {
55 id, _ := uuid.New()
56 sess, err := NewClientSession(&Pool{}, id, sessionOpts)
57 require.Nil(t, err, "Unexpected error")
58 err = sess.AdvanceClusterTime(clusterTime2)
59 require.Nil(t, err, "Unexpected error")
60 if !bytes.Equal(sess.ClusterTime, clusterTime2) {
61 t.Errorf("Session cluster time incorrect, expected %v, received %v", clusterTime2, sess.ClusterTime)
62 }
63 err = sess.AdvanceClusterTime(clusterTime3)
64 require.Nil(t, err, "Unexpected error")
65 if !bytes.Equal(sess.ClusterTime, clusterTime2) {
66 t.Errorf("Session cluster time incorrect, expected %v, received %v", clusterTime2, sess.ClusterTime)
67 }
68 err = sess.AdvanceClusterTime(clusterTime1)
69 require.Nil(t, err, "Unexpected error")
70 if !bytes.Equal(sess.ClusterTime, clusterTime1) {
71 t.Errorf("Session cluster time incorrect, expected %v, received %v", clusterTime1, sess.ClusterTime)
72 }
73 sess.EndSession()
74 })
75
76 t.Run("TestEndSession", func(t *testing.T) {
77 id, _ := uuid.New()
78 sess, err := NewClientSession(&Pool{}, id, sessionOpts)
79 require.Nil(t, err, "Unexpected error")
80 sess.EndSession()
81 err = sess.UpdateUseTime()
82 require.NotNil(t, err, "Expected error, received nil")
83 })
84
85 t.Run("TestAdvanceOperationTime", func(t *testing.T) {
86 id, _ := uuid.New()
87 sess, err := NewClientSession(&Pool{}, id, sessionOpts)
88 require.Nil(t, err, "Unexpected error")
89
90 optime1 := &primitive.Timestamp{
91 T: 1,
92 I: 0,
93 }
94 err = sess.AdvanceOperationTime(optime1)
95 assert.Nil(t, err, "error updating first operation time: %s", err)
96 compareOperationTimes(t, optime1, sess.OperationTime)
97
98 optime2 := &primitive.Timestamp{
99 T: 2,
100 I: 0,
101 }
102 err = sess.AdvanceOperationTime(optime2)
103 assert.Nil(t, err, "error updating second operation time: %s", err)
104 compareOperationTimes(t, optime2, sess.OperationTime)
105
106 optime3 := &primitive.Timestamp{
107 T: 2,
108 I: 1,
109 }
110 err = sess.AdvanceOperationTime(optime3)
111 assert.Nil(t, err, "error updating third operation time: %s", err)
112 compareOperationTimes(t, optime3, sess.OperationTime)
113
114 err = sess.AdvanceOperationTime(&primitive.Timestamp{
115 T: 1,
116 I: 10,
117 })
118 assert.Nil(t, err, "error updating fourth operation time: %s", err)
119 compareOperationTimes(t, optime3, sess.OperationTime)
120 sess.EndSession()
121 })
122
123 t.Run("TestTransactionState", func(t *testing.T) {
124 id, _ := uuid.New()
125 sess, err := NewClientSession(&Pool{}, id, nil)
126 require.Nil(t, err, "Unexpected error")
127
128 err = sess.CommitTransaction()
129 if !errors.Is(err, ErrNoTransactStarted) {
130 t.Errorf("expected error, got %v", err)
131 }
132
133 err = sess.AbortTransaction()
134 if !errors.Is(err, ErrNoTransactStarted) {
135 t.Errorf("expected error, got %v", err)
136 }
137
138 if sess.TransactionState != None {
139 t.Errorf("incorrect session state, expected None, received %v", sess.TransactionState)
140 }
141
142 err = sess.StartTransaction(nil)
143 require.Nil(t, err, "error starting transaction: %s", err)
144 if sess.TransactionState != Starting {
145 t.Errorf("incorrect session state, expected Starting, received %v", sess.TransactionState)
146 }
147
148 err = sess.StartTransaction(nil)
149 if !errors.Is(err, ErrTransactInProgress) {
150 t.Errorf("expected error, got %v", err)
151 }
152
153 err = sess.ApplyCommand(description.Server{Kind: description.Standalone})
154 assert.Nil(t, err, "ApplyCommand error: %v", err)
155 if sess.TransactionState != InProgress {
156 t.Errorf("incorrect session state, expected InProgress, received %v", sess.TransactionState)
157 }
158
159 err = sess.StartTransaction(nil)
160 if !errors.Is(err, ErrTransactInProgress) {
161 t.Errorf("expected error, got %v", err)
162 }
163
164 err = sess.CommitTransaction()
165 require.Nil(t, err, "error committing transaction: %s", err)
166 if sess.TransactionState != Committed {
167 t.Errorf("incorrect session state, expected Committed, received %v", sess.TransactionState)
168 }
169
170 err = sess.AbortTransaction()
171 if !errors.Is(err, ErrAbortAfterCommit) {
172 t.Errorf("expected error, got %v", err)
173 }
174
175 err = sess.StartTransaction(nil)
176 require.Nil(t, err, "error starting transaction: %s", err)
177 if sess.TransactionState != Starting {
178 t.Errorf("incorrect session state, expected Starting, received %v", sess.TransactionState)
179 }
180
181 err = sess.AbortTransaction()
182 require.Nil(t, err, "error aborting transaction: %s", err)
183 if sess.TransactionState != Aborted {
184 t.Errorf("incorrect session state, expected Aborted, received %v", sess.TransactionState)
185 }
186
187 err = sess.AbortTransaction()
188 if !errors.Is(err, ErrAbortTwice) {
189 t.Errorf("expected error, got %v", err)
190 }
191
192 err = sess.CommitTransaction()
193 if !errors.Is(err, ErrCommitAfterAbort) {
194 t.Errorf("expected error, got %v", err)
195 }
196 })
197
198 t.Run("causal consistency and snapshot", func(t *testing.T) {
199 falseVal := false
200 trueVal := true
201
202
203
204 testCases := []struct {
205 description string
206 consistent *bool
207 snapshot *bool
208 expectedConsistent bool
209 expectedSnapshot bool
210 }{
211 {
212 "both unset",
213 nil,
214 nil,
215 true,
216 false,
217 },
218 {
219 "both false",
220 &falseVal,
221 &falseVal,
222 false,
223 false,
224 },
225 {
226 "cc unset snapshot true",
227 nil,
228 &trueVal,
229 false,
230 true,
231 },
232 {
233 "cc unset snapshot false",
234 nil,
235 &falseVal,
236 true,
237 false,
238 },
239 {
240 "cc true snapshot unset",
241 &trueVal,
242 nil,
243 true,
244 false,
245 },
246 {
247 "cc false snapshot unset",
248 &falseVal,
249 nil,
250 false,
251 false,
252 },
253 {
254 "cc false snapshot true",
255 &falseVal,
256 &trueVal,
257 false,
258 true,
259 },
260 {
261 "cc true snapshot false",
262 &trueVal,
263 &falseVal,
264 true,
265 false,
266 },
267 }
268
269 for _, tc := range testCases {
270 t.Run(tc.description, func(t *testing.T) {
271 sessOpts := &ClientOptions{
272 CausalConsistency: tc.consistent,
273 Snapshot: tc.snapshot,
274 }
275
276 id, _ := uuid.New()
277 sess, err := NewClientSession(&Pool{}, id, sessOpts)
278 require.Nil(t, err, "unexpected NewClientSession error %v", err)
279
280 require.Equal(t, tc.expectedConsistent, sess.Consistent,
281 "expected Consistent to be %v, got %v", tc.expectedConsistent, sess.Consistent)
282 require.Equal(t, tc.expectedSnapshot, sess.Snapshot,
283 "expected Snapshot to be %v, got %v", tc.expectedSnapshot, sess.Snapshot)
284 })
285 }
286 })
287 }
288
289 func TestImplicitClientSession(t *testing.T) {
290 t.Parallel()
291
292 t.Run("causal consistency is false", func(t *testing.T) {
293 t.Parallel()
294
295 id, err := uuid.New()
296 require.NoError(t, err)
297
298 c := NewImplicitClientSession(&Pool{}, id)
299 assert.False(t, c.Consistent, "expected causal consistency to be false for implicit sessions")
300 })
301 }
302
View as plain text