...
1 package assert
2
3 import (
4 "errors"
5 "fmt"
6 "go/ast"
7
8 "gotest.tools/v3/assert/cmp"
9 "gotest.tools/v3/internal/format"
10 "gotest.tools/v3/internal/source"
11 )
12
13
14
15 func RunComparison(
16 t LogT,
17 argSelector argSelector,
18 f cmp.Comparison,
19 msgAndArgs ...interface{},
20 ) bool {
21 if ht, ok := t.(helperT); ok {
22 ht.Helper()
23 }
24 result := f()
25 if result.Success() {
26 return true
27 }
28
29 if source.IsUpdate() {
30 if updater, ok := result.(updateExpected); ok {
31 const stackIndex = 3
32 err := updater.UpdatedExpected(stackIndex)
33 switch {
34 case err == nil:
35 return true
36 case errors.Is(err, source.ErrNotFound):
37
38 default:
39 t.Log("failed to update source", err)
40 return false
41 }
42 }
43 }
44
45 var message string
46 switch typed := result.(type) {
47 case resultWithComparisonArgs:
48 const stackIndex = 3
49 args, err := source.CallExprArgs(stackIndex)
50 if err != nil {
51 t.Log(err.Error())
52 }
53 message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
54 case resultBasic:
55 message = typed.FailureMessage()
56 default:
57 message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
58 }
59
60 t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
61 return false
62 }
63
// Optional interfaces that RunComparison type-asserts the value returned by a
// cmp.Comparison against. They decide how a failure is reported, or whether
// the expected value is updated in source instead.
type (
	// resultWithComparisonArgs builds a failure message from the argument
	// expressions of the assertion call.
	resultWithComparisonArgs interface {
		FailureMessage(args []ast.Expr) string
	}

	// resultBasic builds a failure message without any argument context.
	resultBasic interface {
		FailureMessage() string
	}

	// updateExpected is used when source.IsUpdate() is true. A nil error
	// counts as success in RunComparison. Presumably it rewrites the expected
	// value in the caller's source, and stackIndex locates the caller frame —
	// confirm against implementations.
	updateExpected interface {
		UpdatedExpected(stackIndex int) error
	}
)
75
76
77
78
79
80
81
82
83
84 func filterPrintableExpr(args []ast.Expr) []ast.Expr {
85 result := make([]ast.Expr, len(args))
86 for i, arg := range args {
87 if isShortPrintableExpr(arg) {
88 result[i] = arg
89 continue
90 }
91
92 if starExpr, ok := arg.(*ast.StarExpr); ok {
93 result[i] = starExpr.X
94 continue
95 }
96 }
97 return result
98 }
99
// isShortPrintableExpr reports whether expr is one of the expression kinds
// that are printed as-is in failure messages: identifiers, selectors, index
// and slice expressions, and binary or unary operations.
//
// Every other kind is rejected, including nil, literals, calls, parenthesized
// expressions, type assertions, composite literals, and *x (ast.StarExpr).
func isShortPrintableExpr(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr,
		*ast.BinaryExpr, *ast.UnaryExpr:
		return true
	}
	return false
}
111
// argSelector picks, from the argument expressions of an assertion call, the
// ones a failure message should describe. RunComparison passes its result on
// to filterPrintableExpr.
type argSelector func(callArgs []ast.Expr) []ast.Expr
113
114
115
// ArgsAfterT returns every argument expression after the first one. As the
// name indicates, it is meant for calls whose first argument is the test's t.
// It returns nil when args is empty. The returned slice shares args' backing
// array.
func ArgsAfterT(args []ast.Expr) []ast.Expr {
	if len(args) > 0 {
		return args[1:]
	}
	return nil
}
122
123
124
125
// ArgsFromComparisonCall returns the arguments of the call expression at
// position 1 of args. For argument expressions `t, cmp.Equal(x, y)` it
// returns `x, y`.
//
// It returns nil when args has fewer than two elements or when args[1] is not
// an *ast.CallExpr.
func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
	if len(args) < 2 {
		return nil
	}
	call, isCall := args[1].(*ast.CallExpr)
	if !isCall {
		return nil
	}
	return call.Args
}
135
136
137
// ArgsAtZeroIndex returns the arguments of the call expression at position 0
// of args. For argument expressions `cmp.Equal(x, y)` it returns `x, y`.
//
// It returns nil when args is empty or when args[0] is not an *ast.CallExpr.
func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
	if len(args) > 0 {
		if call, isCall := args[0].(*ast.CallExpr); isCall {
			return call.Args
		}
	}
	return nil
}
147
View as plain text