1
2
3
4
5
6
7 package topology
8
9 import (
10 "context"
11 "crypto/tls"
12 "errors"
13 "fmt"
14 "io"
15 "net"
16 "strings"
17 "sync"
18 "sync/atomic"
19 "time"
20
21 "go.mongodb.org/mongo-driver/internal/csot"
22 "go.mongodb.org/mongo-driver/mongo/address"
23 "go.mongodb.org/mongo-driver/mongo/description"
24 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
25 "go.mongodb.org/mongo-driver/x/mongo/driver"
26 "go.mongodb.org/mongo-driver/x/mongo/driver/ocsp"
27 "go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"
28 )
29
30
// Lifecycle states stored in connection.state. The field is only accessed
// through sync/atomic, so these are plain int64 values.
const (
	// connDisconnected: connect failed or the connection has been closed.
	connDisconnected int64 = 0
	// connConnected: connect has claimed the connection (set before dialing)
	// and it has not been closed since.
	connConnected int64 = 1
	// connInitialized: created by newConnection, connect not yet called.
	connInitialized int64 = 2
)
36
var (
	// globalConnectionID is the most recently issued connection number. It is
	// only modified through sync/atomic in nextConnectionID.
	globalConnectionID uint64 = 1

	// defaultMaxMessageSize caps incoming message lengths until the handshake
	// has reported the server's own MaxMessageSize.
	defaultMaxMessageSize uint32 = 48_000_000

	errResponseTooLarge          = errors.New("length of read message too large")
	errLoadBalancedStateMismatch = errors.New("driver attempted to initialize in load balancing mode, but the server does not support this mode")
)

