...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/mtest/proxy_dialer.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration/mtest

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package mtest
     8  
import (
	"context"
	"errors"
	"fmt"
	"io"
	"net"
	"sync"
	"time"

	"go.mongodb.org/mongo-driver/mongo/options"
	"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
)
    20  
    21  // ProxyMessage represents a sent/received pair of parsed wire messages.
    22  type ProxyMessage struct {
    23  	ServerAddress string
    24  	CommandName   string
    25  	Sent          *SentMessage
    26  	Received      *ReceivedMessage
    27  }
    28  
    29  // proxyDialer is a ContextDialer implementation that wraps a net.Dialer and records the messages sent and received
    30  // using connections created through it.
    31  type proxyDialer struct {
    32  	*net.Dialer
    33  	sync.Mutex
    34  
    35  	messages []*ProxyMessage
    36  	// sentMap temporarily stores the message sent to the server using the requestID so it can map requests to their
    37  	// responses.
    38  	sentMap sync.Map
    39  	// addressTranslations maps dialed addresses to the remote addresses reported by the created connections if they
    40  	// differ. This can happen if a connection is dialed to a host name, in which case the reported remote address will
    41  	// be the resolved IP address.
    42  	addressTranslations sync.Map
    43  }
    44  
    45  var _ options.ContextDialer = (*proxyDialer)(nil)
    46  
    47  func newProxyDialer() *proxyDialer {
    48  	return &proxyDialer{
    49  		Dialer: &net.Dialer{Timeout: 30 * time.Second},
    50  	}
    51  }
    52  
    53  func newProxyErrorWithWireMsg(wm []byte, err error) error {
    54  	return fmt.Errorf("proxy error for wiremessage %v: %w", wm, err)
    55  }
    56  
    57  // DialContext creates a new proxyConnection.
    58  func (p *proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
    59  	netConn, err := p.Dialer.DialContext(ctx, network, address)
    60  	if err != nil {
    61  		return netConn, err
    62  	}
    63  
    64  	// If the connection's remote address does not match the dialed address, store it in the translations map for
    65  	// future look-up. Use the remote address as they key because that's what we'll have access to in the connection's
    66  	// Read/Write functions.
    67  	if remoteAddress := netConn.RemoteAddr().String(); remoteAddress != address {
    68  		p.addressTranslations.Store(remoteAddress, address)
    69  	}
    70  
    71  	proxy := &proxyConn{
    72  		Conn:   netConn,
    73  		dialer: p,
    74  	}
    75  	return proxy, nil
    76  }
    77  
    78  func (p *proxyDialer) storeSentMessage(wm []byte) error {
    79  	p.Lock()
    80  	defer p.Unlock()
    81  
    82  	// Create a copy of the wire message so it can be parsed/stored and will not be affected if the wm slice is
    83  	// changed by the driver.
    84  	wmCopy := copyBytes(wm)
    85  	parsed, err := parseSentMessage(wmCopy)
    86  	if err != nil {
    87  		return err
    88  	}
    89  	p.sentMap.Store(parsed.RequestID, parsed)
    90  	return nil
    91  }
    92  
    93  func (p *proxyDialer) storeReceivedMessage(wm []byte, addr string) error {
    94  	p.Lock()
    95  	defer p.Unlock()
    96  
    97  	serverAddress := addr
    98  	if translated, ok := p.addressTranslations.Load(addr); ok {
    99  		serverAddress = translated.(string)
   100  	}
   101  
   102  	// Create a copy of the wire message so it can be parsed/stored and will not be affected if the wm slice is
   103  	// changed by the driver. Parse the incoming message and get the corresponding outgoing message.
   104  	wmCopy := copyBytes(wm)
   105  	parsed, err := parseReceivedMessage(wmCopy)
   106  	if err != nil {
   107  		return err
   108  	}
   109  	mapValue, ok := p.sentMap.Load(parsed.ResponseTo)
   110  	if !ok {
   111  		return errors.New("no sent message found")
   112  	}
   113  	sent := mapValue.(*SentMessage)
   114  	p.sentMap.Delete(parsed.ResponseTo)
   115  
   116  	// Store the parsed message pair.
   117  	msgPair := &ProxyMessage{
   118  		// The command name is always the first key in the command document.
   119  		CommandName:   sent.Command.Index(0).Key(),
   120  		ServerAddress: serverAddress,
   121  		Sent:          sent,
   122  		Received:      parsed,
   123  	}
   124  	p.messages = append(p.messages, msgPair)
   125  	return nil
   126  }
   127  
   128  // Messages returns a slice of proxied messages. This slice is a copy of the messages proxied so far and will not be
   129  // updated for messages proxied after this call.
   130  func (p *proxyDialer) Messages() []*ProxyMessage {
   131  	p.Lock()
   132  	defer p.Unlock()
   133  
   134  	copiedMessages := make([]*ProxyMessage, len(p.messages))
   135  	copy(copiedMessages, p.messages)
   136  	return copiedMessages
   137  }
   138  
   139  // proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored
   140  // in the associated proxyDialer and are forwarded over the wrapped connection. Errors encountered when parsing and
   141  // storing wire messages are wrapped to add context, while errors returned from the underlying network connection are
   142  // forwarded without wrapping.
   143  type proxyConn struct {
   144  	net.Conn
   145  	dialer *proxyDialer
   146  }
   147  
   148  // Write stores the given message in the proxyDialer associated with this connection and forwards the message to the
   149  // server.
   150  func (pc *proxyConn) Write(wm []byte) (n int, err error) {
   151  	if err := pc.dialer.storeSentMessage(wm); err != nil {
   152  		wrapped := fmt.Errorf("error storing sent message: %w", err)
   153  		return 0, newProxyErrorWithWireMsg(wm, wrapped)
   154  	}
   155  
   156  	return pc.Conn.Write(wm)
   157  }
   158  
   159  // Read reads the message from the server into the given buffer and stores the read message in the proxyDialer
   160  // associated with this connection.
   161  func (pc *proxyConn) Read(buffer []byte) (int, error) {
   162  	n, err := pc.Conn.Read(buffer)
   163  	if err != nil {
   164  		return n, err
   165  	}
   166  
   167  	// The driver reads wire messages in two phases: a four-byte read to get the length of the incoming wire message
   168  	// and a (length-4) byte read to get the message itself. There's nothing to be stored during the initial four-byte
   169  	// read because we can calculate the length from the rest of the message.
   170  	if len(buffer) == 4 {
   171  		return 4, nil
   172  	}
   173  
   174  	// The buffer contains the entire wire message except for the length bytes. Re-create the full message by appending
   175  	// buffer to the end of a four-byte slice and using UpdateLength to set the length bytes.
   176  	idx, wm := bsoncore.ReserveLength(nil)
   177  	wm = append(wm, buffer...)
   178  	wm = bsoncore.UpdateLength(wm, idx, int32(len(wm[idx:])))
   179  
   180  	if err := pc.dialer.storeReceivedMessage(wm, pc.RemoteAddr().String()); err != nil {
   181  		wrapped := fmt.Errorf("error storing received message: %w", err)
   182  		return 0, newProxyErrorWithWireMsg(wm, wrapped)
   183  	}
   184  
   185  	return n, nil
   186  }
   187  

View as plain text