1
2
3
4
5
6
7 package mongo
8
9 import (
10 "bytes"
11 "context"
12 "errors"
13 "fmt"
14 "net"
15 "strings"
16
17 "go.mongodb.org/mongo-driver/bson"
18 "go.mongodb.org/mongo-driver/internal/codecutil"
19 "go.mongodb.org/mongo-driver/x/mongo/driver"
20 "go.mongodb.org/mongo-driver/x/mongo/driver/mongocrypt"
21 "go.mongodb.org/mongo-driver/x/mongo/driver/topology"
22 )
23
24
// Sentinel errors exported by this package. Compare against them with
// errors.Is.
var (
	// ErrUnacknowledgedWrite is returned in place of
	// driver.ErrUnacknowledgedWrite (see processWriteError), i.e. when a write
	// was not acknowledged and therefore has no server result.
	ErrUnacknowledgedWrite = errors.New("unacknowledged write")

	// ErrClientDisconnected is returned in place of
	// topology.ErrTopologyClosed (see replaceErrors).
	ErrClientDisconnected = errors.New("client is disconnected")

	// ErrNilDocument indicates that a nil document was supplied.
	ErrNilDocument = errors.New("document is nil")

	// ErrNilValue indicates that a nil value was supplied; replaceErrors maps
	// codecutil.ErrNilValue to it.
	ErrNilValue = errors.New("value is nil")

	// ErrEmptySlice indicates that an input slice had no elements although at
	// least one is required.
	ErrEmptySlice = errors.New("must provide at least one element in input slice")
)
38
39
// ErrMapForOrderedArgument is returned when a map with more than one key is
// supplied for a parameter whose element order matters. Go map iteration
// order is unspecified, so such a map cannot express an order.
type ErrMapForOrderedArgument struct {
	// ParamName names the offending parameter.
	ParamName string
}

