1
2
3
4 package deep
5
6 import (
7 "errors"
8 "fmt"
9 "log"
10 "reflect"
11 "strings"
12 )
13
// Comparison settings shared by every call to Equal. They are package-level
// variables, so changing one affects all callers.
var (
	// FloatPrecision is the number of decimal places float values are
	// rounded to (by formatting them with fmt's %.Nf verb) before being
	// compared.
	FloatPrecision = 10

	// MaxDiff is the number of differences after which comparison stops
	// descending into further struct fields, map keys and array/slice
	// elements. It is a soft limit: the returned list can be longer.
	MaxDiff = 10

	// MaxDepth, when greater than zero, is the maximum recursion depth;
	// deeper values are not compared and ErrMaxRecursion is logged instead.
	// Zero (or a negative value) means no limit.
	MaxDepth = 0

	// LogErrors makes the internal errors ErrMaxRecursion, ErrTypeMismatch
	// and ErrNotHandled get printed with the standard log package.
	LogErrors = false

	// CompareUnexportedFields makes struct comparison include unexported
	// fields. By default only exported fields are compared.
	CompareUnexportedFields = false

	// CompareFunctions makes func values take part in comparison. Since Go
	// funcs cannot be compared to each other, two funcs are considered equal
	// only when both are nil.
	CompareFunctions = false

	// NilSlicesAreEmpty treats a nil slice as equal to an empty, non-nil one.
	NilSlicesAreEmpty = false

	// NilMapsAreEmpty treats a nil map as equal to an empty, non-nil one.
	NilMapsAreEmpty = false
)
47
// The errors below are never returned by Equal; they are only handed to
// logError, which prints them when LogErrors is true.

// ErrMaxRecursion is logged when recursion goes deeper than MaxDepth.
var ErrMaxRecursion = errors.New("recursed to MaxDepth")

// ErrTypeMismatch is logged when two compared values have different
// reflect.Types.
var ErrTypeMismatch = errors.New("variables are different reflect.Type")

