...

Source file src/go.mongodb.org/mongo-driver/mongo/bulk_write.go

Documentation: go.mongodb.org/mongo-driver/mongo

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mongo
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  
    13  	"go.mongodb.org/mongo-driver/bson/bsoncodec"
    14  	"go.mongodb.org/mongo-driver/bson/primitive"
    15  	"go.mongodb.org/mongo-driver/mongo/description"
    16  	"go.mongodb.org/mongo-driver/mongo/options"
    17  	"go.mongodb.org/mongo-driver/mongo/writeconcern"
    18  	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    20  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    21  	"go.mongodb.org/mongo-driver/x/mongo/driver/session"
    22  )
    23  
    24  type bulkWriteBatch struct {
    25  	models   []WriteModel
    26  	canRetry bool
    27  	indexes  []int
    28  }
    29  
    30  // bulkWrite performs a bulkwrite operation
    31  type bulkWrite struct {
    32  	comment                  interface{}
    33  	ordered                  *bool
    34  	bypassDocumentValidation *bool
    35  	models                   []WriteModel
    36  	session                  *session.Client
    37  	collection               *Collection
    38  	selector                 description.ServerSelector
    39  	writeConcern             *writeconcern.WriteConcern
    40  	result                   BulkWriteResult
    41  	let                      interface{}
    42  }
    43  
    44  func (bw *bulkWrite) execute(ctx context.Context) error {
    45  	ordered := true
    46  	if bw.ordered != nil {
    47  		ordered = *bw.ordered
    48  	}
    49  
    50  	batches := createBatches(bw.models, ordered)
    51  	bw.result = BulkWriteResult{
    52  		UpsertedIDs: make(map[int64]interface{}),
    53  	}
    54  
    55  	bwErr := BulkWriteException{
    56  		WriteErrors: make([]BulkWriteError, 0),
    57  	}
    58  
    59  	var lastErr error
    60  	continueOnError := !ordered
    61  	for _, batch := range batches {
    62  		if len(batch.models) == 0 {
    63  			continue
    64  		}
    65  
    66  		batchRes, batchErr, err := bw.runBatch(ctx, batch)
    67  
    68  		bw.mergeResults(batchRes)
    69  
    70  		bwErr.WriteConcernError = batchErr.WriteConcernError
    71  		bwErr.Labels = append(bwErr.Labels, batchErr.Labels...)
    72  
    73  		bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)
    74  
    75  		commandErrorOccurred := err != nil && !errors.Is(err, driver.ErrUnacknowledgedWrite)
    76  		writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil
    77  		if !continueOnError && (commandErrorOccurred || writeErrorOccurred) {
    78  			if err != nil {
    79  				return err
    80  			}
    81  
    82  			return bwErr
    83  		}
    84  
    85  		if err != nil {
    86  			lastErr = err
    87  		}
    88  	}
    89  
    90  	bw.result.MatchedCount -= bw.result.UpsertedCount
    91  	if lastErr != nil {
    92  		_, lastErr = processWriteError(lastErr)
    93  		return lastErr
    94  	}
    95  	if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil {
    96  		return bwErr
    97  	}
    98  	return nil
    99  }
   100  
   101  func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) {
   102  	batchRes := BulkWriteResult{
   103  		UpsertedIDs: make(map[int64]interface{}),
   104  	}
   105  	batchErr := BulkWriteException{}
   106  
   107  	var writeErrors []driver.WriteError
   108  	switch batch.models[0].(type) {
   109  	case *InsertOneModel:
   110  		res, err := bw.runInsert(ctx, batch)
   111  		if err != nil {
   112  			var writeErr driver.WriteCommandError
   113  			if !errors.As(err, &writeErr) {
   114  				return BulkWriteResult{}, batchErr, err
   115  			}
   116  			writeErrors = writeErr.WriteErrors
   117  			batchErr.Labels = writeErr.Labels
   118  			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
   119  		}
   120  		batchRes.InsertedCount = res.N
   121  	case *DeleteOneModel, *DeleteManyModel:
   122  		res, err := bw.runDelete(ctx, batch)
   123  		if err != nil {
   124  			var writeErr driver.WriteCommandError
   125  			if !errors.As(err, &writeErr) {
   126  				return BulkWriteResult{}, batchErr, err
   127  			}
   128  			writeErrors = writeErr.WriteErrors
   129  			batchErr.Labels = writeErr.Labels
   130  			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
   131  		}
   132  		batchRes.DeletedCount = res.N
   133  	case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
   134  		res, err := bw.runUpdate(ctx, batch)
   135  		if err != nil {
   136  			var writeErr driver.WriteCommandError
   137  			if !errors.As(err, &writeErr) {
   138  				return BulkWriteResult{}, batchErr, err
   139  			}
   140  			writeErrors = writeErr.WriteErrors
   141  			batchErr.Labels = writeErr.Labels
   142  			batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
   143  		}
   144  		batchRes.MatchedCount = res.N
   145  		batchRes.ModifiedCount = res.NModified
   146  		batchRes.UpsertedCount = int64(len(res.Upserted))
   147  		for _, upsert := range res.Upserted {
   148  			batchRes.UpsertedIDs[int64(batch.indexes[upsert.Index])] = upsert.ID
   149  		}
   150  	}
   151  
   152  	batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
   153  	convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors)
   154  	for _, we := range convWriteErrors {
   155  		request := batch.models[we.Index]
   156  		we.Index = batch.indexes[we.Index]
   157  		batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
   158  			WriteError: we,
   159  			Request:    request,
   160  		})
   161  	}
   162  	return batchRes, batchErr, nil
   163  }
   164  
   165  func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) {
   166  	docs := make([]bsoncore.Document, len(batch.models))
   167  	var i int
   168  	for _, model := range batch.models {
   169  		converted := model.(*InsertOneModel)
   170  		doc, err := marshal(converted.Document, bw.collection.bsonOpts, bw.collection.registry)
   171  		if err != nil {
   172  			return operation.InsertResult{}, err
   173  		}
   174  		doc, _, err = ensureID(doc, primitive.NilObjectID, bw.collection.bsonOpts, bw.collection.registry)
   175  		if err != nil {
   176  			return operation.InsertResult{}, err
   177  		}
   178  
   179  		docs[i] = doc
   180  		i++
   181  	}
   182  
   183  	op := operation.NewInsert(docs...).
   184  		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
   185  		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
   186  		Database(bw.collection.db.name).Collection(bw.collection.name).
   187  		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).
   188  		ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout).
   189  		Logger(bw.collection.client.logger)
   190  	if bw.comment != nil {
   191  		comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
   192  		if err != nil {
   193  			return op.Result(), err
   194  		}
   195  		op.Comment(comment)
   196  	}
   197  	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
   198  		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
   199  	}
   200  	if bw.ordered != nil {
   201  		op = op.Ordered(*bw.ordered)
   202  	}
   203  
   204  	retry := driver.RetryNone
   205  	if bw.collection.client.retryWrites && batch.canRetry {
   206  		retry = driver.RetryOncePerCommand
   207  	}
   208  	op = op.Retry(retry)
   209  
   210  	err := op.Execute(ctx)
   211  
   212  	return op.Result(), err
   213  }
   214  
   215  func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) {
   216  	docs := make([]bsoncore.Document, len(batch.models))
   217  	var i int
   218  	var hasHint bool
   219  
   220  	for _, model := range batch.models {
   221  		var doc bsoncore.Document
   222  		var err error
   223  
   224  		switch converted := model.(type) {
   225  		case *DeleteOneModel:
   226  			doc, err = createDeleteDoc(
   227  				converted.Filter,
   228  				converted.Collation,
   229  				converted.Hint,
   230  				true,
   231  				bw.collection.bsonOpts,
   232  				bw.collection.registry)
   233  			hasHint = hasHint || (converted.Hint != nil)
   234  		case *DeleteManyModel:
   235  			doc, err = createDeleteDoc(
   236  				converted.Filter,
   237  				converted.Collation,
   238  				converted.Hint,
   239  				false,
   240  				bw.collection.bsonOpts,
   241  				bw.collection.registry)
   242  			hasHint = hasHint || (converted.Hint != nil)
   243  		}
   244  
   245  		if err != nil {
   246  			return operation.DeleteResult{}, err
   247  		}
   248  
   249  		docs[i] = doc
   250  		i++
   251  	}
   252  
   253  	op := operation.NewDelete(docs...).
   254  		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
   255  		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
   256  		Database(bw.collection.db.name).Collection(bw.collection.name).
   257  		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint).
   258  		ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout).
   259  		Logger(bw.collection.client.logger)
   260  	if bw.comment != nil {
   261  		comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
   262  		if err != nil {
   263  			return op.Result(), err
   264  		}
   265  		op.Comment(comment)
   266  	}
   267  	if bw.let != nil {
   268  		let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry)
   269  		if err != nil {
   270  			return operation.DeleteResult{}, err
   271  		}
   272  		op = op.Let(let)
   273  	}
   274  	if bw.ordered != nil {
   275  		op = op.Ordered(*bw.ordered)
   276  	}
   277  	retry := driver.RetryNone
   278  	if bw.collection.client.retryWrites && batch.canRetry {
   279  		retry = driver.RetryOncePerCommand
   280  	}
   281  	op = op.Retry(retry)
   282  
   283  	err := op.Execute(ctx)
   284  
   285  	return op.Result(), err
   286  }
   287  
   288  func createDeleteDoc(
   289  	filter interface{},
   290  	collation *options.Collation,
   291  	hint interface{},
   292  	deleteOne bool,
   293  	bsonOpts *options.BSONOptions,
   294  	registry *bsoncodec.Registry,
   295  ) (bsoncore.Document, error) {
   296  	f, err := marshal(filter, bsonOpts, registry)
   297  	if err != nil {
   298  		return nil, err
   299  	}
   300  
   301  	var limit int32
   302  	if deleteOne {
   303  		limit = 1
   304  	}
   305  	didx, doc := bsoncore.AppendDocumentStart(nil)
   306  	doc = bsoncore.AppendDocumentElement(doc, "q", f)
   307  	doc = bsoncore.AppendInt32Element(doc, "limit", limit)
   308  	if collation != nil {
   309  		doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument())
   310  	}
   311  	if hint != nil {
   312  		if isUnorderedMap(hint) {
   313  			return nil, ErrMapForOrderedArgument{"hint"}
   314  		}
   315  		hintVal, err := marshalValue(hint, bsonOpts, registry)
   316  		if err != nil {
   317  			return nil, err
   318  		}
   319  		doc = bsoncore.AppendValueElement(doc, "hint", hintVal)
   320  	}
   321  	doc, _ = bsoncore.AppendDocumentEnd(doc, didx)
   322  
   323  	return doc, nil
   324  }
   325  
   326  func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) {
   327  	docs := make([]bsoncore.Document, len(batch.models))
   328  	var hasHint bool
   329  	var hasArrayFilters bool
   330  	for i, model := range batch.models {
   331  		var doc bsoncore.Document
   332  		var err error
   333  
   334  		switch converted := model.(type) {
   335  		case *ReplaceOneModel:
   336  			doc, err = createUpdateDoc(
   337  				converted.Filter,
   338  				converted.Replacement,
   339  				converted.Hint,
   340  				nil,
   341  				converted.Collation,
   342  				converted.Upsert,
   343  				false,
   344  				false,
   345  				bw.collection.bsonOpts,
   346  				bw.collection.registry)
   347  			hasHint = hasHint || (converted.Hint != nil)
   348  		case *UpdateOneModel:
   349  			doc, err = createUpdateDoc(
   350  				converted.Filter,
   351  				converted.Update,
   352  				converted.Hint,
   353  				converted.ArrayFilters,
   354  				converted.Collation,
   355  				converted.Upsert,
   356  				false,
   357  				true,
   358  				bw.collection.bsonOpts,
   359  				bw.collection.registry)
   360  			hasHint = hasHint || (converted.Hint != nil)
   361  			hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
   362  		case *UpdateManyModel:
   363  			doc, err = createUpdateDoc(
   364  				converted.Filter,
   365  				converted.Update,
   366  				converted.Hint,
   367  				converted.ArrayFilters,
   368  				converted.Collation,
   369  				converted.Upsert,
   370  				true,
   371  				true,
   372  				bw.collection.bsonOpts,
   373  				bw.collection.registry)
   374  			hasHint = hasHint || (converted.Hint != nil)
   375  			hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
   376  		}
   377  		if err != nil {
   378  			return operation.UpdateResult{}, err
   379  		}
   380  
   381  		docs[i] = doc
   382  	}
   383  
   384  	op := operation.NewUpdate(docs...).
   385  		Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
   386  		ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
   387  		Database(bw.collection.db.name).Collection(bw.collection.name).
   388  		Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint).
   389  		ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI).
   390  		Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger)
   391  	if bw.comment != nil {
   392  		comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
   393  		if err != nil {
   394  			return op.Result(), err
   395  		}
   396  		op.Comment(comment)
   397  	}
   398  	if bw.let != nil {
   399  		let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry)
   400  		if err != nil {
   401  			return operation.UpdateResult{}, err
   402  		}
   403  		op = op.Let(let)
   404  	}
   405  	if bw.ordered != nil {
   406  		op = op.Ordered(*bw.ordered)
   407  	}
   408  	if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
   409  		op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
   410  	}
   411  	retry := driver.RetryNone
   412  	if bw.collection.client.retryWrites && batch.canRetry {
   413  		retry = driver.RetryOncePerCommand
   414  	}
   415  	op = op.Retry(retry)
   416  
   417  	err := op.Execute(ctx)
   418  
   419  	return op.Result(), err
   420  }
   421  
   422  func createUpdateDoc(
   423  	filter interface{},
   424  	update interface{},
   425  	hint interface{},
   426  	arrayFilters *options.ArrayFilters,
   427  	collation *options.Collation,
   428  	upsert *bool,
   429  	multi bool,
   430  	checkDollarKey bool,
   431  	bsonOpts *options.BSONOptions,
   432  	registry *bsoncodec.Registry,
   433  ) (bsoncore.Document, error) {
   434  	f, err := marshal(filter, bsonOpts, registry)
   435  	if err != nil {
   436  		return nil, err
   437  	}
   438  
   439  	uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
   440  	updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)
   441  
   442  	u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey)
   443  	if err != nil {
   444  		return nil, err
   445  	}
   446  
   447  	updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
   448  
   449  	if multi {
   450  		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
   451  	}
   452  
   453  	if arrayFilters != nil {
   454  		reg := registry
   455  		if arrayFilters.Registry != nil {
   456  			reg = arrayFilters.Registry
   457  		}
   458  		arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg)
   459  		if err != nil {
   460  			return nil, err
   461  		}
   462  		updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data)
   463  	}
   464  
   465  	if collation != nil {
   466  		updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
   467  	}
   468  
   469  	if upsert != nil {
   470  		updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
   471  	}
   472  
   473  	if hint != nil {
   474  		if isUnorderedMap(hint) {
   475  			return nil, ErrMapForOrderedArgument{"hint"}
   476  		}
   477  		hintVal, err := marshalValue(hint, bsonOpts, registry)
   478  		if err != nil {
   479  			return nil, err
   480  		}
   481  		updateDoc = bsoncore.AppendValueElement(updateDoc, "hint", hintVal)
   482  	}
   483  
   484  	updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx)
   485  	return updateDoc, nil
   486  }
   487  
   488  func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
   489  	if ordered {
   490  		return createOrderedBatches(models)
   491  	}
   492  
   493  	batches := make([]bulkWriteBatch, 5)
   494  	batches[insertCommand].canRetry = true
   495  	batches[deleteOneCommand].canRetry = true
   496  	batches[updateOneCommand].canRetry = true
   497  
   498  	// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
   499  	for i, model := range models {
   500  		switch model.(type) {
   501  		case *InsertOneModel:
   502  			batches[insertCommand].models = append(batches[insertCommand].models, model)
   503  			batches[insertCommand].indexes = append(batches[insertCommand].indexes, i)
   504  		case *DeleteOneModel:
   505  			batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model)
   506  			batches[deleteOneCommand].indexes = append(batches[deleteOneCommand].indexes, i)
   507  		case *DeleteManyModel:
   508  			batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model)
   509  			batches[deleteManyCommand].indexes = append(batches[deleteManyCommand].indexes, i)
   510  		case *ReplaceOneModel, *UpdateOneModel:
   511  			batches[updateOneCommand].models = append(batches[updateOneCommand].models, model)
   512  			batches[updateOneCommand].indexes = append(batches[updateOneCommand].indexes, i)
   513  		case *UpdateManyModel:
   514  			batches[updateManyCommand].models = append(batches[updateManyCommand].models, model)
   515  			batches[updateManyCommand].indexes = append(batches[updateManyCommand].indexes, i)
   516  		}
   517  	}
   518  
   519  	return batches
   520  }
   521  
   522  func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
   523  	var batches []bulkWriteBatch
   524  	var prevKind writeCommandKind = -1
   525  	i := -1 // batch index
   526  
   527  	for ind, model := range models {
   528  		var createNewBatch bool
   529  		var canRetry bool
   530  		var newKind writeCommandKind
   531  
   532  		// TODO(GODRIVER-1157): fix batching once operation retryability is fixed
   533  		switch model.(type) {
   534  		case *InsertOneModel:
   535  			createNewBatch = prevKind != insertCommand
   536  			canRetry = true
   537  			newKind = insertCommand
   538  		case *DeleteOneModel:
   539  			createNewBatch = prevKind != deleteOneCommand
   540  			canRetry = true
   541  			newKind = deleteOneCommand
   542  		case *DeleteManyModel:
   543  			createNewBatch = prevKind != deleteManyCommand
   544  			newKind = deleteManyCommand
   545  		case *ReplaceOneModel, *UpdateOneModel:
   546  			createNewBatch = prevKind != updateOneCommand
   547  			canRetry = true
   548  			newKind = updateOneCommand
   549  		case *UpdateManyModel:
   550  			createNewBatch = prevKind != updateManyCommand
   551  			newKind = updateManyCommand
   552  		}
   553  
   554  		if createNewBatch {
   555  			batches = append(batches, bulkWriteBatch{
   556  				models:   []WriteModel{model},
   557  				canRetry: canRetry,
   558  				indexes:  []int{ind},
   559  			})
   560  			i++
   561  		} else {
   562  			batches[i].models = append(batches[i].models, model)
   563  			if !canRetry {
   564  				batches[i].canRetry = false // don't make it true if it was already false
   565  			}
   566  			batches[i].indexes = append(batches[i].indexes, ind)
   567  		}
   568  
   569  		prevKind = newKind
   570  	}
   571  
   572  	return batches
   573  }
   574  
   575  func (bw *bulkWrite) mergeResults(newResult BulkWriteResult) {
   576  	bw.result.InsertedCount += newResult.InsertedCount
   577  	bw.result.MatchedCount += newResult.MatchedCount
   578  	bw.result.ModifiedCount += newResult.ModifiedCount
   579  	bw.result.DeletedCount += newResult.DeletedCount
   580  	bw.result.UpsertedCount += newResult.UpsertedCount
   581  
   582  	for index, upsertID := range newResult.UpsertedIDs {
   583  		bw.result.UpsertedIDs[index] = upsertID
   584  	}
   585  }
   586  
   587  // WriteCommandKind is the type of command represented by a Write
   588  type writeCommandKind int8
   589  
   590  // These constants represent the valid types of write commands.
   591  const (
   592  	insertCommand writeCommandKind = iota
   593  	updateOneCommand
   594  	updateManyCommand
   595  	deleteOneCommand
   596  	deleteManyCommand
   597  )
   598  

View as plain text