1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "math"
14 "os"
15 "strconv"
16 "strings"
17 "testing"
18 "time"
19
20 "go.mongodb.org/mongo-driver/bson"
21 "go.mongodb.org/mongo-driver/event"
22 "go.mongodb.org/mongo-driver/internal/assert"
23 "go.mongodb.org/mongo-driver/internal/integtest"
24 "go.mongodb.org/mongo-driver/mongo/description"
25 "go.mongodb.org/mongo-driver/mongo/options"
26 "go.mongodb.org/mongo-driver/mongo/readpref"
27 "go.mongodb.org/mongo-driver/mongo/writeconcern"
28 "go.mongodb.org/mongo-driver/x/mongo/driver"
29 "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
30 )
31
var (
	// connsCheckedOut tracks the number of connections currently checked out
	// of the pool. The pool monitor installed by setupConvenientTransactions
	// increments it on event.GetSucceeded and decrements it on
	// event.ConnectionReturned; TestConvenientTransactions asserts that it is
	// zero during cleanup.
	connsCheckedOut int
	// errorInterrupted is the server error code that is tolerated when the
	// killAllSessions cleanup command returns a CommandError.
	errorInterrupted int32 = 11601
)
36
37 func TestConvenientTransactions(t *testing.T) {
38 if testing.Short() {
39 t.Skip("skipping integration test in short mode")
40 }
41 if os.Getenv("DOCKER_RUNNING") != "" {
42 t.Skip("skipping test in docker environment")
43 }
44
45 client := setupConvenientTransactions(t)
46 db := client.Database("TestConvenientTransactions")
47 dbAdmin := client.Database("admin")
48
49 defer func() {
50 sessions := client.NumberSessionsInProgress()
51 conns := connsCheckedOut
52
53 err := dbAdmin.RunCommand(bgCtx, bson.D{
54 {"killAllSessions", bson.A{}},
55 }).Err()
56 if err != nil {
57 if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
58 t.Fatalf("killAllSessions error: %v", err)
59 }
60 }
61
62 _ = db.Drop(bgCtx)
63 _ = client.Disconnect(bgCtx)
64
65 assert.Equal(t, 0, sessions, "%v sessions checked out", sessions)
66 assert.Equal(t, 0, conns, "%v connections checked out", conns)
67 }()
68
69 t.Run("callback raises custom error", func(t *testing.T) {
70 coll := db.Collection(t.Name())
71 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
72 assert.Nil(t, err, "InsertOne error: %v", err)
73
74 sess, err := client.StartSession()
75 assert.Nil(t, err, "StartSession error: %v", err)
76 defer sess.EndSession(context.Background())
77
78 testErr := errors.New("test error")
79 _, err = sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
80 return nil, testErr
81 })
82 assert.Equal(t, testErr, err, "expected error %v, got %v", testErr, err)
83 })
84 t.Run("callback returns value", func(t *testing.T) {
85 coll := db.Collection(t.Name())
86 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
87 assert.Nil(t, err, "InsertOne error: %v", err)
88
89 sess, err := client.StartSession()
90 assert.Nil(t, err, "StartSession error: %v", err)
91 defer sess.EndSession(context.Background())
92
93 res, err := sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
94 return false, nil
95 })
96 assert.Nil(t, err, "WithTransaction error: %v", err)
97 resBool, ok := res.(bool)
98 assert.True(t, ok, "expected result type %T, got %T", false, res)
99 assert.False(t, resBool, "expected result false, got %v", resBool)
100 })
101 t.Run("retry timeout enforced", func(t *testing.T) {
102 withTransactionTimeout = time.Second
103
104 coll := db.Collection(t.Name())
105 _, err := coll.InsertOne(bgCtx, bson.D{{"x", 1}})
106 assert.Nil(t, err, "InsertOne error: %v", err)
107
108 t.Run("transient transaction error", func(t *testing.T) {
109 sess, err := client.StartSession()
110 assert.Nil(t, err, "StartSession error: %v", err)
111 defer sess.EndSession(context.Background())
112
113 _, err = sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
114 return nil, CommandError{Name: "test Error", Labels: []string{driver.TransientTransactionError}}
115 })
116 assert.NotNil(t, err, "expected WithTransaction error, got nil")
117 cmdErr, ok := err.(CommandError)
118 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
119 assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
120 "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
121 })
122 t.Run("unknown transaction commit result", func(t *testing.T) {
123
124 failpoint := bson.D{{"configureFailPoint", "failCommand"},
125 {"mode", "alwaysOn"},
126 {"data", bson.D{
127 {"failCommands", bson.A{"commitTransaction"}},
128 {"closeConnection", true},
129 }},
130 }
131 err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
132 assert.Nil(t, err, "error setting failpoint: %v", err)
133 defer func() {
134 err = dbAdmin.RunCommand(bgCtx, bson.D{
135 {"configureFailPoint", "failCommand"},
136 {"mode", "off"},
137 }).Err()
138 assert.Nil(t, err, "error turning off failpoint: %v", err)
139 }()
140
141 sess, err := client.StartSession()
142 assert.Nil(t, err, "StartSession error: %v", err)
143 defer sess.EndSession(context.Background())
144
145 _, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) {
146 _, err := coll.InsertOne(ctx, bson.D{{"x", 1}})
147 return nil, err
148 })
149 assert.NotNil(t, err, "expected WithTransaction error, got nil")
150 cmdErr, ok := err.(CommandError)
151 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
152 assert.True(t, cmdErr.HasErrorLabel(driver.UnknownTransactionCommitResult),
153 "expected error with label %v, got %v", driver.UnknownTransactionCommitResult, cmdErr)
154 })
155 t.Run("commit transient transaction error", func(t *testing.T) {
156
157 failpoint := bson.D{{"configureFailPoint", "failCommand"},
158 {"mode", "alwaysOn"},
159 {"data", bson.D{
160 {"failCommands", bson.A{"commitTransaction"}},
161 {"errorCode", 251},
162 }},
163 }
164 err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
165 assert.Nil(t, err, "error setting failpoint: %v", err)
166 defer func() {
167 err = dbAdmin.RunCommand(bgCtx, bson.D{
168 {"configureFailPoint", "failCommand"},
169 {"mode", "off"},
170 }).Err()
171 assert.Nil(t, err, "error turning off failpoint: %v", err)
172 }()
173
174 sess, err := client.StartSession()
175 assert.Nil(t, err, "StartSession error: %v", err)
176 defer sess.EndSession(context.Background())
177
178 _, err = sess.WithTransaction(context.Background(), func(ctx SessionContext) (interface{}, error) {
179 _, err := coll.InsertOne(ctx, bson.D{{"x", 1}})
180 return nil, err
181 })
182 assert.NotNil(t, err, "expected WithTransaction error, got nil")
183 cmdErr, ok := err.(CommandError)
184 assert.True(t, ok, "expected error type %T, got %T", CommandError{}, err)
185 assert.True(t, cmdErr.HasErrorLabel(driver.TransientTransactionError),
186 "expected error with label %v, got %v", driver.TransientTransactionError, cmdErr)
187 })
188 })
189 t.Run("abortTransaction does not time out", func(t *testing.T) {
190
191
192 var abortStarted []*event.CommandStartedEvent
193 var abortSucceeded []*event.CommandSucceededEvent
194 var abortFailed []*event.CommandFailedEvent
195 var abortCtx context.Context
196 monitor := &event.CommandMonitor{
197 Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
198 if evt.CommandName == "abortTransaction" {
199 abortStarted = append(abortStarted, evt)
200 if abortCtx == nil {
201 abortCtx = ctx
202 }
203 }
204 },
205 Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
206 if evt.CommandName == "abortTransaction" {
207 abortSucceeded = append(abortSucceeded, evt)
208 }
209 },
210 Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
211 if evt.CommandName == "abortTransaction" {
212 abortFailed = append(abortFailed, evt)
213 }
214 },
215 }
216
217
218
219
220 client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
221 db := client.Database("foo")
222 coll := db.Collection("bar")
223 err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
224 assert.Nil(t, err, "error creating collection on server: %v\n", err)
225
226 sess, err := client.StartSession()
227 assert.Nil(t, err, "StartSession error: %v", err)
228 defer func() {
229 sess.EndSession(bgCtx)
230 _ = coll.Drop(bgCtx)
231 _ = client.Disconnect(bgCtx)
232 }()
233
234
235 type ctxKey struct{}
236 ctx, cancel := context.WithCancel(context.WithValue(context.Background(), ctxKey{}, "foobar"))
237 defer cancel()
238
239
240
241
242 callbackErr := errors.New("error")
243 callback := func(sc SessionContext) (interface{}, error) {
244 _, err = coll.InsertOne(sc, bson.D{{"x", 1}})
245 if err != nil {
246 return nil, err
247 }
248
249 cancel()
250 return nil, callbackErr
251 }
252
253 _, err = sess.WithTransaction(ctx, callback)
254 assert.Equal(t, callbackErr, err, "expected WithTransaction error %v, got %v", callbackErr, err)
255
256
257 assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
258 assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
259 len(abortSucceeded))
260 assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed event, got %d", len(abortFailed))
261
262
263
264 ctxValue, ok := abortCtx.Value(ctxKey{}).(string)
265 assert.True(t, ok, "expected context for abortTransaction to contain ctxKey")
266 assert.Equal(t, "foobar", ctxValue, "expected value for ctxKey to be 'world', got %s", ctxValue)
267 })
268 t.Run("commitTransaction timeout allows abortTransaction", func(t *testing.T) {
269
270 var abortStarted []*event.CommandStartedEvent
271 var abortSucceeded []*event.CommandSucceededEvent
272 var abortFailed []*event.CommandFailedEvent
273 monitor := &event.CommandMonitor{
274 Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
275 if evt.CommandName == "abortTransaction" {
276 abortStarted = append(abortStarted, evt)
277 }
278 },
279 Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
280 if evt.CommandName == "abortTransaction" {
281 abortSucceeded = append(abortSucceeded, evt)
282 }
283 },
284 Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
285 if evt.CommandName == "abortTransaction" {
286 abortFailed = append(abortFailed, evt)
287 }
288 },
289 }
290
291
292
293
294 client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
295 db := client.Database("foo")
296 coll := db.Collection("test")
297 defer func() {
298 _ = coll.Drop(bgCtx)
299 }()
300
301 err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
302 assert.Nil(t, err, "error creating collection on server: %v", err)
303
304
305 session, err := client.StartSession()
306 defer session.EndSession(bgCtx)
307 assert.Nil(t, err, "StartSession error: %v", err)
308
309 _ = WithSession(bgCtx, session, func(sessionContext SessionContext) error {
310
311 err = session.StartTransaction()
312 assert.Nil(t, err, "StartTransaction error: %v", err)
313
314
315 _, err := coll.InsertOne(sessionContext, bson.D{{"val", 17}})
316 assert.Nil(t, err, "InsertOne error: %v", err)
317
318
319 commitTimeoutCtx, commitCancel := context.WithTimeout(sessionContext, 0)
320 defer commitCancel()
321
322
323 commitErr := session.CommitTransaction(commitTimeoutCtx)
324 assert.True(t, IsTimeout(commitErr),
325 "expected timeout error error; got %v", commitErr)
326
327
328 clientSession := session.(XSession).ClientSession()
329 assert.False(t, clientSession.TransactionCommitted(), "expected session state to not be Committed")
330
331
332 abortErr := session.AbortTransaction(context.Background())
333 assert.Nil(t, abortErr, "AbortTransaction error: %v", abortErr)
334
335
336 assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
337 assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
338 len(abortSucceeded))
339 assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed))
340
341 return nil
342 })
343 })
344 t.Run("context error before commitTransaction does not retry and aborts", func(t *testing.T) {
345 withTransactionTimeout = 2 * time.Second
346
347
348 var abortStarted []*event.CommandStartedEvent
349 var abortSucceeded []*event.CommandSucceededEvent
350 var abortFailed []*event.CommandFailedEvent
351 monitor := &event.CommandMonitor{
352 Started: func(ctx context.Context, evt *event.CommandStartedEvent) {
353 if evt.CommandName == "abortTransaction" {
354 abortStarted = append(abortStarted, evt)
355 }
356 },
357 Succeeded: func(_ context.Context, evt *event.CommandSucceededEvent) {
358 if evt.CommandName == "abortTransaction" {
359 abortSucceeded = append(abortSucceeded, evt)
360 }
361 },
362 Failed: func(_ context.Context, evt *event.CommandFailedEvent) {
363 if evt.CommandName == "abortTransaction" {
364 abortFailed = append(abortFailed, evt)
365 }
366 },
367 }
368
369
370
371
372 client := setupConvenientTransactions(t, options.Client().SetMonitor(monitor))
373 db := client.Database("foo")
374 coll := db.Collection("test")
375
376
377 err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
378 assert.Nil(t, err, "error creating collection on server: %v", err)
379 defer func() {
380 _ = coll.Drop(bgCtx)
381 }()
382
383
384 sess, err := client.StartSession()
385 assert.Nil(t, err, "StartSession error: %v", err)
386 defer sess.EndSession(context.Background())
387
388
389 defer func() {
390 err := dbAdmin.RunCommand(bgCtx, bson.D{
391 {"killAllSessions", bson.A{}},
392 }).Err()
393 if err != nil {
394 if ce, ok := err.(CommandError); !ok || ce.Code != errorInterrupted {
395 t.Fatalf("killAllSessions error: %v", err)
396 }
397 }
398 }()
399
400
401
402 callback := func(ctx context.Context) {
403 transactionCtx, cancel := context.WithCancel(ctx)
404
405 _, _ = sess.WithTransaction(transactionCtx, func(ctx SessionContext) (interface{}, error) {
406 _, err := coll.InsertOne(ctx, bson.M{"x": 1})
407 assert.Nil(t, err, "InsertOne error: %v", err)
408 cancel()
409 return nil, nil
410 })
411 }
412
413
414 assert.Soon(t, callback, 500*time.Millisecond)
415
416
417 assert.Equal(t, 1, len(abortStarted), "expected 1 abortTransaction started event, got %d", len(abortStarted))
418 assert.Equal(t, 1, len(abortSucceeded), "expected 1 abortTransaction succeeded event, got %d",
419 len(abortSucceeded))
420 assert.Equal(t, 0, len(abortFailed), "expected 0 abortTransaction failed events, got %d", len(abortFailed))
421 })
422 t.Run("wrapped transient transaction error retried", func(t *testing.T) {
423 sess, err := client.StartSession()
424 assert.Nil(t, err, "StartSession error: %v", err)
425 defer sess.EndSession(context.Background())
426
427
428 returnError := true
429 res, err := sess.WithTransaction(context.Background(), func(SessionContext) (interface{}, error) {
430 if returnError {
431 returnError = false
432 return nil, fmt.Errorf("%w",
433 CommandError{
434 Name: "test Error",
435 Labels: []string{driver.TransientTransactionError},
436 },
437 )
438 }
439 return false, nil
440 })
441 assert.Nil(t, err, "WithTransaction error: %v", err)
442 resBool, ok := res.(bool)
443 assert.True(t, ok, "expected result type %T, got %T", false, res)
444 assert.False(t, resBool, "expected result false, got %v", resBool)
445 })
446 t.Run("expired context before callback does not retry", func(t *testing.T) {
447 withTransactionTimeout = 2 * time.Second
448
449 coll := db.Collection("test")
450
451
452 err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
453 assert.Nil(t, err, "error creating collection on server: %v", err)
454 defer func() {
455 _ = coll.Drop(bgCtx)
456 }()
457
458 sess, err := client.StartSession()
459 assert.Nil(t, err, "StartSession error: %v", err)
460 defer sess.EndSession(context.Background())
461
462 callback := func(ctx context.Context) {
463
464 withTransactionContext, cancel := context.WithTimeout(ctx, time.Nanosecond)
465 defer cancel()
466
467 _, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) {
468 _, err := coll.InsertOne(ctx, bson.D{{}})
469 return nil, err
470 })
471 }
472
473
474 assert.Soon(t, callback, 500*time.Millisecond)
475 })
476 t.Run("canceled context before callback does not retry", func(t *testing.T) {
477 withTransactionTimeout = 2 * time.Second
478
479 coll := db.Collection("test")
480
481
482 err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
483 assert.Nil(t, err, "error creating collection on server: %v", err)
484 defer func() {
485 _ = coll.Drop(bgCtx)
486 }()
487
488 sess, err := client.StartSession()
489 assert.Nil(t, err, "StartSession error: %v", err)
490 defer sess.EndSession(context.Background())
491
492 callback := func(ctx context.Context) {
493
494 withTransactionContext, cancel := context.WithTimeout(ctx, 2*time.Second)
495 cancel()
496
497 _, _ = sess.WithTransaction(withTransactionContext, func(ctx SessionContext) (interface{}, error) {
498 _, err := coll.InsertOne(ctx, bson.D{{}})
499 return nil, err
500 })
501 }
502
503
504 assert.Soon(t, callback, 500*time.Millisecond)
505 })
506 t.Run("slow operation in callback retries", func(t *testing.T) {
507 withTransactionTimeout = 2 * time.Second
508
509 coll := db.Collection("test")
510
511
512 err := db.RunCommand(bgCtx, bson.D{{"create", coll.Name()}}).Err()
513 assert.Nil(t, err, "error creating collection on server: %v", err)
514 defer func() {
515 _ = coll.Drop(bgCtx)
516 }()
517
518
519 failpoint := bson.D{{"configureFailPoint", "failCommand"},
520 {"mode", bson.D{
521 {"times", 1},
522 }},
523 {"data", bson.D{
524 {"failCommands", bson.A{"insert"}},
525 {"blockConnection", true},
526 {"blockTimeMS", 500},
527 }},
528 }
529 err = dbAdmin.RunCommand(bgCtx, failpoint).Err()
530 assert.Nil(t, err, "error setting failpoint: %v", err)
531 defer func() {
532 err = dbAdmin.RunCommand(bgCtx, bson.D{
533 {"configureFailPoint", "failCommand"},
534 {"mode", "off"},
535 }).Err()
536 assert.Nil(t, err, "error turning off failpoint: %v", err)
537 }()
538
539 sess, err := client.StartSession()
540 assert.Nil(t, err, "StartSession error: %v", err)
541 defer sess.EndSession(context.Background())
542
543 callback := func(ctx context.Context) {
544 _, err = sess.WithTransaction(ctx, func(ctx SessionContext) (interface{}, error) {
545
546
547 c, cancel := context.WithTimeout(ctx, 300*time.Millisecond)
548 defer cancel()
549
550 _, err := coll.InsertOne(c, bson.D{{}})
551 return nil, err
552 })
553 assert.Nil(t, err, "WithTransaction error: %v", err)
554 }
555
556
557 assert.Soon(t, callback, 2*time.Second)
558 })
559 }
560
561 func setupConvenientTransactions(t *testing.T, extraClientOpts ...*options.ClientOptions) *Client {
562 cs := integtest.ConnString(t)
563 poolMonitor := &event.PoolMonitor{
564 Event: func(evt *event.PoolEvent) {
565 switch evt.Type {
566 case event.GetSucceeded:
567 connsCheckedOut++
568 case event.ConnectionReturned:
569 connsCheckedOut--
570 }
571 },
572 }
573
574 baseClientOpts := options.Client().
575 ApplyURI(cs.Original).
576 SetReadPreference(readpref.Primary()).
577 SetWriteConcern(writeconcern.New(writeconcern.WMajority())).
578 SetPoolMonitor(poolMonitor)
579 integtest.AddTestServerAPIVersion(baseClientOpts)
580 fullClientOpts := []*options.ClientOptions{baseClientOpts}
581 fullClientOpts = append(fullClientOpts, extraClientOpts...)
582
583 client, err := Connect(bgCtx, fullClientOpts...)
584 assert.Nil(t, err, "Connect error: %v", err)
585
586 version, err := getServerVersion(client.Database("admin"))
587 assert.Nil(t, err, "getServerVersion error: %v", err)
588 topoKind := client.deployment.(*topology.Topology).Kind()
589 if compareVersions(version, "4.1") < 0 || topoKind == description.Single {
590 t.Skip("skipping standalones and versions < 4.1")
591 }
592
593 if topoKind != description.Sharded {
594 return client
595 }
596
597
598 _ = client.Disconnect(bgCtx)
599 fullClientOpts = append(fullClientOpts, options.Client().SetHosts([]string{cs.Hosts[0]}))
600 client, err = Connect(bgCtx, fullClientOpts...)
601 assert.Nil(t, err, "Connect error: %v", err)
602 return client
603 }
604
605 func getServerVersion(db *Database) (string, error) {
606 serverStatus, err := db.RunCommand(
607 context.Background(),
608 bson.D{{"serverStatus", 1}},
609 ).Raw()
610 if err != nil {
611 return "", err
612 }
613
614 version, err := serverStatus.LookupErr("version")
615 if err != nil {
616 return "", err
617 }
618
619 return version.StringValue(), nil
620 }
621
622
623
624
625
626
627
// compareVersions compares two dot-separated version strings component by
// component. Only as many components as the shorter version has are compared,
// so "4.1" and "4.1.5" are considered equal.
//
// It returns a positive int if v1 is greater than v2, a negative int if v1 is
// less than v2, and 0 if all compared components are equal. A compared
// component of v1 that is not an integer yields 1; otherwise a compared
// component of v2 that is not an integer yields -1.
func compareVersions(v1 string, v2 string) int {
	left := strings.Split(v1, ".")
	right := strings.Split(v2, ".")

	// Number of leading components present in both versions.
	shared := int(math.Min(float64(len(left)), float64(len(right))))

	for idx, rawLeft := range left[:shared] {
		numLeft, err := strconv.Atoi(rawLeft)
		if err != nil {
			return 1
		}

		numRight, err := strconv.Atoi(right[idx])
		if err != nil {
			return -1
		}

		// The first differing component decides the result.
		if numLeft != numRight {
			return numLeft - numRight
		}
	}

	return 0
}
651
View as plain text