...
1
2
3
4
5
6
7 package session
8
9 import (
10 "context"
11 "errors"
12 "time"
13
14 "go.mongodb.org/mongo-driver/bson"
15 "go.mongodb.org/mongo-driver/bson/primitive"
16 "go.mongodb.org/mongo-driver/internal/uuid"
17 "go.mongodb.org/mongo-driver/mongo/address"
18 "go.mongodb.org/mongo-driver/mongo/description"
19 "go.mongodb.org/mongo-driver/mongo/readconcern"
20 "go.mongodb.org/mongo-driver/mongo/readpref"
21 "go.mongodb.org/mongo-driver/mongo/writeconcern"
22 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
23 )
24
25
// Sentinel errors reported by the session and transaction state machine.
// Callers can match them with errors.Is.
var (
	// ErrSessionEnded is returned by operations on a Client whose Terminated flag is set.
	ErrSessionEnded = errors.New("ended session was used")

	// ErrNoTransactStarted is returned when committing or aborting while the
	// transaction state is None.
	ErrNoTransactStarted = errors.New("no transaction started")

	// ErrTransactInProgress is returned when starting a transaction while one is
	// Starting or InProgress.
	ErrTransactInProgress = errors.New("transaction already in progress")

	// ErrAbortAfterCommit is returned when aborting a Committed transaction.
	ErrAbortAfterCommit = errors.New("cannot call abortTransaction after calling commitTransaction")

	// ErrAbortTwice is returned when aborting an already Aborted transaction.
	ErrAbortTwice = errors.New("cannot call abortTransaction twice")

	// ErrCommitAfterAbort is returned when committing an Aborted transaction.
	ErrCommitAfterAbort = errors.New("cannot call commitTransaction after calling abortTransaction")

	// ErrUnackWCUnsupported is returned by StartTransaction when the resolved write
	// concern is not acknowledged.
	ErrUnackWCUnsupported = errors.New("transactions do not support unacknowledged write concerns")

	// ErrSnapshotTransaction is returned when starting a transaction on a snapshot session.
	ErrSnapshotTransaction = errors.New("transactions are not supported in snapshot sessions")
)
48
49
// TransactionState is the current state of a session's transaction state machine.
type TransactionState uint8

// Transaction states. StartTransaction sets Starting, ApplyCommand advances Starting to
// InProgress, CommitTransaction and AbortTransaction set Committed and Aborted, and a
// later ApplyCommand resets Committed or Aborted back to None.
const (
	None TransactionState = iota
	Starting
	InProgress
	Committed
	Aborted
)