// Error implements the error interface.
func (err ErrMapForOrderedArgument) Error() string {
	// Both operands are strings, so fmt.Sprint inserts no separator.
	return fmt.Sprint("multi-key map passed in for ordered parameter ", err.ParamName)
}
48
49 func replaceErrors(err error) error {
50
51 if err == nil {
52 return nil
53 }
54
55 if errors.Is(err, topology.ErrTopologyClosed) {
56 return ErrClientDisconnected
57 }
58 if de, ok := err.(driver.Error); ok {
59 return CommandError{
60 Code: de.Code,
61 Message: de.Message,
62 Labels: de.Labels,
63 Name: de.Name,
64 Wrapped: de.Wrapped,
65 Raw: bson.Raw(de.Raw),
66 }
67 }
68 if qe, ok := err.(driver.QueryFailureError); ok {
69
70 ce := CommandError{
71 Name: qe.Message,
72 Wrapped: qe.Wrapped,
73 Raw: bson.Raw(qe.Response),
74 }
75
76 dollarErr, err := qe.Response.LookupErr("$err")
77 if err == nil {
78 ce.Message, _ = dollarErr.StringValueOK()
79 }
80 code, err := qe.Response.LookupErr("code")
81 if err == nil {
82 ce.Code, _ = code.Int32OK()
83 }
84
85 return ce
86 }
87 if me, ok := err.(mongocrypt.Error); ok {
88 return MongocryptError{Code: me.Code, Message: me.Message}
89 }
90
91 if errors.Is(err, codecutil.ErrNilValue) {
92 return ErrNilValue
93 }
94
95 if marshalErr, ok := err.(codecutil.MarshalError); ok {
96 return MarshalError{
97 Value: marshalErr.Value,
98 Err: marshalErr.Err,
99 }
100 }
101
102 return err
103 }
104
105
106 func IsDuplicateKeyError(err error) bool {
107 if se := ServerError(nil); errors.As(err, &se) {
108 return se.HasErrorCode(11000) ||
109 se.HasErrorCode(11001) ||
110
111 se.HasErrorCode(12582) ||
112
113
114 se.HasErrorCodeWithMessage(16460, " E11000 ")
115 }
116 return false
117 }
118
119
120 var timeoutErrs = [...]error{
121 context.DeadlineExceeded,
122 driver.ErrDeadlineWouldBeExceeded,
123 topology.ErrServerSelectionTimeout,
124 }
125
126
127
128 func IsTimeout(err error) bool {
129
130 for _, target := range timeoutErrs {
131 if errors.Is(err, target) {
132 return true
133 }
134 }
135
136
137
138 if errors.As(err, &topology.WaitQueueTimeoutError{}) {
139 return true
140 }
141 if ce := (CommandError{}); errors.As(err, &ce) && ce.IsMaxTimeMSExpiredError() {
142 return true
143 }
144 if we := (WriteException{}); errors.As(err, &we) && we.WriteConcernError != nil && we.WriteConcernError.IsMaxTimeMSExpiredError() {
145 return true
146 }
147 if ne := net.Error(nil); errors.As(err, &ne) {
148 return ne.Timeout()
149 }
150
151 if le := LabeledError(nil); errors.As(err, &le) {
152 if le.HasErrorLabel("NetworkTimeoutError") || le.HasErrorLabel("ExceededTimeLimitError") {
153 return true
154 }
155 }
156
157 return false
158 }
159
160
// unwrap returns the error wrapped by err (via an Unwrap() error method), or
// nil when err is nil or does not wrap a single error.
func unwrap(err error) error {
	type singleWrapper interface {
		Unwrap() error
	}
	if w, ok := err.(singleWrapper); ok {
		return w.Unwrap()
	}
	return nil
}
170
171
172 func errorHasLabel(err error, label string) bool {
173 for ; err != nil; err = unwrap(err) {
174 if le, ok := err.(LabeledError); ok && le.HasErrorLabel(label) {
175 return true
176 }
177 }
178 return false
179 }
180
181
182 func IsNetworkError(err error) bool {
183 return errorHasLabel(err, "NetworkError")
184 }
185
186
// MongocryptError is the public form of a mongocrypt.Error; replaceErrors
// performs the conversion.
type MongocryptError struct {
	Code    int32
	Message string
}

// Error formats the error as "mongocrypt error <Code>: <Message>".
func (me MongocryptError) Error() string {
	return fmt.Sprintf("mongocrypt error %v: %v", me.Code, me.Message)
}
196
197
198
// EncryptionKeyVaultError wraps an error that occurred while communicating
// with the key vault.
type EncryptionKeyVaultError struct {
	// Wrapped is the underlying error; it may be nil.
	Wrapped error
}

// Error formats the error as "key vault communication error: <Wrapped>".
// A nil Wrapped is rendered as "<nil>".
func (e EncryptionKeyVaultError) Error() string {
	return fmt.Sprintf("key vault communication error: %v", e.Wrapped)
}

// Unwrap exposes the underlying error to errors.Is and errors.As.
func (e EncryptionKeyVaultError) Unwrap() error {
	return e.Wrapped
}
212
213
// MongocryptdError wraps an error that occurred while communicating with
// mongocryptd.
type MongocryptdError struct {
	// Wrapped is the underlying error; it may be nil.
	Wrapped error
}

// Error formats the error as "mongocryptd communication error: <Wrapped>".
// A nil Wrapped is rendered as "<nil>".
func (me MongocryptdError) Error() string {
	return fmt.Sprintf("mongocryptd communication error: %v", me.Wrapped)
}

// Unwrap exposes the underlying error to errors.Is and errors.As.
func (me MongocryptdError) Unwrap() error {
	return me.Wrapped
}
227
228
// LabeledError is an error that can report whether it carries a given error
// label.
type LabeledError interface {
	error

	// HasErrorLabel reports whether the error carries the given label.
	HasErrorLabel(string) bool
}

