...

Source file src/github.com/golang/mock/gomock/controller.go

Documentation: github.com/golang/mock/gomock

     1  // Copyright 2010 Google Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package gomock
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"reflect"
    21  	"runtime"
    22  	"sync"
    23  )
    24  
    25  // A TestReporter is something that can be used to report test failures.  It
    26  // is satisfied by the standard library's *testing.T.
    27  type TestReporter interface {
    28  	Errorf(format string, args ...interface{})
    29  	Fatalf(format string, args ...interface{})
    30  }
    31  
    32  // TestHelper is a TestReporter that has the Helper method.  It is satisfied
    33  // by the standard library's *testing.T.
    34  type TestHelper interface {
    35  	TestReporter
    36  	Helper()
    37  }
    38  
    39  // cleanuper is used to check if TestHelper also has the `Cleanup` method. A
    40  // common pattern is to pass in a `*testing.T` to
    41  // `NewController(t TestReporter)`. In Go 1.14+, `*testing.T` has a cleanup
    42  // method. This can be utilized to call `Finish()` so the caller of this library
    43  // does not have to.
    44  type cleanuper interface {
    45  	Cleanup(func())
    46  }
    47  
    48  // A Controller represents the top-level control of a mock ecosystem.  It
    49  // defines the scope and lifetime of mock objects, as well as their
    50  // expectations.  It is safe to call Controller's methods from multiple
    51  // goroutines. Each test should create a new Controller and invoke Finish via
    52  // defer.
    53  //
    54  //   func TestFoo(t *testing.T) {
    55  //     ctrl := gomock.NewController(t)
    56  //     defer ctrl.Finish()
    57  //     // ..
    58  //   }
    59  //
    60  //   func TestBar(t *testing.T) {
    61  //     t.Run("Sub-Test-1", st) {
    62  //       ctrl := gomock.NewController(st)
    63  //       defer ctrl.Finish()
    64  //       // ..
    65  //     })
    66  //     t.Run("Sub-Test-2", st) {
    67  //       ctrl := gomock.NewController(st)
    68  //       defer ctrl.Finish()
    69  //       // ..
    70  //     })
    71  //   })
    72  type Controller struct {
    73  	// T should only be called within a generated mock. It is not intended to
    74  	// be used in user code and may be changed in future versions. T is the
    75  	// TestReporter passed in when creating the Controller via NewController.
    76  	// If the TestReporter does not implement a TestHelper it will be wrapped
    77  	// with a nopTestHelper.
    78  	T             TestHelper
    79  	mu            sync.Mutex
    80  	expectedCalls *callSet
    81  	finished      bool
    82  }
    83  
    84  // NewController returns a new Controller. It is the preferred way to create a
    85  // Controller.
    86  //
    87  // New in go1.14+, if you are passing a *testing.T into this function you no
    88  // longer need to call ctrl.Finish() in your test methods.
    89  func NewController(t TestReporter) *Controller {
    90  	h, ok := t.(TestHelper)
    91  	if !ok {
    92  		h = &nopTestHelper{t}
    93  	}
    94  	ctrl := &Controller{
    95  		T:             h,
    96  		expectedCalls: newCallSet(),
    97  	}
    98  	if c, ok := isCleanuper(ctrl.T); ok {
    99  		c.Cleanup(func() {
   100  			ctrl.T.Helper()
   101  			ctrl.finish(true, nil)
   102  		})
   103  	}
   104  
   105  	return ctrl
   106  }
   107  
   108  type cancelReporter struct {
   109  	t      TestHelper
   110  	cancel func()
   111  }
   112  
   113  func (r *cancelReporter) Errorf(format string, args ...interface{}) {
   114  	r.t.Errorf(format, args...)
   115  }
   116  func (r *cancelReporter) Fatalf(format string, args ...interface{}) {
   117  	defer r.cancel()
   118  	r.t.Fatalf(format, args...)
   119  }
   120  
   121  func (r *cancelReporter) Helper() {
   122  	r.t.Helper()
   123  }
   124  
   125  // WithContext returns a new Controller and a Context, which is cancelled on any
   126  // fatal failure.
   127  func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
   128  	h, ok := t.(TestHelper)
   129  	if !ok {
   130  		h = &nopTestHelper{t: t}
   131  	}
   132  
   133  	ctx, cancel := context.WithCancel(ctx)
   134  	return NewController(&cancelReporter{t: h, cancel: cancel}), ctx
   135  }
   136  
   137  type nopTestHelper struct {
   138  	t TestReporter
   139  }
   140  
   141  func (h *nopTestHelper) Errorf(format string, args ...interface{}) {
   142  	h.t.Errorf(format, args...)
   143  }
   144  func (h *nopTestHelper) Fatalf(format string, args ...interface{}) {
   145  	h.t.Fatalf(format, args...)
   146  }
   147  
   148  func (h nopTestHelper) Helper() {}
   149  
   150  // RecordCall is called by a mock. It should not be called by user code.
   151  func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call {
   152  	ctrl.T.Helper()
   153  
   154  	recv := reflect.ValueOf(receiver)
   155  	for i := 0; i < recv.Type().NumMethod(); i++ {
   156  		if recv.Type().Method(i).Name == method {
   157  			return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...)
   158  		}
   159  	}
   160  	ctrl.T.Fatalf("gomock: failed finding method %s on %T", method, receiver)
   161  	panic("unreachable")
   162  }
   163  
   164  // RecordCallWithMethodType is called by a mock. It should not be called by user code.
   165  func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
   166  	ctrl.T.Helper()
   167  
   168  	call := newCall(ctrl.T, receiver, method, methodType, args...)
   169  
   170  	ctrl.mu.Lock()
   171  	defer ctrl.mu.Unlock()
   172  	ctrl.expectedCalls.Add(call)
   173  
   174  	return call
   175  }
   176  
   177  // Call is called by a mock. It should not be called by user code.
   178  func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
   179  	ctrl.T.Helper()
   180  
   181  	// Nest this code so we can use defer to make sure the lock is released.
   182  	actions := func() []func([]interface{}) []interface{} {
   183  		ctrl.T.Helper()
   184  		ctrl.mu.Lock()
   185  		defer ctrl.mu.Unlock()
   186  
   187  		expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
   188  		if err != nil {
   189  			// callerInfo's skip should be updated if the number of calls between the user's test
   190  			// and this line changes, i.e. this code is wrapped in another anonymous function.
   191  			// 0 is us, 1 is controller.Call(), 2 is the generated mock, and 3 is the user's test.
   192  			origin := callerInfo(3)
   193  			ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
   194  		}
   195  
   196  		// Two things happen here:
   197  		// * the matching call no longer needs to check prerequite calls,
   198  		// * and the prerequite calls are no longer expected, so remove them.
   199  		preReqCalls := expected.dropPrereqs()
   200  		for _, preReqCall := range preReqCalls {
   201  			ctrl.expectedCalls.Remove(preReqCall)
   202  		}
   203  
   204  		actions := expected.call()
   205  		if expected.exhausted() {
   206  			ctrl.expectedCalls.Remove(expected)
   207  		}
   208  		return actions
   209  	}()
   210  
   211  	var rets []interface{}
   212  	for _, action := range actions {
   213  		if r := action(args); r != nil {
   214  			rets = r
   215  		}
   216  	}
   217  
   218  	return rets
   219  }
   220  
   221  // Finish checks to see if all the methods that were expected to be called
   222  // were called. It should be invoked for each Controller. It is not idempotent
   223  // and therefore can only be invoked once.
   224  //
   225  // New in go1.14+, if you are passing a *testing.T into NewController function you no
   226  // longer need to call ctrl.Finish() in your test methods.
   227  func (ctrl *Controller) Finish() {
   228  	// If we're currently panicking, probably because this is a deferred call.
   229  	// This must be recovered in the deferred function.
   230  	err := recover()
   231  	ctrl.finish(false, err)
   232  }
   233  
   234  func (ctrl *Controller) finish(cleanup bool, panicErr interface{}) {
   235  	ctrl.T.Helper()
   236  
   237  	ctrl.mu.Lock()
   238  	defer ctrl.mu.Unlock()
   239  
   240  	if ctrl.finished {
   241  		if _, ok := isCleanuper(ctrl.T); !ok {
   242  			ctrl.T.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
   243  		}
   244  		return
   245  	}
   246  	ctrl.finished = true
   247  
   248  	// Short-circuit, pass through the panic.
   249  	if panicErr != nil {
   250  		panic(panicErr)
   251  	}
   252  
   253  	// Check that all remaining expected calls are satisfied.
   254  	failures := ctrl.expectedCalls.Failures()
   255  	for _, call := range failures {
   256  		ctrl.T.Errorf("missing call(s) to %v", call)
   257  	}
   258  	if len(failures) != 0 {
   259  		if !cleanup {
   260  			ctrl.T.Fatalf("aborting test due to missing call(s)")
   261  			return
   262  		}
   263  		ctrl.T.Errorf("aborting test due to missing call(s)")
   264  	}
   265  }
   266  
   267  // callerInfo returns the file:line of the call site. skip is the number
   268  // of stack frames to skip when reporting. 0 is callerInfo's call site.
   269  func callerInfo(skip int) string {
   270  	if _, file, line, ok := runtime.Caller(skip + 1); ok {
   271  		return fmt.Sprintf("%s:%d", file, line)
   272  	}
   273  	return "unknown file"
   274  }
   275  
   276  // isCleanuper checks it if t's base TestReporter has a Cleanup method.
   277  func isCleanuper(t TestReporter) (cleanuper, bool) {
   278  	tr := unwrapTestReporter(t)
   279  	c, ok := tr.(cleanuper)
   280  	return c, ok
   281  }
   282  
   283  // unwrapTestReporter unwraps TestReporter to the base implementation.
   284  func unwrapTestReporter(t TestReporter) TestReporter {
   285  	tr := t
   286  	switch nt := t.(type) {
   287  	case *cancelReporter:
   288  		tr = nt.t
   289  		if h, check := tr.(*nopTestHelper); check {
   290  			tr = h.t
   291  		}
   292  	case *nopTestHelper:
   293  		tr = nt.t
   294  	default:
   295  		// not wrapped
   296  	}
   297  	return tr
   298  }
   299  

View as plain text