1
2
3
4
5
6
7 package driver
8
9 import (
10 "bytes"
11 "context"
12 "errors"
13 "fmt"
14 "math"
15 "net"
16 "strconv"
17 "strings"
18 "sync"
19 "time"
20
21 "go.mongodb.org/mongo-driver/bson"
22 "go.mongodb.org/mongo-driver/bson/bsontype"
23 "go.mongodb.org/mongo-driver/bson/primitive"
24 "go.mongodb.org/mongo-driver/event"
25 "go.mongodb.org/mongo-driver/internal/csot"
26 "go.mongodb.org/mongo-driver/internal/driverutil"
27 "go.mongodb.org/mongo-driver/internal/handshake"
28 "go.mongodb.org/mongo-driver/internal/logger"
29 "go.mongodb.org/mongo-driver/mongo/address"
30 "go.mongodb.org/mongo-driver/mongo/description"
31 "go.mongodb.org/mongo-driver/mongo/readconcern"
32 "go.mongodb.org/mongo-driver/mongo/readpref"
33 "go.mongodb.org/mongo-driver/mongo/writeconcern"
34 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
35 "go.mongodb.org/mongo-driver/x/mongo/driver/session"
36 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
37 )
38
// defaultLocalThreshold is the latency window handed to description.LatencySelector when an
// Operation does not supply its own Selector.
const defaultLocalThreshold = 15 * time.Millisecond

// Package-level errors produced by this package.
var (
	// ErrNoDocCommandResponse reports that a command returned no documents.
	ErrNoDocCommandResponse = errors.New("command returned no documents")
	// ErrMultiDocCommandResponse reports that a command returned more than one document.
	ErrMultiDocCommandResponse = errors.New("command returned multiple documents")
	// ErrReplyDocumentMismatch reports that the number of documents in a reply differs from
	// its numberReturned field.
	ErrReplyDocumentMismatch = errors.New("number of documents returned does not match numberReturned field")
	// ErrNonPrimaryReadPref reports that a non-primary read preference was used inside a
	// transaction.
	ErrNonPrimaryReadPref = errors.New("read preference in a transaction must be primary")

	// errDatabaseNameEmpty is returned by Operation.Validate when Database is "".
	errDatabaseNameEmpty = errors.New("database name cannot be empty")
)

const (
	// cryptMaxBsonObjectSize (2 MiB) is the target batch size Execute uses when
	// auto-encryption is active.
	cryptMaxBsonObjectSize uint32 = 2 * 1024 * 1024
	// cryptMinWireVersion is the lowest server max wire version addCommandFields accepts
	// for auto-encryption.
	cryptMinWireVersion int32 = 8
	// readSnapshotMinWireVersion is not referenced in this part of the file; presumably the
	// minimum wire version for snapshot reads — confirm against its users.
	readSnapshotMinWireVersion int32 = 13
)
62
63
type (
	// RetryablePoolError is implemented by errors from server/connection acquisition that can
	// report whether the failed attempt may be retried. Execute retries such errors when
	// Retryable returns true and retries remain.
	RetryablePoolError interface {
		Retryable() bool
	}

	// labeledError is an error that can be queried for error labels (for example
	// NoWritesPerformed or RetryableWriteError).
	labeledError interface {
		error
		HasErrorLabel(string) bool
	}
)
73
74
75
// InvalidOperationError is returned by Operation.Validate when a required field of the
// Operation is unset. MissingField names that field.
type InvalidOperationError struct {
	MissingField string
}