// String returns a human-readable name for the state, or "unknown" for any value
// outside the declared constants.
func (s TransactionState) String() string {
	names := [...]string{
		None:       "none",
		Starting:   "starting",
		InProgress: "in progress",
		Committed:  "committed",
		Aborted:    "aborted",
	}
	if int(s) < len(names) {
		return names[s]
	}
	return "unknown"
}
78
79
80
81
82 type LoadBalancedTransactionConnection interface {
83
84 WriteWireMessage(context.Context, []byte) error
85 ReadWireMessage(ctx context.Context) ([]byte, error)
86 Description() description.Server
87 Close() error
88 ID() string
89 ServerConnectionID() *int64
90 DriverConnectionID() uint64
91 Address() address.Address
92 Stale() bool
93
94
95 PinToCursor() error
96 PinToTransaction() error
97 UnpinFromCursor() error
98 UnpinFromTransaction() error
99 }
100
101
102 type Client struct {
103 *Server
104 ClientID uuid.UUID
105 ClusterTime bson.Raw
106 Consistent bool
107 OperationTime *primitive.Timestamp
108 IsImplicit bool
109 Terminated bool
110 RetryingCommit bool
111 Committing bool
112 Aborting bool
113 RetryWrite bool
114 RetryRead bool
115 Snapshot bool
116
117
118
119 CurrentRc *readconcern.ReadConcern
120 CurrentRp *readpref.ReadPref
121 CurrentWc *writeconcern.WriteConcern
122 CurrentMct *time.Duration
123
124
125 transactionRc *readconcern.ReadConcern
126 transactionRp *readpref.ReadPref
127 transactionWc *writeconcern.WriteConcern
128 transactionMaxCommitTime *time.Duration
129
130 pool *Pool
131 TransactionState TransactionState
132 PinnedServer *description.Server
133 RecoveryToken bson.Raw
134 PinnedConnection LoadBalancedTransactionConnection
135 SnapshotTime *primitive.Timestamp
136 }
137
138 func getClusterTime(clusterTime bson.Raw) (uint32, uint32) {
139 if clusterTime == nil {
140 return 0, 0
141 }
142
143 clusterTimeVal, err := clusterTime.LookupErr("$clusterTime")
144 if err != nil {
145 return 0, 0
146 }
147
148 timestampVal, err := bson.Raw(clusterTimeVal.Value).LookupErr("clusterTime")
149 if err != nil {
150 return 0, 0
151 }
152
153 return timestampVal.Timestamp()
154 }
155
156
157 func MaxClusterTime(ct1, ct2 bson.Raw) bson.Raw {
158 epoch1, ord1 := getClusterTime(ct1)
159 epoch2, ord2 := getClusterTime(ct2)
160
161 if epoch1 > epoch2 {
162 return ct1
163 } else if epoch1 < epoch2 {
164 return ct2
165 } else if ord1 > ord2 {
166 return ct1
167 } else if ord1 < ord2 {
168 return ct2
169 }
170
171 return ct1
172 }
173
174
175 func NewImplicitClientSession(pool *Pool, clientID uuid.UUID) *Client {
176
177
178
179
180 return &Client{
181 pool: pool,
182 ClientID: clientID,
183 IsImplicit: true,
184 }
185 }
186
187
188 func NewClientSession(pool *Pool, clientID uuid.UUID, opts ...*ClientOptions) (*Client, error) {
189 c := &Client{
190 pool: pool,
191 ClientID: clientID,
192 }
193
194 mergedOpts := mergeClientOptions(opts...)
195 if mergedOpts.DefaultReadPreference != nil {
196 c.transactionRp = mergedOpts.DefaultReadPreference
197 }
198 if mergedOpts.DefaultReadConcern != nil {
199 c.transactionRc = mergedOpts.DefaultReadConcern
200 }
201 if mergedOpts.DefaultWriteConcern != nil {
202 c.transactionWc = mergedOpts.DefaultWriteConcern
203 }
204 if mergedOpts.DefaultMaxCommitTime != nil {
205 c.transactionMaxCommitTime = mergedOpts.DefaultMaxCommitTime
206 }
207 if mergedOpts.Snapshot != nil {
208 c.Snapshot = *mergedOpts.Snapshot
209 }
210
211
212
213
214 c.Consistent = !c.Snapshot
215 if mergedOpts.CausalConsistency != nil {
216 c.Consistent = *mergedOpts.CausalConsistency
217 }
218
219 if c.Consistent && c.Snapshot {
220 return nil, errors.New("causal consistency and snapshot cannot both be set for a session")
221 }
222
223 if err := c.SetServer(); err != nil {
224 return nil, err
225 }
226
227 return c, nil
228 }
229
230
231 func (c *Client) SetServer() error {
232 var err error
233 c.Server, err = c.pool.GetSession()
234 return err
235 }
236
237
238 func (c *Client) AdvanceClusterTime(clusterTime bson.Raw) error {
239 if c.Terminated {
240 return ErrSessionEnded
241 }
242 c.ClusterTime = MaxClusterTime(c.ClusterTime, clusterTime)
243 return nil
244 }
245
246
247 func (c *Client) AdvanceOperationTime(opTime *primitive.Timestamp) error {
248 if c.Terminated {
249 return ErrSessionEnded
250 }
251
252 if c.OperationTime == nil {
253 c.OperationTime = opTime
254 return nil
255 }
256
257 if opTime.T > c.OperationTime.T {
258 c.OperationTime = opTime
259 } else if (opTime.T == c.OperationTime.T) && (opTime.I > c.OperationTime.I) {
260 c.OperationTime = opTime
261 }
262
263 return nil
264 }
265
266
267
268
269 func (c *Client) UpdateUseTime() error {
270 if c.Terminated {
271 return ErrSessionEnded
272 }
273 c.updateUseTime()
274 return nil
275 }
276
277
278 func (c *Client) UpdateRecoveryToken(response bson.Raw) {
279 if c == nil {
280 return
281 }
282
283 token, err := response.LookupErr("recoveryToken")
284 if err != nil {
285 return
286 }
287
288 c.RecoveryToken = token.Document()
289 }
290
291
292 func (c *Client) UpdateSnapshotTime(response bsoncore.Document) {
293 if c == nil {
294 return
295 }
296
297 subDoc := response
298 if cur, ok := response.Lookup("cursor").DocumentOK(); ok {
299 subDoc = cur
300 }
301
302 ssTimeElem, err := subDoc.LookupErr("atClusterTime")
303 if err != nil {
304
305 return
306 }
307
308 t, i := ssTimeElem.Timestamp()
309 c.SnapshotTime = &primitive.Timestamp{
310 T: t,
311 I: i,
312 }
313 }
314
315
316 func (c *Client) ClearPinnedResources() error {
317 if c == nil {
318 return nil
319 }
320
321 c.PinnedServer = nil
322 if c.PinnedConnection != nil {
323 if err := c.PinnedConnection.UnpinFromTransaction(); err != nil {
324 return err
325 }
326 if err := c.PinnedConnection.Close(); err != nil {
327 return err
328 }
329 }
330 c.PinnedConnection = nil
331 return nil
332 }
333
334
335
336
337 func (c *Client) unpinConnection() error {
338 if c == nil || c.PinnedConnection == nil {
339 return nil
340 }
341
342 err := c.PinnedConnection.UnpinFromTransaction()
343 closeErr := c.PinnedConnection.Close()
344 if err == nil && closeErr != nil {
345 err = closeErr
346 }
347 c.PinnedConnection = nil
348 return err
349 }
350
351
352 func (c *Client) EndSession() {
353 if c.Terminated {
354 return
355 }
356 c.Terminated = true
357
358
359
360
361
362 _ = c.unpinConnection()
363 c.pool.ReturnSession(c.Server)
364 }
365
366
367 func (c *Client) TransactionInProgress() bool {
368 return c.TransactionState == InProgress
369 }
370
371
372 func (c *Client) TransactionStarting() bool {
373 return c.TransactionState == Starting
374 }
375
376
377
378 func (c *Client) TransactionRunning() bool {
379 return c != nil && (c.TransactionState == Starting || c.TransactionState == InProgress)
380 }
381
382
383 func (c *Client) TransactionCommitted() bool {
384 return c.TransactionState == Committed
385 }
386
387
388
389 func (c *Client) CheckStartTransaction() error {
390 if c.TransactionState == InProgress || c.TransactionState == Starting {
391 return ErrTransactInProgress
392 }
393 if c.Snapshot {
394 return ErrSnapshotTransaction
395 }
396 return nil
397 }
398
399
400
401 func (c *Client) StartTransaction(opts *TransactionOptions) error {
402 err := c.CheckStartTransaction()
403 if err != nil {
404 return err
405 }
406
407 c.IncrementTxnNumber()
408 c.RetryingCommit = false
409
410 if opts != nil {
411 c.CurrentRc = opts.ReadConcern
412 c.CurrentRp = opts.ReadPreference
413 c.CurrentWc = opts.WriteConcern
414 c.CurrentMct = opts.MaxCommitTime
415 }
416
417 if c.CurrentRc == nil {
418 c.CurrentRc = c.transactionRc
419 }
420
421 if c.CurrentRp == nil {
422 c.CurrentRp = c.transactionRp
423 }
424
425 if c.CurrentWc == nil {
426 c.CurrentWc = c.transactionWc
427 }
428
429 if c.CurrentMct == nil {
430 c.CurrentMct = c.transactionMaxCommitTime
431 }
432
433 if !writeconcern.AckWrite(c.CurrentWc) {
434 _ = c.clearTransactionOpts()
435 return ErrUnackWCUnsupported
436 }
437
438 c.TransactionState = Starting
439 return c.ClearPinnedResources()
440 }
441
442
443
444 func (c *Client) CheckCommitTransaction() error {
445 if c.TransactionState == None {
446 return ErrNoTransactStarted
447 } else if c.TransactionState == Aborted {
448 return ErrCommitAfterAbort
449 }
450 return nil
451 }
452
453
454
455 func (c *Client) CommitTransaction() error {
456 err := c.CheckCommitTransaction()
457 if err != nil {
458 return err
459 }
460 c.TransactionState = Committed
461 return nil
462 }
463
464
465
466
467 func (c *Client) UpdateCommitTransactionWriteConcern() {
468 wc := c.CurrentWc
469 timeout := 10 * time.Second
470 if wc != nil && wc.GetWTimeout() != 0 {
471 timeout = wc.GetWTimeout()
472 }
473 c.CurrentWc = wc.WithOptions(writeconcern.WMajority(), writeconcern.WTimeout(timeout))
474 }
475
476
477
478 func (c *Client) CheckAbortTransaction() error {
479 if c.TransactionState == None {
480 return ErrNoTransactStarted
481 } else if c.TransactionState == Committed {
482 return ErrAbortAfterCommit
483 } else if c.TransactionState == Aborted {
484 return ErrAbortTwice
485 }
486 return nil
487 }
488
489
490
491 func (c *Client) AbortTransaction() error {
492 err := c.CheckAbortTransaction()
493 if err != nil {
494 return err
495 }
496 c.TransactionState = Aborted
497 return c.clearTransactionOpts()
498 }
499
500
501
502 func (c *Client) StartCommand() error {
503 if c == nil {
504 return nil
505 }
506
507
508
509 if !c.TransactionRunning() && !c.Committing && !c.Aborting {
510 return c.ClearPinnedResources()
511 }
512 return nil
513 }
514
515
516
517 func (c *Client) ApplyCommand(desc description.Server) error {
518 if c.Committing {
519
520 return nil
521 }
522 if c.TransactionState == Starting {
523 c.TransactionState = InProgress
524
525 if desc.Kind == description.Mongos {
526 c.PinnedServer = &desc
527 }
528 } else if c.TransactionState == Committed || c.TransactionState == Aborted {
529 c.TransactionState = None
530 return c.clearTransactionOpts()
531 }
532
533 return nil
534 }
535
536 func (c *Client) clearTransactionOpts() error {
537 c.RetryingCommit = false
538 c.Aborting = false
539 c.Committing = false
540 c.CurrentWc = nil
541 c.CurrentRp = nil
542 c.CurrentRc = nil
543 c.RecoveryToken = nil
544
545 return c.ClearPinnedResources()
546 }
547
View as plain text