1
2
3
4
5
6
7
8
9
10
11
12
13
14
15 package gomock
16
17 import (
18 "context"
19 "fmt"
20 "reflect"
21 "runtime"
22 "sync"
23 )
24
25
26
// TestReporter is the minimal failure-reporting interface a Controller needs.
// *testing.T satisfies it.
type TestReporter interface {
	// Errorf reports a non-fatal failure.
	Errorf(format string, args ...interface{})
	// Fatalf reports a fatal failure.
	Fatalf(format string, args ...interface{})
}
31
32
33
34 type TestHelper interface {
35 TestReporter
36 Helper()
37 }
38
39
40
41
42
43
// cleanuper is implemented by reporters that can register a function to run
// when the test finishes, such as *testing.T.
type cleanuper interface {
	// Cleanup registers f to be run at the end of the test.
	Cleanup(f func())
}
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72 type Controller struct {
73
74
75
76
77
78 T TestHelper
79 mu sync.Mutex
80 expectedCalls *callSet
81 finished bool
82 }
83
84
85
86
87
88
89 func NewController(t TestReporter) *Controller {
90 h, ok := t.(TestHelper)
91 if !ok {
92 h = &nopTestHelper{t}
93 }
94 ctrl := &Controller{
95 T: h,
96 expectedCalls: newCallSet(),
97 }
98 if c, ok := isCleanuper(ctrl.T); ok {
99 c.Cleanup(func() {
100 ctrl.T.Helper()
101 ctrl.finish(true, nil)
102 })
103 }
104
105 return ctrl
106 }
107
108 type cancelReporter struct {
109 t TestHelper
110 cancel func()
111 }
112
113 func (r *cancelReporter) Errorf(format string, args ...interface{}) {
114 r.t.Errorf(format, args...)
115 }
116 func (r *cancelReporter) Fatalf(format string, args ...interface{}) {
117 defer r.cancel()
118 r.t.Fatalf(format, args...)
119 }
120
121 func (r *cancelReporter) Helper() {
122 r.t.Helper()
123 }
124
125
126
127 func WithContext(ctx context.Context, t TestReporter) (*Controller, context.Context) {
128 h, ok := t.(TestHelper)
129 if !ok {
130 h = &nopTestHelper{t: t}
131 }
132
133 ctx, cancel := context.WithCancel(ctx)
134 return NewController(&cancelReporter{t: h, cancel: cancel}), ctx
135 }
136
137 type nopTestHelper struct {
138 t TestReporter
139 }
140
141 func (h *nopTestHelper) Errorf(format string, args ...interface{}) {
142 h.t.Errorf(format, args...)
143 }
144 func (h *nopTestHelper) Fatalf(format string, args ...interface{}) {
145 h.t.Fatalf(format, args...)
146 }
147
148 func (h nopTestHelper) Helper() {}
149
150
151 func (ctrl *Controller) RecordCall(receiver interface{}, method string, args ...interface{}) *Call {
152 ctrl.T.Helper()
153
154 recv := reflect.ValueOf(receiver)
155 for i := 0; i < recv.Type().NumMethod(); i++ {
156 if recv.Type().Method(i).Name == method {
157 return ctrl.RecordCallWithMethodType(receiver, method, recv.Method(i).Type(), args...)
158 }
159 }
160 ctrl.T.Fatalf("gomock: failed finding method %s on %T", method, receiver)
161 panic("unreachable")
162 }
163
164
165 func (ctrl *Controller) RecordCallWithMethodType(receiver interface{}, method string, methodType reflect.Type, args ...interface{}) *Call {
166 ctrl.T.Helper()
167
168 call := newCall(ctrl.T, receiver, method, methodType, args...)
169
170 ctrl.mu.Lock()
171 defer ctrl.mu.Unlock()
172 ctrl.expectedCalls.Add(call)
173
174 return call
175 }
176
177
178 func (ctrl *Controller) Call(receiver interface{}, method string, args ...interface{}) []interface{} {
179 ctrl.T.Helper()
180
181
182 actions := func() []func([]interface{}) []interface{} {
183 ctrl.T.Helper()
184 ctrl.mu.Lock()
185 defer ctrl.mu.Unlock()
186
187 expected, err := ctrl.expectedCalls.FindMatch(receiver, method, args)
188 if err != nil {
189
190
191
192 origin := callerInfo(3)
193 ctrl.T.Fatalf("Unexpected call to %T.%v(%v) at %s because: %s", receiver, method, args, origin, err)
194 }
195
196
197
198
199 preReqCalls := expected.dropPrereqs()
200 for _, preReqCall := range preReqCalls {
201 ctrl.expectedCalls.Remove(preReqCall)
202 }
203
204 actions := expected.call()
205 if expected.exhausted() {
206 ctrl.expectedCalls.Remove(expected)
207 }
208 return actions
209 }()
210
211 var rets []interface{}
212 for _, action := range actions {
213 if r := action(args); r != nil {
214 rets = r
215 }
216 }
217
218 return rets
219 }
220
221
222
223
224
225
226
227 func (ctrl *Controller) Finish() {
228
229
230 err := recover()
231 ctrl.finish(false, err)
232 }
233
234 func (ctrl *Controller) finish(cleanup bool, panicErr interface{}) {
235 ctrl.T.Helper()
236
237 ctrl.mu.Lock()
238 defer ctrl.mu.Unlock()
239
240 if ctrl.finished {
241 if _, ok := isCleanuper(ctrl.T); !ok {
242 ctrl.T.Fatalf("Controller.Finish was called more than once. It has to be called exactly once.")
243 }
244 return
245 }
246 ctrl.finished = true
247
248
249 if panicErr != nil {
250 panic(panicErr)
251 }
252
253
254 failures := ctrl.expectedCalls.Failures()
255 for _, call := range failures {
256 ctrl.T.Errorf("missing call(s) to %v", call)
257 }
258 if len(failures) != 0 {
259 if !cleanup {
260 ctrl.T.Fatalf("aborting test due to missing call(s)")
261 return
262 }
263 ctrl.T.Errorf("aborting test due to missing call(s)")
264 }
265 }
266
267
268
// callerInfo returns "file:line" for the stack frame skip levels above the
// function that called callerInfo (skip 0 is that caller itself). It returns
// "unknown file" if the frame cannot be found.
func callerInfo(skip int) string {
	// The +1 accounts for callerInfo's own frame.
	_, file, line, ok := runtime.Caller(skip + 1)
	if !ok {
		return "unknown file"
	}
	return fmt.Sprintf("%s:%d", file, line)
}
275
276
277 func isCleanuper(t TestReporter) (cleanuper, bool) {
278 tr := unwrapTestReporter(t)
279 c, ok := tr.(cleanuper)
280 return c, ok
281 }
282
283
284 func unwrapTestReporter(t TestReporter) TestReporter {
285 tr := t
286 switch nt := t.(type) {
287 case *cancelReporter:
288 tr = nt.t
289 if h, check := tr.(*nopTestHelper); check {
290 tr = h.t
291 }
292 case *nopTestHelper:
293 tr = nt.t
294 default:
295
296 }
297 return tr
298 }
299
View as plain text