1
2 package cmp
3
4 import (
5 "errors"
6 "fmt"
7 "reflect"
8 "regexp"
9 "strings"
10
11 "github.com/google/go-cmp/cmp"
12 "gotest.tools/v3/internal/format"
13 )
14
15
16
17
18 type Comparison func() Result
19
20
21
22
23
24
25
26 func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
27 return func() (result Result) {
28 defer func() {
29 if panicmsg, handled := handleCmpPanic(recover()); handled {
30 result = ResultFailure(panicmsg)
31 }
32 }()
33 diff := cmp.Diff(x, y, opts...)
34 if diff == "" {
35 return ResultSuccess
36 }
37 return multiLineDiffResult(diff, x, y)
38 }
39 }
40
// handleCmpPanic inspects a recovered panic value r.
//
// It returns ("", false) when nothing panicked (r is nil). It returns
// (r, true) when r is a string starting with "cannot handle unexported field",
// so the caller can report it as a failure. Any other value is re-panicked
// unchanged.
func handleCmpPanic(r interface{}) (string, bool) {
	if r == nil {
		return "", false
	}
	if msg, isString := r.(string); isString && strings.HasPrefix(msg, "cannot handle unexported field") {
		return msg, true
	}
	panic(r)
}
55
56 func toResult(success bool, msg string) Result {
57 if success {
58 return ResultSuccess
59 }
60 return ResultFailure(msg)
61 }
62
63
64
type (
	// RegexOrPattern is the argument type accepted by Regexp. Regexp handles a
	// *regexp.Regexp or a string pattern; any other dynamic type produces a
	// failure Result.
	RegexOrPattern interface{}
)
66
67
68
69
70
71
72
73
74 func Regexp(re RegexOrPattern, v string) Comparison {
75 match := func(re *regexp.Regexp) Result {
76 return toResult(
77 re.MatchString(v),
78 fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
79 }
80
81 return func() Result {
82 switch regex := re.(type) {
83 case *regexp.Regexp:
84 return match(regex)
85 case string:
86 re, err := regexp.Compile(regex)
87 if err != nil {
88 return ResultFailure(err.Error())
89 }
90 return match(re)
91 default:
92 return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", regex))
93 }
94 }
95 }
96
97
98 func Equal(x, y interface{}) Comparison {
99 return func() Result {
100 switch {
101 case x == y:
102 return ResultSuccess
103 case isMultiLineStringCompare(x, y):
104 diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
105 return multiLineDiffResult(diff, x, y)
106 }
107 return ResultFailureTemplate(`
108 {{- printf "%v" .Data.x}} (
109 {{- with callArg 0 }}{{ formatNode . }} {{end -}}
110 {{- printf "%T" .Data.x -}}
111 ) != {{ printf "%v" .Data.y}} (
112 {{- with callArg 1 }}{{ formatNode . }} {{end -}}
113 {{- printf "%T" .Data.y -}}
114 )`,
115 map[string]interface{}{"x": x, "y": y})
116 }
117 }
118
// isMultiLineStringCompare reports whether x and y are both strings and at
// least one of them contains a newline.
func isMultiLineStringCompare(x, y interface{}) bool {
	strX, xIsString := x.(string)
	strY, yIsString := y.(string)
	if !xIsString || !yIsString {
		return false
	}
	for _, s := range [...]string{strX, strY} {
		if strings.ContainsRune(s, '\n') {
			return true
		}
	}
	return false
}
130
131 func multiLineDiffResult(diff string, x, y interface{}) Result {
132 return ResultFailureTemplate(`
133 --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
134 +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
135 {{ .Data.diff }}`,
136 map[string]interface{}{"diff": diff, "x": x, "y": y})
137 }
138
139
140 func Len(seq interface{}, expected int) Comparison {
141 return func() (result Result) {
142 defer func() {
143 if e := recover(); e != nil {
144 result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
145 }
146 }()
147 value := reflect.ValueOf(seq)
148 length := value.Len()
149 if length == expected {
150 return ResultSuccess
151 }
152 msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
153 return ResultFailure(msg)
154 }
155 }
156
157
158
159
160
161
162
163
164
165 func Contains(collection interface{}, item interface{}) Comparison {
166 return func() Result {
167 colValue := reflect.ValueOf(collection)
168 if !colValue.IsValid() {
169 return ResultFailure("nil does not contain items")
170 }
171 msg := fmt.Sprintf("%v does not contain %v", collection, item)
172
173 itemValue := reflect.ValueOf(item)
174 switch colValue.Type().Kind() {
175 case reflect.String:
176 if itemValue.Type().Kind() != reflect.String {
177 return ResultFailure("string may only contain strings")
178 }
179 return toResult(
180 strings.Contains(colValue.String(), itemValue.String()),
181 fmt.Sprintf("string %q does not contain %q", collection, item))
182
183 case reflect.Map:
184 if itemValue.Type() != colValue.Type().Key() {
185 return ResultFailure(fmt.Sprintf(
186 "%v can not contain a %v key", colValue.Type(), itemValue.Type()))
187 }
188 return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
189
190 case reflect.Slice, reflect.Array:
191 for i := 0; i < colValue.Len(); i++ {
192 if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
193 return ResultSuccess
194 }
195 }
196 return ResultFailure(msg)
197 default:
198 return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
199 }
200 }
201 }
202
203
204 func Panics(f func()) Comparison {
205 return func() (result Result) {
206 defer func() {
207 if err := recover(); err != nil {
208 result = ResultSuccess
209 }
210 }()
211 f()
212 return ResultFailure("did not panic")
213 }
214 }
215
216
217
218 func Error(err error, message string) Comparison {
219 return func() Result {
220 switch {
221 case err == nil:
222 return ResultFailure("expected an error, got nil")
223 case err.Error() != message:
224 return ResultFailure(fmt.Sprintf(
225 "expected error %q, got %s", message, formatErrorMessage(err)))
226 }
227 return ResultSuccess
228 }
229 }
230
231
232
233 func ErrorContains(err error, substring string) Comparison {
234 return func() Result {
235 switch {
236 case err == nil:
237 return ResultFailure("expected an error, got nil")
238 case !strings.Contains(err.Error(), substring):
239 return ResultFailure(fmt.Sprintf(
240 "expected error to contain %q, got %s", substring, formatErrorMessage(err)))
241 }
242 return ResultSuccess
243 }
244 }
245
type (
	// causer matches errors that expose an underlying cause through a Cause
	// method. This is presumably the github.com/pkg/errors convention, whose
	// %+v output adds detail — verify.
	causer interface {
		Cause() error
	}
)

