...

Source file src/go.mongodb.org/mongo-driver/x/mongo/driver/topology/cmap_prose_test.go

Documentation: go.mongodb.org/mongo-driver/x/mongo/driver/topology

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"net"
    13  	"testing"
    14  	"time"
    15  
    16  	"go.mongodb.org/mongo-driver/event"
    17  	"go.mongodb.org/mongo-driver/internal/assert"
    18  	"go.mongodb.org/mongo-driver/internal/require"
    19  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    20  )
    21  
    22  func TestCMAPProse(t *testing.T) {
    23  	t.Run("created and closed events", func(t *testing.T) {
    24  		created := make(chan *event.PoolEvent, 10)
    25  		closed := make(chan *event.PoolEvent, 10)
    26  		clearEvents := func() {
    27  			for len(created) > 0 {
    28  				<-created
    29  			}
    30  			for len(closed) > 0 {
    31  				<-closed
    32  			}
    33  		}
    34  		monitor := &event.PoolMonitor{
    35  			Event: func(evt *event.PoolEvent) {
    36  				switch evt.Type {
    37  				case event.ConnectionCreated:
    38  					created <- evt
    39  				case event.ConnectionClosed:
    40  					closed <- evt
    41  				}
    42  			},
    43  		}
    44  		getConfig := func() poolConfig {
    45  			return poolConfig{
    46  				PoolMonitor: monitor,
    47  			}
    48  		}
    49  		assertConnectionCounts := func(t *testing.T, p *pool, numCreated, numClosed int) {
    50  			t.Helper()
    51  
    52  			require.Eventuallyf(t,
    53  				func() bool {
    54  					return numCreated == len(created) && numClosed == len(closed)
    55  				},
    56  				1*time.Second,
    57  				10*time.Millisecond,
    58  				"expected %d creation events, got %d; expected %d closed events, got %d",
    59  				numCreated,
    60  				len(created),
    61  				numClosed,
    62  				len(closed))
    63  
    64  			netCount := numCreated - numClosed
    65  			assert.Equal(t, netCount, p.totalConnectionCount(), "expected %d total connections, got %d", netCount,
    66  				p.totalConnectionCount())
    67  		}
    68  
    69  		t.Run("maintain", func(t *testing.T) {
    70  			t.Run("connection error publishes events", func(t *testing.T) {
    71  				// If a connection is created as part of minPoolSize maintenance and errors while connecting, checkOut()
    72  				// should report that error and publish an event.
    73  				// If maintain() creates a connection that encounters an error while connecting,
    74  				// the pool should publish connection created and closed events.
    75  				clearEvents()
    76  
    77  				var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
    78  					return &testNetConn{writeerr: errors.New("write error")}, nil
    79  				}
    80  
    81  				cfg := getConfig()
    82  				cfg.MinPoolSize = 1
    83  				connOpts := []ConnectionOption{
    84  					WithDialer(func(Dialer) Dialer { return dialer }),
    85  					WithHandshaker(func(Handshaker) Handshaker {
    86  						return operation.NewHello()
    87  					}),
    88  				}
    89  				pool := createTestPool(t, cfg, connOpts...)
    90  				defer pool.close(context.Background())
    91  
    92  				// Wait up to 3 seconds for the maintain() goroutine to run and for 1 connection
    93  				// created and 1 connection closed events to be published.
    94  				start := time.Now()
    95  				for len(created) != 1 || len(closed) != 1 {
    96  					if time.Since(start) > 3*time.Second {
    97  						t.Errorf(
    98  							"Expected 1 connection created and 1 connection closed events within 3 seconds. "+
    99  								"Actual created events: %d, actual closed events: %d",
   100  							len(created),
   101  							len(closed))
   102  					}
   103  					time.Sleep(time.Millisecond)
   104  				}
   105  			})
   106  		})
   107  		t.Run("checkOut", func(t *testing.T) {
   108  			t.Run("connection error publishes events", func(t *testing.T) {
   109  				// If checkOut() creates a connection that encounters an error while connecting,
   110  				// the pool should publish connection created and closed events and checkOut should
   111  				// return the error.
   112  				clearEvents()
   113  
   114  				var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
   115  					return &testNetConn{writeerr: errors.New("write error")}, nil
   116  				}
   117  
   118  				cfg := getConfig()
   119  				connOpts := []ConnectionOption{
   120  					WithDialer(func(Dialer) Dialer { return dialer }),
   121  					WithHandshaker(func(Handshaker) Handshaker {
   122  						return operation.NewHello()
   123  					}),
   124  				}
   125  				pool := createTestPool(t, cfg, connOpts...)
   126  				defer pool.close(context.Background())
   127  
   128  				_, err := pool.checkOut(context.Background())
   129  				assert.NotNil(t, err, "expected checkOut() error, got nil")
   130  
   131  				assertConnectionCounts(t, pool, 1, 1)
   132  			})
   133  			t.Run("pool is empty", func(t *testing.T) {
   134  				// If a checkOut() has to create a new connection and that connection encounters an
   135  				// error while connecting, checkOut() should return that error and publish an event.
   136  				clearEvents()
   137  
   138  				var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
   139  					return &testNetConn{writeerr: errors.New("write error")}, nil
   140  				}
   141  
   142  				connOpts := []ConnectionOption{
   143  					WithDialer(func(Dialer) Dialer { return dialer }),
   144  					WithHandshaker(func(Handshaker) Handshaker {
   145  						return operation.NewHello()
   146  					}),
   147  				}
   148  				pool := createTestPool(t, getConfig(), connOpts...)
   149  				defer pool.close(context.Background())
   150  
   151  				_, err := pool.checkOut(context.Background())
   152  				assert.NotNil(t, err, "expected checkOut() error, got nil")
   153  				assertConnectionCounts(t, pool, 1, 1)
   154  			})
   155  		})
   156  		t.Run("checkIn", func(t *testing.T) {
   157  			t.Run("errored connection", func(t *testing.T) {
   158  				// If the connection being returned to the pool encountered a network error, it should be removed from
   159  				// the pool and an event should be published.
   160  				clearEvents()
   161  
   162  				var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
   163  					return &testNetConn{writeerr: errors.New("write error")}, nil
   164  				}
   165  
   166  				// We don't use the WithHandshaker option so the connection won't error during handshaking.
   167  				connOpts := []ConnectionOption{
   168  					WithDialer(func(Dialer) Dialer { return dialer }),
   169  				}
   170  				pool := createTestPool(t, getConfig(), connOpts...)
   171  				defer pool.close(context.Background())
   172  
   173  				conn, err := pool.checkOut(context.Background())
   174  				assert.Nil(t, err, "checkOut() error: %v", err)
   175  
   176  				// Force a network error by writing to the connection.
   177  				err = conn.writeWireMessage(context.Background(), nil)
   178  				assert.NotNil(t, err, "expected writeWireMessage error, got nil")
   179  
   180  				err = pool.checkIn(conn)
   181  				assert.Nil(t, err, "checkIn() error: %v", err)
   182  
   183  				assertConnectionCounts(t, pool, 1, 1)
   184  				evt := <-closed
   185  				assert.Equal(t, event.ReasonError, evt.Reason, "expected reason %q, got %q",
   186  					event.ReasonError, evt.Reason)
   187  			})
   188  		})
   189  		t.Run("close", func(t *testing.T) {
   190  			t.Run("connections returned gracefully", func(t *testing.T) {
   191  				// If all connections are in the pool when close is called, they should be closed gracefully and
   192  				// events should be published.
   193  				clearEvents()
   194  
   195  				numConns := 5
   196  				var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
   197  					return &testNetConn{}, nil
   198  				}
   199  				pool := createTestPool(t, getConfig(), WithDialer(func(Dialer) Dialer { return dialer }))
   200  				defer pool.close(context.Background())
   201  
   202  				conns := checkoutConnections(t, pool, numConns)
   203  				assertConnectionCounts(t, pool, numConns, 0)
   204  
   205  				// Return all connections to the pool and assert that none were closed by checkIn().
   206  				for i, c := range conns {
   207  					err := pool.checkIn(c)
   208  					assert.Nil(t, err, "checkIn() error at index %d: %v", i, err)
   209  				}
   210  				assertConnectionCounts(t, pool, numConns, 0)
   211  
   212  				// Close the pool and assert that a closed event is published for each connection.
   213  				pool.close(context.Background())
   214  				assertConnectionCounts(t, pool, numConns, numConns)
   215  
   216  				for len(closed) > 0 {
   217  					evt := <-closed
   218  					assert.Equal(t, event.ReasonPoolClosed, evt.Reason, "expected reason %q, got %q",
   219  						event.ReasonPoolClosed, evt.Reason)
   220  				}
   221  			})
   222  			t.Run("connections closed forcefully", func(t *testing.T) {
   223  				// If some connections are still checked out when close is called, they should be closed
   224  				// forcefully and events should be published for them.
   225  				clearEvents()
   226  
   227  				numConns := 5
   228  				var dialer DialerFunc = func(context.Context, string, string) (net.Conn, error) {
   229  					return &testNetConn{}, nil
   230  				}
   231  				pool := createTestPool(t, getConfig(), WithDialer(func(Dialer) Dialer { return dialer }))
   232  
   233  				conns := checkoutConnections(t, pool, numConns)
   234  				assertConnectionCounts(t, pool, numConns, 0)
   235  
   236  				// Only return 2 of the connection.
   237  				for i := 0; i < 2; i++ {
   238  					err := pool.checkIn(conns[i])
   239  					assert.Nil(t, err, "checkIn() error at index %d: %v", i, err)
   240  				}
   241  				conns = conns[2:]
   242  				assertConnectionCounts(t, pool, numConns, 0)
   243  
   244  				// Close and assert that events are published for all connections.
   245  				pool.close(context.Background())
   246  				assertConnectionCounts(t, pool, numConns, numConns)
   247  
   248  				// Return the remaining connections and assert that the closed event count does not increase because
   249  				// these connections have already been closed.
   250  				for i, c := range conns {
   251  					err := pool.checkIn(c)
   252  					assert.Nil(t, err, "checkIn() error at index %d: %v", i, err)
   253  				}
   254  				assertConnectionCounts(t, pool, numConns, numConns)
   255  
   256  				// Ensure all closed events have the correct reason.
   257  				for len(closed) > 0 {
   258  					evt := <-closed
   259  					assert.Equal(t, event.ReasonPoolClosed, evt.Reason, "expected reason %q, got %q",
   260  						event.ReasonPoolClosed, evt.Reason)
   261  				}
   262  
   263  			})
   264  		})
   265  	})
   266  }
   267  
   268  func createTestPool(t *testing.T, cfg poolConfig, opts ...ConnectionOption) *pool {
   269  	t.Helper()
   270  
   271  	pool := newPool(cfg, opts...)
   272  	err := pool.ready()
   273  	assert.Nil(t, err, "connect error: %v", err)
   274  	return pool
   275  }
   276  
   277  func checkoutConnections(t *testing.T, p *pool, numConns int) []*connection {
   278  	conns := make([]*connection, 0, numConns)
   279  
   280  	for i := 0; i < numConns; i++ {
   281  		conn, err := p.checkOut(context.Background())
   282  		assert.Nil(t, err, "checkOut() error at index %d: %v", i, err)
   283  		conns = append(conns, conn)
   284  	}
   285  
   286  	return conns
   287  }
   288  

View as plain text