1
2
3
4
5
6
7 package mongo
8
9 import (
10 "context"
11 "errors"
12
13 "go.mongodb.org/mongo-driver/bson/bsoncodec"
14 "go.mongodb.org/mongo-driver/bson/primitive"
15 "go.mongodb.org/mongo-driver/mongo/description"
16 "go.mongodb.org/mongo-driver/mongo/options"
17 "go.mongodb.org/mongo-driver/mongo/writeconcern"
18 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
19 "go.mongodb.org/mongo-driver/x/mongo/driver"
20 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
22 )
23
// bulkWriteBatch is a group of write models that is sent to the server as a
// single write command.
type bulkWriteBatch struct {
	// models are the write models in this batch; runBatch chooses the command
	// to run from the dynamic type of models[0].
	models []WriteModel
	// canRetry reports whether this batch may be retried: the run* methods
	// request driver.RetryOncePerCommand only when this is true and the
	// client has retryWrites enabled.
	canRetry bool
	// indexes[i] is the position of models[i] in the caller's original model
	// slice; runBatch uses it to translate batch-relative indexes (write
	// errors, upserts) back to caller positions.
	indexes []int
}
29
30
// bulkWrite holds the inputs of one bulk write against a collection together
// with its accumulated result; execute runs it.
type bulkWrite struct {
	// comment, when non-nil, is marshaled and attached to every command.
	comment interface{}
	// ordered controls batching and error handling; nil means ordered.
	ordered *bool
	// bypassDocumentValidation is forwarded to insert and update commands
	// only when it is non-nil and true.
	bypassDocumentValidation *bool
	// models are the write models to execute, in caller order.
	models []WriteModel
	// session is passed to every command.
	session *session.Client
	// collection is the target collection; its client, database, BSON options
	// and registry configure every command.
	collection *Collection
	// selector is passed to each operation's ServerSelector.
	selector description.ServerSelector
	// writeConcern is passed to every command.
	writeConcern *writeconcern.WriteConcern
	// result accumulates counts and upserted IDs across batches; execute
	// resets it before running.
	result BulkWriteResult
	// let, when non-nil, is marshaled and sent with update and delete
	// commands (it is not used for inserts).
	let interface{}
}
43
44 func (bw *bulkWrite) execute(ctx context.Context) error {
45 ordered := true
46 if bw.ordered != nil {
47 ordered = *bw.ordered
48 }
49
50 batches := createBatches(bw.models, ordered)
51 bw.result = BulkWriteResult{
52 UpsertedIDs: make(map[int64]interface{}),
53 }
54
55 bwErr := BulkWriteException{
56 WriteErrors: make([]BulkWriteError, 0),
57 }
58
59 var lastErr error
60 continueOnError := !ordered
61 for _, batch := range batches {
62 if len(batch.models) == 0 {
63 continue
64 }
65
66 batchRes, batchErr, err := bw.runBatch(ctx, batch)
67
68 bw.mergeResults(batchRes)
69
70 bwErr.WriteConcernError = batchErr.WriteConcernError
71 bwErr.Labels = append(bwErr.Labels, batchErr.Labels...)
72
73 bwErr.WriteErrors = append(bwErr.WriteErrors, batchErr.WriteErrors...)
74
75 commandErrorOccurred := err != nil && !errors.Is(err, driver.ErrUnacknowledgedWrite)
76 writeErrorOccurred := len(batchErr.WriteErrors) > 0 || batchErr.WriteConcernError != nil
77 if !continueOnError && (commandErrorOccurred || writeErrorOccurred) {
78 if err != nil {
79 return err
80 }
81
82 return bwErr
83 }
84
85 if err != nil {
86 lastErr = err
87 }
88 }
89
90 bw.result.MatchedCount -= bw.result.UpsertedCount
91 if lastErr != nil {
92 _, lastErr = processWriteError(lastErr)
93 return lastErr
94 }
95 if len(bwErr.WriteErrors) > 0 || bwErr.WriteConcernError != nil {
96 return bwErr
97 }
98 return nil
99 }
100
101 func (bw *bulkWrite) runBatch(ctx context.Context, batch bulkWriteBatch) (BulkWriteResult, BulkWriteException, error) {
102 batchRes := BulkWriteResult{
103 UpsertedIDs: make(map[int64]interface{}),
104 }
105 batchErr := BulkWriteException{}
106
107 var writeErrors []driver.WriteError
108 switch batch.models[0].(type) {
109 case *InsertOneModel:
110 res, err := bw.runInsert(ctx, batch)
111 if err != nil {
112 var writeErr driver.WriteCommandError
113 if !errors.As(err, &writeErr) {
114 return BulkWriteResult{}, batchErr, err
115 }
116 writeErrors = writeErr.WriteErrors
117 batchErr.Labels = writeErr.Labels
118 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
119 }
120 batchRes.InsertedCount = res.N
121 case *DeleteOneModel, *DeleteManyModel:
122 res, err := bw.runDelete(ctx, batch)
123 if err != nil {
124 var writeErr driver.WriteCommandError
125 if !errors.As(err, &writeErr) {
126 return BulkWriteResult{}, batchErr, err
127 }
128 writeErrors = writeErr.WriteErrors
129 batchErr.Labels = writeErr.Labels
130 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
131 }
132 batchRes.DeletedCount = res.N
133 case *ReplaceOneModel, *UpdateOneModel, *UpdateManyModel:
134 res, err := bw.runUpdate(ctx, batch)
135 if err != nil {
136 var writeErr driver.WriteCommandError
137 if !errors.As(err, &writeErr) {
138 return BulkWriteResult{}, batchErr, err
139 }
140 writeErrors = writeErr.WriteErrors
141 batchErr.Labels = writeErr.Labels
142 batchErr.WriteConcernError = convertDriverWriteConcernError(writeErr.WriteConcernError)
143 }
144 batchRes.MatchedCount = res.N
145 batchRes.ModifiedCount = res.NModified
146 batchRes.UpsertedCount = int64(len(res.Upserted))
147 for _, upsert := range res.Upserted {
148 batchRes.UpsertedIDs[int64(batch.indexes[upsert.Index])] = upsert.ID
149 }
150 }
151
152 batchErr.WriteErrors = make([]BulkWriteError, 0, len(writeErrors))
153 convWriteErrors := writeErrorsFromDriverWriteErrors(writeErrors)
154 for _, we := range convWriteErrors {
155 request := batch.models[we.Index]
156 we.Index = batch.indexes[we.Index]
157 batchErr.WriteErrors = append(batchErr.WriteErrors, BulkWriteError{
158 WriteError: we,
159 Request: request,
160 })
161 }
162 return batchRes, batchErr, nil
163 }
164
165 func (bw *bulkWrite) runInsert(ctx context.Context, batch bulkWriteBatch) (operation.InsertResult, error) {
166 docs := make([]bsoncore.Document, len(batch.models))
167 var i int
168 for _, model := range batch.models {
169 converted := model.(*InsertOneModel)
170 doc, err := marshal(converted.Document, bw.collection.bsonOpts, bw.collection.registry)
171 if err != nil {
172 return operation.InsertResult{}, err
173 }
174 doc, _, err = ensureID(doc, primitive.NilObjectID, bw.collection.bsonOpts, bw.collection.registry)
175 if err != nil {
176 return operation.InsertResult{}, err
177 }
178
179 docs[i] = doc
180 i++
181 }
182
183 op := operation.NewInsert(docs...).
184 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
185 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
186 Database(bw.collection.db.name).Collection(bw.collection.name).
187 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).
188 ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout).
189 Logger(bw.collection.client.logger)
190 if bw.comment != nil {
191 comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
192 if err != nil {
193 return op.Result(), err
194 }
195 op.Comment(comment)
196 }
197 if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
198 op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
199 }
200 if bw.ordered != nil {
201 op = op.Ordered(*bw.ordered)
202 }
203
204 retry := driver.RetryNone
205 if bw.collection.client.retryWrites && batch.canRetry {
206 retry = driver.RetryOncePerCommand
207 }
208 op = op.Retry(retry)
209
210 err := op.Execute(ctx)
211
212 return op.Result(), err
213 }
214
215 func (bw *bulkWrite) runDelete(ctx context.Context, batch bulkWriteBatch) (operation.DeleteResult, error) {
216 docs := make([]bsoncore.Document, len(batch.models))
217 var i int
218 var hasHint bool
219
220 for _, model := range batch.models {
221 var doc bsoncore.Document
222 var err error
223
224 switch converted := model.(type) {
225 case *DeleteOneModel:
226 doc, err = createDeleteDoc(
227 converted.Filter,
228 converted.Collation,
229 converted.Hint,
230 true,
231 bw.collection.bsonOpts,
232 bw.collection.registry)
233 hasHint = hasHint || (converted.Hint != nil)
234 case *DeleteManyModel:
235 doc, err = createDeleteDoc(
236 converted.Filter,
237 converted.Collation,
238 converted.Hint,
239 false,
240 bw.collection.bsonOpts,
241 bw.collection.registry)
242 hasHint = hasHint || (converted.Hint != nil)
243 }
244
245 if err != nil {
246 return operation.DeleteResult{}, err
247 }
248
249 docs[i] = doc
250 i++
251 }
252
253 op := operation.NewDelete(docs...).
254 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
255 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
256 Database(bw.collection.db.name).Collection(bw.collection.name).
257 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint).
258 ServerAPI(bw.collection.client.serverAPI).Timeout(bw.collection.client.timeout).
259 Logger(bw.collection.client.logger)
260 if bw.comment != nil {
261 comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
262 if err != nil {
263 return op.Result(), err
264 }
265 op.Comment(comment)
266 }
267 if bw.let != nil {
268 let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry)
269 if err != nil {
270 return operation.DeleteResult{}, err
271 }
272 op = op.Let(let)
273 }
274 if bw.ordered != nil {
275 op = op.Ordered(*bw.ordered)
276 }
277 retry := driver.RetryNone
278 if bw.collection.client.retryWrites && batch.canRetry {
279 retry = driver.RetryOncePerCommand
280 }
281 op = op.Retry(retry)
282
283 err := op.Execute(ctx)
284
285 return op.Result(), err
286 }
287
288 func createDeleteDoc(
289 filter interface{},
290 collation *options.Collation,
291 hint interface{},
292 deleteOne bool,
293 bsonOpts *options.BSONOptions,
294 registry *bsoncodec.Registry,
295 ) (bsoncore.Document, error) {
296 f, err := marshal(filter, bsonOpts, registry)
297 if err != nil {
298 return nil, err
299 }
300
301 var limit int32
302 if deleteOne {
303 limit = 1
304 }
305 didx, doc := bsoncore.AppendDocumentStart(nil)
306 doc = bsoncore.AppendDocumentElement(doc, "q", f)
307 doc = bsoncore.AppendInt32Element(doc, "limit", limit)
308 if collation != nil {
309 doc = bsoncore.AppendDocumentElement(doc, "collation", collation.ToDocument())
310 }
311 if hint != nil {
312 if isUnorderedMap(hint) {
313 return nil, ErrMapForOrderedArgument{"hint"}
314 }
315 hintVal, err := marshalValue(hint, bsonOpts, registry)
316 if err != nil {
317 return nil, err
318 }
319 doc = bsoncore.AppendValueElement(doc, "hint", hintVal)
320 }
321 doc, _ = bsoncore.AppendDocumentEnd(doc, didx)
322
323 return doc, nil
324 }
325
326 func (bw *bulkWrite) runUpdate(ctx context.Context, batch bulkWriteBatch) (operation.UpdateResult, error) {
327 docs := make([]bsoncore.Document, len(batch.models))
328 var hasHint bool
329 var hasArrayFilters bool
330 for i, model := range batch.models {
331 var doc bsoncore.Document
332 var err error
333
334 switch converted := model.(type) {
335 case *ReplaceOneModel:
336 doc, err = createUpdateDoc(
337 converted.Filter,
338 converted.Replacement,
339 converted.Hint,
340 nil,
341 converted.Collation,
342 converted.Upsert,
343 false,
344 false,
345 bw.collection.bsonOpts,
346 bw.collection.registry)
347 hasHint = hasHint || (converted.Hint != nil)
348 case *UpdateOneModel:
349 doc, err = createUpdateDoc(
350 converted.Filter,
351 converted.Update,
352 converted.Hint,
353 converted.ArrayFilters,
354 converted.Collation,
355 converted.Upsert,
356 false,
357 true,
358 bw.collection.bsonOpts,
359 bw.collection.registry)
360 hasHint = hasHint || (converted.Hint != nil)
361 hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
362 case *UpdateManyModel:
363 doc, err = createUpdateDoc(
364 converted.Filter,
365 converted.Update,
366 converted.Hint,
367 converted.ArrayFilters,
368 converted.Collation,
369 converted.Upsert,
370 true,
371 true,
372 bw.collection.bsonOpts,
373 bw.collection.registry)
374 hasHint = hasHint || (converted.Hint != nil)
375 hasArrayFilters = hasArrayFilters || (converted.ArrayFilters != nil)
376 }
377 if err != nil {
378 return operation.UpdateResult{}, err
379 }
380
381 docs[i] = doc
382 }
383
384 op := operation.NewUpdate(docs...).
385 Session(bw.session).WriteConcern(bw.writeConcern).CommandMonitor(bw.collection.client.monitor).
386 ServerSelector(bw.selector).ClusterClock(bw.collection.client.clock).
387 Database(bw.collection.db.name).Collection(bw.collection.name).
388 Deployment(bw.collection.client.deployment).Crypt(bw.collection.client.cryptFLE).Hint(hasHint).
389 ArrayFilters(hasArrayFilters).ServerAPI(bw.collection.client.serverAPI).
390 Timeout(bw.collection.client.timeout).Logger(bw.collection.client.logger)
391 if bw.comment != nil {
392 comment, err := marshalValue(bw.comment, bw.collection.bsonOpts, bw.collection.registry)
393 if err != nil {
394 return op.Result(), err
395 }
396 op.Comment(comment)
397 }
398 if bw.let != nil {
399 let, err := marshal(bw.let, bw.collection.bsonOpts, bw.collection.registry)
400 if err != nil {
401 return operation.UpdateResult{}, err
402 }
403 op = op.Let(let)
404 }
405 if bw.ordered != nil {
406 op = op.Ordered(*bw.ordered)
407 }
408 if bw.bypassDocumentValidation != nil && *bw.bypassDocumentValidation {
409 op = op.BypassDocumentValidation(*bw.bypassDocumentValidation)
410 }
411 retry := driver.RetryNone
412 if bw.collection.client.retryWrites && batch.canRetry {
413 retry = driver.RetryOncePerCommand
414 }
415 op = op.Retry(retry)
416
417 err := op.Execute(ctx)
418
419 return op.Result(), err
420 }
421
422 func createUpdateDoc(
423 filter interface{},
424 update interface{},
425 hint interface{},
426 arrayFilters *options.ArrayFilters,
427 collation *options.Collation,
428 upsert *bool,
429 multi bool,
430 checkDollarKey bool,
431 bsonOpts *options.BSONOptions,
432 registry *bsoncodec.Registry,
433 ) (bsoncore.Document, error) {
434 f, err := marshal(filter, bsonOpts, registry)
435 if err != nil {
436 return nil, err
437 }
438
439 uidx, updateDoc := bsoncore.AppendDocumentStart(nil)
440 updateDoc = bsoncore.AppendDocumentElement(updateDoc, "q", f)
441
442 u, err := marshalUpdateValue(update, bsonOpts, registry, checkDollarKey)
443 if err != nil {
444 return nil, err
445 }
446
447 updateDoc = bsoncore.AppendValueElement(updateDoc, "u", u)
448
449 if multi {
450 updateDoc = bsoncore.AppendBooleanElement(updateDoc, "multi", multi)
451 }
452
453 if arrayFilters != nil {
454 reg := registry
455 if arrayFilters.Registry != nil {
456 reg = arrayFilters.Registry
457 }
458 arr, err := marshalValue(arrayFilters.Filters, bsonOpts, reg)
459 if err != nil {
460 return nil, err
461 }
462 updateDoc = bsoncore.AppendArrayElement(updateDoc, "arrayFilters", arr.Data)
463 }
464
465 if collation != nil {
466 updateDoc = bsoncore.AppendDocumentElement(updateDoc, "collation", bsoncore.Document(collation.ToDocument()))
467 }
468
469 if upsert != nil {
470 updateDoc = bsoncore.AppendBooleanElement(updateDoc, "upsert", *upsert)
471 }
472
473 if hint != nil {
474 if isUnorderedMap(hint) {
475 return nil, ErrMapForOrderedArgument{"hint"}
476 }
477 hintVal, err := marshalValue(hint, bsonOpts, registry)
478 if err != nil {
479 return nil, err
480 }
481 updateDoc = bsoncore.AppendValueElement(updateDoc, "hint", hintVal)
482 }
483
484 updateDoc, _ = bsoncore.AppendDocumentEnd(updateDoc, uidx)
485 return updateDoc, nil
486 }
487
488 func createBatches(models []WriteModel, ordered bool) []bulkWriteBatch {
489 if ordered {
490 return createOrderedBatches(models)
491 }
492
493 batches := make([]bulkWriteBatch, 5)
494 batches[insertCommand].canRetry = true
495 batches[deleteOneCommand].canRetry = true
496 batches[updateOneCommand].canRetry = true
497
498
499 for i, model := range models {
500 switch model.(type) {
501 case *InsertOneModel:
502 batches[insertCommand].models = append(batches[insertCommand].models, model)
503 batches[insertCommand].indexes = append(batches[insertCommand].indexes, i)
504 case *DeleteOneModel:
505 batches[deleteOneCommand].models = append(batches[deleteOneCommand].models, model)
506 batches[deleteOneCommand].indexes = append(batches[deleteOneCommand].indexes, i)
507 case *DeleteManyModel:
508 batches[deleteManyCommand].models = append(batches[deleteManyCommand].models, model)
509 batches[deleteManyCommand].indexes = append(batches[deleteManyCommand].indexes, i)
510 case *ReplaceOneModel, *UpdateOneModel:
511 batches[updateOneCommand].models = append(batches[updateOneCommand].models, model)
512 batches[updateOneCommand].indexes = append(batches[updateOneCommand].indexes, i)
513 case *UpdateManyModel:
514 batches[updateManyCommand].models = append(batches[updateManyCommand].models, model)
515 batches[updateManyCommand].indexes = append(batches[updateManyCommand].indexes, i)
516 }
517 }
518
519 return batches
520 }
521
522 func createOrderedBatches(models []WriteModel) []bulkWriteBatch {
523 var batches []bulkWriteBatch
524 var prevKind writeCommandKind = -1
525 i := -1
526
527 for ind, model := range models {
528 var createNewBatch bool
529 var canRetry bool
530 var newKind writeCommandKind
531
532
533 switch model.(type) {
534 case *InsertOneModel:
535 createNewBatch = prevKind != insertCommand
536 canRetry = true
537 newKind = insertCommand
538 case *DeleteOneModel:
539 createNewBatch = prevKind != deleteOneCommand
540 canRetry = true
541 newKind = deleteOneCommand
542 case *DeleteManyModel:
543 createNewBatch = prevKind != deleteManyCommand
544 newKind = deleteManyCommand
545 case *ReplaceOneModel, *UpdateOneModel:
546 createNewBatch = prevKind != updateOneCommand
547 canRetry = true
548 newKind = updateOneCommand
549 case *UpdateManyModel:
550 createNewBatch = prevKind != updateManyCommand
551 newKind = updateManyCommand
552 }
553
554 if createNewBatch {
555 batches = append(batches, bulkWriteBatch{
556 models: []WriteModel{model},
557 canRetry: canRetry,
558 indexes: []int{ind},
559 })
560 i++
561 } else {
562 batches[i].models = append(batches[i].models, model)
563 if !canRetry {
564 batches[i].canRetry = false
565 }
566 batches[i].indexes = append(batches[i].indexes, ind)
567 }
568
569 prevKind = newKind
570 }
571
572 return batches
573 }
574
575 func (bw *bulkWrite) mergeResults(newResult BulkWriteResult) {
576 bw.result.InsertedCount += newResult.InsertedCount
577 bw.result.MatchedCount += newResult.MatchedCount
578 bw.result.ModifiedCount += newResult.ModifiedCount
579 bw.result.DeletedCount += newResult.DeletedCount
580 bw.result.UpsertedCount += newResult.UpsertedCount
581
582 for index, upsertID := range newResult.UpsertedIDs {
583 bw.result.UpsertedIDs[index] = upsertID
584 }
585 }
586
587
// writeCommandKind identifies the write command a group of models is sent
// with. For updates and deletes it also distinguishes the single-document
// and multi-document forms. It is signed so that createOrderedBatches can
// use -1 as a "no previous kind" sentinel.
type writeCommandKind int8

// Command kinds. createBatches uses these values directly as indexes into a
// five-element batch slice. They are spelled out explicitly because they
// must stay dense and start at zero.
const (
	insertCommand     writeCommandKind = 0
	updateOneCommand  writeCommandKind = 1
	updateManyCommand writeCommandKind = 2
	deleteOneCommand  writeCommandKind = 3
	deleteManyCommand writeCommandKind = 4
)
598
View as plain text