...

Source file src/go.mongodb.org/mongo-driver/mongo/mongo.go

Documentation: go.mongodb.org/mongo-driver/mongo

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mongo // import "go.mongodb.org/mongo-driver/mongo"
     8  
     9  import (
    10  	"bytes"
    11  	"context"
    12  	"errors"
    13  	"fmt"
    14  	"io"
    15  	"net"
    16  	"reflect"
    17  	"strconv"
    18  	"strings"
    19  
    20  	"go.mongodb.org/mongo-driver/internal/codecutil"
    21  	"go.mongodb.org/mongo-driver/mongo/options"
    22  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    23  
    24  	"go.mongodb.org/mongo-driver/bson"
    25  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    26  	"go.mongodb.org/mongo-driver/bson/bsonrw"
    27  	"go.mongodb.org/mongo-driver/bson/bsontype"
    28  	"go.mongodb.org/mongo-driver/bson/primitive"
    29  )
    30  
    31  // Dialer is used to make network connections.
    32  type Dialer interface {
    33  	DialContext(ctx context.Context, network, address string) (net.Conn, error)
    34  }
    35  
    36  // BSONAppender is an interface implemented by types that can marshal a
    37  // provided type into BSON bytes and append those bytes to the provided []byte.
    38  // The AppendBSON can return a non-nil error and non-nil []byte. The AppendBSON
    39  // method may also write incomplete BSON to the []byte.
    40  //
    41  // Deprecated: BSONAppender is unused and will be removed in Go Driver 2.0.
    42  type BSONAppender interface {
    43  	AppendBSON([]byte, interface{}) ([]byte, error)
    44  }
    45  
    46  // BSONAppenderFunc is an adapter function that allows any function that
    47  // satisfies the AppendBSON method signature to be used where a BSONAppender is
    48  // used.
    49  //
    50  // Deprecated: BSONAppenderFunc is unused and will be removed in Go Driver 2.0.
    51  type BSONAppenderFunc func([]byte, interface{}) ([]byte, error)
    52  
    53  // AppendBSON implements the BSONAppender interface
    54  //
    55  // Deprecated: BSONAppenderFunc is unused and will be removed in Go Driver 2.0.
    56  func (baf BSONAppenderFunc) AppendBSON(dst []byte, val interface{}) ([]byte, error) {
    57  	return baf(dst, val)
    58  }
    59  
    60  // MarshalError is returned when attempting to marshal a value into a document
    61  // results in an error.
    62  type MarshalError struct {
    63  	Value interface{}
    64  	Err   error
    65  }
    66  
    67  // Error implements the error interface.
    68  func (me MarshalError) Error() string {
    69  	return fmt.Sprintf("cannot marshal type %s to a BSON Document: %v", reflect.TypeOf(me.Value), me.Err)
    70  }
    71  
// Pipeline is a type that makes creating aggregation pipelines easier. It is a
// helper and is intended for serializing to BSON.
//
// Each element of the slice is a bson.D that describes one pipeline stage, in
// the order the stages should appear.
//
// Example usage:
//
//	mongo.Pipeline{
//		{{"$group", bson.D{{"_id", "$state"}, {"totalPop", bson.D{{"$sum", "$pop"}}}}}},
//		{{"$match", bson.D{{"totalPop", bson.D{{"$gte", 10*1000*1000}}}}}},
//	}
type Pipeline []bson.D
    82  