// nextConnectionID atomically advances globalConnectionID and returns the new
// value, so concurrent callers never receive the same number.
func nextConnectionID() uint64 {
	issued := atomic.AddUint64(&globalConnectionID, 1)
	return issued
}
46
// connection is a single network connection to the server at addr, together
// with the state negotiated over it in connect (server description,
// compression settings and server-side connection ID).
type connection struct {
	// state is one of connDisconnected, connConnected or connInitialized and
	// is only read and written through sync/atomic (see connect and close).
	state int64

	// id has the form "<addr>[-<n>]" with n taken from nextConnectionID.
	id   string
	nc   net.Conn
	addr address.Address

	// idleDeadline holds a time.Time written by bumpIdleDeadline; both it and
	// idleTimeoutExpired do nothing unless idleTimeout > 0.
	idleTimeout  time.Duration
	idleDeadline atomic.Value

	// readTimeout and writeTimeout bound a single wire-message read or write;
	// an earlier context deadline takes precedence.
	readTimeout  time.Duration
	writeTimeout time.Duration

	// desc and helloRTT are populated from the handshake in connect; helloRTT
	// is the time spent in GetHandshakeInformation.
	desc     description.Server
	helloRTT time.Duration

	// compressor is chosen in connect from the client/server compressor
	// lists; zliblevel and zstdLevel hold the level for that compressor.
	compressor wiremessage.CompressorID
	zliblevel  int
	zstdLevel  int

	// connectDone is closed when connect returns; wait blocks on it.
	connectDone chan struct{}
	config      *connectionConfig

	// cancelConnectContext cancels connect's handshake context. It is guarded
	// by connectContextMutex; connectContextMade is closed once connect has
	// assigned it so closeConnectContext can read it.
	cancelConnectContext context.CancelFunc
	connectContextMade   chan struct{}

	// canStream and currentlyStreaming back the streaming accessors exposed
	// through initConnection.
	canStream           bool
	currentlyStreaming  bool
	connectContextMutex sync.Mutex

	// cancellationListener watches the caller's context during each read and
	// write and closes the connection if that context is cancelled.
	cancellationListener cancellationListener

	// serverConnectionID is the server-reported connection ID from the
	// handshake; it may be nil.
	serverConnectionID *int64

	// pool is the pool Connection.cleanupReferences checks this connection
	// back into.
	pool *pool

	// driverConnectionID is not assigned in this file; presumably the owning
	// pool sets it — confirm.
	driverConnectionID uint64
	// generation is set by setGenerationNumber from config.getGenerationFn.
	generation uint64

	// awaitingResponse is set by readWireMessage when a read times out under a
	// CSOT timeout context. In that case the connection is left open instead of
	// being closed (presumably so the pending reply can be drained later —
	// confirm in the pool code).
	awaitingResponse bool
}
86
87
88 func newConnection(addr address.Address, opts ...ConnectionOption) *connection {
89 cfg := newConnectionConfig(opts...)
90
91 id := fmt.Sprintf("%s[-%d]", addr, nextConnectionID())
92
93 c := &connection{
94 id: id,
95 addr: addr,
96 idleTimeout: cfg.idleTimeout,
97 readTimeout: cfg.readTimeout,
98 writeTimeout: cfg.writeTimeout,
99 connectDone: make(chan struct{}),
100 config: cfg,
101 connectContextMade: make(chan struct{}),
102 cancellationListener: newCancellListener(),
103 }
104
105
106 if !c.config.loadBalanced {
107 c.setGenerationNumber()
108 }
109 atomic.StoreInt64(&c.state, connInitialized)
110
111 return c
112 }
113
114
115
116 func (c *connection) DriverConnectionID() uint64 {
117 return c.driverConnectionID
118 }
119
120
121
122 func (c *connection) setGenerationNumber() {
123 if c.config.getGenerationFn != nil {
124 c.generation = c.config.getGenerationFn(c.desc.ServiceID)
125 }
126 }
127
128
129
130 func (c *connection) hasGenerationNumber() bool {
131 if !c.config.loadBalanced {
132
133 return true
134 }
135
136
137
138 return c.desc.LoadBalanced()
139 }
140
141
142
143
// connect dials the server, optionally wraps the socket in TLS, runs the
// configured handshaker and negotiates compression.
//
// Only the first call on a connection in the connInitialized state does any
// work; for any other state it returns nil immediately. On failure the state
// becomes connDisconnected and the underlying net.Conn (if one was created) is
// closed. Dial, TLS and handshake failures are returned as a ConnectionError
// with init set. connectDone is closed when connect returns, whatever the
// outcome.
func (c *connection) connect(ctx context.Context) (err error) {
	if !atomic.CompareAndSwapInt64(&c.state, connInitialized, connConnected) {
		return nil
	}

	defer close(c.connectDone)

	// Deferred calls run LIFO, so this runs before connectDone is closed and
	// waiters observe connDisconnected after a failed connect.
	defer func() {
		if err != nil {
			atomic.StoreInt64(&c.state, connDisconnected)

			if c.nc != nil {
				_ = c.nc.Close()
			}
		}
	}()

	// Create a cancellable handshake context and publish its cancel function
	// under connectContextMutex, so closeConnectContext can abort an
	// in-progress connect from another goroutine.
	c.connectContextMutex.Lock()
	var handshakeCtx context.Context
	handshakeCtx, c.cancelConnectContext = context.WithCancel(ctx)
	c.connectContextMutex.Unlock()

	// When connectTimeout is non-zero it bounds dialing and TLS setup (both
	// use dialCtx). The handshake below uses handshakeCtx.
	dialCtx := handshakeCtx
	var dialCancel context.CancelFunc
	if c.config.connectTimeout != 0 {
		dialCtx, dialCancel = context.WithTimeout(handshakeCtx, c.config.connectTimeout)
		defer dialCancel()
	}

	// Always release the handshake context on return. cancelConnectContext is
	// cleared under the mutex so a later closeConnectContext sees nil.
	defer func() {
		var cancelFn context.CancelFunc

		c.connectContextMutex.Lock()
		cancelFn = c.cancelConnectContext
		c.cancelConnectContext = nil
		c.connectContextMutex.Unlock()

		if cancelFn != nil {
			cancelFn()
		}
	}()

	// Signal closeConnectContext that cancelConnectContext has been assigned.
	close(c.connectContextMade)

	tempNc, err := c.config.dialer.DialContext(dialCtx, c.addr.Network(), c.addr.String())
	if err != nil {
		return ConnectionError{Wrapped: err, init: true}
	}
	c.nc = tempNc

	if c.config.tlsConfig != nil {
		// Clone so that configureTLS filling in ServerName does not mutate the
		// shared configuration.
		tlsConfig := c.config.tlsConfig.Clone()

		ocspOpts := &ocsp.VerifyOptions{
			Cache:                   c.config.ocspCache,
			DisableEndpointChecking: c.config.disableOCSPEndpointCheck,
			HTTPClient:              c.config.httpClient,
		}
		tlsNc, err := configureTLS(dialCtx, c.config.tlsConnectionSource, c.nc, c.addr, tlsConfig, ocspOpts)
		if err != nil {
			return ConnectionError{Wrapped: err, init: true}
		}
		c.nc = tlsNc
	}

	// With no handshaker configured, the connection is ready once the socket
	// (and TLS, if any) is set up.
	handshaker := c.config.handshaker
	if handshaker == nil {
		return nil
	}

	var handshakeInfo driver.HandshakeInformation
	handshakeStartTime := time.Now()
	handshakeConn := initConnection{c}
	handshakeInfo, err = handshaker.GetHandshakeInformation(handshakeCtx, c.addr, handshakeConn)
	if err == nil {
		// Record what the server reported. helloRTT covers only the
		// GetHandshakeInformation call.
		c.desc = handshakeInfo.Description
		c.serverConnectionID = handshakeInfo.ServerConnectionID
		c.helloRTT = time.Since(handshakeStartTime)

		// In load-balanced mode the server's description must carry a
		// service ID.
		if c.config.loadBalanced && c.desc.ServiceID == nil {
			err = errLoadBalancedStateMismatch
		}
	}
	if err == nil {
		// Load-balanced connections can only compute their generation once
		// desc.ServiceID is known (see newConnection for the eager case).
		if c.config.loadBalanced {
			c.setGenerationNumber()
		}

		err = handshaker.FinishHandshake(handshakeCtx, handshakeConn)
	}

	if err != nil {
		return ConnectionError{Wrapped: err, init: true}
	}

	// Pick the first compressor in the client's configured order that the
	// server also advertised. Names must match exactly before the switch
	// lowercases them. A matching but unrecognised name still ends the search
	// and leaves compressor unchanged.
	if len(c.desc.Compression) > 0 {
	clientMethodLoop:
		for _, method := range c.config.compressors {
			for _, serverMethod := range c.desc.Compression {
				if method != serverMethod {
					continue
				}

				switch strings.ToLower(method) {
				case "snappy":
					c.compressor = wiremessage.CompressorSnappy
				case "zlib":
					c.compressor = wiremessage.CompressorZLib
					c.zliblevel = wiremessage.DefaultZlibLevel
					if c.config.zlibLevel != nil {
						c.zliblevel = *c.config.zlibLevel
					}
				case "zstd":
					c.compressor = wiremessage.CompressorZstd
					c.zstdLevel = wiremessage.DefaultZstdLevel
					if c.config.zstdLevel != nil {
						c.zstdLevel = *c.config.zstdLevel
					}
				}
				break clientMethodLoop
			}
		}
	}
	return nil
}
295
296 func (c *connection) wait() {
297 if c.connectDone != nil {
298 <-c.connectDone
299 }
300 }
301
302 func (c *connection) closeConnectContext() {
303 <-c.connectContextMade
304 var cancelFn context.CancelFunc
305
306 c.connectContextMutex.Lock()
307 cancelFn = c.cancelConnectContext
308 c.cancelConnectContext = nil
309 c.connectContextMutex.Unlock()
310
311 if cancelFn != nil {
312 cancelFn()
313 }
314 }
315
// transformNetworkError maps an I/O error from a wire-message read or write to
// the error reported to the caller.
//
// The mapping is:
//   - a nil originalError yields nil;
//   - if ctx has been cancelled, the failure is attributed to the cancellation
//     and context.Canceled is returned;
//   - if the I/O deadline came from ctx (contextDeadlineUsed) and
//     originalError is, or wraps, a net.Error whose Timeout() is true,
//     context.DeadlineExceeded is returned;
//   - otherwise originalError is returned unchanged.
func transformNetworkError(ctx context.Context, originalError error, contextDeadlineUsed bool) error {
	if originalError == nil {
		return nil
	}

	// The read/write paths close the connection when ctx is cancelled, so any
	// error observed after cancellation is assumed to be caused by it.
	if errors.Is(ctx.Err(), context.Canceled) {
		return context.Canceled
	}

	// A timeout only means "context deadline exceeded" if the deadline applied
	// to the socket actually came from ctx.
	if !contextDeadlineUsed {
		return originalError
	}

	// errors.As (rather than a bare type assertion) also recognises timeouts
	// that were wrapped by an intermediate layer, matching how readWireMessage
	// detects timeouts.
	var netErr net.Error
	if errors.As(originalError, &netErr) && netErr.Timeout() {
		return context.DeadlineExceeded
	}

	return originalError
}
337
338 func (c *connection) cancellationListenerCallback() {
339 _ = c.close()
340 }
341
342 func (c *connection) writeWireMessage(ctx context.Context, wm []byte) error {
343 var err error
344 if atomic.LoadInt64(&c.state) != connConnected {
345 return ConnectionError{
346 ConnectionID: c.id,
347 message: "connection is closed",
348 }
349 }
350
351 var deadline time.Time
352 if c.writeTimeout != 0 {
353 deadline = time.Now().Add(c.writeTimeout)
354 }
355
356 var contextDeadlineUsed bool
357 if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
358 contextDeadlineUsed = true
359 deadline = dl
360 }
361
362 if err := c.nc.SetWriteDeadline(deadline); err != nil {
363 return ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set write deadline"}
364 }
365
366 err = c.write(ctx, wm)
367 if err != nil {
368 c.close()
369 return ConnectionError{
370 ConnectionID: c.id,
371 Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
372 message: "unable to write wire message to network",
373 }
374 }
375
376 return nil
377 }
378
379 func (c *connection) write(ctx context.Context, wm []byte) (err error) {
380 go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
381 defer func() {
382
383
384
385
386
387 if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
388 err = context.Canceled
389 }
390 }()
391
392 _, err = c.nc.Write(wm)
393 return err
394 }
395
396
397 func (c *connection) readWireMessage(ctx context.Context) ([]byte, error) {
398 if atomic.LoadInt64(&c.state) != connConnected {
399 return nil, ConnectionError{
400 ConnectionID: c.id,
401 message: "connection is closed",
402 }
403 }
404
405 var deadline time.Time
406 if c.readTimeout != 0 {
407 deadline = time.Now().Add(c.readTimeout)
408 }
409
410 var contextDeadlineUsed bool
411 if dl, ok := ctx.Deadline(); ok && (deadline.IsZero() || dl.Before(deadline)) {
412 contextDeadlineUsed = true
413 deadline = dl
414 }
415
416 if err := c.nc.SetReadDeadline(deadline); err != nil {
417 return nil, ConnectionError{ConnectionID: c.id, Wrapped: err, message: "failed to set read deadline"}
418 }
419
420 dst, errMsg, err := c.read(ctx)
421 if err != nil {
422 if nerr := net.Error(nil); errors.As(err, &nerr) && nerr.Timeout() && csot.IsTimeoutContext(ctx) {
423
424
425
426
427 c.awaitingResponse = true
428 } else {
429
430
431 c.close()
432 }
433 message := errMsg
434 if errors.Is(err, io.EOF) {
435 message = "socket was unexpectedly closed"
436 }
437 return nil, ConnectionError{
438 ConnectionID: c.id,
439 Wrapped: transformNetworkError(ctx, err, contextDeadlineUsed),
440 message: message,
441 }
442 }
443
444 return dst, nil
445 }
446
447 func (c *connection) read(ctx context.Context) (bytesRead []byte, errMsg string, err error) {
448 go c.cancellationListener.Listen(ctx, c.cancellationListenerCallback)
449 defer func() {
450
451
452
453
454 if aborted := c.cancellationListener.StopListening(); aborted && err == nil {
455 errMsg = "unable to read server response"
456 err = context.Canceled
457 }
458 }()
459
460
461
462 var sizeBuf [4]byte
463
464
465
466
467 _, err = io.ReadFull(c.nc, sizeBuf[:])
468 if err != nil {
469 return nil, "incomplete read of message header", err
470 }
471
472
473 size := (int32(sizeBuf[0])) | (int32(sizeBuf[1]) << 8) | (int32(sizeBuf[2]) << 16) | (int32(sizeBuf[3]) << 24)
474
475
476
477 maxMessageSize := c.desc.MaxMessageSize
478 if maxMessageSize == 0 {
479 maxMessageSize = defaultMaxMessageSize
480 }
481 if uint32(size) > maxMessageSize {
482 return nil, errResponseTooLarge.Error(), errResponseTooLarge
483 }
484
485 dst := make([]byte, size)
486 copy(dst, sizeBuf[:])
487
488 _, err = io.ReadFull(c.nc, dst[4:])
489 if err != nil {
490 return dst, "incomplete read of full message", err
491 }
492
493 return dst, "", nil
494 }
495
496 func (c *connection) close() error {
497
498 if !atomic.CompareAndSwapInt64(&c.state, connConnected, connDisconnected) {
499 return nil
500 }
501
502 var err error
503 if c.nc != nil {
504 err = c.nc.Close()
505 }
506
507 return err
508 }
509
510 func (c *connection) closed() bool {
511 return atomic.LoadInt64(&c.state) == connDisconnected
512 }
513
514 func (c *connection) idleTimeoutExpired() bool {
515 now := time.Now()
516 if c.idleTimeout > 0 {
517 idleDeadline, ok := c.idleDeadline.Load().(time.Time)
518 if ok && now.After(idleDeadline) {
519 return true
520 }
521 }
522
523 return false
524 }
525
526 func (c *connection) bumpIdleDeadline() {
527 if c.idleTimeout > 0 {
528 c.idleDeadline.Store(time.Now().Add(c.idleTimeout))
529 }
530 }
531
532 func (c *connection) setCanStream(canStream bool) {
533 c.canStream = canStream
534 }
535
536 func (c initConnection) supportsStreaming() bool {
537 return c.canStream
538 }
539
540 func (c *connection) setStreaming(streaming bool) {
541 c.currentlyStreaming = streaming
542 }
543
544 func (c *connection) getCurrentlyStreaming() bool {
545 return c.currentlyStreaming
546 }
547
548 func (c *connection) setSocketTimeout(timeout time.Duration) {
549 c.readTimeout = timeout
550 c.writeTimeout = timeout
551 }
552
553 func (c *connection) ID() string {
554 return c.id
555 }
556
557 func (c *connection) ServerConnectionID() *int64 {
558 return c.serverConnectionID
559 }
560
561
562
563
564 type initConnection struct{ *connection }
565
566 var _ driver.Connection = initConnection{}
567 var _ driver.StreamerConnection = initConnection{}
568
569 func (c initConnection) Description() description.Server {
570 if c.connection == nil {
571 return description.Server{}
572 }
573 return c.connection.desc
574 }
575 func (c initConnection) Close() error { return nil }
576 func (c initConnection) ID() string { return c.id }
577 func (c initConnection) Address() address.Address { return c.addr }
578 func (c initConnection) Stale() bool { return false }
579 func (c initConnection) LocalAddress() address.Address {
580 if c.connection == nil || c.nc == nil {
581 return address.Address("0.0.0.0")
582 }
583 return address.Address(c.nc.LocalAddr().String())
584 }
585 func (c initConnection) WriteWireMessage(ctx context.Context, wm []byte) error {
586 return c.writeWireMessage(ctx, wm)
587 }
588 func (c initConnection) ReadWireMessage(ctx context.Context) ([]byte, error) {
589 return c.readWireMessage(ctx)
590 }
591 func (c initConnection) SetStreaming(streaming bool) {
592 c.setStreaming(streaming)
593 }
594 func (c initConnection) CurrentlyStreaming() bool {
595 return c.getCurrentlyStreaming()
596 }
597 func (c initConnection) SupportsStreaming() bool {
598 return c.supportsStreaming()
599 }
600
601
602
603
// Connection is a handle to a pooled connection. The methods that take mu use
// it to guard connection, refCount and the cleanup callbacks. connection
// becomes nil once the handle has been released through Close or Expire.
type Connection struct {
	// connection is the underlying connection; cleanupReferences sets it to nil.
	connection *connection
	// refCount counts outstanding pins (PinToCursor/PinToTransaction); Close
	// returns without releasing while it is positive.
	refCount int
	// cleanupPoolFn is recorded by the first pin and invoked, then cleared,
	// by cleanupReferences.
	cleanupPoolFn func()

	// cleanupServerFn is invoked, then cleared, by cleanupReferences. It is
	// not assigned in this file; presumably the server that checked out the
	// connection sets it — confirm.
	cleanupServerFn func()

	mu sync.RWMutex
}

