1
2
3
4
5
6
7 package mongo
8
9 import (
10 "bytes"
11 "context"
12 "errors"
13 "fmt"
14 "io"
15 "net"
16 "reflect"
17 "strconv"
18 "strings"
19
20 "go.mongodb.org/mongo-driver/internal/codecutil"
21 "go.mongodb.org/mongo-driver/mongo/options"
22 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
23
24 "go.mongodb.org/mongo-driver/bson"
25 "go.mongodb.org/mongo-driver/bson/bsoncodec"
26 "go.mongodb.org/mongo-driver/bson/bsonrw"
27 "go.mongodb.org/mongo-driver/bson/bsontype"
28 "go.mongodb.org/mongo-driver/bson/primitive"
29 )
30
31
// Dialer opens network connections for the driver. Any type with a matching
// DialContext method satisfies it; *net.Dialer from the standard library is one
// such implementation.
type Dialer interface {
	// DialContext connects to address on the named network, honoring ctx.
	DialContext(ctx context.Context, network string, address string) (net.Conn, error)
}
35
36
37
38
39
40
41
// BSONAppender is implemented by values that can append the BSON encoding of
// an arbitrary value onto a byte slice.
type BSONAppender interface {
	// AppendBSON appends the encoding of val to dst and returns the extended
	// slice, or an error if val cannot be encoded.
	AppendBSON(dst []byte, val interface{}) ([]byte, error)
}
45
46
47
48
49
50
// BSONAppenderFunc adapts a plain function to the BSONAppender interface.
type BSONAppenderFunc func(dst []byte, val interface{}) ([]byte, error)

// AppendBSON invokes the underlying function with dst and val and returns its
// results unchanged.
func (fn BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) {
	out, err := fn(dst, val)
	return out, err
}
59
60
61
// MarshalError is returned when encoding a value into a BSON document fails.
// Value holds the value that was being encoded and Err the underlying encoder
// error.
type MarshalError struct {
	Value interface{}
	Err   error
}

