1
2
3
4
5
6
7 package reflect
8
9 import (
10 "fmt"
11 "reflect"
12 "strings"
13 )
14
15
16
// Equalities maps a type T to a custom equality function for values of T.
// Entries added through AddFunc are funcs of the form func(T, T) bool keyed by
// T; when a deep comparison method reaches a value whose type has an entry, it
// calls that function instead of comparing the value structurally.
type Equalities map[reflect.Type]reflect.Value
18
19
20 func EqualitiesOrDie(funcs ...interface{}) Equalities {
21 e := Equalities{}
22 if err := e.AddFuncs(funcs...); err != nil {
23 panic(err)
24 }
25 return e
26 }
27
28
29 func (e Equalities) AddFuncs(funcs ...interface{}) error {
30 for _, f := range funcs {
31 if err := e.AddFunc(f); err != nil {
32 return err
33 }
34 }
35 return nil
36 }
37
38
39
40 func (e Equalities) AddFunc(eqFunc interface{}) error {
41 fv := reflect.ValueOf(eqFunc)
42 ft := fv.Type()
43 if ft.Kind() != reflect.Func {
44 return fmt.Errorf("expected func, got: %v", ft)
45 }
46 if ft.NumIn() != 2 {
47 return fmt.Errorf("expected two 'in' params, got: %v", ft)
48 }
49 if ft.NumOut() != 1 {
50 return fmt.Errorf("expected one 'out' param, got: %v", ft)
51 }
52 if ft.In(0) != ft.In(1) {
53 return fmt.Errorf("expected arg 1 and 2 to have same type, but got %v", ft)
54 }
55 var forReturnType bool
56 boolType := reflect.TypeOf(forReturnType)
57 if ft.Out(0) != boolType {
58 return fmt.Errorf("expected bool return, got: %v", ft)
59 }
60 e[ft.In(0)] = fv
61 return nil
62 }
63
64
65
66
67
68
69
// visit is a key in the cycle-detection set used by the deep comparison
// functions. It pairs the addresses of two values being compared (stored with
// the smaller address in a1) with their common type.
type visit struct {
	a1, a2 uintptr
	typ    reflect.Type
}
75
76
77
78
// unexportedTypePanic is the panic value raised when a comparison reaches a
// value it cannot read through Interface (for example, one reached via an
// unexported struct field). It holds the chain of types leading to that value,
// outermost first.
type unexportedTypePanic []reflect.Type

// Error implements error by returning the same text as String.
func (u unexportedTypePanic) Error() string { return u.String() }

