1
2
3
4
5
6
7 package topology
8
9 import (
10 "context"
11 "errors"
12 "net"
13 "testing"
14 "time"
15
16 "go.mongodb.org/mongo-driver/event"
17 "go.mongodb.org/mongo-driver/internal/assert"
18 "go.mongodb.org/mongo-driver/internal/require"
19 "go.mongodb.org/mongo-driver/x/mongo/driver/operation"
20 )
21
22 func TestCMAPProse(t *testing.T) {
23 t.Run("created and closed events", func(t *testing.T) {
24 created := make(chan *event.PoolEvent, 10)
25 closed := make(chan *event.PoolEvent, 10)
26 clearEvents := func() {
27 for len(created) > 0 {
28 <-created
29 }
30 for len(closed) > 0 {
31 <-closed
32 }
33 }
34 monitor := &event.PoolMonitor{
35 Event: func(evt *event.PoolEvent) {
36 switch evt.Type {
37 case event.ConnectionCreated:
38 created <- evt
39 case event.ConnectionClosed:
40 closed <- evt
41 }
42 },
43 }
44 getConfig := func() poolConfig {
45 return poolConfig{
46 PoolMonitor: monitor,
47 }
48 }
49 assertConnectionCounts := func(t *testing.T, p *pool, numCreated, numClosed int) {
50 t.Helper()
51
52 require.Eventuallyf(t,
53 func() bool {
54 return numCreated == len(created) && numClosed == len(closed)
55 },
56 1*time.Second,
57 10*time.Millisecond,
58 "expected %d creation events, got %d; expected %d closed events, got %d",
59 numCreated,
60 len(created),
61 numClosed,
62 len(closed))
63
64 netCount := numCreated - numClosed
65 assert.Equal(t, netCount, p.totalConnectionCount(), "expected %d total connections, got %d", netCount,
66 p.totalConnectionCount())
67 }
68
69 t.Run("maintain", func(t *testing.T) {
70 t.Run("connection error publishes events", func(t *testing.T) {
71
72
73
74
75 clearEvents()
76
77 var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
78 return &testNetConn{writeerr: errors.New("write error")}, nil
79 }
80
81 cfg := getConfig()
82 cfg.MinPoolSize = 1
83 connOpts := []ConnectionOption{
84 WithDialer(func(Dialer) Dialer { return dialer }),
85 WithHandshaker(func(Handshaker) Handshaker {
86 return operation.NewHello()
87 }),
88 }
89 pool := createTestPool(t, cfg, connOpts...)
90 defer pool.close(context.Background())
91
92
93
94 start := time.Now()
95 for len(created) != 1 || len(closed) != 1 {
96 if time.Since(start) > 3*time.Second {
97 t.Errorf(
98 "Expected 1 connection created and 1 connection closed events within 3 seconds. "+
99 "Actual created events: %d, actual closed events: %d",
100 len(created),
101 len(closed))
102 }
103 time.Sleep(time.Millisecond)
104 }
105 })
106 })
107 t.Run("checkOut", func(t *testing.T) {
108 t.Run("connection error publishes events", func(t *testing.T) {
109
110
111
112 clearEvents()
113
114 var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
115 return &testNetConn{writeerr: errors.New("write error")}, nil
116 }
117
118 cfg := getConfig()
119 connOpts := []ConnectionOption{
120 WithDialer(func(Dialer) Dialer { return dialer }),
121 WithHandshaker(func(Handshaker) Handshaker {
122 return operation.NewHello()
123 }),
124 }
125 pool := createTestPool(t, cfg, connOpts...)
126 defer pool.close(context.Background())
127
128 _, err := pool.checkOut(context.Background())
129 assert.NotNil(t, err, "expected checkOut() error, got nil")
130
131 assertConnectionCounts(t, pool, 1, 1)
132 })
133 t.Run("pool is empty", func(t *testing.T) {
134
135
136 clearEvents()
137
138 var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
139 return &testNetConn{writeerr: errors.New("write error")}, nil
140 }
141
142 connOpts := []ConnectionOption{
143 WithDialer(func(Dialer) Dialer { return dialer }),
144 WithHandshaker(func(Handshaker) Handshaker {
145 return operation.NewHello()
146 }),
147 }
148 pool := createTestPool(t, getConfig(), connOpts...)
149 defer pool.close(context.Background())
150
151 _, err := pool.checkOut(context.Background())
152 assert.NotNil(t, err, "expected checkOut() error, got nil")
153 assertConnectionCounts(t, pool, 1, 1)
154 })
155 })
156 t.Run("checkIn", func(t *testing.T) {
157 t.Run("errored connection", func(t *testing.T) {
158
159
160 clearEvents()
161
162 var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
163 return &testNetConn{writeerr: errors.New("write error")}, nil
164 }
165
166
167 connOpts := []ConnectionOption{
168 WithDialer(func(Dialer) Dialer { return dialer }),
169 }
170 pool := createTestPool(t, getConfig(), connOpts...)
171 defer pool.close(context.Background())
172
173 conn, err := pool.checkOut(context.Background())
174 assert.Nil(t, err, "checkOut() error: %v", err)
175
176
177 err = conn.writeWireMessage(context.Background(), nil)
178 assert.NotNil(t, err, "expected writeWireMessage error, got nil")
179
180 err = pool.checkIn(conn)
181 assert.Nil(t, err, "checkIn() error: %v", err)
182
183 assertConnectionCounts(t, pool, 1, 1)
184 evt := <-closed
185 assert.Equal(t, event.ReasonError, evt.Reason, "expected reason %q, got %q",
186 event.ReasonError, evt.Reason)
187 })
188 })
189 t.Run("close", func(t *testing.T) {
190 t.Run("connections returned gracefully", func(t *testing.T) {
191
192
193 clearEvents()
194
195 numConns := 5
196 var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
197 return &testNetConn{}, nil
198 }
199 pool := createTestPool(t, getConfig(), WithDialer(func(Dialer) Dialer { return dialer }))
200 defer pool.close(context.Background())
201
202 conns := checkoutConnections(t, pool, numConns)
203 assertConnectionCounts(t, pool, numConns, 0)
204
205
206 for i, c := range conns {
207 err := pool.checkIn(c)
208 assert.Nil(t, err, "checkIn() error at index %d: %v", i, err)
209 }
210 assertConnectionCounts(t, pool, numConns, 0)
211
212
213 pool.close(context.Background())
214 assertConnectionCounts(t, pool, numConns, numConns)
215
216 for len(closed) > 0 {
217 evt := <-closed
218 assert.Equal(t, event.ReasonPoolClosed, evt.Reason, "expected reason %q, got %q",
219 event.ReasonPoolClosed, evt.Reason)
220 }
221 })
222 t.Run("connections closed forcefully", func(t *testing.T) {
223
224
225 clearEvents()
226
227 numConns := 5
228 var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
229 return &testNetConn{}, nil
230 }
231 pool := createTestPool(t, getConfig(), WithDialer(func(Dialer) Dialer { return dialer }))
232
233 conns := checkoutConnections(t, pool, numConns)
234 assertConnectionCounts(t, pool, numConns, 0)
235
236
237 for i := 0; i < 2; i++ {
238 err := pool.checkIn(conns[i])
239 assert.Nil(t, err, "checkIn() error at index %d: %v", i, err)
240 }
241 conns = conns[2:]
242 assertConnectionCounts(t, pool, numConns, 0)
243
244
245 pool.close(context.Background())
246 assertConnectionCounts(t, pool, numConns, numConns)
247
248
249
250 for i, c := range conns {
251 err := pool.checkIn(c)
252 assert.Nil(t, err, "checkIn() error at index %d: %v", i, err)
253 }
254 assertConnectionCounts(t, pool, numConns, numConns)
255
256
257 for len(closed) > 0 {
258 evt := <-closed
259 assert.Equal(t, event.ReasonPoolClosed, evt.Reason, "expected reason %q, got %q",
260 event.ReasonPoolClosed, evt.Reason)
261 }
262
263 })
264 })
265 })
266 }
267
268 func createTestPool(t *testing.T, cfg poolConfig, opts ...ConnectionOption) *pool {
269 t.Helper()
270
271 pool := newPool(cfg, opts...)
272 err := pool.ready()
273 assert.Nil(t, err, "connect error: %v", err)
274 return pool
275 }
276
277 func checkoutConnections(t *testing.T, p *pool, numConns int) []*connection {
278 conns := make([]*connection, 0, numConns)
279
280 for i := 0; i < numConns; i++ {
281 conn, err := p.checkOut(context.Background())
282 assert.Nil(t, err, "checkOut() error at index %d: %v", i, err)
283 conns = append(conns, conn)
284 }
285
286 return conns
287 }
288
View as plain text