// ServerError is implemented by this package's server-reported error types
// (CommandError, WriteError, WriteException and BulkWriteException). The
// unexported serverError method prevents types outside the package from
// implementing it.
type ServerError interface {
	LabeledError

	// HasErrorCode reports whether the error has the given code.
	HasErrorCode(int) bool
	// HasErrorCodeWithMessage reports whether the error has the given code
	// and a message matching the given string. The implementations in this
	// package use substring matching.
	HasErrorCodeWithMessage(int, string) bool
	// HasErrorMessage reports whether the error's message matches the given
	// string. The implementations in this package use substring matching.
	HasErrorMessage(string) bool

	serverError()
}
248
249 var _ ServerError = CommandError{}
250 var _ ServerError = WriteError{}
251 var _ ServerError = WriteException{}
252 var _ ServerError = BulkWriteException{}
253
254
255 type CommandError struct {
256 Code int32
257 Message string
258 Labels []string
259 Name string
260 Wrapped error
261 Raw bson.Raw
262 }
263
264
265 func (e CommandError) Error() string {
266 if e.Name != "" {
267 return fmt.Sprintf("(%v) %v", e.Name, e.Message)
268 }
269 return e.Message
270 }
271
272
273 func (e CommandError) Unwrap() error {
274 return e.Wrapped
275 }
276
277
278 func (e CommandError) HasErrorCode(code int) bool {
279 return int(e.Code) == code
280 }
281
282
283 func (e CommandError) HasErrorLabel(label string) bool {
284 if e.Labels != nil {
285 for _, l := range e.Labels {
286 if l == label {
287 return true
288 }
289 }
290 }
291 return false
292 }
293
294
295 func (e CommandError) HasErrorMessage(message string) bool {
296 return strings.Contains(e.Message, message)
297 }
298
299
300 func (e CommandError) HasErrorCodeWithMessage(code int, message string) bool {
301 return int(e.Code) == code && strings.Contains(e.Message, message)
302 }
303
304
305 func (e CommandError) IsMaxTimeMSExpiredError() bool {
306 return e.Code == 50 || e.Name == "MaxTimeMSExpired"
307 }
308
309
310 func (e CommandError) serverError() {}
311
312
313
314 type WriteError struct {
315
316 Index int
317
318 Code int
319 Message string
320 Details bson.Raw
321
322
323 Raw bson.Raw
324 }
325
326 func (we WriteError) Error() string {
327 msg := we.Message
328 if len(we.Details) > 0 {
329 msg = fmt.Sprintf("%s: %s", msg, we.Details.String())
330 }
331 return msg
332 }
333
334
335 func (we WriteError) HasErrorCode(code int) bool {
336 return we.Code == code
337 }
338
339
340
341 func (we WriteError) HasErrorLabel(string) bool {
342 return false
343 }
344
345
346 func (we WriteError) HasErrorMessage(message string) bool {
347 return strings.Contains(we.Message, message)
348 }
349
350
351 func (we WriteError) HasErrorCodeWithMessage(code int, message string) bool {
352 return we.Code == code && strings.Contains(we.Message, message)
353 }
354
355
356 func (we WriteError) serverError() {}
357
358
359 type WriteErrors []WriteError
360
361
362 func (we WriteErrors) Error() string {
363 errs := make([]error, len(we))
364 for i := 0; i < len(we); i++ {
365 errs[i] = we[i]
366 }
367
368 return "write errors: " + joinBatchErrors(errs)
369 }
370
371 func writeErrorsFromDriverWriteErrors(errs driver.WriteErrors) WriteErrors {
372 wes := make(WriteErrors, 0, len(errs))
373 for _, err := range errs {
374 wes = append(wes, WriteError{
375 Index: int(err.Index),
376 Code: int(err.Code),
377 Message: err.Message,
378 Details: bson.Raw(err.Details),
379 Raw: bson.Raw(err.Raw),
380 })
381 }
382 return wes
383 }
384
385
386
387 type WriteConcernError struct {
388 Name string
389 Code int
390 Message string
391 Details bson.Raw
392 Raw bson.Raw
393 }
394
395
396 func (wce WriteConcernError) Error() string {
397 if wce.Name != "" {
398 return fmt.Sprintf("(%v) %v", wce.Name, wce.Message)
399 }
400 return wce.Message
401 }
402
403
404 func (wce WriteConcernError) IsMaxTimeMSExpiredError() bool {
405 return wce.Code == 50
406 }
407
408
409
410 type WriteException struct {
411
412 WriteConcernError *WriteConcernError
413
414
415 WriteErrors WriteErrors
416
417
418 Labels []string
419
420
421 Raw bson.Raw
422 }
423
424
425 func (mwe WriteException) Error() string {
426 causes := make([]string, 0, 2)
427 if mwe.WriteConcernError != nil {
428 causes = append(causes, "write concern error: "+mwe.WriteConcernError.Error())
429 }
430 if len(mwe.WriteErrors) > 0 {
431
432
433 causes = append(causes, mwe.WriteErrors.Error())
434 }
435
436 message := "write exception: "
437 if len(causes) == 0 {
438 return message + "no causes"
439 }
440 return message + strings.Join(causes, ", ")
441 }
442
443
444 func (mwe WriteException) HasErrorCode(code int) bool {
445 if mwe.WriteConcernError != nil && mwe.WriteConcernError.Code == code {
446 return true
447 }
448 for _, we := range mwe.WriteErrors {
449 if we.Code == code {
450 return true
451 }
452 }
453 return false
454 }
455
456
457 func (mwe WriteException) HasErrorLabel(label string) bool {
458 if mwe.Labels != nil {
459 for _, l := range mwe.Labels {
460 if l == label {
461 return true
462 }
463 }
464 }
465 return false
466 }
467
468
469 func (mwe WriteException) HasErrorMessage(message string) bool {
470 if mwe.WriteConcernError != nil && strings.Contains(mwe.WriteConcernError.Message, message) {
471 return true
472 }
473 for _, we := range mwe.WriteErrors {
474 if strings.Contains(we.Message, message) {
475 return true
476 }
477 }
478 return false
479 }
480
481
482 func (mwe WriteException) HasErrorCodeWithMessage(code int, message string) bool {
483 if mwe.WriteConcernError != nil &&
484 mwe.WriteConcernError.Code == code && strings.Contains(mwe.WriteConcernError.Message, message) {
485 return true
486 }
487 for _, we := range mwe.WriteErrors {
488 if we.Code == code && strings.Contains(we.Message, message) {
489 return true
490 }
491 }
492 return false
493 }
494
495
496 func (mwe WriteException) serverError() {}
497
498 func convertDriverWriteConcernError(wce *driver.WriteConcernError) *WriteConcernError {
499 if wce == nil {
500 return nil
501 }
502
503 return &WriteConcernError{
504 Name: wce.Name,
505 Code: int(wce.Code),
506 Message: wce.Message,
507 Details: bson.Raw(wce.Details),
508 Raw: bson.Raw(wce.Raw),
509 }
510 }
511
512
513
514 type BulkWriteError struct {
515 WriteError
516 Request WriteModel
517 }
518
519
520 func (bwe BulkWriteError) Error() string {
521 return bwe.WriteError.Error()
522 }
523
524
525 type BulkWriteException struct {
526
527 WriteConcernError *WriteConcernError
528
529
530 WriteErrors []BulkWriteError
531
532
533 Labels []string
534 }
535
536
537 func (bwe BulkWriteException) Error() string {
538 causes := make([]string, 0, 2)
539 if bwe.WriteConcernError != nil {
540 causes = append(causes, "write concern error: "+bwe.WriteConcernError.Error())
541 }
542 if len(bwe.WriteErrors) > 0 {
543 errs := make([]error, len(bwe.WriteErrors))
544 for i := 0; i < len(bwe.WriteErrors); i++ {
545 errs[i] = &bwe.WriteErrors[i]
546 }
547 causes = append(causes, "write errors: "+joinBatchErrors(errs))
548 }
549
550 message := "bulk write exception: "
551 if len(causes) == 0 {
552 return message + "no causes"
553 }
554 return "bulk write exception: " + strings.Join(causes, ", ")
555 }
556
557
558 func (bwe BulkWriteException) HasErrorCode(code int) bool {
559 if bwe.WriteConcernError != nil && bwe.WriteConcernError.Code == code {
560 return true
561 }
562 for _, we := range bwe.WriteErrors {
563 if we.Code == code {
564 return true
565 }
566 }
567 return false
568 }
569
570
571 func (bwe BulkWriteException) HasErrorLabel(label string) bool {
572 if bwe.Labels != nil {
573 for _, l := range bwe.Labels {
574 if l == label {
575 return true
576 }
577 }
578 }
579 return false
580 }
581
582
583 func (bwe BulkWriteException) HasErrorMessage(message string) bool {
584 if bwe.WriteConcernError != nil && strings.Contains(bwe.WriteConcernError.Message, message) {
585 return true
586 }
587 for _, we := range bwe.WriteErrors {
588 if strings.Contains(we.Message, message) {
589 return true
590 }
591 }
592 return false
593 }
594
595
596 func (bwe BulkWriteException) HasErrorCodeWithMessage(code int, message string) bool {
597 if bwe.WriteConcernError != nil &&
598 bwe.WriteConcernError.Code == code && strings.Contains(bwe.WriteConcernError.Message, message) {
599 return true
600 }
601 for _, we := range bwe.WriteErrors {
602 if we.Code == code && strings.Contains(we.Message, message) {
603 return true
604 }
605 }
606 return false
607 }
608
609
610 func (bwe BulkWriteException) serverError() {}
611
612
613
614
615
// returnResult is a bit set returned by processWriteError. Callers use it to
// decide whether their operation result should be returned alongside the
// error.
type returnResult int

