1
2 package assert
3
4 import (
5 "fmt"
6 "go/ast"
7 "go/token"
8 "reflect"
9
10 "gotest.tools/v3/assert/cmp"
11 "gotest.tools/v3/internal/format"
12 "gotest.tools/v3/internal/source"
13 )
14
15
type (
	// LogT is the logging subset of a test handle used by this package.
	// *testing.T satisfies it.
	LogT interface {
		Log(args ...interface{})
	}

	// helperT is an optional capability of a LogT: functions in this package
	// call Helper on t whenever t implements it (as *testing.T does), before
	// doing any other work.
	helperT interface {
		Helper()
	}
)

// failureMessage is the prefix prepended to assertion failure messages.
const failureMessage = "assertion failed: "
25
26
27 func Eval(
28 t LogT,
29 argSelector argSelector,
30 comparison interface{},
31 msgAndArgs ...interface{},
32 ) bool {
33 if ht, ok := t.(helperT); ok {
34 ht.Helper()
35 }
36 var success bool
37 switch check := comparison.(type) {
38 case bool:
39 if check {
40 return true
41 }
42 logFailureFromBool(t, msgAndArgs...)
43
44
45 case func() (success bool, message string):
46 success = runCompareFunc(t, check, msgAndArgs...)
47
48 case nil:
49 return true
50
51 case error:
52 msg := failureMsgFromError(check)
53 t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
54
55 case cmp.Comparison:
56 success = RunComparison(t, argSelector, check, msgAndArgs...)
57
58 case func() cmp.Result:
59 success = RunComparison(t, argSelector, check, msgAndArgs...)
60
61 default:
62 t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
63 }
64 return success
65 }
66
67 func runCompareFunc(
68 t LogT,
69 f func() (success bool, message string),
70 msgAndArgs ...interface{},
71 ) bool {
72 if ht, ok := t.(helperT); ok {
73 ht.Helper()
74 }
75 if success, message := f(); !success {
76 t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
77 return false
78 }
79 return true
80 }
81
82 func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
83 if ht, ok := t.(helperT); ok {
84 ht.Helper()
85 }
86 const stackIndex = 3
87 args, err := source.CallExprArgs(stackIndex)
88 if err != nil {
89 t.Log(err.Error())
90 }
91
92 var msg string
93 const comparisonArgIndex = 1
94 if len(args) <= comparisonArgIndex {
95 msg = "but assert failed to find the expression to print"
96 } else {
97 msg, err = boolFailureMessage(args[comparisonArgIndex])
98 if err != nil {
99 t.Log(err.Error())
100 msg = "expression is false"
101 }
102 }
103
104 t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
105 }
106
// failureMsgFromError returns the failure message for an error value that was
// expected to be nil.
//
// When err is a "typed nil" (a non-nil interface whose dynamic value is a nil
// pointer, map, slice, func, chan or unsafe pointer), the message names the
// dynamic type instead of calling err.Error(). Calling Error on such a nil
// receiver can panic (for example a nil func-typed error) or yield an empty,
// misleading message. In all other cases the message includes err.Error().
//
// err must not be a nil interface. Eval only calls this in its error case,
// after a nil comparison has already been handled.
func failureMsgFromError(err error) string {
	v := reflect.ValueOf(err)
	switch v.Kind() {
	case reflect.Ptr, reflect.Map, reflect.Slice, reflect.Func, reflect.Chan, reflect.UnsafePointer:
		// IsNil is only valid for these kinds; any other kind would panic.
		if v.IsNil() {
			return fmt.Sprintf("error is not nil: error has type %T", err)
		}
	}
	return "error is not nil: " + err.Error()
}
115
116 func boolFailureMessage(expr ast.Expr) (string, error) {
117 if binaryExpr, ok := expr.(*ast.BinaryExpr); ok {
118 x, err := source.FormatNode(binaryExpr.X)
119 if err != nil {
120 return "", err
121 }
122 y, err := source.FormatNode(binaryExpr.Y)
123 if err != nil {
124 return "", err
125 }
126
127 switch binaryExpr.Op {
128 case token.NEQ:
129 return x + " is " + y, nil
130 case token.EQL:
131 return x + " is not " + y, nil
132 case token.GTR:
133 return x + " is <= " + y, nil
134 case token.LSS:
135 return x + " is >= " + y, nil
136 case token.GEQ:
137 return x + " is less than " + y, nil
138 case token.LEQ:
139 return x + " is greater than " + y, nil
140 }
141 }
142
143 if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
144 x, err := source.FormatNode(unaryExpr.X)
145 if err != nil {
146 return "", err
147 }
148 return x + " is true", nil
149 }
150
151 if ident, ok := expr.(*ast.Ident); ok {
152 return ident.Name + " is false", nil
153 }
154
155 formatted, err := source.FormatNode(expr)
156 if err != nil {
157 return "", err
158 }
159 return "expression is false: " + formatted, nil
160 }
161
View as plain text