...

Source file src/go.mongodb.org/mongo-driver/mongo/integration/unified_runner_events_helper_test.go

Documentation: go.mongodb.org/mongo-driver/mongo/integration

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package integration
     8  
     9  import (
    10  	"context"
    11  	"sync"
    12  	"time"
    13  
    14  	"go.mongodb.org/mongo-driver/event"
    15  	"go.mongodb.org/mongo-driver/internal/assert"
    16  	"go.mongodb.org/mongo-driver/mongo/address"
    17  	"go.mongodb.org/mongo-driver/mongo/description"
    18  	"go.mongodb.org/mongo-driver/mongo/integration/mtest"
    19  	"go.mongodb.org/mongo-driver/mongo/readpref"
    20  	"go.mongodb.org/mongo-driver/x/mongo/driver/topology"
    21  )
    22  
    23  // Helper functions for the operations in the unified spec test runner that require assertions about SDAM and connection
    24  // pool events.
    25  
    26  var (
    27  	poolEventTypesMap = map[string]string{
    28  		"PoolClearedEvent": event.PoolCleared,
    29  	}
    30  	defaultCallbackTimeout = 10 * time.Second
    31  )
    32  
    33  // unifiedRunnerEventMonitor monitors connection pool-related events.
    34  type unifiedRunnerEventMonitor struct {
    35  	poolEventCount               map[string]int
    36  	poolEventCountLock           sync.Mutex
    37  	sdamMonitor                  *event.ServerMonitor
    38  	serverMarkedUnknownCount     int
    39  	serverMarkedUnknownCountLock sync.Mutex
    40  }
    41  
    42  func newUnifiedRunnerEventMonitor() *unifiedRunnerEventMonitor {
    43  	urem := unifiedRunnerEventMonitor{
    44  		poolEventCount: make(map[string]int),
    45  	}
    46  	urem.sdamMonitor = &event.ServerMonitor{
    47  		ServerDescriptionChanged: (func(e *event.ServerDescriptionChangedEvent) {
    48  			urem.serverMarkedUnknownCountLock.Lock()
    49  			defer urem.serverMarkedUnknownCountLock.Unlock()
    50  
    51  			// Spec tests only ever handle ServerMarkedUnknown ServerDescriptionChangedEvents
    52  			// for the time being.
    53  			if e.NewDescription.Kind == description.Unknown {
    54  				urem.serverMarkedUnknownCount++
    55  			}
    56  		}),
    57  	}
    58  	return &urem
    59  }
    60  
    61  // handlePoolEvent can be used as the event handler for a connection pool monitor.
    62  func (u *unifiedRunnerEventMonitor) handlePoolEvent(evt *event.PoolEvent) {
    63  	u.poolEventCountLock.Lock()
    64  	defer u.poolEventCountLock.Unlock()
    65  
    66  	u.poolEventCount[evt.Type]++
    67  }
    68  
    69  // getPoolEventCount returns the number of pool events of the given type, or 0 if no events were recorded.
    70  func (u *unifiedRunnerEventMonitor) getPoolEventCount(eventType string) int {
    71  	u.poolEventCountLock.Lock()
    72  	defer u.poolEventCountLock.Unlock()
    73  
    74  	mappedType := poolEventTypesMap[eventType]
    75  	return u.poolEventCount[mappedType]
    76  }
    77  
    78  // getServerMarkedUnknownEvent returns the number of ServerMarkedUnknownEvents, or 0 if none were recorded.
    79  func (u *unifiedRunnerEventMonitor) getServerMarkedUnknownCount() int {
    80  	u.serverMarkedUnknownCountLock.Lock()
    81  	defer u.serverMarkedUnknownCountLock.Unlock()
    82  
    83  	return u.serverMarkedUnknownCount
    84  }
    85  
    86  func waitForEvent(mt *mtest.T, test *testCase, op *operation) {
    87  	eventType := op.Arguments.Lookup("event").StringValue()
    88  	expectedCount := int(op.Arguments.Lookup("count").Int32())
    89  
    90  	callback := func(ctx context.Context) {
    91  		for {
    92  			// Stop loop if callback has been canceled.
    93  			select {
    94  			case <-ctx.Done():
    95  				return
    96  			default:
    97  			}
    98  
    99  			var count int
   100  			// Spec tests only ever wait for ServerMarkedUnknown SDAM events for the time being.
   101  			if eventType == "ServerMarkedUnknownEvent" {
   102  				count = test.monitor.getServerMarkedUnknownCount()
   103  			} else {
   104  				count = test.monitor.getPoolEventCount(eventType)
   105  			}
   106  
   107  			if count >= expectedCount {
   108  				return
   109  			}
   110  			time.Sleep(100 * time.Millisecond)
   111  		}
   112  	}
   113  
   114  	assert.Soon(mt, callback, defaultCallbackTimeout)
   115  }
   116  
   117  func assertEventCount(mt *mtest.T, testCase *testCase, op *operation) {
   118  	eventType := op.Arguments.Lookup("event").StringValue()
   119  	expectedCount := int(op.Arguments.Lookup("count").Int32())
   120  
   121  	var gotCount int
   122  	// Spec tests only ever assert ServerMarkedUnknown SDAM events for the time being.
   123  	if eventType == "ServerMarkedUnknownEvent" {
   124  		gotCount = testCase.monitor.getServerMarkedUnknownCount()
   125  	} else {
   126  		gotCount = testCase.monitor.getPoolEventCount(eventType)
   127  	}
   128  	assert.Equal(mt, expectedCount, gotCount, "expected count %d for event %s, got %d", expectedCount, eventType,
   129  		gotCount)
   130  }
   131  
   132  func recordPrimary(mt *mtest.T, testCase *testCase) {
   133  	testCase.recordedPrimary = getPrimaryAddress(mt, testCase.testTopology, true)
   134  }
   135  
   136  func waitForPrimaryChange(mt *mtest.T, testCase *testCase, op *operation) {
   137  	callback := func(ctx context.Context) {
   138  		for {
   139  			// Stop loop if callback has been canceled.
   140  			select {
   141  			case <-ctx.Done():
   142  				return
   143  			default:
   144  			}
   145  
   146  			if getPrimaryAddress(mt, testCase.testTopology, false) != testCase.recordedPrimary {
   147  				return
   148  			}
   149  		}
   150  	}
   151  
   152  	timeout := convertValueToMilliseconds(mt, op.Arguments.Lookup("timeoutMS"))
   153  	assert.Soon(mt, callback, timeout)
   154  }
   155  
   156  // getPrimaryAddress returns the address of the current primary. If failFast is true, the server selection fast path
   157  // is used and the function will fail if the fast path doesn't return a server.
   158  func getPrimaryAddress(mt *mtest.T, topo *topology.Topology, failFast bool) address.Address {
   159  	mt.Helper()
   160  
   161  	ctx, cancel := context.WithCancel(context.Background())
   162  	defer cancel()
   163  	if failFast {
   164  		cancel()
   165  	}
   166  
   167  	primary, err := topo.SelectServer(ctx, description.ReadPrefSelector(readpref.Primary()))
   168  	assert.Nil(mt, err, "SelectServer error: %v", err)
   169  	return primary.(*topology.SelectedServer).Description().Addr
   170  }
   171  

View as plain text