1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "io"
14 "reflect"
15 "time"
16
17 "go.mongodb.org/mongo-driver/bson"
18 "go.mongodb.org/mongo-driver/bson/bsoncodec"
19 "go.mongodb.org/mongo-driver/bson/bsonrw"
20 "go.mongodb.org/mongo-driver/mongo/options"
21 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
22 "go.mongodb.org/mongo-driver/x/mongo/driver"
23 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
24 )
25
26
27
28
29 type Cursor struct {
30
31
32 Current bson.Raw
33
34 bc batchCursor
35 batch *bsoncore.DocumentSequence
36 batchLength int
37 bsonOpts *options.BSONOptions
38 registry *bsoncodec.Registry
39 clientSession *session.Client
40
41 err error
42 }
43
44 func newCursor(
45 bc batchCursor,
46 bsonOpts *options.BSONOptions,
47 registry *bsoncodec.Registry,
48 ) (*Cursor, error) {
49 return newCursorWithSession(bc, bsonOpts, registry, nil)
50 }
51
52 func newCursorWithSession(
53 bc batchCursor,
54 bsonOpts *options.BSONOptions,
55 registry *bsoncodec.Registry,
56 clientSession *session.Client,
57 ) (*Cursor, error) {
58 if registry == nil {
59 registry = bson.DefaultRegistry
60 }
61 if bc == nil {
62 return nil, errors.New("batch cursor must not be nil")
63 }
64 c := &Cursor{
65 bc: bc,
66 bsonOpts: bsonOpts,
67 registry: registry,
68 clientSession: clientSession,
69 }
70 if bc.ID() == 0 {
71 c.closeImplicitSession()
72 }
73
74
75
76 c.batchLength = c.bc.Batch().DocumentCount()
77 return c, nil
78 }
79
80 func newEmptyCursor() *Cursor {
81 return &Cursor{bc: driver.NewEmptyBatchCursor()}
82 }
83
84
85
86
87
88 func NewCursorFromDocuments(documents []interface{}, err error, registry *bsoncodec.Registry) (*Cursor, error) {
89 if registry == nil {
90 registry = bson.DefaultRegistry
91 }
92
93
94 var docsBytes []byte
95 for _, doc := range documents {
96 switch t := doc.(type) {
97 case nil:
98 return nil, ErrNilDocument
99 case []byte:
100
101 doc = bson.Raw(t)
102 }
103 var marshalErr error
104 docsBytes, marshalErr = bson.MarshalAppendWithRegistry(registry, docsBytes, doc)
105 if marshalErr != nil {
106 return nil, marshalErr
107 }
108 }
109
110 c := &Cursor{
111 bc: driver.NewBatchCursorFromDocuments(docsBytes),
112 registry: registry,
113 err: err,
114 }
115
116
117
118 c.batch = c.bc.Batch()
119 c.batchLength = c.bc.Batch().DocumentCount()
120 return c, nil
121 }
122
123
124 func (c *Cursor) ID() int64 { return c.bc.ID() }
125
126
127
128
129
130
131
132
133 func (c *Cursor) Next(ctx context.Context) bool {
134 return c.next(ctx, false)
135 }
136
137
138
139
140
141
142
143
144
145
146
147
148
149 func (c *Cursor) TryNext(ctx context.Context) bool {
150 return c.next(ctx, true)
151 }
152
153 func (c *Cursor) next(ctx context.Context, nonBlocking bool) bool {
154
155 if c.err != nil {
156 return false
157 }
158
159 if ctx == nil {
160 ctx = context.Background()
161 }
162 doc, err := c.batch.Next()
163 switch {
164 case err == nil:
165
166 c.batchLength--
167 c.Current = bson.Raw(doc)
168 return true
169 case errors.Is(err, io.EOF):
170 default:
171 c.err = err
172 return false
173 }
174
175
176
177 for {
178
179 if !c.bc.Next(ctx) {
180
181 c.err = replaceErrors(c.bc.Err())
182 if c.err != nil {
183 return false
184 }
185
186 if c.bc.ID() == 0 {
187 c.closeImplicitSession()
188 return false
189 }
190
191
192 if nonBlocking {
193 return false
194 }
195 continue
196 }
197
198
199 if c.bc.ID() == 0 {
200 c.closeImplicitSession()
201 }
202
203
204 c.batch = c.bc.Batch()
205 c.batchLength = c.batch.DocumentCount()
206 doc, err = c.batch.Next()
207 switch {
208 case err == nil:
209 c.batchLength--
210 c.Current = bson.Raw(doc)
211 return true
212 case errors.Is(err, io.EOF):
213 default:
214 c.err = err
215 return false
216 }
217 }
218 }
219
220 func getDecoder(
221 data []byte,
222 opts *options.BSONOptions,
223 reg *bsoncodec.Registry,
224 ) (*bson.Decoder, error) {
225 dec, err := bson.NewDecoder(bsonrw.NewBSONDocumentReader(data))
226 if err != nil {
227 return nil, err
228 }
229
230 if opts != nil {
231 if opts.AllowTruncatingDoubles {
232 dec.AllowTruncatingDoubles()
233 }
234 if opts.BinaryAsSlice {
235 dec.BinaryAsSlice()
236 }
237 if opts.DefaultDocumentD {
238 dec.DefaultDocumentD()
239 }
240 if opts.DefaultDocumentM {
241 dec.DefaultDocumentM()
242 }
243 if opts.UseJSONStructTags {
244 dec.UseJSONStructTags()
245 }
246 if opts.UseLocalTimeZone {
247 dec.UseLocalTimeZone()
248 }
249 if opts.ZeroMaps {
250 dec.ZeroMaps()
251 }
252 if opts.ZeroStructs {
253 dec.ZeroStructs()
254 }
255 }
256
257 if reg != nil {
258
259 if err := dec.SetRegistry(reg); err != nil {
260 return nil, err
261 }
262 }
263
264 return dec, nil
265 }
266
267
268
269 func (c *Cursor) Decode(val interface{}) error {
270 dec, err := getDecoder(c.Current, c.bsonOpts, c.registry)
271 if err != nil {
272 return fmt.Errorf("error configuring BSON decoder: %w", err)
273 }
274
275 return dec.Decode(val)
276 }
277
278
279 func (c *Cursor) Err() error { return c.err }
280
281
282
283 func (c *Cursor) Close(ctx context.Context) error {
284 defer c.closeImplicitSession()
285 return replaceErrors(c.bc.Close(ctx))
286 }
287
288
289
290
291
292
293 func (c *Cursor) All(ctx context.Context, results interface{}) error {
294 resultsVal := reflect.ValueOf(results)
295 if resultsVal.Kind() != reflect.Ptr {
296 return fmt.Errorf("results argument must be a pointer to a slice, but was a %s", resultsVal.Kind())
297 }
298
299 sliceVal := resultsVal.Elem()
300 if sliceVal.Kind() == reflect.Interface {
301 sliceVal = sliceVal.Elem()
302 }
303
304 if sliceVal.Kind() != reflect.Slice {
305 return fmt.Errorf("results argument must be a pointer to a slice, but was a pointer to %s", sliceVal.Kind())
306 }
307
308 elementType := sliceVal.Type().Elem()
309 var index int
310 var err error
311
312
313
314
315 defer c.Close(context.Background())
316
317 batch := c.batch
318 for {
319 sliceVal, index, err = c.addFromBatch(sliceVal, elementType, batch, index)
320 if err != nil {
321 return err
322 }
323
324 if !c.bc.Next(ctx) {
325 break
326 }
327
328 batch = c.bc.Batch()
329 }
330
331 if err = replaceErrors(c.bc.Err()); err != nil {
332 return err
333 }
334
335 resultsVal.Elem().Set(sliceVal.Slice(0, index))
336 return nil
337 }
338
339
340
341 func (c *Cursor) RemainingBatchLength() int {
342 return c.batchLength
343 }
344
345
346
347 func (c *Cursor) addFromBatch(sliceVal reflect.Value, elemType reflect.Type, batch *bsoncore.DocumentSequence,
348 index int) (reflect.Value, int, error) {
349
350 docs, err := batch.Documents()
351 if err != nil {
352 return sliceVal, index, err
353 }
354
355 for _, doc := range docs {
356 if sliceVal.Len() == index {
357
358 newElem := reflect.New(elemType)
359 sliceVal = reflect.Append(sliceVal, newElem.Elem())
360 sliceVal = sliceVal.Slice(0, sliceVal.Cap())
361 }
362
363 currElem := sliceVal.Index(index).Addr().Interface()
364 dec, err := getDecoder(doc, c.bsonOpts, c.registry)
365 if err != nil {
366 return sliceVal, index, fmt.Errorf("error configuring BSON decoder: %w", err)
367 }
368 err = dec.Decode(currElem)
369 if err != nil {
370 return sliceVal, index, err
371 }
372
373 index++
374 }
375
376 return sliceVal, index, nil
377 }
378
379 func (c *Cursor) closeImplicitSession() {
380 if c.clientSession != nil && c.clientSession.IsImplicit {
381 c.clientSession.EndSession()
382 }
383 }
384
385
386
387
388
389 func (c *Cursor) SetBatchSize(batchSize int32) {
390 c.bc.SetBatchSize(batchSize)
391 }
392
393
394
395
396
397
398
399 func (c *Cursor) SetMaxTime(dur time.Duration) {
400 c.bc.SetMaxTime(dur)
401 }
402
403
404
405 func (c *Cursor) SetComment(comment interface{}) {
406 c.bc.SetComment(comment)
407 }
408
409
410
411
412
413
414 func BatchCursorFromCursor(c *Cursor) *driver.BatchCursor {
415 bc, _ := c.bc.(*driver.BatchCursor)
416 return bc
417 }
418
View as plain text