// Error implements the error interface, naming the Go type of Value and the
// underlying cause.
func (me MarshalError) Error() string {
	return fmt.Sprintf("cannot marshal type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err)
}

// Unwrap returns the underlying encoder error. errors.Is and errors.As use it
// to inspect the cause, which Error only exposes as formatted text.
func (me MarshalError) Unwrap() error {
	return me.Err
}
71
72
73
74
75
76
77
78
79
80
81 type Pipeline []bson.D
82
83
84 var bvwPool = bsonrw.NewBSONValueWriterPool()
85
86
87
88 func getEncoder(
89 w io.Writer,
90 opts *options.BSONOptions,
91 reg *bsoncodec.Registry,
92 ) (*bson.Encoder, error) {
93 vw := bvwPool.Get(w)
94 enc, err := bson.NewEncoder(vw)
95 if err != nil {
96 return nil, err
97 }
98
99 if opts != nil {
100 if opts.ErrorOnInlineDuplicates {
101 enc.ErrorOnInlineDuplicates()
102 }
103 if opts.IntMinSize {
104 enc.IntMinSize()
105 }
106 if opts.NilByteSliceAsEmpty {
107 enc.NilByteSliceAsEmpty()
108 }
109 if opts.NilMapAsEmpty {
110 enc.NilMapAsEmpty()
111 }
112 if opts.NilSliceAsEmpty {
113 enc.NilSliceAsEmpty()
114 }
115 if opts.OmitZeroStruct {
116 enc.OmitZeroStruct()
117 }
118 if opts.StringifyMapKeysWithFmt {
119 enc.StringifyMapKeysWithFmt()
120 }
121 if opts.UseJSONStructTags {
122 enc.UseJSONStructTags()
123 }
124 }
125
126 if reg != nil {
127
128 if err := enc.SetRegistry(reg); err != nil {
129 return nil, err
130 }
131 }
132
133 return enc, nil
134 }
135
136
137
138 func newEncoderFn(opts *options.BSONOptions, registry *bsoncodec.Registry) codecutil.EncoderFn {
139 return func(w io.Writer) (*bson.Encoder, error) {
140 return getEncoder(w, opts, registry)
141 }
142 }
143
144
145
146
147
148
149 func marshal(
150 val interface{},
151 bsonOpts *options.BSONOptions,
152 registry *bsoncodec.Registry,
153 ) (bsoncore.Document, error) {
154 if registry == nil {
155 registry = bson.DefaultRegistry
156 }
157 if val == nil {
158 return nil, ErrNilDocument
159 }
160 if bs, ok := val.([]byte); ok {
161
162 val = bson.Raw(bs)
163 }
164
165 buf := new(bytes.Buffer)
166 enc, err := getEncoder(buf, bsonOpts, registry)
167 if err != nil {
168 return nil, fmt.Errorf("error configuring BSON encoder: %w", err)
169 }
170
171 err = enc.Encode(val)
172 if err != nil {
173 return nil, MarshalError{Value: val, Err: err}
174 }
175
176 return buf.Bytes(), nil
177 }
178
179
180
181
182
183
184
185
186 func ensureID(
187 doc bsoncore.Document,
188 oid primitive.ObjectID,
189 bsonOpts *options.BSONOptions,
190 reg *bsoncodec.Registry,
191 ) (bsoncore.Document, interface{}, error) {
192 if reg == nil {
193 reg = bson.DefaultRegistry
194 }
195
196
197
198
199 if _, err := doc.LookupErr("_id"); err == nil {
200 var id struct {
201 ID interface{} `bson:"_id"`
202 }
203 dec, err := getDecoder(doc, bsonOpts, reg)
204 if err != nil {
205 return nil, nil, fmt.Errorf("error configuring BSON decoder: %w", err)
206 }
207 err = dec.Decode(&id)
208 if err != nil {
209 return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err)
210 }
211
212 return doc, id.ID, nil
213 }
214
215
216
217
218 olddoc := doc
219
220
221
222 const extraSpace = 17
223 doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace)
224 _, doc = bsoncore.ReserveLength(doc)
225 if oid.IsZero() {
226 oid = primitive.NewObjectID()
227 }
228 doc = bsoncore.AppendObjectIDElement(doc, "_id", oid)
229
230
231 const int32Len = 4
232 doc = append(doc, olddoc[int32Len:]...)
233 doc = bsoncore.UpdateLength(doc, 0, int32(len(doc)))
234
235 return doc, oid, nil
236 }
237
238 func ensureDollarKey(doc bsoncore.Document) error {
239 firstElem, err := doc.IndexErr(0)
240 if err != nil {
241 return errors.New("update document must have at least one element")
242 }
243
244 if !strings.HasPrefix(firstElem.Key(), "$") {
245 return errors.New("update document must contain key beginning with '$'")
246 }
247 return nil
248 }
249
250 func ensureNoDollarKey(doc bsoncore.Document) error {
251 if elem, err := doc.IndexErr(0); err == nil && strings.HasPrefix(elem.Key(), "$") {
252 return errors.New("replacement document cannot contain keys beginning with '$'")
253 }
254
255 return nil
256 }
257
258 func marshalAggregatePipeline(
259 pipeline interface{},
260 bsonOpts *options.BSONOptions,
261 registry *bsoncodec.Registry,
262 ) (bsoncore.Document, bool, error) {
263 switch t := pipeline.(type) {
264 case bsoncodec.ValueMarshaler:
265 btype, val, err := t.MarshalBSONValue()
266 if err != nil {
267 return nil, false, err
268 }
269 if btype != bsontype.Array {
270 return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
271 }
272
273 var hasOutputStage bool
274 pipelineDoc := bsoncore.Document(val)
275 values, _ := pipelineDoc.Values()
276 if pipelineLen := len(values); pipelineLen > 0 {
277 if finalDoc, ok := values[pipelineLen-1].DocumentOK(); ok {
278 if elem, err := finalDoc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
279 hasOutputStage = true
280 }
281 }
282 }
283
284 return pipelineDoc, hasOutputStage, nil
285 default:
286 val := reflect.ValueOf(t)
287 if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
288 return nil, false, fmt.Errorf("can only marshal slices and arrays into aggregation pipelines, but got %v", val.Kind())
289 }
290
291 var hasOutputStage bool
292 valLen := val.Len()
293
294 switch t := pipeline.(type) {
295
296
297 case bson.D, bson.Raw, bsoncore.Document:
298 if valLen > 0 {
299 return nil, false,
300 fmt.Errorf("%T is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead", t)
301 }
302
303 case bsoncore.Array:
304 if err := t.Validate(); err != nil {
305 return nil, false, err
306 }
307
308 values, err := t.Values()
309 if err != nil {
310 return nil, false, err
311 }
312
313 numVals := len(values)
314 if numVals == 0 {
315 return bsoncore.Document(t), false, nil
316 }
317
318
319 if lastStage, ok := values[numVals-1].DocumentOK(); ok {
320 if elem, err := lastStage.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
321 hasOutputStage = true
322 }
323 }
324 return bsoncore.Document(t), hasOutputStage, nil
325 }
326
327 aidx, arr := bsoncore.AppendArrayStart(nil)
328 for idx := 0; idx < valLen; idx++ {
329 doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry)
330 if err != nil {
331 return nil, false, err
332 }
333
334 if idx == valLen-1 {
335 if elem, err := doc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
336 hasOutputStage = true
337 }
338 }
339 arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
340 }
341 arr, _ = bsoncore.AppendArrayEnd(arr, aidx)
342 return arr, hasOutputStage, nil
343 }
344 }
345
346 func marshalUpdateValue(
347 update interface{},
348 bsonOpts *options.BSONOptions,
349 registry *bsoncodec.Registry,
350 dollarKeysAllowed bool,
351 ) (bsoncore.Value, error) {
352 documentCheckerFunc := ensureDollarKey
353 if !dollarKeysAllowed {
354 documentCheckerFunc = ensureNoDollarKey
355 }
356
357 var u bsoncore.Value
358 var err error
359 switch t := update.(type) {
360 case nil:
361 return u, ErrNilDocument
362 case primitive.D:
363 u.Type = bsontype.EmbeddedDocument
364 u.Data, err = marshal(update, bsonOpts, registry)
365 if err != nil {
366 return u, err
367 }
368
369 return u, documentCheckerFunc(u.Data)
370 case bson.Raw:
371 u.Type = bsontype.EmbeddedDocument
372 u.Data = t
373 return u, documentCheckerFunc(u.Data)
374 case bsoncore.Document:
375 u.Type = bsontype.EmbeddedDocument
376 u.Data = t
377 return u, documentCheckerFunc(u.Data)
378 case []byte:
379 u.Type = bsontype.EmbeddedDocument
380 u.Data = t
381 return u, documentCheckerFunc(u.Data)
382 case bsoncodec.Marshaler:
383 u.Type = bsontype.EmbeddedDocument
384 u.Data, err = t.MarshalBSON()
385 if err != nil {
386 return u, err
387 }
388
389 return u, documentCheckerFunc(u.Data)
390 case bsoncodec.ValueMarshaler:
391 u.Type, u.Data, err = t.MarshalBSONValue()
392 if err != nil {
393 return u, err
394 }
395 if u.Type != bsontype.Array && u.Type != bsontype.EmbeddedDocument {
396 return u, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v or %v", u.Type, bsontype.Array, bsontype.EmbeddedDocument)
397 }
398 return u, err
399 default:
400 val := reflect.ValueOf(t)
401 if !val.IsValid() {
402 return u, fmt.Errorf("can only marshal slices and arrays into update pipelines, but got %v", val.Kind())
403 }
404 if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
405 u.Type = bsontype.EmbeddedDocument
406 u.Data, err = marshal(update, bsonOpts, registry)
407 if err != nil {
408 return u, err
409 }
410
411 return u, documentCheckerFunc(u.Data)
412 }
413
414 u.Type = bsontype.Array
415 aidx, arr := bsoncore.AppendArrayStart(nil)
416 valLen := val.Len()
417 for idx := 0; idx < valLen; idx++ {
418 doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry)
419 if err != nil {
420 return u, err
421 }
422
423 if err := documentCheckerFunc(doc); err != nil {
424 return u, err
425 }
426
427 arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
428 }
429 u.Data, _ = bsoncore.AppendArrayEnd(arr, aidx)
430 return u, err
431 }
432 }
433
434 func marshalValue(
435 val interface{},
436 bsonOpts *options.BSONOptions,
437 registry *bsoncodec.Registry,
438 ) (bsoncore.Value, error) {
439 return codecutil.MarshalValue(val, newEncoderFn(bsonOpts, registry))
440 }
441
442
443 func countDocumentsAggregatePipeline(
444 filter interface{},
445 encOpts *options.BSONOptions,
446 registry *bsoncodec.Registry,
447 opts *options.CountOptions,
448 ) (bsoncore.Document, error) {
449 filterDoc, err := marshal(filter, encOpts, registry)
450 if err != nil {
451 return nil, err
452 }
453
454 aidx, arr := bsoncore.AppendArrayStart(nil)
455 didx, arr := bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(0))
456 arr = bsoncore.AppendDocumentElement(arr, "$match", filterDoc)
457 arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
458
459 index := 1
460 if opts != nil {
461 if opts.Skip != nil {
462 didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
463 arr = bsoncore.AppendInt64Element(arr, "$skip", *opts.Skip)
464 arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
465 index++
466 }
467 if opts.Limit != nil {
468 didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
469 arr = bsoncore.AppendInt64Element(arr, "$limit", *opts.Limit)
470 arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
471 index++
472 }
473 }
474
475 didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
476 iidx, arr := bsoncore.AppendDocumentElementStart(arr, "$group")
477 arr = bsoncore.AppendInt32Element(arr, "_id", 1)
478 iiidx, arr := bsoncore.AppendDocumentElementStart(arr, "n")
479 arr = bsoncore.AppendInt32Element(arr, "$sum", 1)
480 arr, _ = bsoncore.AppendDocumentEnd(arr, iiidx)
481 arr, _ = bsoncore.AppendDocumentEnd(arr, iidx)
482 arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
483
484 return bsoncore.AppendArrayEnd(arr, aidx)
485 }
486
View as plain text