1
2
3 package matchers
4
5 import (
6 "errors"
7 "fmt"
8 "reflect"
9
10 "github.com/onsi/gomega/format"
11 )
12
// ContainElementMatcher succeeds when the actual array, slice or map contains
// at least one element matching Element.
type ContainElementMatcher struct {
	// Element is either a matcher that is applied to every element of the
	// actual value, or a plain value that is wrapped in an EqualMatcher.
	Element interface{}
	// Result optionally holds at most one non-nil pointer through which the
	// matching element(s) are handed back to the caller.
	Result []interface{}
}
17
18 func (matcher *ContainElementMatcher) Match(actual interface{}) (success bool, err error) {
19 if !isArrayOrSlice(actual) && !isMap(actual) {
20 return false, fmt.Errorf("ContainElement matcher expects an array/slice/map. Got:\n%s", format.Object(actual, 1))
21 }
22
23 var actualT reflect.Type
24 var result reflect.Value
25 switch l := len(matcher.Result); {
26 case l > 1:
27 return false, errors.New("ContainElement matcher expects at most a single optional pointer to store its findings at")
28 case l == 1:
29 if reflect.ValueOf(matcher.Result[0]).Kind() != reflect.Ptr {
30 return false, fmt.Errorf("ContainElement matcher expects a non-nil pointer to store its findings at. Got\n%s",
31 format.Object(matcher.Result[0], 1))
32 }
33 actualT = reflect.TypeOf(actual)
34 resultReference := matcher.Result[0]
35 result = reflect.ValueOf(resultReference).Elem()
36 switch result.Kind() {
37 case reflect.Array:
38 return false, fmt.Errorf("ContainElement cannot return findings. Need *%s, got *%s",
39 reflect.SliceOf(actualT.Elem()).String(), result.Type().String())
40 case reflect.Slice:
41 if !isArrayOrSlice(actual) {
42 return false, fmt.Errorf("ContainElement cannot return findings. Need *%s, got *%s",
43 reflect.MapOf(actualT.Key(), actualT.Elem()).String(), result.Type().String())
44 }
45 if !actualT.Elem().AssignableTo(result.Type().Elem()) {
46 return false, fmt.Errorf("ContainElement cannot return findings. Need *%s, got *%s",
47 actualT.String(), result.Type().String())
48 }
49 case reflect.Map:
50 if !isMap(actual) {
51 return false, fmt.Errorf("ContainElement cannot return findings. Need *%s, got *%s",
52 actualT.String(), result.Type().String())
53 }
54 if !actualT.AssignableTo(result.Type()) {
55 return false, fmt.Errorf("ContainElement cannot return findings. Need *%s, got *%s",
56 actualT.String(), result.Type().String())
57 }
58 default:
59 if !actualT.Elem().AssignableTo(result.Type()) {
60 return false, fmt.Errorf("ContainElement cannot return findings. Need *%s, got *%s",
61 actualT.Elem().String(), result.Type().String())
62 }
63 }
64 }
65
66 elemMatcher, elementIsMatcher := matcher.Element.(omegaMatcher)
67 if !elementIsMatcher {
68 elemMatcher = &EqualMatcher{Expected: matcher.Element}
69 }
70
71 value := reflect.ValueOf(actual)
72 var valueAt func(int) interface{}
73
74 var getFindings func() reflect.Value
75 var foundAt func(int)
76
77 if isMap(actual) {
78 keys := value.MapKeys()
79 valueAt = func(i int) interface{} {
80 return value.MapIndex(keys[i]).Interface()
81 }
82 if result.Kind() != reflect.Invalid {
83 fm := reflect.MakeMap(actualT)
84 getFindings = func() reflect.Value {
85 return fm
86 }
87 foundAt = func(i int) {
88 fm.SetMapIndex(keys[i], value.MapIndex(keys[i]))
89 }
90 }
91 } else {
92 valueAt = func(i int) interface{} {
93 return value.Index(i).Interface()
94 }
95 if result.Kind() != reflect.Invalid {
96 var f reflect.Value
97 if result.Kind() == reflect.Slice {
98 f = reflect.MakeSlice(result.Type(), 0, 0)
99 } else {
100 f = reflect.MakeSlice(reflect.SliceOf(result.Type()), 0, 0)
101 }
102 getFindings = func() reflect.Value {
103 return f
104 }
105 foundAt = func(i int) {
106 f = reflect.Append(f, value.Index(i))
107 }
108 }
109 }
110
111 var lastError error
112 for i := 0; i < value.Len(); i++ {
113 elem := valueAt(i)
114 success, err := elemMatcher.Match(elem)
115 if err != nil {
116 lastError = err
117 continue
118 }
119 if success {
120 if result.Kind() == reflect.Invalid {
121 return true, nil
122 }
123 foundAt(i)
124 }
125 }
126
127
128
129
130 if result.Kind() == reflect.Invalid {
131 return false, lastError
132 }
133
134
135
136
137
138 findings := getFindings()
139 if findings.Len() == 0 {
140 return false, lastError
141 }
142
143
144
145
146 if findings.Len() == 1 && !isArrayOrSlice(result.Interface()) && !isMap(result.Interface()) {
147 if isMap(actual) {
148 miter := findings.MapRange()
149 miter.Next()
150 result.Set(miter.Value())
151 } else {
152 result.Set(findings.Index(0))
153 }
154 return true, nil
155 }
156
157
158
159
160 if !findings.Type().AssignableTo(result.Type()) {
161 return false, fmt.Errorf("ContainElement cannot return multiple findings. Need *%s, got *%s",
162 findings.Type().String(), result.Type().String())
163 }
164 result.Set(findings)
165 return true, nil
166 }
167
168 func (matcher *ContainElementMatcher) FailureMessage(actual interface{}) (message string) {
169 return format.Message(actual, "to contain element matching", matcher.Element)
170 }
171
172 func (matcher *ContainElementMatcher) NegatedFailureMessage(actual interface{}) (message string) {
173 return format.Message(actual, "not to contain element matching", matcher.Element)
174 }
175
View as plain text