1
2
3
4
5
6
7 package reflect
8
9 import (
10 "fmt"
11 "reflect"
12 "strings"
13 )
14
15
16
// Equalities maps a type to a registered function (stored as a reflect.Value)
// that decides whether two values of that type are equal. The comparison
// methods on Equalities consult these functions before falling back to a
// structural comparison.
type Equalities map[reflect.Type]reflect.Value

// EqualitiesOrDie builds an Equalities from funcs (see AddFunc for the shape
// each one must have) and panics with the registration error if any of them
// is invalid.
func EqualitiesOrDie(funcs ...interface{}) Equalities {
	e := Equalities{}
	if err := e.AddFuncs(funcs...); err != nil {
		panic(err)
	}
	return e
}

// AddFuncs registers each element of funcs via AddFunc. It stops at and
// returns the first error; functions registered before the failing one stay
// in e.
func (e Equalities) AddFuncs(funcs ...interface{}) error {
	for _, f := range funcs {
		if err := e.AddFunc(f); err != nil {
			return err
		}
	}
	return nil
}

// AddFunc registers eqFunc as the custom equality function for values of type
// T. eqFunc must be a non-nil function with the signature func(a, b T) bool.
// A later registration for the same T replaces an earlier one.
//
// It returns an error describing the first violated requirement. An untyped
// nil or a nil function value is rejected up front: the former would
// otherwise panic here (reflect.ValueOf(nil) has no Type), and the latter
// would be stored and panic later when a comparison tries to call it.
func (e Equalities) AddFunc(eqFunc interface{}) error {
	if eqFunc == nil {
		return fmt.Errorf("expected func, got: %v", eqFunc)
	}
	fv := reflect.ValueOf(eqFunc)
	ft := fv.Type()
	if ft.Kind() != reflect.Func {
		return fmt.Errorf("expected func, got: %v", ft)
	}
	if fv.IsNil() {
		return fmt.Errorf("expected non-nil func, got nil %v", ft)
	}
	if ft.NumIn() != 2 {
		return fmt.Errorf("expected two 'in' params, got: %v", ft)
	}
	if ft.NumOut() != 1 {
		return fmt.Errorf("expected one 'out' param, got: %v", ft)
	}
	if ft.In(0) != ft.In(1) {
		return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft)
	}
	if ft.Out(0) != reflect.TypeOf(false) {
		return fmt.Errorf("expected bool return, got: %v", ft)
	}
	e[ft.In(0)] = fv
	return nil
}
63
64
65
66
67
68
69
// visit identifies one in-progress comparison: the two operand addresses
// (ordered so that a1 <= a2) and the type being compared there. The
// recursive comparison methods use it as a key of their visited map to stop
// descending into a pair they have already started comparing, which keeps
// cyclic structures from recursing forever.
type visit struct {
	a1, a2 uintptr
	typ    reflect.Type
}
75
76
77
78
// unexportedTypePanic is the panic value raised when a comparison reaches a
// value whose Interface() may not be called. It holds the chain of types
// leading to that value, outermost first; makeUsefulPanic prepends one entry
// per recursion frame as the panic unwinds.
type unexportedTypePanic []reflect.Type

// Error implements the error interface by returning the same text as String.
func (u unexportedTypePanic) Error() string { return u.String() }