// ErrNotHandled is logged when a value's reflect.Kind has no comparison
// rule, so the values are not compared.
var ErrNotHandled = errors.New("cannot compare the reflect.Kind")
58
// Flags accepted by Equal's variadic flags argument. They are typed byte;
// Equal panics when given a flag of any other type.
const (
	// FLAG_NONE is the zero flag. Nothing in the comparison checks for it.
	FLAG_NONE byte = iota

	// FLAG_IGNORE_SLICE_ORDER compares slices as unordered collections:
	// they are equal when every value occurs the same number of times in
	// both, regardless of position.
	FLAG_IGNORE_SLICE_ORDER
)
71
// cmp carries the mutable state of a single Equal call.
type cmp struct {
	// diff collects one human-readable message per difference found.
	// buff is the path (field names, map keys, indexes) to the value
	// currently being compared; saveDiff joins it with "." as a prefix.
	diff, buff []string

	// floatFormat is the fmt verb used to round floats before comparing
	// them; Equal builds it from FloatPrecision (e.g. "%.10f").
	floatFormat string

	// flag records which FLAG_* values were passed to Equal.
	flag map[byte]bool
}
78
// errorType is the reflect.Type of the built-in error interface; it is used
// to detect compared values whose type implements error.
var errorType = reflect.TypeOf(new(error)).Elem()
80
81
82
83
84
85
86
87
88
89
90
91 func Equal(a, b interface{}, flags ...interface{}) []string {
92 aVal := reflect.ValueOf(a)
93 bVal := reflect.ValueOf(b)
94 c := &cmp{
95 diff: []string{},
96 buff: []string{},
97 floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
98 flag: map[byte]bool{},
99 }
100 for i := range flags {
101 c.flag[flags[i].(byte)] = true
102 }
103 if a == nil && b == nil {
104 return nil
105 } else if a == nil && b != nil {
106 c.saveDiff("<nil pointer>", b)
107 } else if a != nil && b == nil {
108 c.saveDiff(a, "<nil pointer>")
109 }
110 if len(c.diff) > 0 {
111 return c.diff
112 }
113
114 c.equals(aVal, bVal, 0)
115 if len(c.diff) > 0 {
116 return c.diff
117 }
118 return nil
119 }
120
121 func (c *cmp) equals(a, b reflect.Value, level int) {
122 if MaxDepth > 0 && level > MaxDepth {
123 logError(ErrMaxRecursion)
124 return
125 }
126
127
128 if !a.IsValid() || !b.IsValid() {
129 if a.IsValid() && !b.IsValid() {
130 c.saveDiff(a.Type(), "<nil pointer>")
131 } else if !a.IsValid() && b.IsValid() {
132 c.saveDiff("<nil pointer>", b.Type())
133 }
134 return
135 }
136
137
138 aType := a.Type()
139 bType := b.Type()
140 if aType != bType {
141
142 if aType.Name() == "" || aType.Name() != bType.Name() {
143 c.saveDiff(aType, bType)
144 } else {
145
146
147
148
149 aFullType := aType.PkgPath() + "." + aType.Name()
150 bFullType := bType.PkgPath() + "." + bType.Name()
151 c.saveDiff(aFullType, bFullType)
152 }
153 logError(ErrTypeMismatch)
154 return
155 }
156
157
158 aKind := a.Kind()
159 bKind := b.Kind()
160
161
162 aElem := aKind == reflect.Ptr || aKind == reflect.Interface
163 bElem := bKind == reflect.Ptr || bKind == reflect.Interface
164
165
166
167
168
169
170
171
172
173
174 if (aType.Implements(errorType) && bType.Implements(errorType)) &&
175 ((!aElem || !a.IsNil()) && (!bElem || !b.IsNil())) &&
176 (a.CanInterface() && b.CanInterface()) {
177 aString := a.MethodByName("Error").Call(nil)[0].String()
178 bString := b.MethodByName("Error").Call(nil)[0].String()
179 if aString != bString {
180 c.saveDiff(aString, bString)
181 }
182 return
183 }
184
185
186 if aElem || bElem {
187 if aElem {
188 a = a.Elem()
189 }
190 if bElem {
191 b = b.Elem()
192 }
193 c.equals(a, b, level+1)
194 return
195 }
196
197 switch aKind {
198
199
200
201
202
203 case reflect.Struct:
204
214
215
216
217 if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() {
218
219
220
221
222
223
224
225
226
227 funcType := eqFunc.Type()
228 if funcType.NumIn() == 1 && funcType.In(0) == bType {
229 retVals := eqFunc.Call([]reflect.Value{b})
230 if !retVals[0].Bool() {
231 c.saveDiff(a, b)
232 }
233 return
234 }
235 }
236
237 for i := 0; i < a.NumField(); i++ {
238 if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
239 continue
240 }
241
242 if aType.Field(i).Tag.Get("deep") == "-" {
243 continue
244 }
245
246 c.push(aType.Field(i).Name)
247
248
249
250 af := a.Field(i)
251 bf := b.Field(i)
252
253
254 c.equals(af, bf, level+1)
255
256 c.pop()
257
258 if len(c.diff) >= MaxDiff {
259 break
260 }
261 }
262 case reflect.Map:
263
277
278 if a.IsNil() || b.IsNil() {
279 if NilMapsAreEmpty {
280 if a.IsNil() && b.Len() != 0 {
281 c.saveDiff("<nil map>", b)
282 return
283 } else if a.Len() != 0 && b.IsNil() {
284 c.saveDiff(a, "<nil map>")
285 return
286 }
287 } else {
288 if a.IsNil() && !b.IsNil() {
289 c.saveDiff("<nil map>", b)
290 } else if !a.IsNil() && b.IsNil() {
291 c.saveDiff(a, "<nil map>")
292 }
293 }
294 return
295 }
296
297 if a.Pointer() == b.Pointer() {
298 return
299 }
300
301 for _, key := range a.MapKeys() {
302 c.push(fmt.Sprintf("map[%v]", key))
303
304 aVal := a.MapIndex(key)
305 bVal := b.MapIndex(key)
306 if bVal.IsValid() {
307 c.equals(aVal, bVal, level+1)
308 } else {
309 c.saveDiff(aVal, "<does not have key>")
310 }
311
312 c.pop()
313
314 if len(c.diff) >= MaxDiff {
315 return
316 }
317 }
318
319 for _, key := range b.MapKeys() {
320 if aVal := a.MapIndex(key); aVal.IsValid() {
321 continue
322 }
323
324 c.push(fmt.Sprintf("map[%v]", key))
325 c.saveDiff("<does not have key>", b.MapIndex(key))
326 c.pop()
327 if len(c.diff) >= MaxDiff {
328 return
329 }
330 }
331 case reflect.Array:
332 n := a.Len()
333 for i := 0; i < n; i++ {
334 c.push(fmt.Sprintf("array[%d]", i))
335 c.equals(a.Index(i), b.Index(i), level+1)
336 c.pop()
337 if len(c.diff) >= MaxDiff {
338 break
339 }
340 }
341 case reflect.Slice:
342 if NilSlicesAreEmpty {
343 if a.IsNil() && b.Len() != 0 {
344 c.saveDiff("<nil slice>", b)
345 return
346 } else if a.Len() != 0 && b.IsNil() {
347 c.saveDiff(a, "<nil slice>")
348 return
349 }
350 } else {
351 if a.IsNil() && !b.IsNil() {
352 c.saveDiff("<nil slice>", b)
353 return
354 } else if !a.IsNil() && b.IsNil() {
355 c.saveDiff(a, "<nil slice>")
356 return
357 }
358 }
359
360
361
362
363
364
365
366 aLen := a.Len()
367 bLen := b.Len()
368 if a.Pointer() == b.Pointer() && aLen == bLen {
369 return
370 }
371
372 if c.flag[FLAG_IGNORE_SLICE_ORDER] {
373
374
375
376
377
378
379 am := map[interface{}]int{}
380 for i := 0; i < a.Len(); i++ {
381 am[a.Index(i).Interface()] += 1
382 }
383 bm := map[interface{}]int{}
384 for i := 0; i < b.Len(); i++ {
385 bm[b.Index(i).Interface()] += 1
386 }
387 c.cmpMapValueCounts(a, b, am, bm, true)
388 c.cmpMapValueCounts(b, a, bm, am, false)
389 } else {
390
391 n := aLen
392 if bLen > aLen {
393 n = bLen
394 }
395 for i := 0; i < n; i++ {
396 c.push(fmt.Sprintf("slice[%d]", i))
397 if i < aLen && i < bLen {
398 c.equals(a.Index(i), b.Index(i), level+1)
399 } else if i < aLen {
400 c.saveDiff(a.Index(i), "<no value>")
401 } else {
402 c.saveDiff("<no value>", b.Index(i))
403 }
404 c.pop()
405 if len(c.diff) >= MaxDiff {
406 break
407 }
408 }
409 }
410
411
412
413
414
415 case reflect.Float32, reflect.Float64:
416
417
418
419
420
421
422
423 aval := fmt.Sprintf(c.floatFormat, a.Float())
424 bval := fmt.Sprintf(c.floatFormat, b.Float())
425 if aval != bval {
426 c.saveDiff(a.Float(), b.Float())
427 }
428 case reflect.Bool:
429 if a.Bool() != b.Bool() {
430 c.saveDiff(a.Bool(), b.Bool())
431 }
432 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
433 if a.Int() != b.Int() {
434 c.saveDiff(a.Int(), b.Int())
435 }
436 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
437 if a.Uint() != b.Uint() {
438 c.saveDiff(a.Uint(), b.Uint())
439 }
440 case reflect.String:
441 if a.String() != b.String() {
442 c.saveDiff(a.String(), b.String())
443 }
444 case reflect.Func:
445 if CompareFunctions {
446 if !a.IsNil() || !b.IsNil() {
447 aVal, bVal := "nil func", "nil func"
448 if !a.IsNil() {
449 aVal = "func"
450 }
451 if !b.IsNil() {
452 bVal = "func"
453 }
454 c.saveDiff(aVal, bVal)
455 }
456 }
457 default:
458 logError(ErrNotHandled)
459 }
460 }
461
462 func (c *cmp) push(name string) {
463 c.buff = append(c.buff, name)
464 }
465
466 func (c *cmp) pop() {
467 if len(c.buff) > 0 {
468 c.buff = c.buff[0 : len(c.buff)-1]
469 }
470 }
471
472 func (c *cmp) saveDiff(aval, bval interface{}) {
473 if len(c.buff) > 0 {
474 varName := strings.Join(c.buff, ".")
475 c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
476 } else {
477 c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
478 }
479 }
480
481 func (c *cmp) cmpMapValueCounts(a, b reflect.Value, am, bm map[interface{}]int, a2b bool) {
482 for v := range am {
483 aCount, _ := am[v]
484 bCount, _ := bm[v]
485
486 if aCount != bCount {
487 c.push(fmt.Sprintf("(unordered) slice[]=%v: value count", v))
488 if a2b {
489 c.saveDiff(fmt.Sprintf("%d", aCount), fmt.Sprintf("%d", bCount))
490 } else {
491 c.saveDiff(fmt.Sprintf("%d", bCount), fmt.Sprintf("%d", aCount))
492 }
493 c.pop()
494 }
495 delete(am, v)
496 delete(bm, v)
497 }
498 }
499
500 func logError(err error) {
501 if LogErrors {
502 log.Println(err)
503 }
504 }
505
View as plain text