1 package mock
2
3 import (
4 "errors"
5 "fmt"
6 "path"
7 "reflect"
8 "regexp"
9 "runtime"
10 "strings"
11 "sync"
12 "time"
13
14 "github.com/davecgh/go-spew/spew"
15 "github.com/pmezard/go-difflib/difflib"
16 "github.com/stretchr/objx"
17
18 "github.com/stretchr/testify/assert"
19 )
20
21
// gccgoRE matches a ".pN<digits>_" marker inside a function path. Called uses
// it to cut such a path at the marker and keep only the leading part.
// NOTE(review): the name suggests this marker is emitted by gccgo — confirm.
var gccgoRE = regexp.MustCompile(`\.pN\d+_`)
23
24
// TestingT is the subset of a test handle that this package needs in order to
// log, report errors and abort the running test.
type TestingT interface {
	// Logf records a formatted informational message.
	Logf(format string, args ...interface{})
	// Errorf records a formatted failure message.
	Errorf(format string, args ...interface{})
	// FailNow stops the current test.
	FailNow()
}
30
31
34
35
36
37 type Call struct {
38 Parent *Mock
39
40
41 Method string
42
43
44 Arguments Arguments
45
46
47
48 ReturnArguments Arguments
49
50
51 callerInfo []string
52
53
54
55 Repeatability int
56
57
58 totalCalls int
59
60
61 optional bool
62
63
64
65 WaitFor <-chan time.Time
66
67 waitTime time.Duration
68
69
70
71
72 RunFn func(Arguments)
73
74
75
76
77 PanicMsg *string
78
79
80 requires []*Call
81 }
82
83 func newCall(parent *Mock, methodName string, callerInfo []string, methodArguments ...interface{}) *Call {
84 return &Call{
85 Parent: parent,
86 Method: methodName,
87 Arguments: methodArguments,
88 ReturnArguments: make([]interface{}, 0),
89 callerInfo: callerInfo,
90 Repeatability: 0,
91 WaitFor: nil,
92 RunFn: nil,
93 PanicMsg: nil,
94 }
95 }
96
97 func (c *Call) lock() {
98 c.Parent.mutex.Lock()
99 }
100
101 func (c *Call) unlock() {
102 c.Parent.mutex.Unlock()
103 }
104
105
106
107
108 func (c *Call) Return(returnArguments ...interface{}) *Call {
109 c.lock()
110 defer c.unlock()
111
112 c.ReturnArguments = returnArguments
113
114 return c
115 }
116
117
118
119
120 func (c *Call) Panic(msg string) *Call {
121 c.lock()
122 defer c.unlock()
123
124 c.PanicMsg = &msg
125
126 return c
127 }
128
129
130
131
132 func (c *Call) Once() *Call {
133 return c.Times(1)
134 }
135
136
137
138
139 func (c *Call) Twice() *Call {
140 return c.Times(2)
141 }
142
143
144
145
146
147 func (c *Call) Times(i int) *Call {
148 c.lock()
149 defer c.unlock()
150 c.Repeatability = i
151 return c
152 }
153
154
155
156
157
158 func (c *Call) WaitUntil(w <-chan time.Time) *Call {
159 c.lock()
160 defer c.unlock()
161 c.WaitFor = w
162 return c
163 }
164
165
166
167
168 func (c *Call) After(d time.Duration) *Call {
169 c.lock()
170 defer c.unlock()
171 c.waitTime = d
172 return c
173 }
174
175
176
177
178
179
180
181
182
183 func (c *Call) Run(fn func(args Arguments)) *Call {
184 c.lock()
185 defer c.unlock()
186 c.RunFn = fn
187 return c
188 }
189
190
191
192 func (c *Call) Maybe() *Call {
193 c.lock()
194 defer c.unlock()
195 c.optional = true
196 return c
197 }
198
199
200
201
202
203
204
205
206
207 func (c *Call) On(methodName string, arguments ...interface{}) *Call {
208 return c.Parent.On(methodName, arguments...)
209 }
210
211
212
213
214 func (c *Call) Unset() *Call {
215 var unlockOnce sync.Once
216
217 for _, arg := range c.Arguments {
218 if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
219 panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
220 }
221 }
222
223 c.lock()
224 defer unlockOnce.Do(c.unlock)
225
226 foundMatchingCall := false
227
228
229 var index int
230 for _, call := range c.Parent.ExpectedCalls {
231 if call.Method == c.Method {
232 _, diffCount := call.Arguments.Diff(c.Arguments)
233 if diffCount == 0 {
234 foundMatchingCall = true
235
236 continue
237 }
238 }
239 c.Parent.ExpectedCalls[index] = call
240 index++
241 }
242
243 c.Parent.ExpectedCalls = c.Parent.ExpectedCalls[:index]
244
245 if !foundMatchingCall {
246 unlockOnce.Do(c.unlock)
247 c.Parent.fail("\n\nmock: Could not find expected call\n-----------------------------\n\n%s\n\n",
248 callString(c.Method, c.Arguments, true),
249 )
250 }
251
252 return c
253 }
254
255
256
257
258
259
260
261
262 func (c *Call) NotBefore(calls ...*Call) *Call {
263 c.lock()
264 defer c.unlock()
265
266 for _, call := range calls {
267 if call.Parent == nil {
268 panic("not before calls must be created with Mock.On()")
269 }
270 }
271
272 c.requires = append(c.requires, calls...)
273 return c
274 }
275
276
277
278
279 type Mock struct {
280
281
282 ExpectedCalls []*Call
283
284
285 Calls []Call
286
287
288
289 test TestingT
290
291
292
293 testData objx.Map
294
295 mutex sync.Mutex
296 }
297
298
299
300
301
302 func (m *Mock) String() string {
303 return fmt.Sprintf("%[1]T<%[1]p>", m)
304 }
305
306
307
308 func (m *Mock) TestData() objx.Map {
309 if m.testData == nil {
310 m.testData = make(objx.Map)
311 }
312
313 return m.testData
314 }
315
316
319
320
321 func (m *Mock) Test(t TestingT) {
322 m.mutex.Lock()
323 defer m.mutex.Unlock()
324 m.test = t
325 }
326
327
328
329
330 func (m *Mock) fail(format string, args ...interface{}) {
331 m.mutex.Lock()
332 defer m.mutex.Unlock()
333
334 if m.test == nil {
335 panic(fmt.Sprintf(format, args...))
336 }
337 m.test.Errorf(format, args...)
338 m.test.FailNow()
339 }
340
341
342
343
344
345 func (m *Mock) On(methodName string, arguments ...interface{}) *Call {
346 for _, arg := range arguments {
347 if v := reflect.ValueOf(arg); v.Kind() == reflect.Func {
348 panic(fmt.Sprintf("cannot use Func in expectations. Use mock.AnythingOfType(\"%T\")", arg))
349 }
350 }
351
352 m.mutex.Lock()
353 defer m.mutex.Unlock()
354 c := newCall(m, methodName, assert.CallerInfo(), arguments...)
355 m.ExpectedCalls = append(m.ExpectedCalls, c)
356 return c
357 }
358
359
360
361
362
363 func (m *Mock) findExpectedCall(method string, arguments ...interface{}) (int, *Call) {
364 var expectedCall *Call
365
366 for i, call := range m.ExpectedCalls {
367 if call.Method == method {
368 _, diffCount := call.Arguments.Diff(arguments)
369 if diffCount == 0 {
370 expectedCall = call
371 if call.Repeatability > -1 {
372 return i, call
373 }
374 }
375 }
376 }
377
378 return -1, expectedCall
379 }
380
381 type matchCandidate struct {
382 call *Call
383 mismatch string
384 diffCount int
385 }
386
387 func (c matchCandidate) isBetterMatchThan(other matchCandidate) bool {
388 if c.call == nil {
389 return false
390 }
391 if other.call == nil {
392 return true
393 }
394
395 if c.diffCount > other.diffCount {
396 return false
397 }
398 if c.diffCount < other.diffCount {
399 return true
400 }
401
402 if c.call.Repeatability > 0 && other.call.Repeatability <= 0 {
403 return true
404 }
405 return false
406 }
407
408 func (m *Mock) findClosestCall(method string, arguments ...interface{}) (*Call, string) {
409 var bestMatch matchCandidate
410
411 for _, call := range m.expectedCalls() {
412 if call.Method == method {
413
414 errInfo, tempDiffCount := call.Arguments.Diff(arguments)
415 tempCandidate := matchCandidate{
416 call: call,
417 mismatch: errInfo,
418 diffCount: tempDiffCount,
419 }
420 if tempCandidate.isBetterMatchThan(bestMatch) {
421 bestMatch = tempCandidate
422 }
423 }
424 }
425
426 return bestMatch.call, bestMatch.mismatch
427 }
428
429 func callString(method string, arguments Arguments, includeArgumentValues bool) string {
430 var argValsString string
431 if includeArgumentValues {
432 var argVals []string
433 for argIndex, arg := range arguments {
434 if _, ok := arg.(*FunctionalOptionsArgument); ok {
435 argVals = append(argVals, fmt.Sprintf("%d: %s", argIndex, arg))
436 continue
437 }
438 argVals = append(argVals, fmt.Sprintf("%d: %#v", argIndex, arg))
439 }
440 argValsString = fmt.Sprintf("\n\t\t%s", strings.Join(argVals, "\n\t\t"))
441 }
442
443 return fmt.Sprintf("%s(%s)%s", method, arguments.String(), argValsString)
444 }
445
446
447
448
449
450 func (m *Mock) Called(arguments ...interface{}) Arguments {
451
452 pc, _, _, ok := runtime.Caller(1)
453 if !ok {
454 panic("Couldn't get the caller information")
455 }
456 functionPath := runtime.FuncForPC(pc).Name()
457
458
459
460
461 if gccgoRE.MatchString(functionPath) {
462 functionPath = gccgoRE.Split(functionPath, -1)[0]
463 }
464 parts := strings.Split(functionPath, ".")
465 functionName := parts[len(parts)-1]
466 return m.MethodCalled(functionName, arguments...)
467 }
468
469
470
471
472
473 func (m *Mock) MethodCalled(methodName string, arguments ...interface{}) Arguments {
474 m.mutex.Lock()
475
476 found, call := m.findExpectedCall(methodName, arguments...)
477
478 if found < 0 {
479
480 if call != nil {
481 m.mutex.Unlock()
482 m.fail("\nassert: mock: The method has been called over %d times.\n\tEither do one more Mock.On(\"%s\").Return(...), or remove extra call.\n\tThis call was unexpected:\n\t\t%s\n\tat: %s", call.totalCalls, methodName, callString(methodName, arguments, true), assert.CallerInfo())
483 }
484
485
486
487
488
489
490 closestCall, mismatch := m.findClosestCall(methodName, arguments...)
491 m.mutex.Unlock()
492
493 if closestCall != nil {
494 m.fail("\n\nmock: Unexpected Method Call\n-----------------------------\n\n%s\n\nThe closest call I have is: \n\n%s\n\n%s\nDiff: %s",
495 callString(methodName, arguments, true),
496 callString(methodName, closestCall.Arguments, true),
497 diffArguments(closestCall.Arguments, arguments),
498 strings.TrimSpace(mismatch),
499 )
500 } else {
501 m.fail("\nassert: mock: I don't know what to return because the method call was unexpected.\n\tEither do Mock.On(\"%s\").Return(...) first, or remove the %s() call.\n\tThis method was unexpected:\n\t\t%s\n\tat: %s", methodName, methodName, callString(methodName, arguments, true), assert.CallerInfo())
502 }
503 }
504
505 for _, requirement := range call.requires {
506 if satisfied, _ := requirement.Parent.checkExpectation(requirement); !satisfied {
507 m.mutex.Unlock()
508 m.fail("mock: Unexpected Method Call\n-----------------------------\n\n%s\n\nMust not be called before%s:\n\n%s",
509 callString(call.Method, call.Arguments, true),
510 func() (s string) {
511 if requirement.totalCalls > 0 {
512 s = " another call of"
513 }
514 if call.Parent != requirement.Parent {
515 s += " method from another mock instance"
516 }
517 return
518 }(),
519 callString(requirement.Method, requirement.Arguments, true),
520 )
521 }
522 }
523
524 if call.Repeatability == 1 {
525 call.Repeatability = -1
526 } else if call.Repeatability > 1 {
527 call.Repeatability--
528 }
529 call.totalCalls++
530
531
532 m.Calls = append(m.Calls, *newCall(m, methodName, assert.CallerInfo(), arguments...))
533 m.mutex.Unlock()
534
535
536 if call.WaitFor != nil {
537 <-call.WaitFor
538 } else {
539 time.Sleep(call.waitTime)
540 }
541
542 m.mutex.Lock()
543 panicMsg := call.PanicMsg
544 m.mutex.Unlock()
545 if panicMsg != nil {
546 panic(*panicMsg)
547 }
548
549 m.mutex.Lock()
550 runFn := call.RunFn
551 m.mutex.Unlock()
552
553 if runFn != nil {
554 runFn(arguments)
555 }
556
557 m.mutex.Lock()
558 returnArgs := call.ReturnArguments
559 m.mutex.Unlock()
560
561 return returnArgs
562 }
563
564
567
568 type assertExpectationiser interface {
569 AssertExpectations(TestingT) bool
570 }
571
572
573
574
575
576 func AssertExpectationsForObjects(t TestingT, testObjects ...interface{}) bool {
577 if h, ok := t.(tHelper); ok {
578 h.Helper()
579 }
580 for _, obj := range testObjects {
581 if m, ok := obj.(*Mock); ok {
582 t.Logf("Deprecated mock.AssertExpectationsForObjects(myMock.Mock) use mock.AssertExpectationsForObjects(myMock)")
583 obj = m
584 }
585 m := obj.(assertExpectationiser)
586 if !m.AssertExpectations(t) {
587 t.Logf("Expectations didn't match for Mock: %+v", reflect.TypeOf(m))
588 return false
589 }
590 }
591 return true
592 }
593
594
595
596 func (m *Mock) AssertExpectations(t TestingT) bool {
597 if s, ok := t.(interface{ Skipped() bool }); ok && s.Skipped() {
598 return true
599 }
600 if h, ok := t.(tHelper); ok {
601 h.Helper()
602 }
603
604 m.mutex.Lock()
605 defer m.mutex.Unlock()
606 var failedExpectations int
607
608
609 expectedCalls := m.expectedCalls()
610 for _, expectedCall := range expectedCalls {
611 satisfied, reason := m.checkExpectation(expectedCall)
612 if !satisfied {
613 failedExpectations++
614 t.Logf(reason)
615 }
616 }
617
618 if failedExpectations != 0 {
619 t.Errorf("FAIL: %d out of %d expectation(s) were met.\n\tThe code you are testing needs to make %d more call(s).\n\tat: %s", len(expectedCalls)-failedExpectations, len(expectedCalls), failedExpectations, assert.CallerInfo())
620 }
621
622 return failedExpectations == 0
623 }
624
625 func (m *Mock) checkExpectation(call *Call) (bool, string) {
626 if !call.optional && !m.methodWasCalled(call.Method, call.Arguments) && call.totalCalls == 0 {
627 return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
628 }
629 if call.Repeatability > 0 {
630 return false, fmt.Sprintf("FAIL:\t%s(%s)\n\t\tat: %s", call.Method, call.Arguments.String(), call.callerInfo)
631 }
632 return true, fmt.Sprintf("PASS:\t%s(%s)", call.Method, call.Arguments.String())
633 }
634
635
636 func (m *Mock) AssertNumberOfCalls(t TestingT, methodName string, expectedCalls int) bool {
637 if h, ok := t.(tHelper); ok {
638 h.Helper()
639 }
640 m.mutex.Lock()
641 defer m.mutex.Unlock()
642 var actualCalls int
643 for _, call := range m.calls() {
644 if call.Method == methodName {
645 actualCalls++
646 }
647 }
648 return assert.Equal(t, expectedCalls, actualCalls, fmt.Sprintf("Expected number of calls (%d) does not match the actual number of calls (%d).", expectedCalls, actualCalls))
649 }
650
651
652
653 func (m *Mock) AssertCalled(t TestingT, methodName string, arguments ...interface{}) bool {
654 if h, ok := t.(tHelper); ok {
655 h.Helper()
656 }
657 m.mutex.Lock()
658 defer m.mutex.Unlock()
659 if !m.methodWasCalled(methodName, arguments) {
660 var calledWithArgs []string
661 for _, call := range m.calls() {
662 calledWithArgs = append(calledWithArgs, fmt.Sprintf("%v", call.Arguments))
663 }
664 if len(calledWithArgs) == 0 {
665 return assert.Fail(t, "Should have called with given arguments",
666 fmt.Sprintf("Expected %q to have been called with:\n%v\nbut no actual calls happened", methodName, arguments))
667 }
668 return assert.Fail(t, "Should have called with given arguments",
669 fmt.Sprintf("Expected %q to have been called with:\n%v\nbut actual calls were:\n %v", methodName, arguments, strings.Join(calledWithArgs, "\n")))
670 }
671 return true
672 }
673
674
675
676 func (m *Mock) AssertNotCalled(t TestingT, methodName string, arguments ...interface{}) bool {
677 if h, ok := t.(tHelper); ok {
678 h.Helper()
679 }
680 m.mutex.Lock()
681 defer m.mutex.Unlock()
682 if m.methodWasCalled(methodName, arguments) {
683 return assert.Fail(t, "Should not have called with given arguments",
684 fmt.Sprintf("Expected %q to not have been called with:\n%v\nbut actually it was.", methodName, arguments))
685 }
686 return true
687 }
688
689
690
691 func (m *Mock) IsMethodCallable(t TestingT, methodName string, arguments ...interface{}) bool {
692 if h, ok := t.(tHelper); ok {
693 h.Helper()
694 }
695 m.mutex.Lock()
696 defer m.mutex.Unlock()
697
698 for _, v := range m.ExpectedCalls {
699 if v.Method != methodName {
700 continue
701 }
702 if len(arguments) != len(v.Arguments) {
703 continue
704 }
705 if v.Repeatability < v.totalCalls {
706 continue
707 }
708 if isArgsEqual(v.Arguments, arguments) {
709 return true
710 }
711 }
712 return false
713 }
714
715
716 func isArgsEqual(expected Arguments, args []interface{}) bool {
717 if len(expected) != len(args) {
718 return false
719 }
720 for i, v := range args {
721 if !reflect.DeepEqual(expected[i], v) {
722 return false
723 }
724 }
725 return true
726 }
727
728 func (m *Mock) methodWasCalled(methodName string, expected []interface{}) bool {
729 for _, call := range m.calls() {
730 if call.Method == methodName {
731
732 _, differences := Arguments(expected).Diff(call.Arguments)
733
734 if differences == 0 {
735
736 return true
737 }
738
739 }
740 }
741
742 return false
743 }
744
745 func (m *Mock) expectedCalls() []*Call {
746 return append([]*Call{}, m.ExpectedCalls...)
747 }
748
749 func (m *Mock) calls() []Call {
750 return append([]Call{}, m.Calls...)
751 }
752
753
756
757
// Arguments is an ordered list of values, used both for the arguments passed
// to a mocked method and for the values it returns.
type Arguments []interface{}
759
// Anything is the placeholder that matches any value when arguments are
// compared.
const Anything = "mock.Anything"
765
766
767
768
769
770 type AnythingOfTypeArgument = anythingOfTypeArgument
771
772
773
// anythingOfTypeArgument holds the name of a type; an argument carrying it
// matches any actual value whose type has that name.
type anythingOfTypeArgument string
775
776
777
778
779
780
781
782
783
784 func AnythingOfType(t string) AnythingOfTypeArgument {
785 return anythingOfTypeArgument(t)
786 }
787
788
789
790
// IsTypeArgument is a matcher that accepts any value whose dynamic type is
// identical to the one captured by IsType.
type IsTypeArgument struct {
	// t is the type an actual argument must have.
	t reflect.Type
}
794
795
796
797
798
799
800
801 func IsType(t interface{}) *IsTypeArgument {
802 return &IsTypeArgument{t: reflect.TypeOf(t)}
803 }
804
805
806
// FunctionalOptionsArgument is a matcher for a list of functional options.
type FunctionalOptionsArgument struct {
	// value holds the expected options, as collected by FunctionalOptions.
	value interface{}
}
810
811
812 func (f *FunctionalOptionsArgument) String() string {
813 var name string
814 tValue := reflect.ValueOf(f.value)
815 if tValue.Len() > 0 {
816 name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
817 }
818
819 return strings.Replace(fmt.Sprintf("%#v", f.value), "[]interface {}", name, 1)
820 }
821
822
823
824
825
826
827 func FunctionalOptions(value ...interface{}) *FunctionalOptionsArgument {
828 return &FunctionalOptionsArgument{
829 value: value,
830 }
831 }
832
833
834
// argumentMatcher wraps a user-supplied predicate that decides whether an
// actual argument is acceptable.
type argumentMatcher struct {
	// fn is the predicate; MatchedBy makes sure it takes exactly one
	// argument and returns a bool.
	fn reflect.Value
}
839
840 func (f argumentMatcher) Matches(argument interface{}) bool {
841 expectType := f.fn.Type().In(0)
842 expectTypeNilSupported := false
843 switch expectType.Kind() {
844 case reflect.Interface, reflect.Chan, reflect.Func, reflect.Map, reflect.Slice, reflect.Ptr:
845 expectTypeNilSupported = true
846 }
847
848 argType := reflect.TypeOf(argument)
849 var arg reflect.Value
850 if argType == nil {
851 arg = reflect.New(expectType).Elem()
852 } else {
853 arg = reflect.ValueOf(argument)
854 }
855
856 if argType == nil && !expectTypeNilSupported {
857 panic(errors.New("attempting to call matcher with nil for non-nil expected type"))
858 }
859 if argType == nil || argType.AssignableTo(expectType) {
860 result := f.fn.Call([]reflect.Value{arg})
861 return result[0].Bool()
862 }
863 return false
864 }
865
866 func (f argumentMatcher) String() string {
867 return fmt.Sprintf("func(%s) bool", f.fn.Type().In(0).String())
868 }
869
870
871
872
873
874
875
876
877
878
879
880
881 func MatchedBy(fn interface{}) argumentMatcher {
882 fnType := reflect.TypeOf(fn)
883
884 if fnType.Kind() != reflect.Func {
885 panic(fmt.Sprintf("assert: arguments: %s is not a func", fn))
886 }
887 if fnType.NumIn() != 1 {
888 panic(fmt.Sprintf("assert: arguments: %s does not take exactly one argument", fn))
889 }
890 if fnType.NumOut() != 1 || fnType.Out(0).Kind() != reflect.Bool {
891 panic(fmt.Sprintf("assert: arguments: %s does not return a bool", fn))
892 }
893
894 return argumentMatcher{fn: reflect.ValueOf(fn)}
895 }
896
897
898 func (args Arguments) Get(index int) interface{} {
899 if index+1 > len(args) {
900 panic(fmt.Sprintf("assert: arguments: Cannot call Get(%d) because there are %d argument(s).", index, len(args)))
901 }
902 return args[index]
903 }
904
905
906 func (args Arguments) Is(objects ...interface{}) bool {
907 for i, obj := range args {
908 if obj != objects[i] {
909 return false
910 }
911 }
912 return true
913 }
914
915
916
917
918
919 func (args Arguments) Diff(objects []interface{}) (string, int) {
920
921
922 output := "\n"
923 var differences int
924
925 maxArgCount := len(args)
926 if len(objects) > maxArgCount {
927 maxArgCount = len(objects)
928 }
929
930 for i := 0; i < maxArgCount; i++ {
931 var actual, expected interface{}
932 var actualFmt, expectedFmt string
933
934 if len(objects) <= i {
935 actual = "(Missing)"
936 actualFmt = "(Missing)"
937 } else {
938 actual = objects[i]
939 actualFmt = fmt.Sprintf("(%[1]T=%[1]v)", actual)
940 }
941
942 if len(args) <= i {
943 expected = "(Missing)"
944 expectedFmt = "(Missing)"
945 } else {
946 expected = args[i]
947 expectedFmt = fmt.Sprintf("(%[1]T=%[1]v)", expected)
948 }
949
950 if matcher, ok := expected.(argumentMatcher); ok {
951 var matches bool
952 func() {
953 defer func() {
954 if r := recover(); r != nil {
955 actualFmt = fmt.Sprintf("panic in argument matcher: %v", r)
956 }
957 }()
958 matches = matcher.Matches(actual)
959 }()
960 if matches {
961 output = fmt.Sprintf("%s\t%d: PASS: %s matched by %s\n", output, i, actualFmt, matcher)
962 } else {
963 differences++
964 output = fmt.Sprintf("%s\t%d: FAIL: %s not matched by %s\n", output, i, actualFmt, matcher)
965 }
966 } else {
967 switch expected := expected.(type) {
968 case anythingOfTypeArgument:
969
970 if reflect.TypeOf(actual).Name() != string(expected) && reflect.TypeOf(actual).String() != string(expected) {
971
972 differences++
973 output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected, reflect.TypeOf(actual).Name(), actualFmt)
974 }
975 case *IsTypeArgument:
976 actualT := reflect.TypeOf(actual)
977 if actualT != expected.t {
978 differences++
979 output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, expected.t.Name(), actualT.Name(), actualFmt)
980 }
981 case *FunctionalOptionsArgument:
982 t := expected.value
983
984 var name string
985 tValue := reflect.ValueOf(t)
986 if tValue.Len() > 0 {
987 name = "[]" + reflect.TypeOf(tValue.Index(0).Interface()).String()
988 }
989
990 tName := reflect.TypeOf(t).Name()
991 if name != reflect.TypeOf(actual).String() && tValue.Len() != 0 {
992 differences++
993 output = fmt.Sprintf("%s\t%d: FAIL: type %s != type %s - %s\n", output, i, tName, reflect.TypeOf(actual).Name(), actualFmt)
994 } else {
995 if ef, af := assertOpts(t, actual); ef == "" && af == "" {
996
997 output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, tName, tName)
998 } else {
999
1000 differences++
1001 output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, af, ef)
1002 }
1003 }
1004
1005 default:
1006 if assert.ObjectsAreEqual(expected, Anything) || assert.ObjectsAreEqual(actual, Anything) || assert.ObjectsAreEqual(actual, expected) {
1007
1008 output = fmt.Sprintf("%s\t%d: PASS: %s == %s\n", output, i, actualFmt, expectedFmt)
1009 } else {
1010
1011 differences++
1012 output = fmt.Sprintf("%s\t%d: FAIL: %s != %s\n", output, i, actualFmt, expectedFmt)
1013 }
1014 }
1015 }
1016
1017 }
1018
1019 if differences == 0 {
1020 return "No differences.", differences
1021 }
1022
1023 return output, differences
1024 }
1025
1026
1027
1028 func (args Arguments) Assert(t TestingT, objects ...interface{}) bool {
1029 if h, ok := t.(tHelper); ok {
1030 h.Helper()
1031 }
1032
1033
1034 diff, diffCount := args.Diff(objects)
1035
1036 if diffCount == 0 {
1037 return true
1038 }
1039
1040
1041 t.Logf(diff)
1042 t.Errorf("%sArguments do not match.", assert.CallerInfo())
1043
1044 return false
1045 }
1046
1047
1048
1049
1050
1051
1052 func (args Arguments) String(indexOrNil ...int) string {
1053 if len(indexOrNil) == 0 {
1054
1055 var argsStr []string
1056 for _, arg := range args {
1057 argsStr = append(argsStr, fmt.Sprintf("%T", arg))
1058 }
1059 return strings.Join(argsStr, ",")
1060 } else if len(indexOrNil) == 1 {
1061
1062 index := indexOrNil[0]
1063 var s string
1064 var ok bool
1065 if s, ok = args.Get(index).(string); !ok {
1066 panic(fmt.Sprintf("assert: arguments: String(%d) failed because object wasn't correct type: %s", index, args.Get(index)))
1067 }
1068 return s
1069 }
1070
1071 panic(fmt.Sprintf("assert: arguments: Wrong number of arguments passed to String. Must be 0 or 1, not %d", len(indexOrNil)))
1072 }
1073
1074
1075
1076 func (args Arguments) Int(index int) int {
1077 var s int
1078 var ok bool
1079 if s, ok = args.Get(index).(int); !ok {
1080 panic(fmt.Sprintf("assert: arguments: Int(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
1081 }
1082 return s
1083 }
1084
1085
1086
1087 func (args Arguments) Error(index int) error {
1088 obj := args.Get(index)
1089 var s error
1090 var ok bool
1091 if obj == nil {
1092 return nil
1093 }
1094 if s, ok = obj.(error); !ok {
1095 panic(fmt.Sprintf("assert: arguments: Error(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
1096 }
1097 return s
1098 }
1099
1100
1101
1102 func (args Arguments) Bool(index int) bool {
1103 var s bool
1104 var ok bool
1105 if s, ok = args.Get(index).(bool); !ok {
1106 panic(fmt.Sprintf("assert: arguments: Bool(%d) failed because object wasn't correct type: %v", index, args.Get(index)))
1107 }
1108 return s
1109 }
1110
// typeAndKind returns the type and kind of v, looking through one level of
// pointer. v must not be nil.
func typeAndKind(v interface{}) (reflect.Type, reflect.Kind) {
	typ := reflect.TypeOf(v)
	if typ.Kind() == reflect.Ptr {
		typ = typ.Elem()
	}
	return typ, typ.Kind()
}
1121
1122 func diffArguments(expected Arguments, actual Arguments) string {
1123 if len(expected) != len(actual) {
1124 return fmt.Sprintf("Provided %v arguments, mocked for %v arguments", len(expected), len(actual))
1125 }
1126
1127 for x := range expected {
1128 if diffString := diff(expected[x], actual[x]); diffString != "" {
1129 return fmt.Sprintf("Difference found in argument %v:\n\n%s", x, diffString)
1130 }
1131 }
1132
1133 return ""
1134 }
1135
1136
1137
1138 func diff(expected interface{}, actual interface{}) string {
1139 if expected == nil || actual == nil {
1140 return ""
1141 }
1142
1143 et, ek := typeAndKind(expected)
1144 at, _ := typeAndKind(actual)
1145
1146 if et != at {
1147 return ""
1148 }
1149
1150 if ek != reflect.Struct && ek != reflect.Map && ek != reflect.Slice && ek != reflect.Array {
1151 return ""
1152 }
1153
1154 e := spewConfig.Sdump(expected)
1155 a := spewConfig.Sdump(actual)
1156
1157 diff, _ := difflib.GetUnifiedDiffString(difflib.UnifiedDiff{
1158 A: difflib.SplitLines(e),
1159 B: difflib.SplitLines(a),
1160 FromFile: "Expected",
1161 FromDate: "",
1162 ToFile: "Actual",
1163 ToDate: "",
1164 Context: 1,
1165 })
1166
1167 return diff
1168 }
1169
1170 var spewConfig = spew.ConfigState{
1171 Indent: " ",
1172 DisablePointerAddresses: true,
1173 DisableCapacities: true,
1174 SortKeys: true,
1175 }
1176
// tHelper is implemented by test handles that offer Helper. The assertion
// functions in this file call it when the handle provides it.
type tHelper interface {
	// Helper marks the calling function as a test helper.
	Helper()
}
1180
1181 func assertOpts(expected, actual interface{}) (expectedFmt, actualFmt string) {
1182 expectedOpts := reflect.ValueOf(expected)
1183 actualOpts := reflect.ValueOf(actual)
1184 var expectedNames []string
1185 for i := 0; i < expectedOpts.Len(); i++ {
1186 expectedNames = append(expectedNames, funcName(expectedOpts.Index(i).Interface()))
1187 }
1188 var actualNames []string
1189 for i := 0; i < actualOpts.Len(); i++ {
1190 actualNames = append(actualNames, funcName(actualOpts.Index(i).Interface()))
1191 }
1192 if !assert.ObjectsAreEqual(expectedNames, actualNames) {
1193 expectedFmt = fmt.Sprintf("%v", expectedNames)
1194 actualFmt = fmt.Sprintf("%v", actualNames)
1195 return
1196 }
1197
1198 for i := 0; i < expectedOpts.Len(); i++ {
1199 expectedOpt := expectedOpts.Index(i).Interface()
1200 actualOpt := actualOpts.Index(i).Interface()
1201
1202 expectedFunc := expectedNames[i]
1203 actualFunc := actualNames[i]
1204 if expectedFunc != actualFunc {
1205 expectedFmt = expectedFunc
1206 actualFmt = actualFunc
1207 return
1208 }
1209
1210 ot := reflect.TypeOf(expectedOpt)
1211 var expectedValues []reflect.Value
1212 var actualValues []reflect.Value
1213 if ot.NumIn() == 0 {
1214 return
1215 }
1216
1217 for i := 0; i < ot.NumIn(); i++ {
1218 vt := ot.In(i).Elem()
1219 expectedValues = append(expectedValues, reflect.New(vt))
1220 actualValues = append(actualValues, reflect.New(vt))
1221 }
1222
1223 reflect.ValueOf(expectedOpt).Call(expectedValues)
1224 reflect.ValueOf(actualOpt).Call(actualValues)
1225
1226 for i := 0; i < ot.NumIn(); i++ {
1227 if !assert.ObjectsAreEqual(expectedValues[i].Interface(), actualValues[i].Interface()) {
1228 expectedFmt = fmt.Sprintf("%s %+v", expectedNames[i], expectedValues[i].Interface())
1229 actualFmt = fmt.Sprintf("%s %+v", expectedNames[i], actualValues[i].Interface())
1230 return
1231 }
1232 }
1233 }
1234
1235 return "", ""
1236 }
1237
// funcName returns the runtime name of the function opt, reduced to the last
// slash-separated element and with the trailing dot-suffix removed.
func funcName(opt interface{}) string {
	entry := reflect.ValueOf(opt).Pointer()
	full := runtime.FuncForPC(entry).Name()
	base := path.Base(full)
	return strings.TrimSuffix(base, path.Ext(full))
}
1242
View as plain text