1
2
3
4
5
6
7 package mtest
8
9 import (
10 "context"
11 "errors"
12 "fmt"
13 "net"
14 "sync"
15 "time"
16
17 "go.mongodb.org/mongo-driver/mongo/options"
18 "go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
19 )
20
21
22 type ProxyMessage struct {
23 ServerAddress string
24 CommandName string
25 Sent *SentMessage
26 Received *ReceivedMessage
27 }
28
29
30
// proxyDialer dials network connections and wraps each one in a proxyConn.
// It records every sent command together with the reply that answers it.
type proxyDialer struct {
	// Dialer performs the real network dial; newProxyDialer configures it with
	// a 30 second timeout.
	*net.Dialer
	// Mutex guards messages and serializes storeSentMessage,
	// storeReceivedMessage and Messages.
	sync.Mutex

	// messages holds completed request/reply pairs, appended in the order the
	// replies were stored. Guarded by the embedded Mutex.
	messages []*ProxyMessage

	// sentMap holds parsed sent messages keyed by request ID. Each entry is
	// removed once a reply whose responseTo matches that ID is stored.
	sentMap sync.Map

	// addressTranslations maps a connection's resolved remote address
	// (RemoteAddr().String()) to the address that was passed to DialContext.
	// It is only populated for connections where the two differ.
	addressTranslations sync.Map
}
44
45 var _ options.ContextDialer = (*proxyDialer)(nil)
46
47 func newProxyDialer() *proxyDialer {
48 return &proxyDialer{
49 Dialer: &net.Dialer{Timeout: 30 * time.Second},
50 }
51 }
52
// newProxyErrorWithWireMsg returns an error that wraps err and includes the
// raw wire message bytes (formatted with %v) for debugging. The result
// unwraps to err.
func newProxyErrorWithWireMsg(wm []byte, err error) error {
	const format = "proxy error for wiremessage %v: %w"
	return fmt.Errorf(format, wm, err)
}
56
57
58 func (p *proxyDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
59 netConn, err := p.Dialer.DialContext(ctx, network, address)
60 if err != nil {
61 return netConn, err
62 }
63
64
65
66
67 if remoteAddress := netConn.RemoteAddr().String(); remoteAddress != address {
68 p.addressTranslations.Store(remoteAddress, address)
69 }
70
71 proxy := &proxyConn{
72 Conn: netConn,
73 dialer: p,
74 }
75 return proxy, nil
76 }
77
78 func (p *proxyDialer) storeSentMessage(wm []byte) error {
79 p.Lock()
80 defer p.Unlock()
81
82
83
84 wmCopy := copyBytes(wm)
85 parsed, err := parseSentMessage(wmCopy)
86 if err != nil {
87 return err
88 }
89 p.sentMap.Store(parsed.RequestID, parsed)
90 return nil
91 }
92
93 func (p *proxyDialer) storeReceivedMessage(wm []byte, addr string) error {
94 p.Lock()
95 defer p.Unlock()
96
97 serverAddress := addr
98 if translated, ok := p.addressTranslations.Load(addr); ok {
99 serverAddress = translated.(string)
100 }
101
102
103
104 wmCopy := copyBytes(wm)
105 parsed, err := parseReceivedMessage(wmCopy)
106 if err != nil {
107 return err
108 }
109 mapValue, ok := p.sentMap.Load(parsed.ResponseTo)
110 if !ok {
111 return errors.New("no sent message found")
112 }
113 sent := mapValue.(*SentMessage)
114 p.sentMap.Delete(parsed.ResponseTo)
115
116
117 msgPair := &ProxyMessage{
118
119 CommandName: sent.Command.Index(0).Key(),
120 ServerAddress: serverAddress,
121 Sent: sent,
122 Received: parsed,
123 }
124 p.messages = append(p.messages, msgPair)
125 return nil
126 }
127
128
129
130 func (p *proxyDialer) Messages() []*ProxyMessage {
131 p.Lock()
132 defer p.Unlock()
133
134 copiedMessages := make([]*ProxyMessage, len(p.messages))
135 copy(copiedMessages, p.messages)
136 return copiedMessages
137 }
138
139
140
141
142
// proxyConn wraps a net.Conn and hands the wire messages written to and read
// from it to the proxyDialer that created it, which records them.
type proxyConn struct {
	// Conn is the underlying network connection that all I/O is delegated to.
	net.Conn
	// dialer is the proxyDialer that records this connection's traffic.
	dialer *proxyDialer
}
147
148
149
150 func (pc *proxyConn) Write(wm []byte) (n int, err error) {
151 if err := pc.dialer.storeSentMessage(wm); err != nil {
152 wrapped := fmt.Errorf("error storing sent message: %w", err)
153 return 0, newProxyErrorWithWireMsg(wm, wrapped)
154 }
155
156 return pc.Conn.Write(wm)
157 }
158
159
160
161 func (pc *proxyConn) Read(buffer []byte) (int, error) {
162 n, err := pc.Conn.Read(buffer)
163 if err != nil {
164 return n, err
165 }
166
167
168
169
170 if len(buffer) == 4 {
171 return 4, nil
172 }
173
174
175
176 idx, wm := bsoncore.ReserveLength(nil)
177 wm = append(wm, buffer...)
178 wm = bsoncore.UpdateLength(wm, idx, int32(len(wm[idx:])))
179
180 if err := pc.dialer.storeReceivedMessage(wm, pc.RemoteAddr().String()); err != nil {
181 wrapped := fmt.Errorf("error storing received message: %w", err)
182 return 0, newProxyErrorWithWireMsg(wm, wrapped)
183 }
184
185 return n, nil
186 }
187
View as plain text