...
1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12 "time"
13
14 "go.mongodb.org/mongo-driver/bson"
15 "go.mongodb.org/mongo-driver/bson/primitive"
16 "go.mongodb.org/mongo-driver/mongo/description"
17 "go.mongodb.org/mongo-driver/mongo/options"
18 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
19 "go.mongodb.org/mongo-driver/x/mongo/driver"
20 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
22 )
23
24
25
var (
	// ErrWrongClient reports that a session was not created by this client.
	// NOTE(review): presumably returned by Client methods that are handed a
	// session belonging to a different Client — confirm at the call sites.
	ErrWrongClient = errors.New("session was not created by this client")

	// withTransactionTimeout bounds how long WithTransaction keeps retrying.
	// Once this much time has passed since the first attempt, the most recent
	// error is returned instead of starting another retry.
	withTransactionTimeout = 120 * time.Second
)
29
30
31
32
33
34
35
36
37
38 type SessionContext interface {
39 context.Context
40 Session
41 }
42
43 type sessionContext struct {
44 context.Context
45 Session
46 }
47
48 type sessionKey struct {
49 }
50
51
52 func NewSessionContext(ctx context.Context, sess Session) SessionContext {
53 return &sessionContext{
54 Context: context.WithValue(ctx, sessionKey{}, sess),
55 Session: sess,
56 }
57 }
58
59
60
61
62 func SessionFromContext(ctx context.Context) Session {
63 val := ctx.Value(sessionKey{})
64 if val == nil {
65 return nil
66 }
67
68 sess, ok := val.(Session)
69 if !ok {
70 return nil
71 }
72
73 return sess
74 }
75
76
77
78
79
80
81
82
83
84
85
86 type Session interface {
87
88
89
90 StartTransaction(...*options.TransactionOptions) error
91
92
93
94
95 AbortTransaction(context.Context) error
96
97
98
99
100 CommitTransaction(context.Context) error
101
102
103
104
105
106
107
108
109
110
111
112
113 WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error),
114 opts ...*options.TransactionOptions) (interface{}, error)
115
116
117 EndSession(context.Context)
118
119
120 ClusterTime() bson.Raw
121
122
123 OperationTime() *primitive.Timestamp
124
125
126 Client() *Client
127
128
129
130 ID() bson.Raw
131
132
133
134 AdvanceClusterTime(bson.Raw) error
135
136
137
138 AdvanceOperationTime(*primitive.Timestamp) error
139
140 session()
141 }
142
143
144
145
146
147 type XSession interface {
148 ClientSession() *session.Client
149 }
150
151
// sessionImpl is this package's implementation of Session and XSession. It
// wraps a low-level session.Client together with the Client and deployment
// used to run commitTransaction/abortTransaction.
type sessionImpl struct {
	// clientSession holds the session and transaction state: the transaction
	// state machine, cluster/operation times, recovery token and current
	// write concern.
	clientSession *session.Client

	// client is the Client returned by Client(). Its clock, command monitor
	// and server API settings are used when running commit/abort.
	client *Client

	// deployment is the deployment that commit/abort operations execute against.
	deployment driver.Deployment

	// didCommitAfterStart is set by CommitTransaction when it is called while
	// the transaction is still starting (TransactionStarting). While set,
	// later commit/abort calls only update local state without contacting the
	// server. StartTransaction clears it.
	didCommitAfterStart bool
}
158
159 var _ Session = &sessionImpl{}
160 var _ XSession = &sessionImpl{}
161
162
163 func (s *sessionImpl) ClientSession() *session.Client {
164 return s.clientSession
165 }
166
167
168 func (s *sessionImpl) ID() bson.Raw {
169 return bson.Raw(s.clientSession.SessionID)
170 }
171
172
173 func (s *sessionImpl) EndSession(ctx context.Context) {
174 if s.clientSession.TransactionInProgress() {
175
176 _ = s.AbortTransaction(ctx)
177 }
178 s.clientSession.EndSession()
179 }
180
181
182 func (s *sessionImpl) WithTransaction(ctx context.Context, fn func(ctx SessionContext) (interface{}, error),
183 opts ...*options.TransactionOptions) (interface{}, error) {
184 timeout := time.NewTimer(withTransactionTimeout)
185 defer timeout.Stop()
186 var err error
187 for {
188 err = s.StartTransaction(opts...)
189 if err != nil {
190 return nil, err
191 }
192
193 res, err := fn(NewSessionContext(ctx, s))
194 if err != nil {
195 if s.clientSession.TransactionRunning() {
196
197
198 _ = s.AbortTransaction(newBackgroundContext(ctx))
199 }
200
201 select {
202 case <-timeout.C:
203 return nil, err
204 default:
205 }
206
207 if errorHasLabel(err, driver.TransientTransactionError) {
208 continue
209 }
210 return res, err
211 }
212
213
214
215 err = s.clientSession.CheckAbortTransaction()
216 if err != nil {
217 return res, nil
218 }
219
220
221
222
223
224
225
226
227 if ctx.Err() != nil {
228
229
230 _ = s.AbortTransaction(newBackgroundContext(ctx))
231 return nil, ctx.Err()
232 }
233
234 CommitLoop:
235 for {
236 err = s.CommitTransaction(newBackgroundContext(ctx))
237
238 if err == nil {
239 return res, nil
240 }
241
242 select {
243 case <-timeout.C:
244 return res, err
245 default:
246 }
247
248 if cerr, ok := err.(CommandError); ok {
249 if cerr.HasErrorLabel(driver.UnknownTransactionCommitResult) && !cerr.IsMaxTimeMSExpiredError() {
250 continue
251 }
252 if cerr.HasErrorLabel(driver.TransientTransactionError) {
253 break CommitLoop
254 }
255 }
256 return res, err
257 }
258 }
259 }
260
261
262 func (s *sessionImpl) StartTransaction(opts ...*options.TransactionOptions) error {
263 err := s.clientSession.CheckStartTransaction()
264 if err != nil {
265 return err
266 }
267
268 s.didCommitAfterStart = false
269
270 topts := options.MergeTransactionOptions(opts...)
271 coreOpts := &session.TransactionOptions{
272 ReadConcern: topts.ReadConcern,
273 ReadPreference: topts.ReadPreference,
274 WriteConcern: topts.WriteConcern,
275 MaxCommitTime: topts.MaxCommitTime,
276 }
277
278 return s.clientSession.StartTransaction(coreOpts)
279 }
280
281
282 func (s *sessionImpl) AbortTransaction(ctx context.Context) error {
283 err := s.clientSession.CheckAbortTransaction()
284 if err != nil {
285 return err
286 }
287
288
289 if s.clientSession.TransactionStarting() || s.didCommitAfterStart {
290 return s.clientSession.AbortTransaction()
291 }
292
293 selector := makePinnedSelector(s.clientSession, description.WriteSelector())
294
295 s.clientSession.Aborting = true
296 _ = operation.NewAbortTransaction().Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").
297 Deployment(s.deployment).WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).
298 Retry(driver.RetryOncePerCommand).CommandMonitor(s.client.monitor).
299 RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).ServerAPI(s.client.serverAPI).Execute(ctx)
300
301 s.clientSession.Aborting = false
302 _ = s.clientSession.AbortTransaction()
303
304 return nil
305 }
306
307
308 func (s *sessionImpl) CommitTransaction(ctx context.Context) error {
309 err := s.clientSession.CheckCommitTransaction()
310 if err != nil {
311 return err
312 }
313
314
315 if s.clientSession.TransactionStarting() || s.didCommitAfterStart {
316 s.didCommitAfterStart = true
317 return s.clientSession.CommitTransaction()
318 }
319
320 if s.clientSession.TransactionCommitted() {
321 s.clientSession.RetryingCommit = true
322 }
323
324 selector := makePinnedSelector(s.clientSession, description.WriteSelector())
325
326 s.clientSession.Committing = true
327 op := operation.NewCommitTransaction().
328 Session(s.clientSession).ClusterClock(s.client.clock).Database("admin").Deployment(s.deployment).
329 WriteConcern(s.clientSession.CurrentWc).ServerSelector(selector).Retry(driver.RetryOncePerCommand).
330 CommandMonitor(s.client.monitor).RecoveryToken(bsoncore.Document(s.clientSession.RecoveryToken)).
331 ServerAPI(s.client.serverAPI).MaxTime(s.clientSession.CurrentMct)
332
333 err = op.Execute(ctx)
334
335
336 if IsTimeout(err) {
337 return replaceErrors(err)
338 }
339 s.clientSession.Committing = false
340 commitErr := s.clientSession.CommitTransaction()
341
342
343 s.clientSession.UpdateCommitTransactionWriteConcern()
344
345 if err != nil {
346 return replaceErrors(err)
347 }
348 return commitErr
349 }
350
351
352 func (s *sessionImpl) ClusterTime() bson.Raw {
353 return s.clientSession.ClusterTime
354 }
355
356
357 func (s *sessionImpl) AdvanceClusterTime(d bson.Raw) error {
358 return s.clientSession.AdvanceClusterTime(d)
359 }
360
361
362 func (s *sessionImpl) OperationTime() *primitive.Timestamp {
363 return s.clientSession.OperationTime
364 }
365
366
367 func (s *sessionImpl) AdvanceOperationTime(ts *primitive.Timestamp) error {
368 return s.clientSession.AdvanceOperationTime(ts)
369 }
370
371
372 func (s *sessionImpl) Client() *Client {
373 return s.client
374 }
375
376
377 func (*sessionImpl) session() {
378 }
379
380
381
382 func sessionFromContext(ctx context.Context) *session.Client {
383 s := ctx.Value(sessionKey{})
384 if ses, ok := s.(*sessionImpl); ses != nil && ok {
385 return ses.clientSession
386 }
387
388 return nil
389 }
390
View as plain text