// bvwPool is a pool of BSON value writers. getEncoder obtains a value writer
// from this pool for every encoder it constructs.
var bvwPool = bsonrw.NewBSONValueWriterPool()
    85  
    86  // getEncoder takes a writer, BSON options, and a BSON registry and returns a properly configured
    87  // bson.Encoder that writes to the given writer.
    88  func getEncoder(
    89  	w io.Writer,
    90  	opts *options.BSONOptions,
    91  	reg *bsoncodec.Registry,
    92  ) (*bson.Encoder, error) {
    93  	vw := bvwPool.Get(w)
    94  	enc, err := bson.NewEncoder(vw)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	if opts != nil {
   100  		if opts.ErrorOnInlineDuplicates {
   101  			enc.ErrorOnInlineDuplicates()
   102  		}
   103  		if opts.IntMinSize {
   104  			enc.IntMinSize()
   105  		}
   106  		if opts.NilByteSliceAsEmpty {
   107  			enc.NilByteSliceAsEmpty()
   108  		}
   109  		if opts.NilMapAsEmpty {
   110  			enc.NilMapAsEmpty()
   111  		}
   112  		if opts.NilSliceAsEmpty {
   113  			enc.NilSliceAsEmpty()
   114  		}
   115  		if opts.OmitZeroStruct {
   116  			enc.OmitZeroStruct()
   117  		}
   118  		if opts.StringifyMapKeysWithFmt {
   119  			enc.StringifyMapKeysWithFmt()
   120  		}
   121  		if opts.UseJSONStructTags {
   122  			enc.UseJSONStructTags()
   123  		}
   124  	}
   125  
   126  	if reg != nil {
   127  		// TODO:(GODRIVER-2719): Remove error handling.
   128  		if err := enc.SetRegistry(reg); err != nil {
   129  			return nil, err
   130  		}
   131  	}
   132  
   133  	return enc, nil
   134  }
   135  
   136  // newEncoderFn will return a function for constructing an encoder based on the
   137  // provided codec options.
   138  func newEncoderFn(opts *options.BSONOptions, registry *bsoncodec.Registry) codecutil.EncoderFn {
   139  	return func(w io.Writer) (*bson.Encoder, error) {
   140  		return getEncoder(w, opts, registry)
   141  	}
   142  }
   143  
   144  // marshal marshals the given value as a BSON document. Byte slices are always converted to a
   145  // bson.Raw before marshaling.
   146  //
   147  // If bsonOpts and registry are specified, the encoder is configured with the requested behaviors.
   148  // If they are nil, the default behaviors are used.
   149  func marshal(
   150  	val interface{},
   151  	bsonOpts *options.BSONOptions,
   152  	registry *bsoncodec.Registry,
   153  ) (bsoncore.Document, error) {
   154  	if registry == nil {
   155  		registry = bson.DefaultRegistry
   156  	}
   157  	if val == nil {
   158  		return nil, ErrNilDocument
   159  	}
   160  	if bs, ok := val.([]byte); ok {
   161  		// Slight optimization so we'll just use MarshalBSON and not go through the codec machinery.
   162  		val = bson.Raw(bs)
   163  	}
   164  
   165  	buf := new(bytes.Buffer)
   166  	enc, err := getEncoder(buf, bsonOpts, registry)
   167  	if err != nil {
   168  		return nil, fmt.Errorf("error configuring BSON encoder: %w", err)
   169  	}
   170  
   171  	err = enc.Encode(val)
   172  	if err != nil {
   173  		return nil, MarshalError{Value: val, Err: err}
   174  	}
   175  
   176  	return buf.Bytes(), nil
   177  }
   178  
   179  // ensureID inserts the given ObjectID as an element named "_id" at the
   180  // beginning of the given BSON document if there is not an "_id" already.
   181  // If the given ObjectID is primitive.NilObjectID, a new object ID will be
   182  // generated with time.Now().
   183  //
   184  // If there is already an element named "_id", the document is not modified. It
   185  // returns the resulting document and the decoded Go value of the "_id" element.
   186  func ensureID(
   187  	doc bsoncore.Document,
   188  	oid primitive.ObjectID,
   189  	bsonOpts *options.BSONOptions,
   190  	reg *bsoncodec.Registry,
   191  ) (bsoncore.Document, interface{}, error) {
   192  	if reg == nil {
   193  		reg = bson.DefaultRegistry
   194  	}
   195  
   196  	// Try to find the "_id" element. If it exists, try to unmarshal just the
   197  	// "_id" field as an interface{} and return it along with the unmodified
   198  	// BSON document.
   199  	if _, err := doc.LookupErr("_id"); err == nil {
   200  		var id struct {
   201  			ID interface{} `bson:"_id"`
   202  		}
   203  		dec, err := getDecoder(doc, bsonOpts, reg)
   204  		if err != nil {
   205  			return nil, nil, fmt.Errorf("error configuring BSON decoder: %w", err)
   206  		}
   207  		err = dec.Decode(&id)
   208  		if err != nil {
   209  			return nil, nil, fmt.Errorf("error unmarshaling BSON document: %w", err)
   210  		}
   211  
   212  		return doc, id.ID, nil
   213  	}
   214  
   215  	// We couldn't find an "_id" element, so add one with the value of the
   216  	// provided ObjectID.
   217  
   218  	olddoc := doc
   219  
   220  	// Reserve an extra 17 bytes for the "_id" field we're about to add:
   221  	// type (1) + "_id" (3) + terminator (1) + object ID (12)
   222  	const extraSpace = 17
   223  	doc = make(bsoncore.Document, 0, len(olddoc)+extraSpace)
   224  	_, doc = bsoncore.ReserveLength(doc)
   225  	if oid.IsZero() {
   226  		oid = primitive.NewObjectID()
   227  	}
   228  	doc = bsoncore.AppendObjectIDElement(doc, "_id", oid)
   229  
   230  	// Remove and re-write the BSON document length header.
   231  	const int32Len = 4
   232  	doc = append(doc, olddoc[int32Len:]...)
   233  	doc = bsoncore.UpdateLength(doc, 0, int32(len(doc)))
   234  
   235  	return doc, oid, nil
   236  }
   237  
   238  func ensureDollarKey(doc bsoncore.Document) error {
   239  	firstElem, err := doc.IndexErr(0)
   240  	if err != nil {
   241  		return errors.New("update document must have at least one element")
   242  	}
   243  
   244  	if !strings.HasPrefix(firstElem.Key(), "$") {
   245  		return errors.New("update document must contain key beginning with '$'")
   246  	}
   247  	return nil
   248  }
   249  
   250  func ensureNoDollarKey(doc bsoncore.Document) error {
   251  	if elem, err := doc.IndexErr(0); err == nil && strings.HasPrefix(elem.Key(), "$") {
   252  		return errors.New("replacement document cannot contain keys beginning with '$'")
   253  	}
   254  
   255  	return nil
   256  }
   257  
   258  func marshalAggregatePipeline(
   259  	pipeline interface{},
   260  	bsonOpts *options.BSONOptions,
   261  	registry *bsoncodec.Registry,
   262  ) (bsoncore.Document, bool, error) {
   263  	switch t := pipeline.(type) {
   264  	case bsoncodec.ValueMarshaler:
   265  		btype, val, err := t.MarshalBSONValue()
   266  		if err != nil {
   267  			return nil, false, err
   268  		}
   269  		if btype != bsontype.Array {
   270  			return nil, false, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v", btype, bsontype.Array)
   271  		}
   272  
   273  		var hasOutputStage bool
   274  		pipelineDoc := bsoncore.Document(val)
   275  		values, _ := pipelineDoc.Values()
   276  		if pipelineLen := len(values); pipelineLen > 0 {
   277  			if finalDoc, ok := values[pipelineLen-1].DocumentOK(); ok {
   278  				if elem, err := finalDoc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
   279  					hasOutputStage = true
   280  				}
   281  			}
   282  		}
   283  
   284  		return pipelineDoc, hasOutputStage, nil
   285  	default:
   286  		val := reflect.ValueOf(t)
   287  		if !val.IsValid() || (val.Kind() != reflect.Slice && val.Kind() != reflect.Array) {
   288  			return nil, false, fmt.Errorf("can only marshal slices and arrays into aggregation pipelines, but got %v", val.Kind())
   289  		}
   290  
   291  		var hasOutputStage bool
   292  		valLen := val.Len()
   293  
   294  		switch t := pipeline.(type) {
   295  		// Explicitly forbid non-empty pipelines that are semantically single documents
   296  		// and are implemented as slices.
   297  		case bson.D, bson.Raw, bsoncore.Document:
   298  			if valLen > 0 {
   299  				return nil, false,
   300  					fmt.Errorf("%T is not an allowed pipeline type as it represents a single document. Use bson.A or mongo.Pipeline instead", t)
   301  			}
   302  		// bsoncore.Arrays do not need to be marshaled. Only check validity and presence of output stage.
   303  		case bsoncore.Array:
   304  			if err := t.Validate(); err != nil {
   305  				return nil, false, err
   306  			}
   307  
   308  			values, err := t.Values()
   309  			if err != nil {
   310  				return nil, false, err
   311  			}
   312  
   313  			numVals := len(values)
   314  			if numVals == 0 {
   315  				return bsoncore.Document(t), false, nil
   316  			}
   317  
   318  			// If not empty, check if first value of the last stage is $out or $merge.
   319  			if lastStage, ok := values[numVals-1].DocumentOK(); ok {
   320  				if elem, err := lastStage.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
   321  					hasOutputStage = true
   322  				}
   323  			}
   324  			return bsoncore.Document(t), hasOutputStage, nil
   325  		}
   326  
   327  		aidx, arr := bsoncore.AppendArrayStart(nil)
   328  		for idx := 0; idx < valLen; idx++ {
   329  			doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry)
   330  			if err != nil {
   331  				return nil, false, err
   332  			}
   333  
   334  			if idx == valLen-1 {
   335  				if elem, err := doc.IndexErr(0); err == nil && (elem.Key() == "$out" || elem.Key() == "$merge") {
   336  					hasOutputStage = true
   337  				}
   338  			}
   339  			arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
   340  		}
   341  		arr, _ = bsoncore.AppendArrayEnd(arr, aidx)
   342  		return arr, hasOutputStage, nil
   343  	}
   344  }
   345  
   346  func marshalUpdateValue(
   347  	update interface{},
   348  	bsonOpts *options.BSONOptions,
   349  	registry *bsoncodec.Registry,
   350  	dollarKeysAllowed bool,
   351  ) (bsoncore.Value, error) {
   352  	documentCheckerFunc := ensureDollarKey
   353  	if !dollarKeysAllowed {
   354  		documentCheckerFunc = ensureNoDollarKey
   355  	}
   356  
   357  	var u bsoncore.Value
   358  	var err error
   359  	switch t := update.(type) {
   360  	case nil:
   361  		return u, ErrNilDocument
   362  	case primitive.D:
   363  		u.Type = bsontype.EmbeddedDocument
   364  		u.Data, err = marshal(update, bsonOpts, registry)
   365  		if err != nil {
   366  			return u, err
   367  		}
   368  
   369  		return u, documentCheckerFunc(u.Data)
   370  	case bson.Raw:
   371  		u.Type = bsontype.EmbeddedDocument
   372  		u.Data = t
   373  		return u, documentCheckerFunc(u.Data)
   374  	case bsoncore.Document:
   375  		u.Type = bsontype.EmbeddedDocument
   376  		u.Data = t
   377  		return u, documentCheckerFunc(u.Data)
   378  	case []byte:
   379  		u.Type = bsontype.EmbeddedDocument
   380  		u.Data = t
   381  		return u, documentCheckerFunc(u.Data)
   382  	case bsoncodec.Marshaler:
   383  		u.Type = bsontype.EmbeddedDocument
   384  		u.Data, err = t.MarshalBSON()
   385  		if err != nil {
   386  			return u, err
   387  		}
   388  
   389  		return u, documentCheckerFunc(u.Data)
   390  	case bsoncodec.ValueMarshaler:
   391  		u.Type, u.Data, err = t.MarshalBSONValue()
   392  		if err != nil {
   393  			return u, err
   394  		}
   395  		if u.Type != bsontype.Array && u.Type != bsontype.EmbeddedDocument {
   396  			return u, fmt.Errorf("ValueMarshaler returned a %v, but was expecting %v or %v", u.Type, bsontype.Array, bsontype.EmbeddedDocument)
   397  		}
   398  		return u, err
   399  	default:
   400  		val := reflect.ValueOf(t)
   401  		if !val.IsValid() {
   402  			return u, fmt.Errorf("can only marshal slices and arrays into update pipelines, but got %v", val.Kind())
   403  		}
   404  		if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
   405  			u.Type = bsontype.EmbeddedDocument
   406  			u.Data, err = marshal(update, bsonOpts, registry)
   407  			if err != nil {
   408  				return u, err
   409  			}
   410  
   411  			return u, documentCheckerFunc(u.Data)
   412  		}
   413  
   414  		u.Type = bsontype.Array
   415  		aidx, arr := bsoncore.AppendArrayStart(nil)
   416  		valLen := val.Len()
   417  		for idx := 0; idx < valLen; idx++ {
   418  			doc, err := marshal(val.Index(idx).Interface(), bsonOpts, registry)
   419  			if err != nil {
   420  				return u, err
   421  			}
   422  
   423  			if err := documentCheckerFunc(doc); err != nil {
   424  				return u, err
   425  			}
   426  
   427  			arr = bsoncore.AppendDocumentElement(arr, strconv.Itoa(idx), doc)
   428  		}
   429  		u.Data, _ = bsoncore.AppendArrayEnd(arr, aidx)
   430  		return u, err
   431  	}
   432  }
   433  
   434  func marshalValue(
   435  	val interface{},
   436  	bsonOpts *options.BSONOptions,
   437  	registry *bsoncodec.Registry,
   438  ) (bsoncore.Value, error) {
   439  	return codecutil.MarshalValue(val, newEncoderFn(bsonOpts, registry))
   440  }
   441  
   442  // Build the aggregation pipeline for the CountDocument command.
   443  func countDocumentsAggregatePipeline(
   444  	filter interface{},
   445  	encOpts *options.BSONOptions,
   446  	registry *bsoncodec.Registry,
   447  	opts *options.CountOptions,
   448  ) (bsoncore.Document, error) {
   449  	filterDoc, err := marshal(filter, encOpts, registry)
   450  	if err != nil {
   451  		return nil, err
   452  	}
   453  
   454  	aidx, arr := bsoncore.AppendArrayStart(nil)
   455  	didx, arr := bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(0))
   456  	arr = bsoncore.AppendDocumentElement(arr, "$match", filterDoc)
   457  	arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
   458  
   459  	index := 1
   460  	if opts != nil {
   461  		if opts.Skip != nil {
   462  			didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
   463  			arr = bsoncore.AppendInt64Element(arr, "$skip", *opts.Skip)
   464  			arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
   465  			index++
   466  		}
   467  		if opts.Limit != nil {
   468  			didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
   469  			arr = bsoncore.AppendInt64Element(arr, "$limit", *opts.Limit)
   470  			arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
   471  			index++
   472  		}
   473  	}
   474  
   475  	didx, arr = bsoncore.AppendDocumentElementStart(arr, strconv.Itoa(index))
   476  	iidx, arr := bsoncore.AppendDocumentElementStart(arr, "$group")
   477  	arr = bsoncore.AppendInt32Element(arr, "_id", 1)
   478  	iiidx, arr := bsoncore.AppendDocumentElementStart(arr, "n")
   479  	arr = bsoncore.AppendInt32Element(arr, "$sum", 1)
   480  	arr, _ = bsoncore.AppendDocumentEnd(arr, iiidx)
   481  	arr, _ = bsoncore.AppendDocumentEnd(arr, iidx)
   482  	arr, _ = bsoncore.AppendDocumentEnd(arr, didx)
   483  
   484  	return bsoncore.AppendArrayEnd(arr, aidx)
   485  }
   486  

View as plain text