// String renders the type chain joined by " -> " after a fixed prefix.
func (u unexportedTypePanic) String() string {
	var sb strings.Builder
	sb.WriteString("an unexported field was encountered, nested like this: ")
	for idx, typ := range u {
		if idx > 0 {
			sb.WriteString(" -> ")
		}
		fmt.Fprint(&sb, typ)
	}
	return sb.String()
}
89
90 func makeUsefulPanic(v reflect.Value) {
91 if x := recover(); x != nil {
92 if u, ok := x.(unexportedTypePanic); ok {
93 u = append(unexportedTypePanic{v.Type()}, u...)
94 x = u
95 }
96 panic(x)
97 }
98 }
99
100
101
102
103 func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
104 defer makeUsefulPanic(v1)
105
106 if !v1.IsValid() || !v2.IsValid() {
107 return v1.IsValid() == v2.IsValid()
108 }
109 if v1.Type() != v2.Type() {
110 return false
111 }
112 if fv, ok := e[v1.Type()]; ok {
113 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
114 }
115 if v1.CanAddr() {
116 if fv, ok := e[v1.Addr().Type()]; ok {
117 return fv.Call([]reflect.Value{v1.Addr(), v2.Addr()})[0].Bool()
118 }
119 }
120
121 hard := func(k reflect.Kind) bool {
122 switch k {
123 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
124 return true
125 }
126 return false
127 }
128
129 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
130 addr1 := v1.UnsafeAddr()
131 addr2 := v2.UnsafeAddr()
132 if addr1 > addr2 {
133
134 addr1, addr2 = addr2, addr1
135 }
136
137
138 if addr1 == addr2 {
139 return true
140 }
141
142
143 typ := v1.Type()
144 v := visit{addr1, addr2, typ}
145 if visited[v] {
146 return true
147 }
148
149
150 visited[v] = true
151 }
152
153 switch v1.Kind() {
154 case reflect.Array:
155
156
157 for i := 0; i < v1.Len(); i++ {
158 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
159 return false
160 }
161 }
162 return true
163 case reflect.Slice:
164 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
165 return false
166 }
167 if v1.IsNil() || v1.Len() == 0 {
168 return true
169 }
170 if v1.Len() != v2.Len() {
171 return false
172 }
173 if v1.Pointer() == v2.Pointer() {
174 return true
175 }
176 for i := 0; i < v1.Len(); i++ {
177 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, depth+1) {
178 return false
179 }
180 }
181 return true
182 case reflect.Interface:
183 if v1.IsNil() || v2.IsNil() {
184 return v1.IsNil() == v2.IsNil()
185 }
186 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
187 case reflect.Ptr:
188 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, depth+1)
189 case reflect.Struct:
190 for i, n := 0, v1.NumField(); i < n; i++ {
191 if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, depth+1) {
192 return false
193 }
194 }
195 return true
196 case reflect.Map:
197 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
198 return false
199 }
200 if v1.IsNil() || v1.Len() == 0 {
201 return true
202 }
203 if v1.Len() != v2.Len() {
204 return false
205 }
206 if v1.Pointer() == v2.Pointer() {
207 return true
208 }
209 for _, k := range v1.MapKeys() {
210 if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
211 return false
212 }
213 }
214 return true
215 case reflect.Func:
216 if v1.IsNil() && v2.IsNil() {
217 return true
218 }
219
220 return false
221 default:
222
223 if !v1.CanInterface() || !v2.CanInterface() {
224 panic(unexportedTypePanic{})
225 }
226 return v1.Interface() == v2.Interface()
227 }
228 }
229
230
231
232
233
234
235
236
237
238
239 func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
240 if a1 == nil || a2 == nil {
241 return a1 == a2
242 }
243 v1 := reflect.ValueOf(a1)
244 v2 := reflect.ValueOf(a2)
245 if v1.Type() != v2.Type() {
246 return false
247 }
248 return e.deepValueEqual(v1, v2, make(map[visit]bool), 0)
249 }
250
251 func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
252 defer makeUsefulPanic(v1)
253
254 if !v1.IsValid() || !v2.IsValid() {
255 return v1.IsValid() == v2.IsValid()
256 }
257 if v1.Type() != v2.Type() {
258 return false
259 }
260 if fv, ok := e[v1.Type()]; ok {
261 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
262 }
263 if v1.CanAddr() {
264 if fv, ok := e[v1.Addr().Type()]; ok {
265 return fv.Call([]reflect.Value{v1.Addr(), v2.Addr()})[0].Bool()
266 }
267 }
268
269 hard := func(k reflect.Kind) bool {
270 switch k {
271 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
272 return true
273 }
274 return false
275 }
276
277 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
278 addr1 := v1.UnsafeAddr()
279 addr2 := v2.UnsafeAddr()
280 if addr1 > addr2 {
281
282 addr1, addr2 = addr2, addr1
283 }
284
285
286 if addr1 == addr2 {
287 return true
288 }
289
290
291 typ := v1.Type()
292 v := visit{addr1, addr2, typ}
293 if visited[v] {
294 return true
295 }
296
297
298 visited[v] = true
299 }
300
301 switch v1.Kind() {
302 case reflect.Array:
303
304
305 for i := 0; i < v1.Len(); i++ {
306 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
307 return false
308 }
309 }
310 return true
311 case reflect.Slice:
312 if v1.IsNil() || v1.Len() == 0 {
313 return true
314 }
315 if v1.Len() > v2.Len() {
316 return false
317 }
318 if v1.Pointer() == v2.Pointer() {
319 return true
320 }
321 for i := 0; i < v1.Len(); i++ {
322 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
323 return false
324 }
325 }
326 return true
327 case reflect.String:
328 if v1.Len() == 0 {
329 return true
330 }
331 if v1.Len() > v2.Len() {
332 return false
333 }
334 return v1.String() == v2.String()
335 case reflect.Interface:
336 if v1.IsNil() {
337 return true
338 }
339 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
340 case reflect.Ptr:
341 if v1.IsNil() {
342 return true
343 }
344 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
345 case reflect.Struct:
346 for i, n := 0, v1.NumField(); i < n; i++ {
347 if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) {
348 return false
349 }
350 }
351 return true
352 case reflect.Map:
353 if v1.IsNil() || v1.Len() == 0 {
354 return true
355 }
356 if v1.Len() > v2.Len() {
357 return false
358 }
359 if v1.Pointer() == v2.Pointer() {
360 return true
361 }
362 for _, k := range v1.MapKeys() {
363 if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
364 return false
365 }
366 }
367 return true
368 case reflect.Func:
369 if v1.IsNil() && v2.IsNil() {
370 return true
371 }
372
373 return false
374 default:
375
376 if !v1.CanInterface() || !v2.CanInterface() {
377 panic(unexportedTypePanic{})
378 }
379 return v1.Interface() == v2.Interface()
380 }
381 }
382
383
384
385
386
387
388 func (e Equalities) DeepDerivative(a1, a2 interface{}) bool {
389 if a1 == nil {
390 return true
391 }
392 v1 := reflect.ValueOf(a1)
393 v2 := reflect.ValueOf(a2)
394 if v1.Type() != v2.Type() {
395 return false
396 }
397 return e.deepValueDerive(v1, v2, make(map[visit]bool), 0)
398 }
399
View as plain text