...

Source file src/github.com/stretchr/testify/mock/mock.go

Documentation: github.com/stretchr/testify/mock

     1  package mock
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"path"
     7  	"reflect"
     8  	"regexp"
     9  	"runtime"
    10  	"strings"
    11  	"sync"
    12  	"time"
    13  
    14  	"github.com/davecgh/go-spew/spew"
    15  	"github.com/pmezard/go-difflib/difflib"
    16  	"github.com/stretchr/objx"
    17  
    18  	"github.com/stretchr/testify/assert"
    19  )
    20  
    21  // regex for GCCGO functions
    22  var gccgoRE = regexp.MustCompile(`\.pN\d+_`)
    23  
    24  // TestingT is an interface wrapper around *testing.T
    25  type TestingT interface {
    26  	Logf(format string, args ...interface{})
    27  	Errorf(format string, args ...interface{})
    28  	FailNow()
    29  }
    30  
    31  /*
    32  	Call
    33  */
    34  
// Call describes a single method invocation. The same type is used both for
// expectations registered through On and for the record of calls that were
// actually made on a Mock.
type Call struct {
	// Parent is the Mock this call belongs to; its mutex guards the
	// fields of the call.
	Parent *Mock

	// The name of the method that was or will be called.
	Method string

	// Holds the arguments of the method.
	Arguments Arguments

	// Holds the arguments that should be returned when
	// this method is called.
	ReturnArguments Arguments

	// Holds the caller info for the On() call
	callerInfo []string

	// The number of times to return the return arguments when setting
	// expectations. 0 means to always return the value. MethodCalled
	// counts this down on every call and stores -1 once the last allowed
	// call has been consumed.
	Repeatability int

	// Number of times this call has been called.
	totalCalls int

	// Call to this method can be optional; set through Maybe.
	optional bool

	// Holds a channel that will be used to block the Return until it either
	// receives a message or is closed. nil means it returns immediately.
	WaitFor <-chan time.Time

	// waitTime is how long MethodCalled sleeps before returning when
	// WaitFor is nil; set through After.
	waitTime time.Duration

	// Holds a handler used to manipulate arguments content that are passed by
	// reference. It's useful when mocking methods such as unmarshalers or
	// decoders.
	RunFn func(Arguments)

	// PanicMsg holds the message used to mock a panic on the function call.
	// If PanicMsg is non-nil the function call panics with that message
	// irrespective of other settings.
	PanicMsg *string

	// Calls which must be satisfied before this call can be made; filled
	// in through NotBefore.
	requires []*Call
}
    82  
    83  func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
    84  	return &Call{
    85  		Parent:          parent,
    86  		Method:          methodName,
    87  		Arguments:       methodArguments,
    88  		ReturnArguments: make([]interface{}, 0),
    89  		callerInfo:      callerInfo,
    90  		Repeatability:   0,
    91  		WaitFor:         nil,
    92  		RunFn:           nil,
    93  		PanicMsg:        nil,
    94  	}
    95  }
    96  
    97  func (c *Call) lock() {
    98  	c.Parent.mutex.Lock()
    99  }
   100  
   101  func (c *Call) unlock() {
   102  	c.Parent.mutex.Unlock()
   103  }
   104  
   105  // Return specifies the return arguments for the expectation.
   106  //
   107  //	Mock.On("DoSomething").Return(errors.New("failed"))
   108  func (c *Call) Return(returnArguments ...interface{}) *Call {
   109  	c.lock()
   110  	defer c.unlock()
   111  
   112  	c.ReturnArguments = returnArguments
   113  
   114  	return c
   115  }
   116  
   117  // Panic specifies if the function call should fail and the panic message
   118  //
   119  //	Mock.On("DoSomething").Panic("test panic")
   120  func (c *Call) Panic(msg string) *Call {
   121  	c.lock()
   122  	defer c.unlock()
   123  
   124  	c.PanicMsg = &msg
   125  
   126  	return c
   127  }
   128  
   129  // Once indicates that the mock should only return the value once.
   130  //
   131  //	Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Once()
   132  func (c *Call) Once() *Call {
   133  	return c.Times(1)
   134  }
   135  
   136  // Twice indicates that the mock should only return the value twice.
   137  //
   138  //	Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Twice()
   139  func (c *Call) Twice() *Call {
   140  	return c.Times(2)
   141  }
   142  
   143  // Times indicates that the mock should only return the indicated number
   144  // of times.
   145  //
   146  //	Mock.On("MyMethod", arg1, arg2).Return(returnArg1, returnArg2).Times(5)
   147  func (c *Call) Times(i int) *Call {
   148  	c.lock()
   149  	defer c.unlock()
   150  	c.Repeatability = i
   151  	return c
   152  }
   153  
   154  // WaitUntil sets the channel that will block the mock's return until its closed
   155  // or a message is received.
   156  //
   157  //	Mock.On("MyMethod", arg1, arg2).WaitUntil(time.After(time.Second))
   158  func (c *Call) WaitUntil(w <-chan time.Time) *Call {
   159  	c.lock()
   160  	defer c.unlock()
   161  	c.WaitFor = w
   162  	return c
   163  }
   164  
   165  // After sets how long to block until the call returns
   166  //
   167  //	Mock.On("MyMethod", arg1, arg2).After(time.Second)
   168  func (c *Call) After(d time.Duration) *Call {
   169  	c.lock()
   170  	defer c.unlock()
   171  	c.waitTime = d
   172  	return c
   173  }
   174  
   175  // Run sets a handler to be called before returning. It can be used when
   176  // mocking a method (such as an unmarshaler) that takes a pointer to a struct and
   177  // sets properties in such struct
   178  //
   179  //	Mock.On("Unmarshal", AnythingOfType("*map[string]interface{}")).Return().Run(func(args Arguments) {
   180  //		arg := args.Get(0).(*map[string]interface{})
   181  //		arg["foo"] = "bar"
   182  //	})
   183  func (c *Call) Run(fn func(args Arguments)) *Call {
   184  	c.lock()
   185  	defer c.unlock()
   186  	c.RunFn = fn
   187  	return c
   188  }
   189  
   190  // Maybe allows the method call to be optional. Not calling an optional method
   191  // will not cause an error while asserting expectations
   192  func (c *Call) Maybe() *Call {
   193  	c.lock()
   194  	defer c.unlock()
   195  	c.optional = true
   196  	return c
   197  }
   198  
   199  // On chains a new expectation description onto the mocked interface. This
   200  // allows syntax like.
   201  //
   202  //	Mock.
   203  //	   On("MyMethod", 1).Return(nil).
   204  //	   On("MyOtherMethod", 'a', 'b', 'c').Return(errors.New("Some Error"))
   205  //
   206  //go:noinline
   207  func (c *Call) On(methodName string, arguments ...interface{}) *Call {
   208  	return c.Parent.On(methodName, arguments...)
   209  }
   210  
   211  // Unset removes a mock handler from being called.
   212  //
   213  //	test.On("func", mock.Anything).Unset()
   214  func (c *Call) Unset() *Call {
   215  	var unlockOnce sync.Once
   216  
   217  	for _, arg := range c.Arguments {
   218  		if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
   219  			panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
   220  		}
   221  	}
   222  
   223  	c.lock()
   224  	defer unlockOnce.Do(c.unlock)
   225  
   226  	foundMatchingCall := false
   227  
   228  	// in-place filter slice for calls to be removed - iterate from 0'th to last skipping unnecessary ones
   229  	var index int // write index
   230  	for _, call := range c.Parent.ExpectedCalls {
   231  		if call.Method == c.Method {
   232  			_, diffCount := call.Arguments.Diff(c.Arguments)
   233  			if diffCount == 0 {
   234  				foundMatchingCall = true
   235  				// Remove from ExpectedCalls - just skip it
   236  				continue
   237  			}
   238  		}
   239  		c.Parent.ExpectedCalls[index] = call
   240  		index++
   241  	}
   242  	// trim slice up to last copied index
   243  	c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index]
   244  
   245  	if !foundMatchingCall {
   246  		unlockOnce.Do(c.unlock)
   247  		c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n",
   248  			callString(c.Method, c.Arguments, true),
   249  		)
   250  	}
   251  
   252  	return c
   253  }
   254  
   255  // NotBefore indicates that the mock should only be called after the referenced
   256  // calls have been called as expected. The referenced calls may be from the
   257  // same mock instance and/or other mock instances.
   258  //
   259  //	Mock.On("Do").Return(nil).Notbefore(
   260  //	    Mock.On("Init").Return(nil)
   261  //	)
   262  func (c *Call) NotBefore(calls ...*Call) *Call {
   263  	c.lock()
   264  	defer c.unlock()
   265  
   266  	for _, call := range calls {
   267  		if call.Parent == nil {
   268  			panic("not before calls must be created with Mock.On()")
   269  		}
   270  	}
   271  
   272  	c.requires = append(c.requires, calls...)
   273  	return c
   274  }
   275  