// NOTE(review): presumably rrOne covers single-document operations and
// rrMany multi-document ones; confirm against the callers of
// processWriteError.
const (
	rrNone returnResult = 1 << 0
	rrOne  returnResult = 1 << 1
	rrMany returnResult = 1 << 2

	rrAll returnResult = rrOne | rrMany
)
625
626
627
628
629
630
631 func processWriteError(err error) (returnResult, error) {
632 switch {
633 case errors.Is(err, driver.ErrUnacknowledgedWrite):
634 return rrAll, ErrUnacknowledgedWrite
635 case err != nil:
636 switch tt := err.(type) {
637 case driver.WriteCommandError:
638 return rrMany, WriteException{
639 WriteConcernError: convertDriverWriteConcernError(tt.WriteConcernError),
640 WriteErrors: writeErrorsFromDriverWriteErrors(tt.WriteErrors),
641 Labels: tt.Labels,
642 Raw: bson.Raw(tt.Raw),
643 }
644 default:
645 return rrNone, replaceErrors(err)
646 }
647 default:
648 return rrAll, nil
649 }
650 }
651
652
653
654
// batchErrorsTargetLength is the soft limit, in bytes, on the string built by
// joinBatchErrors. Once the output exceeds it, the remaining errors are
// summarized as a count instead of being printed.
const batchErrorsTargetLength = 2000

// joinBatchErrors renders errs as "[e1, e2, ...]". The length is checked only
// before each message is appended, so the result can overshoot
// batchErrorsTargetLength by about one message plus the summary. Once the
// limit is exceeded, the rest are replaced by "+N more errors...". Every
// element of errs must be non-nil.
func joinBatchErrors(errs []error) string {
	var out bytes.Buffer
	out.WriteByte('[')
	for i := range errs {
		if i > 0 {
			out.WriteString(", ")
		}
		// Stop printing messages once the output is already past the target.
		if out.Len() > batchErrorsTargetLength {
			fmt.Fprintf(&out, "+%d more errors...", len(errs)-i)
			break
		}
		out.WriteString(errs[i].Error())
	}
	out.WriteByte(']')
	return out.String()
}
683
View as plain text