// String describes the panic and the nesting path as "T1 -> T2 -> ...".
func (u unexportedTypePanic) String() string {
	names := make([]string, 0, len(u))
	for _, typ := range u {
		names = append(names, fmt.Sprint(typ))
	}
	const prefix = "an unexported field was encountered, nested like this: "
	return prefix + strings.Join(names, " -> ")
}
89
90 func makeUsefulPanic(v reflect.Value) {
91 if x := recover(); x != nil {
92 if u, ok := x.(unexportedTypePanic); ok {
93 u = append(unexportedTypePanic{v.Type()}, u...)
94 x = u
95 }
96 panic(x)
97 }
98 }
99
100
101
102
103
104 func (e Equalities) deepValueEqual(v1, v2 reflect.Value, visited map[visit]bool, equateNilAndEmpty bool, depth int) bool {
105 defer makeUsefulPanic(v1)
106
107 if !v1.IsValid() || !v2.IsValid() {
108 return v1.IsValid() == v2.IsValid()
109 }
110 if v1.Type() != v2.Type() {
111 return false
112 }
113 if fv, ok := e[v1.Type()]; ok {
114 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
115 }
116
117 hard := func(k reflect.Kind) bool {
118 switch k {
119 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
120 return true
121 }
122 return false
123 }
124
125 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
126 addr1 := v1.UnsafeAddr()
127 addr2 := v2.UnsafeAddr()
128 if addr1 > addr2 {
129
130 addr1, addr2 = addr2, addr1
131 }
132
133
134 if addr1 == addr2 {
135 return true
136 }
137
138
139 typ := v1.Type()
140 v := visit{addr1, addr2, typ}
141 if visited[v] {
142 return true
143 }
144
145
146 visited[v] = true
147 }
148
149 switch v1.Kind() {
150 case reflect.Array:
151
152
153 for i := 0; i < v1.Len(); i++ {
154 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, equateNilAndEmpty, depth+1) {
155 return false
156 }
157 }
158 return true
159 case reflect.Slice:
160 if equateNilAndEmpty {
161 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
162 return false
163 }
164
165 if v1.IsNil() || v1.Len() == 0 {
166 return true
167 }
168 } else {
169 if v1.IsNil() != v2.IsNil() {
170 return false
171 }
172
173
174
175
176 if v1.IsNil() {
177 return true
178 }
179
180
181 if v1.Len() == 0 || v2.Len() == 0 {
182 return true
183 }
184 }
185 if v1.Len() != v2.Len() {
186 return false
187 }
188 if v1.Pointer() == v2.Pointer() {
189 return true
190 }
191 for i := 0; i < v1.Len(); i++ {
192 if !e.deepValueEqual(v1.Index(i), v2.Index(i), visited, equateNilAndEmpty, depth+1) {
193 return false
194 }
195 }
196 return true
197 case reflect.Interface:
198 if v1.IsNil() || v2.IsNil() {
199 return v1.IsNil() == v2.IsNil()
200 }
201 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, equateNilAndEmpty, depth+1)
202 case reflect.Ptr:
203 return e.deepValueEqual(v1.Elem(), v2.Elem(), visited, equateNilAndEmpty, depth+1)
204 case reflect.Struct:
205 for i, n := 0, v1.NumField(); i < n; i++ {
206 if !e.deepValueEqual(v1.Field(i), v2.Field(i), visited, equateNilAndEmpty, depth+1) {
207 return false
208 }
209 }
210 return true
211 case reflect.Map:
212 if equateNilAndEmpty {
213 if (v1.IsNil() || v1.Len() == 0) != (v2.IsNil() || v2.Len() == 0) {
214 return false
215 }
216 if v1.IsNil() || v1.Len() == 0 {
217 return true
218 }
219 } else {
220 if v1.IsNil() != v2.IsNil() {
221 return false
222 }
223
224
225
226
227 if v1.IsNil() {
228 return true
229 }
230
231
232 if v1.Len() == 0 || v2.Len() == 0 {
233 return true
234 }
235 }
236 if v1.Len() != v2.Len() {
237 return false
238 }
239 if v1.Pointer() == v2.Pointer() {
240 return true
241 }
242 for _, k := range v1.MapKeys() {
243 if !e.deepValueEqual(v1.MapIndex(k), v2.MapIndex(k), visited, equateNilAndEmpty, depth+1) {
244 return false
245 }
246 }
247 return true
248 case reflect.Func:
249 if v1.IsNil() && v2.IsNil() {
250 return true
251 }
252
253 return false
254 default:
255
256 if !v1.CanInterface() || !v2.CanInterface() {
257 panic(unexportedTypePanic{})
258 }
259 return v1.Interface() == v2.Interface()
260 }
261 }
262
263
264
265
266
267
268
269
270
271
272 func (e Equalities) DeepEqual(a1, a2 interface{}) bool {
273 return e.deepEqual(a1, a2, true)
274 }
275
276 func (e Equalities) DeepEqualWithNilDifferentFromEmpty(a1, a2 interface{}) bool {
277 return e.deepEqual(a1, a2, false)
278 }
279
280 func (e Equalities) deepEqual(a1, a2 interface{}, equateNilAndEmpty bool) bool {
281 if a1 == nil || a2 == nil {
282 return a1 == a2
283 }
284 v1 := reflect.ValueOf(a1)
285 v2 := reflect.ValueOf(a2)
286 if v1.Type() != v2.Type() {
287 return false
288 }
289 return e.deepValueEqual(v1, v2, make(map[visit]bool), equateNilAndEmpty, 0)
290 }
291
292 func (e Equalities) deepValueDerive(v1, v2 reflect.Value, visited map[visit]bool, depth int) bool {
293 defer makeUsefulPanic(v1)
294
295 if !v1.IsValid() || !v2.IsValid() {
296 return v1.IsValid() == v2.IsValid()
297 }
298 if v1.Type() != v2.Type() {
299 return false
300 }
301 if fv, ok := e[v1.Type()]; ok {
302 return fv.Call([]reflect.Value{v1, v2})[0].Bool()
303 }
304
305 hard := func(k reflect.Kind) bool {
306 switch k {
307 case reflect.Array, reflect.Map, reflect.Slice, reflect.Struct:
308 return true
309 }
310 return false
311 }
312
313 if v1.CanAddr() && v2.CanAddr() && hard(v1.Kind()) {
314 addr1 := v1.UnsafeAddr()
315 addr2 := v2.UnsafeAddr()
316 if addr1 > addr2 {
317
318 addr1, addr2 = addr2, addr1
319 }
320
321
322 if addr1 == addr2 {
323 return true
324 }
325
326
327 typ := v1.Type()
328 v := visit{addr1, addr2, typ}
329 if visited[v] {
330 return true
331 }
332
333
334 visited[v] = true
335 }
336
337 switch v1.Kind() {
338 case reflect.Array:
339
340
341 for i := 0; i < v1.Len(); i++ {
342 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
343 return false
344 }
345 }
346 return true
347 case reflect.Slice:
348 if v1.IsNil() || v1.Len() == 0 {
349 return true
350 }
351 if v1.Len() > v2.Len() {
352 return false
353 }
354 if v1.Pointer() == v2.Pointer() {
355 return true
356 }
357 for i := 0; i < v1.Len(); i++ {
358 if !e.deepValueDerive(v1.Index(i), v2.Index(i), visited, depth+1) {
359 return false
360 }
361 }
362 return true
363 case reflect.String:
364 if v1.Len() == 0 {
365 return true
366 }
367 if v1.Len() > v2.Len() {
368 return false
369 }
370 return v1.String() == v2.String()
371 case reflect.Interface:
372 if v1.IsNil() {
373 return true
374 }
375 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
376 case reflect.Pointer:
377 if v1.IsNil() {
378 return true
379 }
380 return e.deepValueDerive(v1.Elem(), v2.Elem(), visited, depth+1)
381 case reflect.Struct:
382 for i, n := 0, v1.NumField(); i < n; i++ {
383 if !e.deepValueDerive(v1.Field(i), v2.Field(i), visited, depth+1) {
384 return false
385 }
386 }
387 return true
388 case reflect.Map:
389 if v1.IsNil() || v1.Len() == 0 {
390 return true
391 }
392 if v1.Len() > v2.Len() {
393 return false
394 }
395 if v1.Pointer() == v2.Pointer() {
396 return true
397 }
398 for _, k := range v1.MapKeys() {
399 if !e.deepValueDerive(v1.MapIndex(k), v2.MapIndex(k), visited, depth+1) {
400 return false
401 }
402 }
403 return true
404 case reflect.Func:
405 if v1.IsNil() && v2.IsNil() {
406 return true
407 }
408
409 return false
410 default:
411
412 if !v1.CanInterface() || !v2.CanInterface() {
413 panic(unexportedTypePanic{})
414 }
415 return v1.Interface() == v2.Interface()
416 }
417 }
418
419
420
421
422
423
424 func (e Equalities) DeepDerivative(a1, a2 interface{}) bool {
425 if a1 == nil {
426 return true
427 }
428 v1 := reflect.ValueOf(a1)
429 v2 := reflect.ValueOf(a2)
430 if v1.Type() != v2.Type() {
431 return false
432 }
433 return e.deepValueDerive(v1, v2, make(map[visit]bool), 0)
434 }
435
View as plain text