// Mock is the workhorse used to track activity on another object.
// For an example of its usage, refer to the "Example Usage" section at the top
// of this document.
type Mock struct {
	// Represents the calls that are expected of
	// an object. On appends to this list.
	ExpectedCalls []*Call

	// Holds the calls that were made to this mocked object.
	// MethodCalled appends to this list.
	Calls []Call

	// test is an optional variable that holds the test struct, to be used when an
	// invalid mock call was made. When it is nil, fail panics instead.
	test TestingT

	// testData holds any data that might be useful for testing.  Testify ignores
	// this data completely allowing you to do whatever you like with it.
	// It is created on first use by TestData.
	testData objx.Map

	// mutex guards the expectation and call lists as well as the settings
	// of the calls owned by this mock.
	mutex sync.Mutex
}
   297  
   298  // String provides a %v format string for Mock.
   299  // Note: this is used implicitly by Arguments.Diff if a Mock is passed.
   300  // It exists because go's default %v formatting traverses the struct
   301  // without acquiring the mutex, which is detected by go test -race.
   302  func (m *Mock) String() string {
   303  	return fmt.Sprintf("%[1]T<%[1]p>", m)
   304  }
   305  
   306  // TestData holds any data that might be useful for testing.  Testify ignores
   307  // this data completely allowing you to do whatever you like with it.
   308  func (m *Mock) TestData() objx.Map {
   309  	if m.testData == nil {
   310  		m.testData = make(objx.Map)
   311  	}
   312  
   313  	return m.testData
   314  }
   315  
   316  /*
   317  	Setting expectations
   318  */
   319  
   320  // Test sets the test struct variable of the mock object
   321  func (m *Mock) Test(t TestingT) {
   322  	m.mutex.Lock()
   323  	defer m.mutex.Unlock()
   324  	m.test = t
   325  }
   326  
   327  // fail fails the current test with the given formatted format and args.
   328  // In case that a test was defined, it uses the test APIs for failing a test,
   329  // otherwise it uses panic.
   330  func (m *Mock) fail(format string, args ...interface{}) {
   331  	m.mutex.Lock()
   332  	defer m.mutex.Unlock()
   333  
   334  	if m.test == nil {
   335  		panic(fmt.Sprintf(format, args...))
   336  	}
   337  	m.test.Errorf(format, args...)
   338  	m.test.FailNow()
   339  }
   340  
   341  // On starts a description of an expectation of the specified method
   342  // being called.
   343  //
   344  //	Mock.On("MyMethod", arg1, arg2)
   345  func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
   346  	for _, arg := range arguments {
   347  		if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
   348  			panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
   349  		}
   350  	}
   351  
   352  	m.mutex.Lock()
   353  	defer m.mutex.Unlock()
   354  	c := newCall(m, methodName, assert.CallerInfo(), arguments...)
   355  	m.ExpectedCalls = append(m.ExpectedCalls, c)
   356  	return c
   357  }
   358  
   359  // /*
   360  // 	Recording and responding to activity
   361  // */
   362  
   363  func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
   364  	var expectedCall *Call
   365  
   366  	for i, call := range m.ExpectedCalls {
   367  		if call.Method == method {
   368  			_, diffCount := call.Arguments.Diff(arguments)
   369  			if diffCount == 0 {
   370  				expectedCall = call
   371  				if call.Repeatability > -1 {
   372  					return i, call
   373  				}
   374  			}
   375  		}
   376  	}
   377  
   378  	return -1, expectedCall
   379  }
   380  
   381  type matchCandidate struct {
   382  	call      *Call
   383  	mismatch  string
   384  	diffCount int
   385  }
   386  
   387  func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool {
   388  	if c.call == nil {
   389  		return false
   390  	}
   391  	if other.call == nil {
   392  		return true
   393  	}
   394  
   395  	if c.diffCount > other.diffCount {
   396  		return false
   397  	}
   398  	if c.diffCount < other.diffCount {
   399  		return true
   400  	}
   401  
   402  	if c.call.Repeatability > 0 && other.call.Repeatability <= 0 {
   403  		return true
   404  	}
   405  	return false
   406  }
   407  
   408  func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) {
   409  	var bestMatch matchCandidate
   410  
   411  	for _, call := range m.expectedCalls() {
   412  		if call.Method == method {
   413  
   414  			errInfo, tempDiffCount := call.Arguments.Diff(arguments)
   415  			tempCandidate := matchCandidate{
   416  				call:      call,
   417  				mismatch:  errInfo,
   418  				diffCount: tempDiffCount,
   419  			}
   420  			if tempCandidate.isBetterMatchThan(bestMatch) {
   421  				bestMatch = tempCandidate
   422  			}
   423  		}
   424  	}
   425  
   426  	return bestMatch.call, bestMatch.mismatch
   427  }
   428  
   429  func callString(method string, arguments Arguments, includeArgumentValues bool) string {
   430  	var argValsString string
   431  	if includeArgumentValues {
   432  		var argVals []string
   433  		for argIndex, arg := range arguments {
   434  			if _, ok := arg.(*FunctionalOptionsArgument); ok {
   435  				argVals = append(argVals, fmt.Sprintf("%d: %s", argIndex, arg))
   436  				continue
   437  			}
   438  			argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg))
   439  		}
   440  		argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
   441  	}
   442  
   443  	return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
   444  }
   445  
   446  // Called tells the mock object that a method has been called, and gets an array
   447  // of arguments to return.  Panics if the call is unexpected (i.e. not preceded by
   448  // appropriate .On .Return() calls)
   449  // If Call.WaitFor is set, blocks until the channel is closed or receives a message.
   450  func (m *Mock) Called(arguments ...interface{}) Arguments {
   451  	// get the calling function's name
   452  	pc, _, _, ok := runtime.Caller(1)
   453  	if !ok {
   454  		panic("Couldn't get the caller information")
   455  	}
   456  	functionPath := runtime.FuncForPC(pc).Name()
   457  	// Next four lines are required to use GCCGO function naming conventions.
   458  	// For Ex:  github_com_docker_libkv_store_mock.WatchTree.pN39_github_com_docker_libkv_store_mock.Mock
   459  	// uses interface information unlike golang github.com/docker/libkv/store/mock.(*Mock).WatchTree
   460  	// With GCCGO we need to remove interface information starting from pN<dd>.
   461  	if gccgoRE.MatchString(functionPath) {
   462  		functionPath = gccgoRE.Split(functionPath, -1)[0]
   463  	}
   464  	parts := strings.Split(functionPath, ".")
   465  	functionName := parts[len(parts)-1]
   466  	return m.MethodCalled(functionName, arguments...)
   467  }
   468  