// Error implements the error interface.
func (e InvalidOperationError) Error() string {
	const prefix, suffix = "the ", " field must be set on Operation"
	return prefix + e.MissingField + suffix
}
81
82
83
// opReply is not referenced in the visible part of this file. Judging by its fields, it
// presumably holds a decoded legacy OP_REPLY message (flags, cursor, documents) — confirm
// against its users elsewhere in the package.
type opReply struct {
	responseFlags wiremessage.ReplyFlag // reply flags from the message
	cursorID      int64                 // cursor ID from the reply
	startingFrom  int32                 // startingFrom field from the reply
	numReturned   int32                 // numberReturned field from the reply
	documents     []bsoncore.Document   // documents contained in the reply
	err           error                 // presumably a decoding/reply error — confirm
}
92
93
// startedInformation collects the data Execute gathers for a command-started event. Execute
// passes it to publishStartedEvent and copies most fields into finishedInformation.
type startedInformation struct {
	cmd                      bsoncore.Document   // command document as built (excludes any OP_MSG document-sequence section)
	requestID                int32               // request ID written into the wire message header
	cmdName                  string              // command name, from getCommandName(cmd)
	documentSequenceIncluded bool                // batch documents were sent as a separate OP_MSG document sequence
	connID                   string              // conn.ID()
	driverConnectionID       uint64              // conn.DriverConnectionID()
	serverConnID             *int64              // conn.ServerConnectionID(); pointer, presumably nil when unknown
	redacted                 bool                // result of redactCommand; when true the command body is withheld from monitoring
	serviceID                *primitive.ObjectID // ServiceID from the connection description
	serverAddress            address.Address     // Addr from the connection description
}
106
107
// finishedInformation collects the data Execute gathers for a command-finished event
// (success or failure). Execute passes it to publishFinishedEvent.
type finishedInformation struct {
	cmdName            string              // command name copied from startedInformation
	requestID          int32               // request ID copied from startedInformation
	response           bsoncore.Document   // server response (may be empty on error)
	cmdErr             error               // error from the deadline check or the round trip, if any
	connID             string              // copied from startedInformation
	driverConnectionID uint64              // copied from startedInformation
	serverConnID       *int64              // copied from startedInformation
	redacted           bool                // when true, the response is withheld from monitoring
	serviceID          *primitive.ObjectID // copied from startedInformation
	serverAddress      address.Address     // address of the selected server
	duration           time.Duration       // time elapsed between the pre-send deadline check and the end of the round trip
}
121
122
123
124
// convertInt64PtrToInt32Ptr narrows *i64 to a freshly allocated int32. It returns nil when
// i64 is nil or when the value does not fit in an int32 (no truncation is performed).
func convertInt64PtrToInt32Ptr(i64 *int64) *int32 {
	switch {
	case i64 == nil:
		return nil
	case *i64 < math.MinInt32, *i64 > math.MaxInt32:
		return nil
	}

	narrowed := int32(*i64)
	return &narrowed
}
137
138
139
140
141
142
143
144 func (info finishedInformation) success() bool {
145 if _, ok := info.cmdErr.(WriteCommandError); ok {
146 return true
147 }
148
149 return info.cmdErr == nil
150 }
151
152
// ResponseInfo is passed to Operation.ProcessResponseFn after each round trip.
type ResponseInfo struct {
	ServerResponse        bsoncore.Document  // raw server response for the current attempt/batch
	Server                Server             // server the command was sent to
	Connection            Connection         // connection the command was sent on
	ConnectionDescription description.Server // description of that connection's server
	CurrentIndex          int                // number of batched documents sent before the current batch
}
160
161 func redactStartedInformationCmd(op Operation, info startedInformation) bson.Raw {
162 var cmdCopy bson.Raw
163
164
165
166
167 if !info.redacted {
168 cmdCopy = make([]byte, len(info.cmd))
169 copy(cmdCopy, info.cmd)
170
171 if info.documentSequenceIncluded {
172
173 cmdCopy = cmdCopy[:len(info.cmd)-1]
174 cmdCopy = op.addBatchArray(cmdCopy)
175
176
177 cmdCopy, _ = bsoncore.AppendDocumentEnd(cmdCopy, 0)
178 }
179 }
180
181 return cmdCopy
182 }
183
184 func redactFinishedInformationResponse(info finishedInformation) bson.Raw {
185 if !info.redacted {
186 return bson.Raw(info.response)
187 }
188
189 return bson.Raw{}
190 }
191
192
193
194
195
196
197
198
199
200
201
// Operation describes a single database command and everything Execute needs to run it:
// how to build the command, where to send it, which concerns/session to attach, and how to
// retry and process the response.
type Operation struct {
	// CommandFn appends the command's fields to dst for the selected server and returns the
	// extended slice. Required (see Validate).
	CommandFn func(dst []byte, desc description.SelectedServer) ([]byte, error)

	// Database is the database the command runs against. It is sent as "$db" in OP_MSG, or as
	// "<Database>.$cmd" in the legacy OP_QUERY handshake. Must be non-empty (see Validate).
	Database string

	// Deployment is used for server selection, and its Kind() drives sharded-cluster handling.
	// Required (see Validate).
	Deployment Deployment

	// ProcessResponseFn, when non-nil, is called with the response after each round trip. Its
	// returned error is only surfaced by Execute when the command itself did not fail.
	ProcessResponseFn func(ResponseInfo) error

	// Selector is the server selector. When nil, selectServer uses a composite of a
	// read-preference selector (ReadPreference, defaulting to primary) and a latency selector
	// with defaultLocalThreshold.
	Selector description.ServerSelector

	// ReadPreference is used by the default selector when Selector is nil. It is presumably
	// also consulted by createReadPref (not shown here).
	ReadPreference *readpref.ReadPref

	// ReadConcern is appended by addReadConcern. When a transaction is starting, the session's
	// CurrentRc takes precedence.
	ReadConcern *readconcern.ReadConcern

	// MinimumReadConcernWireVersion, when > 0, causes the read concern to be omitted unless the
	// server's wire version range includes this version.
	MinimumReadConcernWireVersion int32

	// WriteConcern is the write concern for the command. An unacknowledged write concern is
	// rejected by Validate when a session is present. It sets the OP_MSG moreToCome flag when
	// no further batches remain, and it disables retryable writes. During a retried
	// commitTransaction, Execute replaces it with the session's CurrentWc.
	WriteConcern *writeconcern.WriteConcern

	// MinimumWriteConcernWireVersion is presumably the write-concern analogue of
	// MinimumReadConcernWireVersion; it is consumed outside this view — confirm.
	MinimumWriteConcernWireVersion int32

	// Client is the session for this operation and may be nil. Its transaction state controls
	// retryability, error labels, and connection pinning.
	Client *session.Client

	// Clock is not referenced in the visible code; presumably used when adding/updating
	// cluster times — confirm.
	Clock *session.ClusterClock

	// RetryMode selects the retry policy; nil disables retries.
	RetryMode *RetryMode

	// Type is Read or Write and selects which retry rules apply.
	Type Type

	// Batches, when valid, splits the operation into multiple commands. Each command carries
	// the Current batch under Batches.Identifier.
	Batches *Batches

	// Legacy set to LegacyHandshake makes the command use an OP_QUERY wire message while the
	// server's wire version is unknown or zero.
	Legacy LegacyOperationKind

	// CommandMonitor is not referenced in the visible code; presumably used by the
	// publish*Event helpers — confirm.
	CommandMonitor *event.CommandMonitor

	// Crypt enables client-side encryption. Commands are encrypted unless
	// BypassAutoEncryption() is true, and successful responses are decrypted whenever Crypt is
	// non-nil.
	Crypt Crypt

	// ServerAPI, when set, adds apiVersion and optionally apiStrict/apiDeprecationErrors.
	ServerAPI *ServerAPIOptions

	// IsOutputAggregate is not referenced in the visible code — confirm its consumer.
	IsOutputAggregate bool

	// MaxTime is not referenced in the visible code; presumably consulted by
	// calculateMaxTimeMS — confirm.
	MaxTime *time.Duration

	// Timeout, when set and the incoming context is not already a CSOT timeout context, is
	// used by Execute to derive a timeout context.
	Timeout *time.Duration

	// Logger is not referenced in the visible code — confirm its consumer.
	Logger *logger.Logger

	// Name is the operation name. It is attached to the server-selection context for logging,
	// and Execute overwrites it with the actual command name. readWireMessage compares it with
	// the find/aggregate/distinct names to decide whether to update the snapshot time.
	Name string

	// OmitCSOTMaxTimeMS is not referenced in the visible code; presumably consulted by
	// calculateMaxTimeMS — confirm.
	OmitCSOTMaxTimeMS bool

	// omitReadPreference is not referenced in the visible code; presumably consulted by
	// createReadPref — confirm.
	omitReadPreference bool
}
324
325
326 func (op Operation) shouldEncrypt() bool {
327 return op.Crypt != nil && !op.Crypt.BypassAutoEncryption()
328 }
329
330
331
332
333
334
335
336
337 func filterDeprioritizedServers(candidates, deprioritized []description.Server) []description.Server {
338 if len(deprioritized) == 0 {
339 return candidates
340 }
341
342 dpaSet := make(map[address.Address]*description.Server)
343 for i, srv := range deprioritized {
344 dpaSet[srv.Addr] = &deprioritized[i]
345 }
346
347 allowed := []description.Server{}
348
349
350
351 for _, candidate := range candidates {
352 if srv, ok := dpaSet[candidate.Addr]; !ok || !srv.Equal(candidate) {
353 allowed = append(allowed, candidate)
354 }
355 }
356
357
358
359
360 if len(allowed) == 0 {
361 return candidates
362 }
363
364 return allowed
365 }
366
367
368
369
370 type opServerSelector struct {
371 selector description.ServerSelector
372 deprioritizedServers []description.Server
373 }
374
375
376
377 func (oss *opServerSelector) SelectServer(
378 topo description.Topology,
379 candidates []description.Server,
380 ) ([]description.Server, error) {
381 selectedServers, err := oss.selector.SelectServer(topo, candidates)
382 if err != nil {
383 return nil, err
384 }
385
386 filteredServers := filterDeprioritizedServers(selectedServers, oss.deprioritizedServers)
387
388 return filteredServers, nil
389 }
390
391
392 func (op Operation) selectServer(
393 ctx context.Context,
394 requestID int32,
395 deprioritized []description.Server,
396 ) (Server, error) {
397 if err := op.Validate(); err != nil {
398 return nil, err
399 }
400
401 selector := op.Selector
402 if selector == nil {
403 rp := op.ReadPreference
404 if rp == nil {
405 rp = readpref.Primary()
406 }
407 selector = description.CompositeSelector([]description.ServerSelector{
408 description.ReadPrefSelector(rp),
409 description.LatencySelector(defaultLocalThreshold),
410 })
411 }
412
413 oss := &opServerSelector{
414 selector: selector,
415 deprioritizedServers: deprioritized,
416 }
417
418 ctx = logger.WithOperationName(ctx, op.Name)
419 ctx = logger.WithOperationID(ctx, requestID)
420
421 return op.Deployment.SelectServer(ctx, oss)
422 }
423
424
425 func (op Operation) getServerAndConnection(
426 ctx context.Context,
427 requestID int32,
428 deprioritized []description.Server,
429 ) (Server, Connection, error) {
430 server, err := op.selectServer(ctx, requestID, deprioritized)
431 if err != nil {
432 if op.Client != nil &&
433 !(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning() {
434 err = Error{
435 Message: err.Error(),
436 Labels: []string{TransientTransactionError},
437 Wrapped: err,
438 }
439 }
440 return nil, nil, err
441 }
442
443
444
445 if op.Client != nil && op.Client.PinnedConnection != nil {
446 return server, op.Client.PinnedConnection, nil
447 }
448
449
450 conn, err := server.Connection(ctx)
451 if err != nil {
452 return nil, nil, err
453 }
454
455
456 if conn.Description().LoadBalanced() && op.Client != nil && op.Client.TransactionStarting() {
457 pinnedConn, ok := conn.(PinnedConnection)
458 if !ok {
459
460 _ = conn.Close()
461 return nil, nil, fmt.Errorf("expected Connection used to start a transaction to be a PinnedConnection, but got %T", conn)
462 }
463 if err := pinnedConn.PinToTransaction(); err != nil {
464
465 _ = conn.Close()
466 return nil, nil, fmt.Errorf("error incrementing connection reference count when starting a transaction: %w", err)
467 }
468 op.Client.PinnedConnection = pinnedConn
469 }
470
471 return server, conn, nil
472 }
473
474
475 func (op Operation) Validate() error {
476 if op.CommandFn == nil {
477 return InvalidOperationError{MissingField: "CommandFn"}
478 }
479 if op.Deployment == nil {
480 return InvalidOperationError{MissingField: "Deployment"}
481 }
482 if op.Database == "" {
483 return errDatabaseNameEmpty
484 }
485 if op.Client != nil && !writeconcern.AckWrite(op.WriteConcern) {
486 return errors.New("session provided for an unacknowledged write")
487 }
488 return nil
489 }
490
// memoryPool recycles wire-message byte buffers between Execute calls. Entries are *[]byte
// whose initial slice is 1 KiB long. Storing a pointer rather than the slice itself avoids
// boxing the slice header on every Put (staticcheck SA6002).
var memoryPool = sync.Pool{
	New: func() interface{} {
		buf := make([]byte, 1024)
		return &buf
	},
}
499
500
// Execute runs the operation.
//
// Steps:
//  1. Validates the Operation and, when op.Timeout is set and ctx is not already a CSOT
//     timeout context, derives one.
//  2. Selects a server and connection, then builds and sends the wire message (compressed
//     when allowed) and publishes started/finished command events.
//  3. Retries according to RetryMode, and advances through Batches when batching.
//
// Write errors accumulated across batches are returned as a WriteCommandError.
func (op Operation) Execute(ctx context.Context) error {
	err := op.Validate()
	if err != nil {
		return err
	}

	// When op.Timeout is set and ctx is not already a CSOT timeout context, run the operation
	// under a timeout context derived from op.Timeout.
	if op.Timeout != nil && !csot.IsTimeoutContext(ctx) {
		newCtx, cancelFunc := csot.MakeTimeoutContext(ctx, *op.Timeout)
		// Redefine ctx to be the new timeout-derived context.
		ctx = newCtx
		// Cancel the derived context when Execute returns so it is not leaked.
		defer cancelFunc()
	}

	if op.Client != nil {
		if err := op.Client.StartCommand(); err != nil {
			return err
		}
	}

	// Retry budget:
	//   - 1 for RetryOnce / RetryOncePerCommand;
	//   - -1 (no fixed limit; the loop treats any non-zero value as "retries remain") for
	//     RetryContext.
	// Writes are only retried when a session is present.
	var retries int
	if op.RetryMode != nil {
		switch op.Type {
		case Write:
			if op.Client == nil {
				break
			}
			switch *op.RetryMode {
			case RetryOnce, RetryOncePerCommand:
				retries = 1
			case RetryContext:
				retries = -1
			}
		case Read:
			switch *op.RetryMode {
			case RetryOnce, RetryOncePerCommand:
				retries = 1
			case RetryContext:
				retries = -1
			}
		}
	}

	// Under a CSOT timeout context with retries enabled, retry without a fixed limit (-1).
	retryEnabled := op.RetryMode != nil && op.RetryMode.Enabled()
	if csot.IsTimeoutContext(ctx) && retryEnabled {
		retries = -1
	}

	var srvr Server
	var conn Connection
	var res bsoncore.Document
	var operationErr WriteCommandError
	var prevErr error
	var prevIndefiniteErr error
	batching := op.Batches.Valid()
	retrySupported := false
	first := true
	currIndex := 0

	// deprioritizedServers holds the server of the previous failed attempt (sharded
	// deployments only); it is passed to server selection so that server is demoted.
	var deprioritizedServers []description.Server

	// resetForRetry performs the bookkeeping for a retry:
	//   - consumes one retry and records err as prevErr;
	//   - tracks the error to report if later attempts are labeled NoWritesPerformed;
	//   - closes the current connection (deprioritizing its server on sharded deployments);
	//   - clears srvr/conn so the next iteration selects again.
	resetForRetry := func(err error) {
		retries--
		prevErr = err

		// Remember an "indefinite" error to return if a later attempt reports
		// NoWritesPerformed (meaning that later error is not the real outcome).
		switch err := err.(type) {
		case labeledError:
			// Keep the first error seen so there is always something to fall back to, even
			// if every later error is labeled NoWritesPerformed.
			if prevIndefiniteErr == nil {
				prevIndefiniteErr = err
			}

			// A retryable write error without NoWritesPerformed replaces the fallback.
			if !err.HasErrorLabel(NoWritesPerformed) && err.HasErrorLabel(RetryableWriteError) {
				prevIndefiniteErr = err
			}
		}

		// Close the failed attempt's connection now so its resources are released before the
		// next attempt.
		if conn != nil {
			// On sharded deployments, deprioritize the failed attempt's server.
			// NOTE(review): conn.Description is a method value and is never nil, so this
			// condition reduces to the Sharded check.
			if desc := conn.Description; desc != nil && op.Deployment.Kind() == description.Sharded {
				deprioritizedServers = []description.Server{conn.Description()}
			}

			conn.Close()
		}

		// Clear the server and connection so the next iteration selects new ones.
		srvr = nil
		conn = nil
	}

	wm := memoryPool.Get().(*[]byte)
	defer func() {
		// Return the buffer to the pool only if it is smaller than 16 MiB and more than half
		// of its capacity is in use. Oversized or mostly-empty buffers are dropped so the pool
		// does not retain disproportionately large allocations.
		if c := cap(*wm); c < 16*1024*1024 && c/2 < len(*wm) {
			memoryPool.Put(wm)
		}
	}()
	for {
		// Stop retrying once the previous attempt failed because the context was canceled or
		// its deadline passed.
		if errors.Is(prevErr, context.Canceled) || errors.Is(prevErr, context.DeadlineExceeded) {
			return prevErr
		}

		requestID := wiremessage.NextRequestID()

		// Select a server and connection for the first attempt and after every resetForRetry.
		if srvr == nil || conn == nil {
			srvr, conn, err = op.getServerAndConnection(ctx, requestID, deprioritizedServers)
			if err != nil {
				// Retry acquisition failures that report themselves retryable while retries
				// remain (negative retries means no fixed limit).
				if rerr, ok := err.(RetryablePoolError); ok && rerr.Retryable() && retries != 0 {
					resetForRetry(err)
					continue
				}

				// On a retry, prefer reporting the previous attempt's error over this
				// acquisition error.
				if prevErr != nil {
					return prevErr
				}
				return err
			}
			// NOTE(review): this defer runs when Execute returns, not per iteration, so each
			// connection from an earlier attempt gets a second Close at exit (resetForRetry
			// already closed it); this relies on Close tolerating repeated calls — confirm.
			defer conn.Close()

			// For an implicit session with no server session yet, attach one now.
			if op.Client != nil && op.Client.Server == nil && op.Client.IsImplicit {
				if op.Client.Terminated {
					return fmt.Errorf("unexpected nil session for a terminated implicit session")
				}
				if err := op.Client.SetServer(); err != nil {
					return err
				}
			}
		}

		// One-time setup on the first attempt only.
		if first {
			// Retry support is decided once, from the first selected server's description.
			retrySupported = op.retryable(conn.Description())

			// For retry-capable writes with a session, set RetryWrite from RetryMode. When it
			// is enabled, increment the transaction number (except while committing or
			// aborting).
			if retrySupported && op.RetryMode != nil && op.Type == Write && op.Client != nil {
				op.Client.RetryWrite = false
				if op.RetryMode.Enabled() {
					op.Client.RetryWrite = true
					if !op.Client.Committing && !op.Client.Aborting {
						op.Client.IncrementTxnNumber()
					}
				}
			}

			first = false
		}

		// Note: this ":=" declares a new err scoped to the loop body; references to err below
		// (inside this iteration) use it rather than the function-level err.
		maxTimeMS, err := op.calculateMaxTimeMS(ctx, srvr.RTTMonitor())
		if err != nil {
			return err
		}

		// For mongocryptd connections, zero maxTimeMS so the wire-message builders (which only
		// append it when > 0) omit the field.
		if conn.Description().IsCryptd {
			maxTimeMS = 0
		}

		desc := description.SelectedServer{Server: conn.Description(), Kind: op.Deployment.Kind()}

		if batching {
			targetBatchSize := desc.MaxDocumentSize
			maxDocSize := desc.MaxDocumentSize
			if op.shouldEncrypt() {
				// With auto-encryption, aim for batches of at most cryptMaxBsonObjectSize while
				// still allowing individual documents up to the server's MaxDocumentSize.
				targetBatchSize = cryptMaxBsonObjectSize
			}

			err = op.Batches.AdvanceBatch(int(desc.MaxBatchCount), int(targetBatchSize), int(maxDocSize))
			if err != nil {
				// NOTE(review): write errors already accumulated in operationErr are not
				// returned on this path.
				return err
			}
		}

		var startedInfo startedInformation
		*wm, startedInfo, err = op.createWireMessage(ctx, maxTimeMS, (*wm)[:0], desc, conn, requestID)

		if err != nil {
			return err
		}

		// Fill in connection/command details for the started event.
		startedInfo.connID = conn.ID()
		startedInfo.driverConnectionID = conn.DriverConnectionID()
		startedInfo.cmdName = op.getCommandName(startedInfo.cmd)

		// Align the operation name with the command name actually sent; readWireMessage uses
		// op.Name to decide whether to update the snapshot time.
		if startedInfo.cmdName != op.Name {
			op.Name = startedInfo.cmdName
		}

		startedInfo.redacted = op.redactCommand(startedInfo.cmdName, startedInfo.cmd)
		startedInfo.serviceID = conn.Description().ServiceID
		startedInfo.serverConnID = conn.ServerConnectionID()
		startedInfo.serverAddress = conn.Description().Addr

		op.publishStartedEvent(ctx, startedInfo)

		// Read the moreToCome flag before compression rewrites the message.
		moreToCome := wiremessage.IsMsgMoreToCome(*wm)

		// Compress the wire message when the connection supports it and the command allows it.
		// The uncompressed buffer goes back to the pool and wm now points at the compressed one.
		if compressor, ok := conn.(Compressor); ok && op.canCompress(startedInfo.cmdName) {
			b := memoryPool.Get().(*[]byte)
			*b, err = compressor.CompressWireMessage(*wm, (*b)[:0])
			memoryPool.Put(wm)
			wm = b
			if err != nil {
				return err
			}
		}

		finishedInfo := finishedInformation{
			cmdName:            startedInfo.cmdName,
			driverConnectionID: startedInfo.driverConnectionID,
			requestID:          startedInfo.requestID,
			connID:             startedInfo.connID,
			serverConnID:       startedInfo.serverConnID,
			redacted:           startedInfo.redacted,
			serviceID:          startedInfo.serviceID,
			serverAddress:      desc.Server.Addr,
		}

		startedTime := time.Now()

		// Fail fast before sending:
		//   - if the context is already done, use its error;
		//   - otherwise, if a deadline exists and a round trip would overrun it, fail now.
		// The round-trip estimate is the P90 RTT for CSOT timeout contexts and the minimum
		// observed RTT otherwise.
		if ctx.Err() != nil {
			err = ctx.Err()
		} else if deadline, ok := ctx.Deadline(); ok {
			if csot.IsTimeoutContext(ctx) && time.Now().Add(srvr.RTTMonitor().P90()).After(deadline) {
				err = fmt.Errorf(
					"remaining time %v until context deadline is less than 90th percentile network round-trip time: %w\n%v",
					time.Until(deadline),
					ErrDeadlineWouldBeExceeded,
					srvr.RTTMonitor().Stats())
			} else if time.Now().Add(srvr.RTTMonitor().Min()).After(deadline) {
				err = context.DeadlineExceeded
			}
		}

		if err == nil {
			// Use the write-only round trip when moreToCome is set (no reply is read).
			roundTrip := op.roundTrip
			if moreToCome {
				roundTrip = op.moreToComeRoundTrip
			}
			res, err = roundTrip(ctx, conn, *wm)

			if ep, ok := srvr.(ErrorProcessor); ok {
				_ = ep.ProcessError(err, conn)
			}
		}

		finishedInfo.response = res
		finishedInfo.cmdErr = err
		finishedInfo.duration = time.Since(startedTime)

		op.publishFinishedEvent(ctx, finishedInfo)

		// prevIndefiniteErrIsSet records that err was already replaced by prevIndefiniteErr
		// below, so the NoWritesPerformed substitution happens at most once.
		var prevIndefiniteErrIsSet bool

		// checkError is re-entered (via goto) after err is replaced by prevIndefiniteErr.
	checkError:
		var perr error
		switch tt := err.(type) {
		case WriteCommandError:
			if e := err.(WriteCommandError); retrySupported && op.Type == Write && e.UnsupportedStorageEngine() {
				return ErrUnsupportedStorageEngine
			}

			connDesc := conn.Description()
			retryableErr := tt.Retryable(connDesc.WireVersion)
			preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9
			inTransaction := op.Client != nil &&
				!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning()
			// Outside a transaction, with retries enabled, label retryable errors from
			// servers whose max wire version is below 9 as RetryableWriteError.
			if retryableErr && preRetryWriteLabelVersion && retryEnabled && !inTransaction {
				tt.Labels = append(tt.Labels, RetryableWriteError)
			}

			// Retry when supported, the error is retryable, and retries remain (negative
			// retries means no fixed limit).
			if retrySupported && retryableErr && retries != 0 {
				if op.Client != nil && op.Client.Committing {
					// Presumably upgrades the commit write concern for the retry; the new
					// value is taken from the session.
					op.Client.UpdateCommitTransactionWriteConcern()
					op.WriteConcern = op.Client.CurrentWc
				}
				resetForRetry(tt)
				continue
			}

			// A NoWritesPerformed error after retries is not the real outcome; report the
			// recorded indefinite error instead (once).
			if tt.HasErrorLabel(NoWritesPerformed) && !prevIndefiniteErrIsSet {
				err = prevIndefiniteErr
				prevIndefiniteErrIsSet = true

				goto checkError
			}

			// Not retrying: let the caller process the response (its error is ignored here).
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				_ = op.ProcessResponseFn(info)
			}

			// Shift write-error indexes from batch-relative to operation-relative.
			if batching && len(tt.WriteErrors) > 0 && currIndex > 0 {
				for i := range tt.WriteErrors {
					tt.WriteErrors[i].Index += int64(currIndex)
				}
			}

			// Ordered batching (Ordered nil is treated as ordered) stops at the first batch
			// with write errors.
			if batching && (op.Batches.Ordered == nil || *op.Batches.Ordered) && len(tt.WriteErrors) > 0 {
				return tt
			}
			if op.Client != nil && op.Client.Committing && tt.WriteConcernError != nil {
				// While committing, a write concern error is returned as an Error.
				err := Error{
					Name:    tt.WriteConcernError.Name,
					Code:    int32(tt.WriteConcernError.Code),
					Message: tt.WriteConcernError.Message,
					Labels:  tt.Labels,
					Raw:     tt.Raw,
				}
				// Every write concern error except these two codes gets
				// UnknownTransactionCommitResult.
				if err.Code != unknownReplWriteConcernCode && err.Code != unsatisfiableWriteConcernCode {
					err.Labels = append(err.Labels, UnknownTransactionCommitResult)
				}
				if retryableErr && retryEnabled {
					err.Labels = append(err.Labels, RetryableWriteError)
				}
				return err
			}
			// Otherwise accumulate into the operation-wide error and keep going.
			operationErr.WriteConcernError = tt.WriteConcernError
			operationErr.WriteErrors = append(operationErr.WriteErrors, tt.WriteErrors...)
			operationErr.Labels = tt.Labels
			operationErr.Raw = tt.Raw
		case Error:
			// Transaction-related labels release any resources pinned to the session.
			// NOTE(review): op.Client may be nil here; this relies on ClearPinnedResources
			// accepting a nil receiver — confirm.
			if tt.HasErrorLabel(TransientTransactionError) || tt.HasErrorLabel(UnknownTransactionCommitResult) {
				if err := op.Client.ClearPinnedResources(); err != nil {
					return err
				}
			}

			if e := err.(Error); retrySupported && op.Type == Write && e.UnsupportedStorageEngine() {
				return ErrUnsupportedStorageEngine
			}

			connDesc := conn.Description()
			var retryableErr bool
			if op.Type == Write {
				retryableErr = tt.RetryableWrite(connDesc.WireVersion)
				preRetryWriteLabelVersion := connDesc.WireVersion != nil && connDesc.WireVersion.Max < 9
				inTransaction := op.Client != nil &&
					!(op.Client.Committing || op.Client.Aborting) && op.Client.TransactionRunning()
				// Outside a transaction, with retries enabled, label network errors and
				// retryable errors from servers with max wire version below 9 as
				// RetryableWriteError.
				if retryEnabled && !inTransaction &&
					(tt.HasErrorLabel(NetworkError) || (retryableErr && preRetryWriteLabelVersion)) {
					tt.Labels = append(tt.Labels, RetryableWriteError)
				}
			} else {
				retryableErr = tt.RetryableRead()
			}

			// Retry when supported, the error is retryable, and retries remain (negative
			// retries means no fixed limit).
			if retrySupported && retryableErr && retries != 0 {
				if op.Client != nil && op.Client.Committing {
					// Presumably upgrades the commit write concern for the retry; the new
					// value is taken from the session.
					op.Client.UpdateCommitTransactionWriteConcern()
					op.WriteConcern = op.Client.CurrentWc
				}
				resetForRetry(tt)
				continue
			}

			// A NoWritesPerformed error after retries is not the real outcome; report the
			// recorded indefinite error instead (once).
			if tt.HasErrorLabel(NoWritesPerformed) && !prevIndefiniteErrIsSet {
				err = prevIndefiniteErr
				prevIndefiniteErrIsSet = true

				goto checkError
			}

			// Not retrying: let the caller process the response (its error is ignored here).
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				_ = op.ProcessResponseFn(info)
			}

			// While committing, retryable errors and code 50 (presumably MaxTimeMSExpired) get
			// UnknownTransactionCommitResult.
			if op.Client != nil && op.Client.Committing && (retryableErr || tt.Code == 50) {
				tt.Labels = append(tt.Labels, UnknownTransactionCommitResult)
			}
			return tt
		case nil:
			// An unacknowledged write has no response to process.
			if moreToCome {
				return ErrUnacknowledgedWrite
			}
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				perr = op.ProcessResponseFn(info)
			}
			if perr != nil {
				return perr
			}
		default:
			if op.ProcessResponseFn != nil {
				info := ResponseInfo{
					ServerResponse:        res,
					Server:                srvr,
					Connection:            conn,
					ConnectionDescription: desc.Server,
					CurrentIndex:          currIndex,
				}
				_ = op.ProcessResponseFn(info)
			}
			return err
		}

		// More batches remain: this is not a retry, so keep the same server/connection and
		// move on to the next batch.
		if batching && len(op.Batches.Documents) > 0 {
			// Each batch is a new retryable write: bump the transaction number when retries
			// are enabled and supported.
			if retrySupported && op.Client != nil && op.RetryMode != nil {
				if op.RetryMode.Enabled() {
					op.Client.IncrementTxnNumber()
				}

				// RetryOncePerCommand gets a fresh retry per batch, except under a CSOT
				// timeout context where retries stay at -1.
				if *op.RetryMode == RetryOncePerCommand && !csot.IsTimeoutContext(ctx) {
					retries = 1
				}
			}
			currIndex += len(op.Batches.Current)
			op.Batches.ClearBatch()
			continue
		}
		break
	}
	if len(operationErr.WriteErrors) > 0 || operationErr.WriteConcernError != nil {
		return operationErr
	}
	return nil
}
1042
1043
1044
1045 func (op Operation) retryable(desc description.Server) bool {
1046 switch op.Type {
1047 case Write:
1048 if op.Client != nil && (op.Client.Committing || op.Client.Aborting) {
1049 return true
1050 }
1051 if retryWritesSupported(desc) &&
1052 op.Client != nil && !(op.Client.TransactionInProgress() || op.Client.TransactionStarting()) &&
1053 writeconcern.AckWrite(op.WriteConcern) {
1054 return true
1055 }
1056 case Read:
1057 if op.Client != nil && (op.Client.Committing || op.Client.Aborting) {
1058 return true
1059 }
1060 if op.Client == nil || !(op.Client.TransactionInProgress() || op.Client.TransactionStarting()) {
1061 return true
1062 }
1063 }
1064 return false
1065 }
1066
1067
1068
1069 func (op Operation) roundTrip(ctx context.Context, conn Connection, wm []byte) ([]byte, error) {
1070 err := conn.WriteWireMessage(ctx, wm)
1071 if err != nil {
1072 return nil, op.networkError(err)
1073 }
1074 return op.readWireMessage(ctx, conn)
1075 }
1076
1077 func (op Operation) readWireMessage(ctx context.Context, conn Connection) (result []byte, err error) {
1078 wm, err := conn.ReadWireMessage(ctx)
1079 if err != nil {
1080 return nil, op.networkError(err)
1081 }
1082
1083
1084
1085 if streamer, ok := conn.(StreamerConnection); ok {
1086 streamer.SetStreaming(wiremessage.IsMsgMoreToCome(wm))
1087 }
1088
1089 length, _, _, opcode, rem, ok := wiremessage.ReadHeader(wm)
1090 if !ok || len(wm) < int(length) {
1091 return nil, errors.New("malformed wire message: insufficient bytes")
1092 }
1093 if opcode == wiremessage.OpCompressed {
1094 rawsize := length - 16
1095
1096 opcode, rem, err = op.decompressWireMessage(rem[:rawsize])
1097 if err != nil {
1098 return nil, err
1099 }
1100 }
1101
1102
1103 res, err := op.decodeResult(ctx, opcode, rem)
1104
1105
1106 op.updateClusterTimes(res)
1107 op.updateOperationTime(res)
1108 op.Client.UpdateRecoveryToken(bson.Raw(res))
1109
1110
1111 if op.Name == driverutil.FindOp || op.Name == driverutil.AggregateOp || op.Name == driverutil.DistinctOp {
1112 op.Client.UpdateSnapshotTime(res)
1113 }
1114
1115 if err != nil {
1116 return res, err
1117 }
1118
1119
1120 if op.Crypt != nil {
1121 res, err = op.Crypt.Decrypt(ctx, res)
1122 }
1123 return res, err
1124 }
1125
1126
1127
1128
1129 func (op Operation) networkError(err error) error {
1130 if err == nil {
1131 return nil
1132 }
1133
1134 labels := []string{NetworkError}
1135 if op.Client != nil {
1136 op.Client.MarkDirty()
1137 }
1138 if op.Client != nil && op.Client.TransactionRunning() && !op.Client.Committing {
1139 labels = append(labels, TransientTransactionError)
1140 }
1141 if op.Client != nil && op.Client.Committing {
1142 labels = append(labels, UnknownTransactionCommitResult)
1143 }
1144 return Error{Message: err.Error(), Labels: labels, Wrapped: err}
1145 }
1146
1147
1148
1149 func (op *Operation) moreToComeRoundTrip(ctx context.Context, conn Connection, wm []byte) (result []byte, err error) {
1150 err = conn.WriteWireMessage(ctx, wm)
1151 if err != nil {
1152 if op.Client != nil {
1153 op.Client.MarkDirty()
1154 }
1155 err = Error{Message: err.Error(), Labels: []string{TransientTransactionError, NetworkError}, Wrapped: err}
1156 }
1157 return bsoncore.BuildDocument(nil, bsoncore.AppendInt32Element(nil, "ok", 1)), err
1158 }
1159
1160
1161 func (Operation) decompressWireMessage(wm []byte) (wiremessage.OpCode, []byte, error) {
1162
1163 opcode, rem, ok := wiremessage.ReadCompressedOriginalOpCode(wm)
1164 if !ok {
1165 return 0, nil, errors.New("malformed OP_COMPRESSED: missing original opcode")
1166 }
1167 uncompressedSize, rem, ok := wiremessage.ReadCompressedUncompressedSize(rem)
1168 if !ok {
1169 return 0, nil, errors.New("malformed OP_COMPRESSED: missing uncompressed size")
1170 }
1171
1172 compressorID, rem, ok := wiremessage.ReadCompressedCompressorID(rem)
1173 if !ok {
1174 return 0, nil, errors.New("malformed OP_COMPRESSED: missing compressor ID")
1175 }
1176 compressedSize := len(wm) - 9
1177
1178 msg, _, ok := wiremessage.ReadCompressedCompressedMessage(rem, int32(compressedSize))
1179 if !ok {
1180 return 0, nil, errors.New("malformed OP_COMPRESSED: insufficient bytes for compressed wiremessage")
1181 }
1182
1183 opts := CompressionOpts{
1184 Compressor: compressorID,
1185 UncompressedSize: uncompressedSize,
1186 }
1187 uncompressed, err := DecompressPayload(msg, opts)
1188 if err != nil {
1189 return 0, nil, err
1190 }
1191
1192 return opcode, uncompressed, nil
1193 }
1194
1195 func (op Operation) addBatchArray(dst []byte) []byte {
1196 aidx, dst := bsoncore.AppendArrayElementStart(dst, op.Batches.Identifier)
1197 for i, doc := range op.Batches.Current {
1198 dst = bsoncore.AppendDocumentElement(dst, strconv.Itoa(i), doc)
1199 }
1200 dst, _ = bsoncore.AppendArrayEnd(dst, aidx)
1201 return dst
1202 }
1203
1204 func (op Operation) createLegacyHandshakeWireMessage(
1205 maxTimeMS uint64,
1206 dst []byte,
1207 desc description.SelectedServer,
1208 ) ([]byte, startedInformation, error) {
1209 var info startedInformation
1210 flags := op.secondaryOK(desc)
1211 var wmindex int32
1212 info.requestID = wiremessage.NextRequestID()
1213 wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpQuery)
1214 dst = wiremessage.AppendQueryFlags(dst, flags)
1215
1216 dollarCmd := [...]byte{'.', '$', 'c', 'm', 'd'}
1217
1218
1219 dst = append(dst, op.Database...)
1220 dst = append(dst, dollarCmd[:]...)
1221 dst = append(dst, 0x00)
1222 dst = wiremessage.AppendQueryNumberToSkip(dst, 0)
1223 dst = wiremessage.AppendQueryNumberToReturn(dst, -1)
1224
1225 wrapper := int32(-1)
1226 rp, err := op.createReadPref(desc, true)
1227 if err != nil {
1228 return dst, info, err
1229 }
1230 if len(rp) > 0 {
1231 wrapper, dst = bsoncore.AppendDocumentStart(dst)
1232 dst = bsoncore.AppendHeader(dst, bsontype.EmbeddedDocument, "$query")
1233 }
1234 idx, dst := bsoncore.AppendDocumentStart(dst)
1235 dst, err = op.CommandFn(dst, desc)
1236 if err != nil {
1237 return dst, info, err
1238 }
1239
1240 if op.Batches != nil && len(op.Batches.Current) > 0 {
1241 dst = op.addBatchArray(dst)
1242 }
1243
1244 dst, err = op.addReadConcern(dst, desc)
1245 if err != nil {
1246 return dst, info, err
1247 }
1248
1249 dst, err = op.addWriteConcern(dst, desc)
1250 if err != nil {
1251 return dst, info, err
1252 }
1253
1254 dst, err = op.addSession(dst, desc)
1255 if err != nil {
1256 return dst, info, err
1257 }
1258
1259 dst = op.addClusterTime(dst, desc)
1260 dst = op.addServerAPI(dst)
1261
1262
1263 if maxTimeMS > 0 {
1264 dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
1265 }
1266
1267 dst, _ = bsoncore.AppendDocumentEnd(dst, idx)
1268
1269 info.cmd = dst[idx:]
1270
1271 if len(rp) > 0 {
1272 var err error
1273 dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
1274 dst, err = bsoncore.AppendDocumentEnd(dst, wrapper)
1275 if err != nil {
1276 return dst, info, err
1277 }
1278 }
1279
1280 return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
1281 }
1282
// createMsgWireMessage builds an OP_MSG wire message for the command, using the given
// requestID.
//
// Message layout:
//   - Section 0 is the command document: command fields (possibly auto-encrypted),
//     read/write concern, session, cluster time, server API, maxTimeMS (when > 0), $db, and
//     $readPreference when applicable.
//   - When not auto-encrypting and a batch is current, the batch follows as a section-1
//     document sequence named op.Batches.Identifier.
//
// info.cmd is the section-0 document only.
func (op Operation) createMsgWireMessage(
	ctx context.Context,
	maxTimeMS uint64,
	dst []byte,
	desc description.SelectedServer,
	conn Connection,
	requestID int32,
) ([]byte, startedInformation, error) {
	var info startedInformation
	var flags wiremessage.MsgFlag
	var wmindex int32
	// Set moreToCome for an unacknowledged write when no batch documents remain (no batching,
	// or presumably the final batch).
	if op.WriteConcern != nil && !writeconcern.AckWrite(op.WriteConcern) && (op.Batches == nil || len(op.Batches.Documents) == 0) {
		flags = wiremessage.MoreToCome
	}
	// Advertise ExhaustAllowed when the connection supports streaming.
	if streamer, ok := conn.(StreamerConnection); ok && streamer.SupportsStreaming() {
		flags |= wiremessage.ExhaustAllowed
	}

	info.requestID = requestID
	wmindex, dst = wiremessage.AppendHeaderStart(dst, info.requestID, 0, wiremessage.OpMsg)
	dst = wiremessage.AppendMsgFlags(dst, flags)
	// Section 0: the single command document.
	dst = wiremessage.AppendMsgSectionType(dst, wiremessage.SingleDocument)

	idx, dst := bsoncore.AppendDocumentStart(dst)

	dst, err := op.addCommandFields(ctx, dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addReadConcern(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addWriteConcern(dst, desc)
	if err != nil {
		return dst, info, err
	}
	dst, err = op.addSession(dst, desc)
	if err != nil {
		return dst, info, err
	}

	dst = op.addClusterTime(dst, desc)
	dst = op.addServerAPI(dst)

	// maxTimeMS is only sent when positive.
	if maxTimeMS > 0 {
		dst = bsoncore.AppendInt64Element(dst, "maxTimeMS", int64(maxTimeMS))
	}

	dst = bsoncore.AppendStringElement(dst, "$db", op.Database)
	rp, err := op.createReadPref(desc, false)
	if err != nil {
		return dst, info, err
	}
	if len(rp) > 0 {
		dst = bsoncore.AppendDocumentElement(dst, "$readPreference", rp)
	}

	dst, _ = bsoncore.AppendDocumentEnd(dst, idx)

	// Capture the command for monitoring before the document sequence is appended;
	// redactStartedInformationCmd re-inserts the batch when documentSequenceIncluded is set.
	info.cmd = dst[idx:]

	// Section 1: send the batch as a document sequence unless auto-encrypting (in which case
	// addCommandFields already embedded it as an array).
	if !op.shouldEncrypt() && op.Batches != nil && len(op.Batches.Current) > 0 {
		info.documentSequenceIncluded = true
		dst = wiremessage.AppendMsgSectionType(dst, wiremessage.DocumentSequence)
		idx, dst = bsoncore.ReserveLength(dst)

		// Sequence identifier as a NUL-terminated string.
		dst = append(dst, op.Batches.Identifier...)
		dst = append(dst, 0x00)

		for _, doc := range op.Batches.Current {
			dst = append(dst, doc...)
		}

		dst = bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:])))
	}

	return bsoncore.UpdateLength(dst, wmindex, int32(len(dst[wmindex:]))), info, nil
}
1370
1371
1372
1373 func isLegacyHandshake(op Operation, desc description.SelectedServer) bool {
1374 isInitialHandshake := desc.WireVersion == nil || desc.WireVersion.Max == 0
1375
1376 return op.Legacy == LegacyHandshake && isInitialHandshake
1377 }
1378
1379 func (op Operation) createWireMessage(
1380 ctx context.Context,
1381 maxTimeMS uint64,
1382 dst []byte,
1383 desc description.SelectedServer,
1384 conn Connection,
1385 requestID int32,
1386 ) ([]byte, startedInformation, error) {
1387 if isLegacyHandshake(op, desc) {
1388 return op.createLegacyHandshakeWireMessage(maxTimeMS, dst, desc)
1389 }
1390
1391 return op.createMsgWireMessage(ctx, maxTimeMS, dst, desc, conn, requestID)
1392 }
1393
1394
1395
1396 func (op Operation) addCommandFields(ctx context.Context, dst []byte, desc description.SelectedServer) ([]byte, error) {
1397 if !op.shouldEncrypt() {
1398 return op.CommandFn(dst, desc)
1399 }
1400
1401 if desc.WireVersion.Max < cryptMinWireVersion {
1402 return dst, errors.New("auto-encryption requires a MongoDB version of 4.2")
1403 }
1404
1405
1406 cidx, cmdDst := bsoncore.AppendDocumentStart(nil)
1407 var err error
1408 cmdDst, err = op.CommandFn(cmdDst, desc)
1409 if err != nil {
1410 return dst, err
1411 }
1412
1413 if op.Batches != nil && len(op.Batches.Current) > 0 {
1414 cmdDst = op.addBatchArray(cmdDst)
1415 }
1416 cmdDst, _ = bsoncore.AppendDocumentEnd(cmdDst, cidx)
1417
1418
1419 encrypted, err := op.Crypt.Encrypt(ctx, op.Database, cmdDst)
1420 if err != nil {
1421 return dst, err
1422 }
1423
1424 dst = append(dst, encrypted[4:len(encrypted)-1]...)
1425 return dst, nil
1426 }
1427
1428
1429 func (op Operation) addServerAPI(dst []byte) []byte {
1430 sa := op.ServerAPI
1431 if sa == nil {
1432 return dst
1433 }
1434
1435 dst = bsoncore.AppendStringElement(dst, "apiVersion", sa.ServerAPIVersion)
1436 if sa.Strict != nil {
1437 dst = bsoncore.AppendBooleanElement(dst, "apiStrict", *sa.Strict)
1438 }
1439 if sa.DeprecationErrors != nil {
1440 dst = bsoncore.AppendBooleanElement(dst, "apiDeprecationErrors", *sa.DeprecationErrors)
1441 }
1442 return dst
1443 }
1444
1445 func (op Operation) addReadConcern(dst []byte, desc description.SelectedServer) ([]byte, error) {
1446 if op.MinimumReadConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumReadConcernWireVersion)) {
1447 return dst, nil
1448 }
1449 rc := op.ReadConcern
1450 client := op.Client
1451
1452 if client != nil && client.TransactionStarting() && client.CurrentRc != nil {
1453 rc = client.CurrentRc
1454 }
1455
1456
1457 if rc == nil && client != nil && client.TransactionStarting() && client.Consistent && client.OperationTime != nil {
1458 rc = readconcern.New()
1459 }
1460
1461 if client != nil && client.Snapshot {
1462 if desc.WireVersion.Max < readSnapshotMinWireVersion {
1463 return dst, errors.New("snapshot reads require MongoDB 5.0 or later")
1464 }
1465 rc = readconcern.Snapshot()
1466 }
1467
1468 if rc == nil {
1469 return dst, nil
1470 }
1471
1472 _, data, err := rc.MarshalBSONValue()
1473 if err != nil {
1474 return dst, err
1475 }
1476
1477 if sessionsSupported(desc.WireVersion) && client != nil {
1478 if client.Consistent && client.OperationTime != nil {
1479 data = data[:len(data)-1]
1480 data = bsoncore.AppendTimestampElement(data, "afterClusterTime", client.OperationTime.T, client.OperationTime.I)
1481 data, _ = bsoncore.AppendDocumentEnd(data, 0)
1482 }
1483 if client.Snapshot && client.SnapshotTime != nil {
1484 data = data[:len(data)-1]
1485 data = bsoncore.AppendTimestampElement(data, "atClusterTime", client.SnapshotTime.T, client.SnapshotTime.I)
1486 data, _ = bsoncore.AppendDocumentEnd(data, 0)
1487 }
1488 }
1489
1490 if len(data) == bsoncore.EmptyDocumentLength {
1491 return dst, nil
1492 }
1493 return bsoncore.AppendDocumentElement(dst, "readConcern", data), nil
1494 }
1495
1496 func (op Operation) addWriteConcern(dst []byte, desc description.SelectedServer) ([]byte, error) {
1497 if op.MinimumWriteConcernWireVersion > 0 && (desc.WireVersion == nil || !desc.WireVersion.Includes(op.MinimumWriteConcernWireVersion)) {
1498 return dst, nil
1499 }
1500 wc := op.WriteConcern
1501 if wc == nil {
1502 return dst, nil
1503 }
1504
1505 t, data, err := wc.MarshalBSONValue()
1506 if errors.Is(err, writeconcern.ErrEmptyWriteConcern) {
1507 return dst, nil
1508 }
1509 if err != nil {
1510 return dst, err
1511 }
1512
1513 return append(bsoncore.AppendHeader(dst, t, "writeConcern"), data...), nil
1514 }
1515
1516 func (op Operation) addSession(dst []byte, desc description.SelectedServer) ([]byte, error) {
1517 client := op.Client
1518
1519
1520
1521 if client != nil && !client.IsImplicit && desc.SessionTimeoutMinutesPtr == nil {
1522 return nil, fmt.Errorf("current topology does not support sessions")
1523 }
1524
1525 if client == nil || !sessionsSupported(desc.WireVersion) || desc.SessionTimeoutMinutesPtr == nil {
1526 return dst, nil
1527 }
1528 if err := client.UpdateUseTime(); err != nil {
1529 return dst, err
1530 }
1531 dst = bsoncore.AppendDocumentElement(dst, "lsid", client.SessionID)
1532
1533 var addedTxnNumber bool
1534 if op.Type == Write && client.RetryWrite {
1535 addedTxnNumber = true
1536 dst = bsoncore.AppendInt64Element(dst, "txnNumber", op.Client.TxnNumber)
1537 }
1538 if client.TransactionRunning() || client.RetryingCommit {
1539 if !addedTxnNumber {
1540 dst = bsoncore.AppendInt64Element(dst, "txnNumber", op.Client.TxnNumber)
1541 }
1542 if client.TransactionStarting() {
1543 dst = bsoncore.AppendBooleanElement(dst, "startTransaction", true)
1544 }
1545 dst = bsoncore.AppendBooleanElement(dst, "autocommit", false)
1546 }
1547
1548 return dst, client.ApplyCommand(desc.Server)
1549 }
1550
1551 func (op Operation) addClusterTime(dst []byte, desc description.SelectedServer) []byte {
1552 client, clock := op.Client, op.Clock
1553 if (clock == nil && client == nil) || !sessionsSupported(desc.WireVersion) {
1554 return dst
1555 }
1556 clusterTime := clock.GetClusterTime()
1557 if client != nil {
1558 clusterTime = session.MaxClusterTime(clusterTime, client.ClusterTime)
1559 }
1560 if clusterTime == nil {
1561 return dst
1562 }
1563 val, err := clusterTime.LookupErr("$clusterTime")
1564 if err != nil {
1565 return dst
1566 }
1567 return append(bsoncore.AppendHeader(dst, val.Type, "$clusterTime"), val.Value...)
1568
1569 }
1570
1571
1572
1573
1574
1575
1576 func (op Operation) calculateMaxTimeMS(ctx context.Context, mon RTTMonitor) (uint64, error) {
1577 if csot.IsTimeoutContext(ctx) {
1578 if op.OmitCSOTMaxTimeMS {
1579 return 0, nil
1580 }
1581
1582 if deadline, ok := ctx.Deadline(); ok {
1583 remainingTimeout := time.Until(deadline)
1584 rtt90 := mon.P90()
1585 maxTime := remainingTimeout - rtt90
1586
1587
1588
1589 maxTimeMS := int64((maxTime + (time.Millisecond - 1)) / time.Millisecond)
1590 if maxTimeMS <= 0 {
1591 return 0, fmt.Errorf(
1592 "negative maxTimeMS: remaining time %v until context deadline is less than 90th percentile network round-trip time (%v): %w",
1593 remainingTimeout,
1594 mon.Stats(),
1595 ErrDeadlineWouldBeExceeded)
1596 }
1597
1598
1599
1600
1601
1602
1603 if maxTimeMS > math.MaxInt32 {
1604 return 0, nil
1605 }
1606
1607 return uint64(maxTimeMS), nil
1608 }
1609 } else if op.MaxTime != nil {
1610
1611
1612 if *op.MaxTime < 0 {
1613 return 0, ErrNegativeMaxTime
1614 }
1615
1616
1617 return uint64((*op.MaxTime + (time.Millisecond - 1)) / time.Millisecond), nil
1618 }
1619 return 0, nil
1620 }
1621
1622
1623
1624
1625 func (op Operation) updateClusterTimes(response bsoncore.Document) {
1626
1627 value, err := response.LookupErr("$clusterTime")
1628 if err != nil {
1629
1630 return
1631 }
1632 clusterTime := bsoncore.BuildDocumentFromElements(nil, bsoncore.AppendValueElement(nil, "$clusterTime", value))
1633
1634 sess, clock := op.Client, op.Clock
1635
1636 if sess != nil {
1637 _ = sess.AdvanceClusterTime(bson.Raw(clusterTime))
1638 }
1639
1640 if clock != nil {
1641 clock.AdvanceClusterTime(bson.Raw(clusterTime))
1642 }
1643 }
1644
1645
1646
1647
1648 func (op Operation) updateOperationTime(response bsoncore.Document) {
1649 sess := op.Client
1650 if sess == nil {
1651 return
1652 }
1653
1654 opTimeElem, err := response.LookupErr("operationTime")
1655 if err != nil {
1656
1657 return
1658 }
1659
1660 t, i := opTimeElem.Timestamp()
1661 _ = sess.AdvanceOperationTime(&primitive.Timestamp{
1662 T: t,
1663 I: i,
1664 })
1665 }
1666
1667 func (op Operation) getReadPrefBasedOnTransaction() (*readpref.ReadPref, error) {
1668 if op.Client != nil && op.Client.TransactionRunning() {
1669
1670 rp := op.Client.CurrentRp
1671
1672
1673 if rp != nil && !op.Client.TransactionStarting() && rp.Mode() != readpref.PrimaryMode {
1674 return nil, ErrNonPrimaryReadPref
1675 }
1676 return rp, nil
1677 }
1678 return op.ReadPreference, nil
1679 }
1680
1681
1682
1683
1684 func (op Operation) createReadPref(desc description.SelectedServer, isOpQuery bool) (bsoncore.Document, error) {
1685 if op.omitReadPreference {
1686 return nil, nil
1687 }
1688
1689
1690
1691 if desc.Server.Kind == description.Standalone || (isOpQuery && desc.Server.Kind != description.Mongos) ||
1692 op.Type == Write || (op.IsOutputAggregate && desc.Server.WireVersion.Max < 13) {
1693
1694
1695
1696
1697
1698
1699 return nil, nil
1700 }
1701
1702 idx, doc := bsoncore.AppendDocumentStart(nil)
1703 rp, err := op.getReadPrefBasedOnTransaction()
1704 if err != nil {
1705 return nil, err
1706 }
1707
1708 if rp == nil {
1709 if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
1710 doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
1711 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
1712 return doc, nil
1713 }
1714 return nil, nil
1715 }
1716
1717 switch rp.Mode() {
1718 case readpref.PrimaryMode:
1719 if desc.Server.Kind == description.Mongos {
1720 return nil, nil
1721 }
1722 if desc.Kind == description.Single {
1723 doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
1724 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
1725 return doc, nil
1726 }
1727
1728
1729
1730
1731
1732
1733
1734 return nil, nil
1735 case readpref.PrimaryPreferredMode:
1736 doc = bsoncore.AppendStringElement(doc, "mode", "primaryPreferred")
1737 case readpref.SecondaryPreferredMode:
1738 _, ok := rp.MaxStaleness()
1739 if desc.Server.Kind == description.Mongos && isOpQuery && !ok && len(rp.TagSets()) == 0 && rp.HedgeEnabled() == nil {
1740 return nil, nil
1741 }
1742 doc = bsoncore.AppendStringElement(doc, "mode", "secondaryPreferred")
1743 case readpref.SecondaryMode:
1744 doc = bsoncore.AppendStringElement(doc, "mode", "secondary")
1745 case readpref.NearestMode:
1746 doc = bsoncore.AppendStringElement(doc, "mode", "nearest")
1747 }
1748
1749 sets := make([]bsoncore.Document, 0, len(rp.TagSets()))
1750 for _, ts := range rp.TagSets() {
1751 i, set := bsoncore.AppendDocumentStart(nil)
1752 for _, t := range ts {
1753 set = bsoncore.AppendStringElement(set, t.Name, t.Value)
1754 }
1755 set, _ = bsoncore.AppendDocumentEnd(set, i)
1756 sets = append(sets, set)
1757 }
1758 if len(sets) > 0 {
1759 var aidx int32
1760 aidx, doc = bsoncore.AppendArrayElementStart(doc, "tags")
1761 for i, set := range sets {
1762 doc = bsoncore.AppendDocumentElement(doc, strconv.Itoa(i), set)
1763 }
1764 doc, _ = bsoncore.AppendArrayEnd(doc, aidx)
1765 }
1766
1767 if d, ok := rp.MaxStaleness(); ok {
1768 doc = bsoncore.AppendInt32Element(doc, "maxStalenessSeconds", int32(d.Seconds()))
1769 }
1770
1771 if hedgeEnabled := rp.HedgeEnabled(); hedgeEnabled != nil {
1772 var hedgeIdx int32
1773 hedgeIdx, doc = bsoncore.AppendDocumentElementStart(doc, "hedge")
1774 doc = bsoncore.AppendBooleanElement(doc, "enabled", *hedgeEnabled)
1775 doc, err = bsoncore.AppendDocumentEnd(doc, hedgeIdx)
1776 if err != nil {
1777 return nil, fmt.Errorf("error creating hedge document: %w", err)
1778 }
1779 }
1780
1781 doc, _ = bsoncore.AppendDocumentEnd(doc, idx)
1782 return doc, nil
1783 }
1784
1785 func (op Operation) secondaryOK(desc description.SelectedServer) wiremessage.QueryFlag {
1786 if desc.Kind == description.Single && desc.Server.Kind != description.Mongos {
1787 return wiremessage.SecondaryOK
1788 }
1789
1790 if rp := op.ReadPreference; rp != nil && rp.Mode() != readpref.PrimaryMode {
1791 return wiremessage.SecondaryOK
1792 }
1793
1794 return 0
1795 }
1796
1797 func (Operation) canCompress(cmd string) bool {
1798 if cmd == handshake.LegacyHello || cmd == "hello" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "authenticate" ||
1799 cmd == "createUser" || cmd == "updateUser" || cmd == "copydbSaslStart" || cmd == "copydbgetnonce" || cmd == "copydb" {
1800 return false
1801 }
1802 return true
1803 }
1804
1805
1806
1807
1808 func (Operation) decodeOpReply(wm []byte) opReply {
1809 var reply opReply
1810 var ok bool
1811
1812 reply.responseFlags, wm, ok = wiremessage.ReadReplyFlags(wm)
1813 if !ok {
1814 reply.err = errors.New("malformed OP_REPLY: missing flags")
1815 return reply
1816 }
1817 reply.cursorID, wm, ok = wiremessage.ReadReplyCursorID(wm)
1818 if !ok {
1819 reply.err = errors.New("malformed OP_REPLY: missing cursorID")
1820 return reply
1821 }
1822 reply.startingFrom, wm, ok = wiremessage.ReadReplyStartingFrom(wm)
1823 if !ok {
1824 reply.err = errors.New("malformed OP_REPLY: missing startingFrom")
1825 return reply
1826 }
1827 reply.numReturned, wm, ok = wiremessage.ReadReplyNumberReturned(wm)
1828 if !ok {
1829 reply.err = errors.New("malformed OP_REPLY: missing numberReturned")
1830 return reply
1831 }
1832 reply.documents, _, ok = wiremessage.ReadReplyDocuments(wm)
1833 if !ok {
1834 reply.err = errors.New("malformed OP_REPLY: could not read documents from reply")
1835 }
1836
1837 if reply.responseFlags&wiremessage.QueryFailure == wiremessage.QueryFailure {
1838 reply.err = QueryFailureError{
1839 Message: "command failure",
1840 Response: reply.documents[0],
1841 }
1842 return reply
1843 }
1844 if reply.responseFlags&wiremessage.CursorNotFound == wiremessage.CursorNotFound {
1845 reply.err = ErrCursorNotFound
1846 return reply
1847 }
1848 if reply.numReturned != int32(len(reply.documents)) {
1849 reply.err = ErrReplyDocumentMismatch
1850 return reply
1851 }
1852
1853 return reply
1854 }
1855
1856 func (op Operation) decodeResult(ctx context.Context, opcode wiremessage.OpCode, wm []byte) (bsoncore.Document, error) {
1857 switch opcode {
1858 case wiremessage.OpReply:
1859 reply := op.decodeOpReply(wm)
1860 if reply.err != nil {
1861 return nil, reply.err
1862 }
1863 if reply.numReturned == 0 {
1864 return nil, ErrNoDocCommandResponse
1865 }
1866 if reply.numReturned > 1 {
1867 return nil, ErrMultiDocCommandResponse
1868 }
1869 rdr := reply.documents[0]
1870 if err := rdr.Validate(); err != nil {
1871 return nil, NewCommandResponseError("malformed OP_REPLY: invalid document", err)
1872 }
1873
1874 return rdr, ExtractErrorFromServerResponse(ctx, rdr)
1875 case wiremessage.OpMsg:
1876 _, wm, ok := wiremessage.ReadMsgFlags(wm)
1877 if !ok {
1878 return nil, errors.New("malformed wire message: missing OP_MSG flags")
1879 }
1880
1881 var res bsoncore.Document
1882 for len(wm) > 0 {
1883 var stype wiremessage.SectionType
1884 stype, wm, ok = wiremessage.ReadMsgSectionType(wm)
1885 if !ok {
1886 return nil, errors.New("malformed wire message: insuffienct bytes to read section type")
1887 }
1888
1889 switch stype {
1890 case wiremessage.SingleDocument:
1891 res, wm, ok = wiremessage.ReadMsgSectionSingleDocument(wm)
1892 if !ok {
1893 return nil, errors.New("malformed wire message: insufficient bytes to read single document")
1894 }
1895 case wiremessage.DocumentSequence:
1896
1897 _, _, wm, ok = wiremessage.ReadMsgSectionDocumentSequence(wm)
1898 if !ok {
1899 return nil, errors.New("malformed wire message: insufficient bytes to read document sequence")
1900 }
1901 default:
1902 return nil, fmt.Errorf("malformed wire message: unknown section type %v", stype)
1903 }
1904 }
1905
1906 err := res.Validate()
1907 if err != nil {
1908 return nil, NewCommandResponseError("malformed OP_MSG: invalid document", err)
1909 }
1910
1911 return res, ExtractErrorFromServerResponse(ctx, res)
1912 default:
1913 return nil, fmt.Errorf("cannot decode result from %s", opcode)
1914 }
1915 }
1916
1917
1918 func (op Operation) getCommandName(doc []byte) string {
1919
1920 idx := bytes.IndexByte(doc[5:], 0x00)
1921 return string(doc[5 : idx+5])
1922 }
1923
1924 func (op *Operation) redactCommand(cmd string, doc bsoncore.Document) bool {
1925 if cmd == "authenticate" || cmd == "saslStart" || cmd == "saslContinue" || cmd == "getnonce" || cmd == "createUser" ||
1926 cmd == "updateUser" || cmd == "copydbgetnonce" || cmd == "copydbsaslstart" || cmd == "copydb" {
1927
1928 return true
1929 }
1930 if strings.ToLower(cmd) != handshake.LegacyHelloLowercase && cmd != "hello" {
1931 return false
1932 }
1933
1934
1935 _, err := doc.LookupErr("speculativeAuthenticate")
1936 return err == nil
1937 }
1938
1939
1940 func (op Operation) canLogCommandMessage() bool {
1941 return op.Logger != nil && op.Logger.LevelComponentEnabled(logger.LevelDebug, logger.ComponentCommand)
1942 }
1943
1944 func (op Operation) canPublishStartedEvent() bool {
1945 return op.CommandMonitor != nil && op.CommandMonitor.Started != nil
1946 }
1947
1948
1949
1950
1951 func (op Operation) publishStartedEvent(ctx context.Context, info startedInformation) {
1952
1953 if op.canLogCommandMessage() {
1954 host, port, _ := net.SplitHostPort(info.serverAddress.String())
1955
1956 redactedCmd := redactStartedInformationCmd(op, info).String()
1957 formattedCmd := logger.FormatMessage(redactedCmd, op.Logger.MaxDocumentLength)
1958
1959 op.Logger.Print(logger.LevelDebug,
1960 logger.ComponentCommand,
1961 logger.CommandStarted,
1962 logger.SerializeCommand(logger.Command{
1963 DriverConnectionID: info.driverConnectionID,
1964 Message: logger.CommandStarted,
1965 Name: info.cmdName,
1966 DatabaseName: op.Database,
1967 RequestID: int64(info.requestID),
1968 ServerConnectionID: info.serverConnID,
1969 ServerHost: host,
1970 ServerPort: port,
1971 ServiceID: info.serviceID,
1972 },
1973 logger.KeyCommand, formattedCmd)...)
1974
1975 }
1976
1977 if op.canPublishStartedEvent() {
1978 started := &event.CommandStartedEvent{
1979 Command: redactStartedInformationCmd(op, info),
1980 DatabaseName: op.Database,
1981 CommandName: info.cmdName,
1982 RequestID: int64(info.requestID),
1983 ConnectionID: info.connID,
1984 ServerConnectionID: convertInt64PtrToInt32Ptr(info.serverConnID),
1985 ServerConnectionID64: info.serverConnID,
1986 ServiceID: info.serviceID,
1987 }
1988 op.CommandMonitor.Started(ctx, started)
1989 }
1990 }
1991
1992
1993
1994
1995 func (op Operation) canPublishFinishedEvent(info finishedInformation) bool {
1996 success := info.success()
1997
1998 return op.CommandMonitor != nil &&
1999 (!success || op.CommandMonitor.Succeeded != nil) &&
2000 (success || op.CommandMonitor.Failed != nil)
2001 }
2002
2003
2004
2005 func (op Operation) publishFinishedEvent(ctx context.Context, info finishedInformation) {
2006 if op.canLogCommandMessage() && info.success() {
2007 host, port, _ := net.SplitHostPort(info.serverAddress.String())
2008
2009 redactedReply := redactFinishedInformationResponse(info).String()
2010 formattedReply := logger.FormatMessage(redactedReply, op.Logger.MaxDocumentLength)
2011
2012 op.Logger.Print(logger.LevelDebug,
2013 logger.ComponentCommand,
2014 logger.CommandSucceeded,
2015 logger.SerializeCommand(logger.Command{
2016 DriverConnectionID: info.driverConnectionID,
2017 Message: logger.CommandSucceeded,
2018 Name: info.cmdName,
2019 DatabaseName: op.Database,
2020 RequestID: int64(info.requestID),
2021 ServerConnectionID: info.serverConnID,
2022 ServerHost: host,
2023 ServerPort: port,
2024 ServiceID: info.serviceID,
2025 },
2026 logger.KeyDurationMS, info.duration.Milliseconds(),
2027 logger.KeyReply, formattedReply)...)
2028 }
2029
2030 if op.canLogCommandMessage() && !info.success() {
2031 host, port, _ := net.SplitHostPort(info.serverAddress.String())
2032
2033 formattedReply := logger.FormatMessage(info.cmdErr.Error(), op.Logger.MaxDocumentLength)
2034
2035 op.Logger.Print(logger.LevelDebug,
2036 logger.ComponentCommand,
2037 logger.CommandFailed,
2038 logger.SerializeCommand(logger.Command{
2039 DriverConnectionID: info.driverConnectionID,
2040 Message: logger.CommandFailed,
2041 Name: info.cmdName,
2042 DatabaseName: op.Database,
2043 RequestID: int64(info.requestID),
2044 ServerConnectionID: info.serverConnID,
2045 ServerHost: host,
2046 ServerPort: port,
2047 ServiceID: info.serviceID,
2048 },
2049 logger.KeyDurationMS, info.duration.Milliseconds(),
2050 logger.KeyFailure, formattedReply)...)
2051 }
2052
2053
2054 if !op.canPublishFinishedEvent(info) {
2055 return
2056 }
2057
2058 finished := event.CommandFinishedEvent{
2059 CommandName: info.cmdName,
2060 DatabaseName: op.Database,
2061 RequestID: int64(info.requestID),
2062 ConnectionID: info.connID,
2063 Duration: info.duration,
2064 DurationNanos: info.duration.Nanoseconds(),
2065 ServerConnectionID: convertInt64PtrToInt32Ptr(info.serverConnID),
2066 ServerConnectionID64: info.serverConnID,
2067 ServiceID: info.serviceID,
2068 }
2069
2070 if info.success() {
2071 successEvent := &event.CommandSucceededEvent{
2072 Reply: redactFinishedInformationResponse(info),
2073 CommandFinishedEvent: finished,
2074 }
2075 op.CommandMonitor.Succeeded(ctx, successEvent)
2076
2077 return
2078 }
2079
2080 failedEvent := &event.CommandFailedEvent{
2081 Failure: info.cmdErr.Error(),
2082 CommandFinishedEvent: finished,
2083 }
2084 op.CommandMonitor.Failed(ctx, failedEvent)
2085 }
2086
2087
2088 func sessionsSupported(wireVersion *description.VersionRange) bool {
2089 return wireVersion != nil
2090 }
2091
2092
2093 func retryWritesSupported(s description.Server) bool {
2094 return s.SessionTimeoutMinutesPtr != nil && s.Kind != description.Standalone
2095 }
2096
View as plain text