...
1 package matchers
2
3 import (
4 "fmt"
5 "reflect"
6 "strings"
7
8 "github.com/onsi/gomega/format"
9 )
10
11
12
13
// missingFieldError is returned when a requested field or method cannot be
// found on the inspected struct. Because it is a dedicated type rather than a
// value built with fmt.Errorf, callers are able to recognise this particular
// failure with a type assertion or errors.As.
type missingFieldError string

// Error implements the error interface by returning the stored message as-is.
func (mfe missingFieldError) Error() string {
	msg := string(mfe)
	return msg
}
19
20 func extractField(actual interface{}, field string, matchername string) (interface{}, error) {
21 fields := strings.SplitN(field, ".", 2)
22 actualValue := reflect.ValueOf(actual)
23
24 if actualValue.Kind() == reflect.Ptr {
25 actualValue = actualValue.Elem()
26 }
27 if actualValue == (reflect.Value{}) {
28 return nil, fmt.Errorf("%s encountered nil while dereferencing a pointer of type %T.", matchername, actual)
29 }
30
31 if actualValue.Kind() != reflect.Struct {
32 return nil, fmt.Errorf("%s encountered:\n%s\nWhich is not a struct.", matchername, format.Object(actual, 1))
33 }
34
35 var extractedValue reflect.Value
36
37 if strings.HasSuffix(fields[0], "()") {
38 extractedValue = actualValue.MethodByName(strings.TrimSuffix(fields[0], "()"))
39 if extractedValue == (reflect.Value{}) && actualValue.CanAddr() {
40 extractedValue = actualValue.Addr().MethodByName(strings.TrimSuffix(fields[0], "()"))
41 }
42 if extractedValue == (reflect.Value{}) {
43 return nil, missingFieldError(fmt.Sprintf("%s could not find method named '%s' in struct of type %T.", matchername, fields[0], actual))
44 }
45 t := extractedValue.Type()
46 if t.NumIn() != 0 || t.NumOut() != 1 {
47 return nil, fmt.Errorf("%s found an invalid method named '%s' in struct of type %T.\nMethods must take no arguments and return exactly one value.", matchername, fields[0], actual)
48 }
49 extractedValue = extractedValue.Call([]reflect.Value{})[0]
50 } else {
51 extractedValue = actualValue.FieldByName(fields[0])
52 if extractedValue == (reflect.Value{}) {
53 return nil, missingFieldError(fmt.Sprintf("%s could not find field named '%s' in struct:\n%s", matchername, fields[0], format.Object(actual, 1)))
54 }
55 }
56
57 if len(fields) == 1 {
58 return extractedValue.Interface(), nil
59 } else {
60 return extractField(extractedValue.Interface(), fields[1], matchername)
61 }
62 }
63
64 type HaveFieldMatcher struct {
65 Field string
66 Expected interface{}
67
68 extractedField interface{}
69 expectedMatcher omegaMatcher
70 }
71
72 func (matcher *HaveFieldMatcher) Match(actual interface{}) (success bool, err error) {
73 matcher.extractedField, err = extractField(actual, matcher.Field, "HaveField")
74 if err != nil {
75 return false, err
76 }
77
78 var isMatcher bool
79 matcher.expectedMatcher, isMatcher = matcher.Expected.(omegaMatcher)
80 if !isMatcher {
81 matcher.expectedMatcher = &EqualMatcher{Expected: matcher.Expected}
82 }
83
84 return matcher.expectedMatcher.Match(matcher.extractedField)
85 }
86
87 func (matcher *HaveFieldMatcher) FailureMessage(actual interface{}) (message string) {
88 message = fmt.Sprintf("Value for field '%s' failed to satisfy matcher.\n", matcher.Field)
89 message += matcher.expectedMatcher.FailureMessage(matcher.extractedField)
90
91 return message
92 }
93
94 func (matcher *HaveFieldMatcher) NegatedFailureMessage(actual interface{}) (message string) {
95 message = fmt.Sprintf("Value for field '%s' satisfied matcher, but should not have.\n", matcher.Field)
96 message += matcher.expectedMatcher.NegatedFailureMessage(matcher.extractedField)
97
98 return message
99 }
100
View as plain text