// MethodCalled tells the mock object that the given method has been called, and gets
// an array of arguments to return. Panics if the call is unexpected (i.e. not preceded
// by appropriate .On .Return() calls)
// If Call.WaitFor is set, blocks until the channel is closed or receives a message.
//
// The steps are: look up a matching expectation, fail if there is none or if
// one of the calls it requires (see NotBefore) is not yet satisfied, update
// the expectation's counters and record the call, wait as configured, then
// panic with PanicMsg if set, run RunFn if set, and finally return the
// expectation's ReturnArguments.
func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
	m.mutex.Lock()
	// TODO: could combine expected and closes in single loop
	found, call := m.findExpectedCall(methodName, arguments...)

	if found < 0 {
		// expected call found, but it has already been called with repeatable times
		if call != nil {
			// fail takes the mutex itself, so it is released first.
			// NOTE(review): the code below assumes fail does not return
			// (it panics, or FailNow stops the goroutine as testing.T's
			// does) — confirm for custom TestingT implementations.
			m.mutex.Unlock()
			m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(\"%s\").Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo())
		}
		// we have to fail here - because we don't know what to do
		// as the return arguments.  This is because:
		//
		//   a) this is a totally unexpected call to this method,
		//   b) the arguments are not what was expected, or
		//   c) the developer has forgotten to add an accompanying On...Return pair.
		closestCall, mismatch := m.findClosestCall(methodName, arguments...)
		m.mutex.Unlock()

		if closestCall != nil {
			m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s",
				callString(methodName, arguments, true),
				callString(methodName, closestCall.Arguments, true),
				diffArguments(closestCall.Arguments, arguments),
				strings.TrimSpace(mismatch),
			)
		} else {
			m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo())
		}
	}

	// Every call registered through NotBefore must already be satisfied.
	for _, requirement := range call.requires {
		if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied {
			m.mutex.Unlock()
			m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s",
				callString(call.Method, call.Arguments, true),
				// Builds the optional wording that follows "Must not be called before".
				func() (s string) {
					if requirement.totalCalls > 0 {
						s = " another call of"
					}
					if call.Parent != requirement.Parent {
						s += " method from another mock instance"
					}
					return
				}(),
				callString(requirement.Method, requirement.Arguments, true),
			)
		}
	}

	// Count down the remaining uses; -1 marks the expectation as used up,
	// while 0 (unlimited) is left untouched.
	if call.Repeatability == 1 {
		call.Repeatability = -1
	} else if call.Repeatability > 1 {
		call.Repeatability--
	}
	call.totalCalls++

	// add the call
	m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...))
	m.mutex.Unlock()

	// block if specified; the mutex is not held while waiting
	if call.WaitFor != nil {
		<-call.WaitFor
	} else {
		time.Sleep(call.waitTime)
	}

	// Each setting is read under the mutex, but the mutex is released
	// before acting on it.
	m.mutex.Lock()
	panicMsg := call.PanicMsg
	m.mutex.Unlock()
	if panicMsg != nil {
		panic(*panicMsg)
	}

	m.mutex.Lock()
	runFn := call.RunFn
	m.mutex.Unlock()

	if runFn != nil {
		runFn(arguments)
	}

	// ReturnArguments is read only after RunFn has finished.
	m.mutex.Lock()
	returnArgs := call.ReturnArguments
	m.mutex.Unlock()

	return returnArgs
}
   563  
   564  /*
   565  	Assertions
   566  */
   567  
   568  type assertExpectationiser interface {
   569  	AssertExpectations(TestingT) bool
   570  }
   571  
   572  // AssertExpectationsForObjects asserts that everything specified with On and Return
   573  // of the specified objects was in fact called as expected.
   574  //
   575  // Calls may have occurred in any order.
   576  func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
   577  	if h, ok := t.(tHelper); ok {
   578  		h.Helper()
   579  	}
   580  	for _, obj := range testObjects {
   581  		if m, ok := obj.(*Mock); ok {
   582  			t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
   583  			obj = m
   584  		}
   585  		m := obj.(assertExpectationiser)
   586  		if !m.AssertExpectations(t) {
   587  			t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m))
   588  			return false
   589  		}
   590  	}
   591  	return true
   592  }
   593  
   594  // AssertExpectations asserts that everything specified with On and Return was
   595  // in fact called as expected.  Calls may have occurred in any order.
   596  func (m *Mock) AssertExpectations(t TestingT) bool {
   597  	if s, ok := t.(interface{ Skipped() bool }); ok && s.Skipped() {
   598  		return true
   599  	}
   600  	if h, ok := t.(tHelper); ok {
   601  		h.Helper()
   602  	}
   603  
   604  	m.mutex.Lock()
   605  	defer m.mutex.Unlock()
   606  	var failedExpectations int
   607  
   608  	// iterate through each expectation
   609  	expectedCalls := m.expectedCalls()
   610  	for _, expectedCall := range expectedCalls {
   611  		satisfied, reason := m.checkExpectation(expectedCall)
   612  		if !satisfied {
   613  			failedExpectations++
   614  			t.Logf(reason)
   615  		}
   616  	}
   617  
   618  	if failedExpectations != 0 {
   619  		t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
   620  	}
   621  
   622  	return failedExpectations == 0
   623  }
   624  
   625  func (m *Mock) checkExpectation(call *Call) (bool, string) {
   626  	if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 {
   627  		return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
   628  	}
   629  	if call.Repeatability > 0 {
   630  		return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
   631  	}
   632  	return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String())
   633  }
   634  
   635  // AssertNumberOfCalls asserts that the method was called expectedCalls times.
   636  func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
   637  	if h, ok := t.(tHelper); ok {
   638  		h.Helper()
   639  	}
   640  	m.mutex.Lock()
   641  	defer m.mutex.Unlock()
   642  	var actualCalls int
   643  	for _, call := range m.calls() {
   644  		if call.Method == methodName {
   645  			actualCalls++
   646  		}
   647  	}
   648  	return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
   649  }
   650  
   651  // AssertCalled asserts that the method was called.
   652  // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
   653  func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
   654  	if h, ok := t.(tHelper); ok {
   655  		h.Helper()
   656  	}
   657  	m.mutex.Lock()
   658  	defer m.mutex.Unlock()
   659  	if !m.methodWasCalled(methodName, arguments) {
   660  		var calledWithArgs []string
   661  		for _, call := range m.calls() {
   662  			calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments))
   663  		}
   664  		if len(calledWithArgs) == 0 {
   665  			return assert.Fail(t, "Should have called with given arguments",
   666  				fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments))
   667  		}
   668  		return assert.Fail(t, "Should have called with given arguments",
   669  			fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n        %v", methodName, arguments, strings.Join(calledWithArgs, "\n")))
   670  	}
   671  	return true
   672  }
   673  
   674  // AssertNotCalled asserts that the method was not called.
   675  // It can produce a false result when an argument is a pointer type and the underlying value changed after calling the mocked method.
   676  func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
   677  	if h, ok := t.(tHelper); ok {
   678  		h.Helper()
   679  	}
   680  	m.mutex.Lock()
   681  	defer m.mutex.Unlock()
   682  	if m.methodWasCalled(methodName, arguments) {
   683  		return assert.Fail(t, "Should not have called with given arguments",
   684  			fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments))
   685  	}
   686  	return true
   687  }
   688  
   689  // IsMethodCallable checking that the method can be called
   690  // If the method was called more than `Repeatability` return false
   691  func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool {
   692  	if h, ok := t.(tHelper); ok {
   693  		h.Helper()
   694  	}
   695  	m.mutex.Lock()
   696  	defer m.mutex.Unlock()
   697  
   698  	for _, v := range m.ExpectedCalls {
   699  		if v.Method != methodName {
   700  			continue
   701  		}
   702  		if len(arguments) != len(v.Arguments) {
   703  			continue
   704  		}
   705  		if v.Repeatability < v.totalCalls {
   706  			continue
   707  		}
   708  		if isArgsEqual(v.Arguments, arguments) {
   709  			return true
   710  		}
   711  	}
   712  	return false
   713  }
   714  
   715  // isArgsEqual compares arguments
   716  func isArgsEqual(expected Arguments, args []interface{}) bool {
   717  	if len(expected) != len(args) {
   718  		return false
   719  	}
   720  	for i, v := range args {
   721  		if !reflect.DeepEqual(expected[i], v) {
   722  			return false
   723  		}
   724  	}
   725  	return true
   726  }
   727  
   728  func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
   729  	for _, call := range m.calls() {
   730  		if call.Method == methodName {
   731  
   732  			_, differences := Arguments(expected).Diff(call.Arguments)
   733  
   734  			if differences == 0 {
   735  				// found the expected call
   736  				return true
   737  			}
   738  
   739  		}
   740  	}
   741  	// we didn't find the expected call
   742  	return false
   743  }
   744  
   745  func (m *Mock) expectedCalls() []*Call {
   746  	return append([]*Call{}, m.ExpectedCalls...)
   747  }
   748  
   749  func (m *Mock) calls() []Call {
   750  	return append([]Call{}, m.Calls...)
   751  }
   752  
   753  /*
   754  	Arguments
   755  */
   756  
   757  // Arguments holds an array of method arguments or return values.
   758  type Arguments []interface{}
   759  
   760  const (
   761  	// Anything is used in Diff and Assert when the argument being tested
   762  	// shouldn't be taken into consideration.
   763  	Anything = "mock.Anything"
   764  )
   765  
   766  // AnythingOfTypeArgument contains the type of an argument
   767  // for use when type checking.  Used in Diff and Assert.
   768  //
   769  // Deprecated: this is an implementation detail that must not be used. Use [AnythingOfType] instead.
   770  type AnythingOfTypeArgument = anythingOfTypeArgument
   771  
   772  // anythingOfTypeArgument is a string that contains the type of an argument
   773  // for use when type checking.  Used in Diff and Assert.
   774  type anythingOfTypeArgument string
   775  
   776  // AnythingOfType returns a special value containing the
   777  // name of the type to check for. The type name will be matched against the type name returned by [reflect.Type.String].
   778  //
   779  // Used in Diff and Assert.
   780  //
   781  // For example:
   782  //
   783  //	Assert(t, AnythingOfType("string"), AnythingOfType("int"))
   784  func AnythingOfType(t string) AnythingOfTypeArgument {
   785  	return anythingOfTypeArgument(t)
   786  }
   787  
   788  // IsTypeArgument is a struct that contains the type of an argument
   789  // for use when type checking.  This is an alternative to AnythingOfType.
   790  // Used in Diff and Assert.
   791  type IsTypeArgument struct {
   792  	t reflect.Type
   793  }
   794  
   795  // IsType returns an IsTypeArgument object containing the type to check for.
   796  // You can provide a zero-value of the type to check.  This is an
   797  // alternative to AnythingOfType.  Used in Diff and Assert.
   798  //
   799  // For example:
   800  // Assert(t, IsType(""), IsType(0))
   801  func IsType(t interface{}) *IsTypeArgument {
   802  	return &IsTypeArgument{t: reflect.TypeOf(t)}
   803  }
   804  
   805  // FunctionalOptionsArgument is a struct that contains the type and value of an functional option argument
   806  // for use when type checking.
   807  type FunctionalOptionsArgument struct {
   808  	value interface{}
   809  }
   810  
   811  // String returns the string representation of FunctionalOptionsArgument
   812  func (f *FunctionalOptionsArgument) String() string {
   813  	var name string
   814  	tValue := reflect.ValueOf(f.value)
   815  	if tValue.Len() > 0 {
   816  		name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
   817  	}
   818  
   819  	return strings.Replace(fmt.Sprintf("%#v", f.value), "[]interface {}", name, 1)
   820  }
   821  
   822  // FunctionalOptions returns an FunctionalOptionsArgument object containing the functional option type
   823  // and the values to check of
   824  //
   825  // For example:
   826  // Assert(t, FunctionalOptions("[]foo.FunctionalOption", foo.Opt1(), foo.Opt2()))
   827  func FunctionalOptions(value ...interface{}) *FunctionalOptionsArgument {
   828  	return &FunctionalOptionsArgument{
   829  		value: value,
   830  	}
   831  }
   832  
   833  // argumentMatcher performs custom argument matching, returning whether or
   834  // not the argument is matched by the expectation fixture function.
   835  type argumentMatcher struct {
   836  	// fn is a function which accepts one argument, and returns a bool.
   837  	fn reflect.Value
   838  }
   839  
   840  func (f argumentMatcher) Matches(argument interface{}) bool {
   841  	expectType := f.fn.Type().In(0)
   842  	expectTypeNilSupported := false
   843  	switch expectType.Kind() {
   844  	case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr:
   845  		expectTypeNilSupported = true
   846  	}
   847  
   848  	argType := reflect.TypeOf(argument)
   849  	var arg reflect.Value
   850  	if argType == nil {
   851  		arg = reflect.New(expectType).Elem()
   852  	} else {
   853  		arg = reflect.ValueOf(argument)
   854  	}
   855  
   856  	if argType == nil && !expectTypeNilSupported {
   857  		panic(errors.New("attempting to call matcher with nil for non-nil expected type"))
   858  	}
   859  	if argType == nil || argType.AssignableTo(expectType) {
   860  		result := f.fn.Call([]reflect.Value{arg})
   861  		return result[0].Bool()
   862  	}
   863  	return false
   864  }
   865  
   866  func (f argumentMatcher) String() string {
   867  	return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
   868  }
   869  
   870  // MatchedBy can be used to match a mock call based on only certain properties
   871  // from a complex struct or some calculation. It takes a function that will be
   872  // evaluated with the called argument and will return true when there's a match
   873  // and false otherwise.
   874  //
   875  // Example:
   876  // m.On("Do", MatchedBy(func(req *http.Request) bool { return req.Host == "example.com" }))
   877  //
   878  // |fn|, must be a function accepting a single argument (of the expected type)
   879  // which returns a bool. If |fn| doesn't match the required signature,
   880  // MatchedBy() panics.
   881  func MatchedBy(fn interface{}) argumentMatcher {
   882  	fnType := reflect.TypeOf(fn)
   883  
   884  	if fnType.Kind() != reflect.Func {
   885  		panic(fmt.Sprintf("assert: arguments: %s is not a func", fn))
   886  	}
   887  	if fnType.NumIn() != 1 {
   888  		panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
   889  	}
   890  	if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
   891  		panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
   892  	}
   893  
   894  	return argumentMatcher{fn: reflect.ValueOf(fn)}
   895  }
   896  
   897  // Get Returns the argument at the specified index.
   898  func (args Arguments) Get(index int) interface{} {
   899  	if index+1 > len(args) {
   900  		panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
   901  	}
   902  	return args[index]
   903  }
   904  
   905  // Is gets whether the objects match the arguments specified.
   906  func (args Arguments) Is(objects ...interface{}) bool {
   907  	for i, obj := range args {
   908  		if obj != objects[i] {
   909  			return false
   910  		}
   911  	}
   912  	return true
   913  }
   914  
   915  // Diff gets a string describing the differences between the arguments
   916  // and the specified objects.
   917  //
   918  // Returns the diff string and number of differences found.
   919  func (args Arguments) Diff(objects []interface{}) (string, int) {
   920  	// TODO: could return string as error and nil for No difference
   921  
   922  	output := "\n"
   923  	var differences int
   924  
   925  	maxArgCount := len(args)
   926  	if len(objects) > maxArgCount {
   927  		maxArgCount = len(objects)
   928  	}
   929  
   930  	for i := 0; i < maxArgCount; i++ {
   931  		var actual, expected interface{}
   932  		var actualFmt, expectedFmt string
   933  
   934  		if len(objects) <= i {
   935  			actual = "(Missing)"
   936  			actualFmt = "(Missing)"
   937  		} else {
   938  			actual = objects[i]
   939  			actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
   940  		}
   941  
   942  		if len(args) <= i {
   943  			expected = "(Missing)"
   944  			expectedFmt = "(Missing)"
   945  		} else {
   946  			expected = args[i]
   947  			expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
   948  		}
   949  
   950  		if matcher, ok := expected.(argumentMatcher); ok {
   951  			var matches bool
   952  			func() {
   953  				defer func() {
   954  					if r := recover(); r != nil {
   955  						actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
   956  					}
   957  				}()
   958  				matches = matcher.Matches(actual)
   959  			}()
   960  			if matches {
   961  				output = fmt.Sprintf("%s\t%d: PASS:  %s matched by %s\n", output, i, actualFmt, matcher)
   962  			} else {
   963  				differences++
   964  				output = fmt.Sprintf("%s\t%d: FAIL:  %s not matched by %s\n", output, i, actualFmt, matcher)
   965  			}
   966  		} else {
   967  			switch expected := expected.(type) {
   968  			case anythingOfTypeArgument:
   969  				// type checking
   970  				if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
   971  					// not match
   972  					differences++
   973  					output = fmt.Sprintf("%s\t%d: FAIL:  type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
   974  				}
   975  			case *IsTypeArgument:
   976  				actualT := reflect.TypeOf(actual)
   977  				if actualT != expected.t {
   978  					differences++
   979  					output = fmt.Sprintf("%s\t%d: FAIL:  type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
   980  				}
   981  			case *FunctionalOptionsArgument:
   982  				t := expected.value
   983  
   984  				var name string
   985  				tValue := reflect.ValueOf(t)
   986  				if tValue.Len() > 0 {
   987  					name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
   988  				}
   989  
   990  				tName := reflect.TypeOf(t).Name()
   991  				if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 {
   992  					differences++
   993  					output = fmt.Sprintf("%s\t%d: FAIL:  type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
   994  				} else {
   995  					if ef, af := assertOpts(t, actual); ef == "" && af == "" {
   996  						// match
   997  						output = fmt.Sprintf("%s\t%d: PASS:  %s == %s\n", output, i, tName, tName)
   998  					} else {
   999  						// not match
  1000  						differences++
  1001  						output = fmt.Sprintf("%s\t%d: FAIL:  %s != %s\n", output, i, af, ef)
  1002  					}
  1003  				}
  1004  
  1005  			default:
  1006  				if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
  1007  					// match
  1008  					output = fmt.Sprintf("%s\t%d: PASS:  %s == %s\n", output, i, actualFmt, expectedFmt)
  1009  				} else {
  1010  					// not match
  1011  					differences++
  1012  					output = fmt.Sprintf("%s\t%d: FAIL:  %s != %s\n", output, i, actualFmt, expectedFmt)
  1013  				}
  1014  			}
  1015  		}
  1016  
  1017  	}
  1018  
  1019  	if differences == 0 {
  1020  		return "No differences.", differences
  1021  	}
  1022  
  1023  	return output, differences
  1024  }
  1025  
  1026  // Assert compares the arguments with the specified objects and fails if
  1027  // they do not exactly match.
  1028  func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
  1029  	if h, ok := t.(tHelper); ok {
  1030  		h.Helper()
  1031  	}
  1032  
  1033  	// get the differences
  1034  	diff, diffCount := args.Diff(objects)
  1035  
  1036  	if diffCount == 0 {
  1037  		return true
  1038  	}
  1039  
  1040  	// there are differences... report them...
  1041  	t.Logf(diff)
  1042  	t.Errorf("%sArguments do not match.", assert.CallerInfo())
  1043  
  1044  	return false
  1045  }
  1046  
  1047  // String gets the argument at the specified index. Panics if there is no argument, or
  1048  // if the argument is of the wrong type.
  1049  //
  1050  // If no index is provided, String() returns a complete string representation
  1051  // of the arguments.
  1052  func (args Arguments) String(indexOrNil ...int) string {
  1053  	if len(indexOrNil) == 0 {
  1054  		// normal String() method - return a string representation of the args
  1055  		var argsStr []string
  1056  		for _, arg := range args {
  1057  			argsStr = append(argsStr, fmt.Sprintf("%T", arg)) // handles nil nicely
  1058  		}
  1059  		return strings.Join(argsStr, ",")
  1060  	} else if len(indexOrNil) == 1 {
  1061  		// Index has been specified - get the argument at that index
  1062  		index := indexOrNil[0]
  1063  		var s string
  1064  		var ok bool
  1065  		if s, ok = args.Get(index).(string); !ok {
  1066  			panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
  1067  		}
  1068  		return s
  1069  	}
  1070  
  1071  	panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String.  Must be 0 or 1, not %d", len(indexOrNil)))
  1072  }
  1073  
  1074  // Int gets the argument at the specified index. Panics if there is no argument, or
  1075  // if the argument is of the wrong type.
  1076  func (args Arguments) Int(index int) int {
  1077  	var s int
  1078  	var ok bool
  1079  	if s, ok = args.Get(index).(int); !ok {
  1080  		panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
  1081  	}
  1082  	return s
  1083  }
  1084  
  1085  // Error gets the argument at the specified index. Panics if there is no argument, or
  1086  // if the argument is of the wrong type.
  1087  func (args Arguments) Error(index int) error {
  1088  	obj := args.Get(index)
  1089  	var s error
  1090  	var ok bool
  1091  	if obj == nil {
  1092  		return nil
  1093  	}
  1094  	if s, ok = obj.(error); !ok {
  1095  		panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
  1096  	}
  1097  	return s
  1098  }
  1099  
  1100  // Bool gets the argument at the specified index. Panics if there is no argument, or
  1101  // if the argument is of the wrong type.
  1102  func (args Arguments) Bool(index int) bool {
  1103  	var s bool
  1104  	var ok bool
  1105  	if s, ok = args.Get(index).(bool); !ok {
  1106  		panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
  1107  	}
  1108  	return s
  1109  }
  1110  
  1111  func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
  1112  	t := reflect.TypeOf(v)
  1113  	k := t.Kind()
  1114  
  1115  	if k == reflect.Ptr {
  1116  		t = t.Elem()
  1117  		k = t.Kind()
  1118  	}
  1119  	return t, k
  1120  }
  1121  
  1122  func diffArguments(expected Arguments, actual Arguments) string {
  1123  	if len(expected) != len(actual) {
  1124  		return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual))
  1125  	}
  1126  
  1127  	for x := range expected {
  1128  		if diffString := diff(expected[x], actual[x]); diffString != "" {
  1129  			return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString)
  1130  		}
  1131  	}
  1132  
  1133  	return ""
  1134  }
  1135  
  1136  // diff returns a diff of both values as long as both are of the same type and
  1137  // are a struct, map, slice or array. Otherwise it returns an empty string.
  1138  func diff(expected interface{}, actual interface{}) string {
  1139  	if expected == nil || actual == nil {
  1140  		return ""
  1141  	}
  1142  
  1143  	et, ek := typeAndKind(expected)
  1144  	at, _ := typeAndKind(actual)
  1145  
  1146  	if et != at {
  1147  		return ""
  1148  	}
  1149  
  1150  	if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
  1151  		return ""
  1152  	}
  1153  
  1154  	e := spewConfig.Sdump(expected)
  1155  	a := spewConfig.Sdump(actual)
  1156  
  1157  	diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
  1158  		A:        difflib.SplitLines(e),
  1159  		B:        difflib.SplitLines(a),
  1160  		FromFile: "Expected",
  1161  		FromDate: "",
  1162  		ToFile:   "Actual",
  1163  		ToDate:   "",
  1164  		Context:  1,
  1165  	})
  1166  
  1167  	return diff
  1168  }
  1169  