// formatErrorMessage renders err for failure messages: the quoted message,
// followed by a newline and the %+v rendering when err itself implements
// causer.
func formatErrorMessage(err error) string {
	quoted := fmt.Sprintf("%q", err)
	// A direct type assertion, not errors.As, so only the outermost error is
	// inspected.
	if _, hasCause := err.(causer); !hasCause {
		return quoted
	}
	return quoted + "\n" + fmt.Sprintf("%+v", err)
}
258
259
260
261
262
263 func Nil(obj interface{}) Comparison {
264 msgFunc := func(value reflect.Value) string {
265 return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
266 }
267 return isNil(obj, msgFunc)
268 }
269
270 func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
271 return func() Result {
272 if obj == nil {
273 return ResultSuccess
274 }
275 value := reflect.ValueOf(obj)
276 kind := value.Type().Kind()
277 if kind >= reflect.Chan && kind <= reflect.Slice {
278 if value.IsNil() {
279 return ResultSuccess
280 }
281 return ResultFailure(msgFunc(value))
282 }
283
284 return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
285 }
286 }
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310 func ErrorType(err error, expected interface{}) Comparison {
311 return func() Result {
312 switch expectedType := expected.(type) {
313 case func(error) bool:
314 return cmpErrorTypeFunc(err, expectedType)
315 case reflect.Type:
316 if expectedType.Kind() == reflect.Interface {
317 return cmpErrorTypeImplementsType(err, expectedType)
318 }
319 return cmpErrorTypeEqualType(err, expectedType)
320 case nil:
321 return ResultFailure("invalid type for expected: nil")
322 }
323
324 expectedType := reflect.TypeOf(expected)
325 switch {
326 case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
327 return cmpErrorTypeEqualType(err, expectedType)
328 case isPtrToInterface(expectedType):
329 return cmpErrorTypeImplementsType(err, expectedType.Elem())
330 }
331 return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
332 }
333 }
334
335 func cmpErrorTypeFunc(err error, f func(error) bool) Result {
336 if f(err) {
337 return ResultSuccess
338 }
339 actual := "nil"
340 if err != nil {
341 actual = fmt.Sprintf("%s (%T)", err, err)
342 }
343 return ResultFailureTemplate(`error is {{ .Data.actual }}
344 {{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
345 map[string]interface{}{"actual": actual})
346 }
347
348 func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
349 if err == nil {
350 return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
351 }
352 errValue := reflect.ValueOf(err)
353 if errValue.Type() == expectedType {
354 return ResultSuccess
355 }
356 return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
357 }
358
359 func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
360 if err == nil {
361 return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
362 }
363 errValue := reflect.ValueOf(err)
364 if errValue.Type().Implements(expectedType) {
365 return ResultSuccess
366 }
367 return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
368 }
369
// isPtrToInterface reports whether typ is a pointer whose element type is an
// interface, e.g. the type of (*error)(nil).
func isPtrToInterface(typ reflect.Type) bool {
	if typ.Kind() != reflect.Ptr {
		return false
	}
	return typ.Elem().Kind() == reflect.Interface
}
373
// isPtrToStruct reports whether typ is a pointer whose element type is a
// struct.
func isPtrToStruct(typ reflect.Type) bool {
	switch typ.Kind() {
	case reflect.Ptr:
		return typ.Elem().Kind() == reflect.Struct
	default:
		return false
	}
}
377
// Dynamic types of the errors built by the standard library constructors
// below. Both are captured once at package initialisation.
var (
	// stdlibErrorNewType is the concrete type returned by errors.New.
	stdlibErrorNewType = reflect.TypeOf(errors.New(""))
	// stdlibFmtErrorType is the concrete type fmt.Errorf returns for a format
	// with a single %w verb, as constructed here.
	stdlibFmtErrorType = reflect.TypeOf(fmt.Errorf("%w", errors.New("")))
)
382
383
384
385 func ErrorIs(actual error, expected error) Comparison {
386 return func() Result {
387 if errors.Is(actual, expected) {
388 return ResultSuccess
389 }
390
391
392
393
394 return ResultFailureTemplate(`error is
395 {{- if not .Data.a }} nil,{{ else }}
396 {{- printf " \"%v\"" .Data.a }}
397 {{- if notStdlibErrorType .Data.a }} ({{ printf "%T" .Data.a }}){{ end }},
398 {{- end }} not {{ printf "\"%v\"" .Data.x }} (
399 {{- with callArg 1 }}{{ formatNode . }}{{ end }}
400 {{- if notStdlibErrorType .Data.x }}{{ printf " %T" .Data.x }}{{ end }})`,
401 map[string]interface{}{"a": actual, "x": expected})
402 }
403 }
404
View as plain text