var _ driver.Connection = (*Connection)(nil)
var _ driver.Expirable = (*Connection)(nil)
var _ driver.PinnedConnection = (*Connection)(nil)
619
620
621 func (c *Connection) WriteWireMessage(ctx context.Context, wm []byte) error {
622 c.mu.RLock()
623 defer c.mu.RUnlock()
624 if c.connection == nil {
625 return ErrConnectionClosed
626 }
627 return c.connection.writeWireMessage(ctx, wm)
628 }
629
630
631
632 func (c *Connection) ReadWireMessage(ctx context.Context) ([]byte, error) {
633 c.mu.RLock()
634 defer c.mu.RUnlock()
635 if c.connection == nil {
636 return nil, ErrConnectionClosed
637 }
638 return c.connection.readWireMessage(ctx)
639 }
640
641
642
643
644 func (c *Connection) CompressWireMessage(src, dst []byte) ([]byte, error) {
645 c.mu.RLock()
646 defer c.mu.RUnlock()
647 if c.connection == nil {
648 return dst, ErrConnectionClosed
649 }
650 if c.connection.compressor == wiremessage.CompressorNoOp {
651 return append(dst, src...), nil
652 }
653 _, reqid, respto, origcode, rem, ok := wiremessage.ReadHeader(src)
654 if !ok {
655 return dst, errors.New("wiremessage is too short to compress, less than 16 bytes")
656 }
657 idx, dst := wiremessage.AppendHeaderStart(dst, reqid, respto, wiremessage.OpCompressed)
658 dst = wiremessage.AppendCompressedOriginalOpCode(dst, origcode)
659 dst = wiremessage.AppendCompressedUncompressedSize(dst, int32(len(rem)))
660 dst = wiremessage.AppendCompressedCompressorID(dst, c.connection.compressor)
661 opts := driver.CompressionOpts{
662 Compressor: c.connection.compressor,
663 ZlibLevel: c.connection.zliblevel,
664 ZstdLevel: c.connection.zstdLevel,
665 }
666 compressed, err := driver.CompressPayload(rem, opts)
667 if err != nil {
668 return nil, err
669 }
670 dst = wiremessage.AppendCompressedCompressedMessage(dst, compressed)
671 return bsoncore.UpdateLength(dst, idx, int32(len(dst[idx:]))), nil
672 }
673
674
675 func (c *Connection) Description() description.Server {
676 c.mu.RLock()
677 defer c.mu.RUnlock()
678 if c.connection == nil {
679 return description.Server{}
680 }
681 return c.connection.desc
682 }
683
684
685
686 func (c *Connection) Close() error {
687 c.mu.Lock()
688 defer c.mu.Unlock()
689 if c.connection == nil || c.refCount > 0 {
690 return nil
691 }
692
693 return c.cleanupReferences()
694 }
695
696
697 func (c *Connection) Expire() error {
698 c.mu.Lock()
699 defer c.mu.Unlock()
700 if c.connection == nil {
701 return nil
702 }
703
704 _ = c.connection.close()
705 return c.cleanupReferences()
706 }
707
708 func (c *Connection) cleanupReferences() error {
709 err := c.connection.pool.checkIn(c.connection)
710 if c.cleanupPoolFn != nil {
711 c.cleanupPoolFn()
712 c.cleanupPoolFn = nil
713 }
714 if c.cleanupServerFn != nil {
715 c.cleanupServerFn()
716 c.cleanupServerFn = nil
717 }
718 c.connection = nil
719 return err
720 }
721
722
723 func (c *Connection) Alive() bool {
724 return c.connection != nil
725 }
726
727
728 func (c *Connection) ID() string {
729 c.mu.RLock()
730 defer c.mu.RUnlock()
731 if c.connection == nil {
732 return "<closed>"
733 }
734 return c.connection.id
735 }
736
737
738 func (c *Connection) ServerConnectionID() *int64 {
739 if c.connection == nil {
740 return nil
741 }
742 return c.connection.serverConnectionID
743 }
744
745
746 func (c *Connection) Stale() bool {
747 c.mu.RLock()
748 defer c.mu.RUnlock()
749 return c.connection.pool.stale(c.connection)
750 }
751
752
753 func (c *Connection) Address() address.Address {
754 c.mu.RLock()
755 defer c.mu.RUnlock()
756 if c.connection == nil {
757 return address.Address("0.0.0.0")
758 }
759 return c.connection.addr
760 }
761
762
763 func (c *Connection) LocalAddress() address.Address {
764 c.mu.RLock()
765 defer c.mu.RUnlock()
766 if c.connection == nil || c.connection.nc == nil {
767 return address.Address("0.0.0.0")
768 }
769 return address.Address(c.connection.nc.LocalAddr().String())
770 }
771
772
773 func (c *Connection) PinToCursor() error {
774 return c.pin("cursor", c.connection.pool.pinConnectionToCursor, c.connection.pool.unpinConnectionFromCursor)
775 }
776
777
778 func (c *Connection) PinToTransaction() error {
779 return c.pin("transaction", c.connection.pool.pinConnectionToTransaction, c.connection.pool.unpinConnectionFromTransaction)
780 }
781
782 func (c *Connection) pin(reason string, updatePoolFn, cleanupPoolFn func()) error {
783 c.mu.Lock()
784 defer c.mu.Unlock()
785 if c.connection == nil {
786 return fmt.Errorf("attempted to pin a connection for a %s, but the connection has already been returned to the pool", reason)
787 }
788
789
790
791 if c.refCount == 0 {
792 updatePoolFn()
793 c.cleanupPoolFn = cleanupPoolFn
794 }
795 c.refCount++
796 return nil
797 }
798
799
800 func (c *Connection) UnpinFromCursor() error {
801 return c.unpin("cursor")
802 }
803
804
805 func (c *Connection) UnpinFromTransaction() error {
806 return c.unpin("transaction")
807 }
808
809 func (c *Connection) unpin(reason string) error {
810 c.mu.Lock()
811 defer c.mu.Unlock()
812 if c.connection == nil {
813
814 return nil
815 }
816 if c.refCount == 0 {
817 return fmt.Errorf("attempted to unpin a connection from a %s, but the connection is not pinned by any resources", reason)
818 }
819
820 c.refCount--
821 return nil
822 }
823
824
825
826 func (c *Connection) DriverConnectionID() uint64 {
827 return c.connection.DriverConnectionID()
828 }
829
830 func configureTLS(ctx context.Context,
831 tlsConnSource tlsConnectionSource,
832 nc net.Conn,
833 addr address.Address,
834 config *tls.Config,
835 ocspOpts *ocsp.VerifyOptions,
836 ) (net.Conn, error) {
837
838 if config.ServerName == "" {
839 hostname := addr.String()
840 colonPos := strings.LastIndex(hostname, ":")
841 if colonPos == -1 {
842 colonPos = len(hostname)
843 }
844
845 hostname = hostname[:colonPos]
846 config.ServerName = hostname
847 }
848
849 client := tlsConnSource.Client(nc, config)
850 if err := clientHandshake(ctx, client); err != nil {
851 return nil, err
852 }
853
854
855 if !config.InsecureSkipVerify {
856 if ocspErr := ocsp.Verify(ctx, client.ConnectionState(), ocspOpts); ocspErr != nil {
857 return nil, ocspErr
858 }
859 }
860 return client, nil
861 }
862
863
864
865
866
// cancellListener watches a context for the duration of one I/O operation.
// Listen runs in its own goroutine and StopListening is called when the
// operation finishes. done is unbuffered, so StopListening blocks until Listen
// receives from it, and the aborted flag written by Listen is visible to
// StopListening's caller.
type cancellListener struct {
	aborted bool
	done    chan struct{}
}

// newCancellListener returns a listener ready for Listen/StopListening pairs.
func newCancellListener() *cancellListener {
	l := new(cancellListener)
	l.done = make(chan struct{})
	return l
}

// Listen blocks until StopListening is called or ctx is done. If ctx was
// cancelled (not merely past its deadline), it records the abort and calls
// abortFn. It then still waits for StopListening, so every Listen is paired
// with exactly one StopListening.
func (c *cancellListener) Listen(ctx context.Context, abortFn func()) {
	c.aborted = false

	select {
	case <-c.done:
		return
	case <-ctx.Done():
	}

	if errors.Is(ctx.Err(), context.Canceled) {
		c.aborted = true
		abortFn()
	}
	<-c.done
}

// StopListening ends the current Listen call and reports whether Listen
// aborted the operation because its context was cancelled.
func (c *cancellListener) StopListening() bool {
	c.done <- struct{}{}
	return c.aborted
}
906
View as plain text