// spewConfig is the go-spew configuration that diff uses to dump values
// before comparing the dumps line by line. Pointer addresses and capacities
// are left out and map keys are sorted, presumably so that two dumps differ
// only where the dumped values themselves differ.
var spewConfig = spew.ConfigState{
	Indent:                  " ",
	DisablePointerAddresses: true,
	DisableCapacities:       true,
	SortKeys:                true,
}
  1176  
// tHelper is implemented by test handles that provide a Helper method, such
// as *testing.T. Assert type-asserts its TestingT to tHelper and calls Helper
// when the assertion succeeds.
type tHelper interface {
	Helper()
}
  1180  
  1181  func assertOpts(expected, actual interface{}) (expectedFmt, actualFmt string) {
  1182  	expectedOpts := reflect.ValueOf(expected)
  1183  	actualOpts := reflect.ValueOf(actual)
  1184  	var expectedNames []string
  1185  	for i := 0; i < expectedOpts.Len(); i++ {
  1186  		expectedNames = append(expectedNames, funcName(expectedOpts.Index(i).Interface()))
  1187  	}
  1188  	var actualNames []string
  1189  	for i := 0; i < actualOpts.Len(); i++ {
  1190  		actualNames = append(actualNames, funcName(actualOpts.Index(i).Interface()))
  1191  	}
  1192  	if !assert.ObjectsAreEqual(expectedNames, actualNames) {
  1193  		expectedFmt = fmt.Sprintf("%v", expectedNames)
  1194  		actualFmt = fmt.Sprintf("%v", actualNames)
  1195  		return
  1196  	}
  1197  
  1198  	for i := 0; i < expectedOpts.Len(); i++ {
  1199  		expectedOpt := expectedOpts.Index(i).Interface()
  1200  		actualOpt := actualOpts.Index(i).Interface()
  1201  
  1202  		expectedFunc := expectedNames[i]
  1203  		actualFunc := actualNames[i]
  1204  		if expectedFunc != actualFunc {
  1205  			expectedFmt = expectedFunc
  1206  			actualFmt = actualFunc
  1207  			return
  1208  		}
  1209  
  1210  		ot := reflect.TypeOf(expectedOpt)
  1211  		var expectedValues []reflect.Value
  1212  		var actualValues []reflect.Value
  1213  		if ot.NumIn() == 0 {
  1214  			return
  1215  		}
  1216  
  1217  		for i := 0; i < ot.NumIn(); i++ {
  1218  			vt := ot.In(i).Elem()
  1219  			expectedValues = append(expectedValues, reflect.New(vt))
  1220  			actualValues = append(actualValues, reflect.New(vt))
  1221  		}
  1222  
  1223  		reflect.ValueOf(expectedOpt).Call(expectedValues)
  1224  		reflect.ValueOf(actualOpt).Call(actualValues)
  1225  
  1226  		for i := 0; i < ot.NumIn(); i++ {
  1227  			if !assert.ObjectsAreEqual(expectedValues[i].Interface(), actualValues[i].Interface()) {
  1228  				expectedFmt = fmt.Sprintf("%s %+v", expectedNames[i], expectedValues[i].Interface())
  1229  				actualFmt = fmt.Sprintf("%s %+v", expectedNames[i], actualValues[i].Interface())
  1230  				return
  1231  			}
  1232  		}
  1233  	}
  1234  
  1235  	return "", ""
  1236  }
  1237  
  1238  func funcName(opt interface{}) string {
  1239  	n := runtime.FuncForPC(reflect.ValueOf(opt).Pointer()).Name()
  1240  	return strings.TrimSuffix(path.Base(n), path.Ext(n))
  1241  }
  1242  

View as plain text