1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32 package cmp
33
34 import (
35 "fmt"
36 "reflect"
37 "strings"
38
39 "github.com/google/go-cmp/cmp/internal/diff"
40 "github.com/google/go-cmp/cmp/internal/function"
41 "github.com/google/go-cmp/cmp/internal/value"
42 )
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95 func Equal(x, y interface{}, opts ...Option) bool {
96 s := newState(opts)
97 s.compareAny(rootStep(x, y))
98 return s.result.Equal()
99 }
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115 func Diff(x, y interface{}, opts ...Option) string {
116 s := newState(opts)
117
118
119
120
121 if len(s.reporters) == 0 {
122 s.compareAny(rootStep(x, y))
123 if s.result.Equal() {
124 return ""
125 }
126 s.result = diff.Result{}
127 }
128
129 r := new(defaultReporter)
130 s.reporters = append(s.reporters, reporter{r})
131 s.compareAny(rootStep(x, y))
132 d := r.String()
133 if (d == "") != s.result.Equal() {
134 panic("inconsistent difference and equality results")
135 }
136 return d
137 }
138
139
140
141 func rootStep(x, y interface{}) PathStep {
142 vx := reflect.ValueOf(x)
143 vy := reflect.ValueOf(y)
144
145
146
147 var t reflect.Type
148 if !vx.IsValid() || !vy.IsValid() || vx.Type() != vy.Type() {
149 t = anyType
150 if vx.IsValid() {
151 vvx := reflect.New(t).Elem()
152 vvx.Set(vx)
153 vx = vvx
154 }
155 if vy.IsValid() {
156 vvy := reflect.New(t).Elem()
157 vvy.Set(vy)
158 vy = vvy
159 }
160 } else {
161 t = vx.Type()
162 }
163
164 return &pathStep{t, vx, vy}
165 }
166
// state holds everything needed while comparing two values.
type state struct {
	// Comparison state. statelessCompare saves and restores result and
	// reporters around a nested comparison. curPath and curPtrs are pushed
	// and popped with deferred calls in the compare methods, so they are
	// balanced on return.
	result    diff.Result // NumSame/NumDiff tallies accumulated by report
	curPath   Path        // steps from the root to the node being compared
	curPtrs   pointerPath // pointer pairs on the current path, for cycle detection
	reporters []reporter  // reporters notified of each step and each result

	// recChecker panics when the same Transformer repeats along a very long
	// path (see recChecker.Check).
	recChecker recChecker

	// dynChecker decides which calls to user-provided functions also get a
	// determinism/symmetry check (see callTRFunc and callTTBFunc).
	dynChecker dynChecker

	// These fields are populated by processOption.
	exporters []exporter // consulted by compareStruct for structs with unexported fields
	opts      Options    // core options; newState installs a validator first
}
187
188 func newState(opts []Option) *state {
189
190 s := &state{opts: Options{validator{}}}
191 s.curPtrs.Init()
192 s.processOption(Options(opts))
193 return s
194 }
195
196 func (s *state) processOption(opt Option) {
197 switch opt := opt.(type) {
198 case nil:
199 case Options:
200 for _, o := range opt {
201 s.processOption(o)
202 }
203 case coreOption:
204 type filtered interface {
205 isFiltered() bool
206 }
207 if fopt, ok := opt.(filtered); ok && !fopt.isFiltered() {
208 panic(fmt.Sprintf("cannot use an unfiltered option: %v", opt))
209 }
210 s.opts = append(s.opts, opt)
211 case exporter:
212 s.exporters = append(s.exporters, opt)
213 case reporter:
214 s.reporters = append(s.reporters, opt)
215 default:
216 panic(fmt.Sprintf("unknown option %T", opt))
217 }
218 }
219
220
221
222
223 func (s *state) statelessCompare(step PathStep) diff.Result {
224
225
226
227
228
229 oldResult, oldReporters := s.result, s.reporters
230 s.result = diff.Result{}
231 s.reporters = nil
232 s.compareAny(step)
233 res := s.result
234 s.result, s.reporters = oldResult, oldReporters
235 return res
236 }
237
// compareAny compares the pair of values held by step and records the
// outcome through s.report, directly or via the helper it dispatches to.
//
// The rules are applied in order:
//  1. an option from s.opts that applies to this node (tryOptions);
//  2. a suitable Equal method on the node's type (tryMethod);
//  3. a comparison chosen by the type's reflect.Kind.
//
// It panics on a Kind it does not handle.
func (s *state) compareAny(step PathStep) {
	// Extend the current path, and every reporter's view of it, by step for
	// the duration of this call. The deferred calls undo it on return.
	s.curPath.push(step)
	defer s.curPath.pop()
	for _, r := range s.reporters {
		r.PushStep(step)
		defer r.PopStep()
	}
	s.recChecker.Check(s.curPath)

	// Cycle detection for slice elements. The addresses of the two elements
	// are tracked on curPtrs, and a pair already being compared higher up
	// is reported as a cycle instead of being compared again.
	t := step.Type()
	vx, vy := step.Values()
	if si, ok := step.(SliceIndex); ok && si.isSlice && vx.IsValid() && vy.IsValid() {
		px, py := vx.Addr(), vy.Addr()
		if eq, visited := s.curPtrs.Push(px, py); visited {
			s.report(eq, reportByCycle)
			return
		}
		defer s.curPtrs.Pop(px, py)
	}

	// Rule 1: a matching option takes precedence.
	if s.tryOptions(t, vx, vy) {
		return
	}

	// Rule 2: otherwise use the type's Equal method, if it has a suitable one.
	if s.tryMethod(t, vx, vy) {
		return
	}

	// Rule 3: otherwise compare according to the underlying kind.
	switch t.Kind() {
	case reflect.Bool:
		s.report(vx.Bool() == vy.Bool(), 0)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		s.report(vx.Int() == vy.Int(), 0)
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		s.report(vx.Uint() == vy.Uint(), 0)
	case reflect.Float32, reflect.Float64:
		s.report(vx.Float() == vy.Float(), 0)
	case reflect.Complex64, reflect.Complex128:
		s.report(vx.Complex() == vy.Complex(), 0)
	case reflect.String:
		s.report(vx.String() == vy.String(), 0)
	case reflect.Chan, reflect.UnsafePointer:
		// Channels and unsafe pointers are compared by pointer value.
		s.report(vx.Pointer() == vy.Pointer(), 0)
	case reflect.Func:
		// Functions are equal only when both are nil.
		s.report(vx.IsNil() && vy.IsNil(), 0)
	case reflect.Struct:
		s.compareStruct(t, vx, vy)
	case reflect.Slice, reflect.Array:
		s.compareSlice(t, vx, vy)
	case reflect.Map:
		s.compareMap(t, vx, vy)
	case reflect.Ptr:
		s.comparePtr(t, vx, vy)
	case reflect.Interface:
		s.compareInterface(t, vx, vy)
	default:
		panic(fmt.Sprintf("%v kind not handled", t.Kind()))
	}
}
302
303 func (s *state) tryOptions(t reflect.Type, vx, vy reflect.Value) bool {
304
305 if opt := s.opts.filter(s, t, vx, vy); opt != nil {
306 opt.apply(s, vx, vy)
307 return true
308 }
309 return false
310 }
311
312 func (s *state) tryMethod(t reflect.Type, vx, vy reflect.Value) bool {
313
314 m, ok := t.MethodByName("Equal")
315 if !ok || !function.IsType(m.Type, function.EqualAssignable) {
316 return false
317 }
318
319 eq := s.callTTBFunc(m.Func, vx, vy)
320 s.report(eq, reportByMethod)
321 return true
322 }
323
// callTRFunc calls the one-argument function f on v and returns its first
// result.
//
// On calls selected by s.dynChecker, f is run twice: once on a separate
// goroutine via detectRaces, then again directly. The two results are
// compared with statelessCompare by writing them into step's vx/vy.
// If they differ, and the direct result compares equal to itself,
// callTRFunc panics reporting f as non-deterministic.
func (s *state) callTRFunc(f, v reflect.Value, step Transform) reflect.Value {
	if !s.dynChecker.Next() {
		return f.Call([]reflect.Value{v})[0]
	}

	// detectRaces recovers any panic from its call to f, in which case
	// got is the zero (invalid) reflect.Value. Its result is received
	// before the direct call below is made.
	c := make(chan reflect.Value)
	go detectRaces(c, f, v)
	got := <-c
	want := f.Call([]reflect.Value{v})[0]
	if step.vx, step.vy = got, want; !s.statelessCompare(step).Equal() {
		// The mismatch may come from equality on this type not being
		// reflexive (want is unequal to itself); do not blame f then.
		if step.vx, step.vy = want, want; !s.statelessCompare(step).Equal() {
			return want
		}
		panic(fmt.Sprintf("non-deterministic function detected: %s", function.NameOf(f)))
	}
	return want
}
346
347 func (s *state) callTTBFunc(f, x, y reflect.Value) bool {
348 if !s.dynChecker.Next() {
349 return f.Call([]reflect.Value{x, y})[0].Bool()
350 }
351
352
353
354
355
356 c := make(chan reflect.Value)
357 go detectRaces(c, f, y, x)
358 got := <-c
359 want := f.Call([]reflect.Value{x, y})[0].Bool()
360 if !got.IsValid() || got.Bool() != want {
361 panic(fmt.Sprintf("non-deterministic or non-symmetric function detected: %s", function.NameOf(f)))
362 }
363 return want
364 }
365
// detectRaces calls f with vs and sends f's first result on c.
//
// Any panic raised by f is recovered and deliberately discarded. In that
// case the zero (invalid) reflect.Value is sent, so callers can tell that
// the call did not complete. Exactly one value is always sent on c.
func detectRaces(c chan<- reflect.Value, f reflect.Value, vs ...reflect.Value) {
	var out reflect.Value
	defer func() {
		_ = recover() // best-effort: the caller decides how to treat a failed call
		c <- out
	}()
	results := f.Call(vs)
	out = results[0]
}
374
// compareStruct compares two struct values of type t field by field, in
// declaration order, calling compareAny once per field. Fields named "_"
// are skipped.
//
// For every other unexported field, the step additionally carries:
//   - addressable copies of the parent structs (pvx/pvy);
//   - whether either original was addressable (paddr);
//   - the reflect.StructField (field);
//   - whether any function in s.exporters returned true for t (mayForce).
//
// NOTE(review): presumably this is what lets the unexported field be read
// later; the consumer is outside this view.
func (s *state) compareStruct(t reflect.Type, vx, vy reflect.Value) {
	var addr bool
	var vax, vay reflect.Value // Addressable versions of vx and vy

	// mayForce is computed lazily, at most once per call, and only if the
	// struct has an unexported field other than "_".
	var mayForce, mayForceInit bool
	// A single step object is reused (mutated) for every field. The fields
	// set only in the unexported branch keep their last values for later
	// exported fields.
	step := StructField{&structField{}}
	for i := 0; i < t.NumField(); i++ {
		step.typ = t.Field(i).Type
		step.vx = vx.Field(i)
		step.vy = vy.Field(i)
		step.name = t.Field(i).Name
		step.idx = i
		step.unexported = !isExported(step.name)
		if step.unexported {
			if step.name == "_" {
				continue
			}
			// The addressable copies are made on the first unexported
			// field and shared by all later ones.
			if !vax.IsValid() || !vay.IsValid() {
				// Record whether the originals were addressable before
				// makeAddressable possibly replaces them with copies.
				addr = vx.CanAddr() || vy.CanAddr()
				vax = makeAddressable(vx)
				vay = makeAddressable(vy)
			}
			if !mayForceInit {
				for _, xf := range s.exporters {
					mayForce = mayForce || xf(t)
				}
				mayForceInit = true
			}
			step.mayForce = mayForce
			step.paddr = addr
			step.pvx = vax
			step.pvy = vay
			step.field = t.Field(i)
		}
		s.compareAny(step)
	}
}
417
// compareSlice compares two slices or arrays of type t element by element.
//
// A nil slice is only equal to another nil slice. Otherwise:
//  1. Each element is first compared on its own, against a missing
//     counterpart. Elements that yield no differences that way are treated
//     as ignored; presumably an Ignore-style option matched them.
//  2. An edit script over the remaining elements is computed with
//     diff.Difference.
//  3. The ignore decisions and the edit script are replayed to drive the
//     real, reported comparisons via compareAny.
//
// Slices are not pushed onto s.curPtrs as a whole. Cycle detection for
// slice elements is done per element in compareAny, for SliceIndex steps.
func (s *state) compareSlice(t reflect.Type, vx, vy reflect.Value) {
	isSlice := t.Kind() == reflect.Slice
	if isSlice && (vx.IsNil() || vy.IsNil()) {
		s.report(vx.IsNil() && vy.IsNil(), 0)
		return
	}

	// withIndexes mutates and returns the single shared step so that it
	// refers to vx[ix] and vy[iy]. A negative index means "no element on
	// that side": an invalid value with key -1.
	step := SliceIndex{&sliceIndex{pathStep: pathStep{typ: t.Elem()}, isSlice: isSlice}}
	withIndexes := func(ix, iy int) SliceIndex {
		if ix >= 0 {
			step.vx, step.xkey = vx.Index(ix), ix
		} else {
			step.vx, step.xkey = reflect.Value{}, -1
		}
		if iy >= 0 {
			step.vy, step.ykey = vy.Index(iy), iy
		} else {
			step.vy, step.ykey = reflect.Value{}, -1
		}
		return step
	}

	// Phase 1. For each side, record per-element whether it is ignored
	// (ignoredX/ignoredY) and collect the indexes of non-ignored elements
	// (indexesX/indexesY), which are the only ones the diff operates on.
	var indexesX, indexesY []int
	var ignoredX, ignoredY []bool
	for ix := 0; ix < vx.Len(); ix++ {
		ignored := s.statelessCompare(withIndexes(ix, -1)).NumDiff == 0
		if !ignored {
			indexesX = append(indexesX, ix)
		}
		ignoredX = append(ignoredX, ignored)
	}
	for iy := 0; iy < vy.Len(); iy++ {
		ignored := s.statelessCompare(withIndexes(-1, iy)).NumDiff == 0
		if !ignored {
			indexesY = append(indexesY, iy)
		}
		ignoredY = append(ignoredY, ignored)
	}

	// Phase 2. Compute an edit script over the non-ignored elements only,
	// mapping the diff's compact indexes back to real element indexes.
	edits := diff.Difference(len(indexesX), len(indexesY), func(ix, iy int) diff.Result {
		return s.statelessCompare(withIndexes(indexesX[ix], indexesY[iy]))
	})

	// Phase 3. Walk both sides. Ignored elements are emitted as unpaired
	// (UniqueX before UniqueY); every other position consumes the next
	// edit. UniqueX/UniqueY advance one side; any other edit type pairs
	// vx[ix] with vy[iy] and advances both.
	var ix, iy int
	for ix < vx.Len() || iy < vy.Len() {
		var e diff.EditType
		switch {
		case ix < len(ignoredX) && ignoredX[ix]:
			e = diff.UniqueX
		case iy < len(ignoredY) && ignoredY[iy]:
			e = diff.UniqueY
		default:
			e, edits = edits[0], edits[1:]
		}
		switch e {
		case diff.UniqueX:
			s.compareAny(withIndexes(ix, -1))
			ix++
		case diff.UniqueY:
			s.compareAny(withIndexes(-1, iy))
			iy++
		default:
			s.compareAny(withIndexes(ix, iy))
			ix++
			iy++
		}
	}
}
509
510 func (s *state) compareMap(t reflect.Type, vx, vy reflect.Value) {
511 if vx.IsNil() || vy.IsNil() {
512 s.report(vx.IsNil() && vy.IsNil(), 0)
513 return
514 }
515
516
517 if eq, visited := s.curPtrs.Push(vx, vy); visited {
518 s.report(eq, reportByCycle)
519 return
520 }
521 defer s.curPtrs.Pop(vx, vy)
522
523
524
525 step := MapIndex{&mapIndex{pathStep: pathStep{typ: t.Elem()}}}
526 for _, k := range value.SortKeys(append(vx.MapKeys(), vy.MapKeys()...)) {
527 step.vx = vx.MapIndex(k)
528 step.vy = vy.MapIndex(k)
529 step.key = k
530 if !step.vx.IsValid() && !step.vy.IsValid() {
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545 const help = "consider providing a Comparer to compare the map"
546 panic(fmt.Sprintf("%#v has map key with NaNs\n%s", s.curPath, help))
547 }
548 s.compareAny(step)
549 }
550 }
551
552 func (s *state) comparePtr(t reflect.Type, vx, vy reflect.Value) {
553 if vx.IsNil() || vy.IsNil() {
554 s.report(vx.IsNil() && vy.IsNil(), 0)
555 return
556 }
557
558
559 if eq, visited := s.curPtrs.Push(vx, vy); visited {
560 s.report(eq, reportByCycle)
561 return
562 }
563 defer s.curPtrs.Pop(vx, vy)
564
565 vx, vy = vx.Elem(), vy.Elem()
566 s.compareAny(Indirect{&indirect{pathStep{t.Elem(), vx, vy}}})
567 }
568
569 func (s *state) compareInterface(t reflect.Type, vx, vy reflect.Value) {
570 if vx.IsNil() || vy.IsNil() {
571 s.report(vx.IsNil() && vy.IsNil(), 0)
572 return
573 }
574 vx, vy = vx.Elem(), vy.Elem()
575 if vx.Type() != vy.Type() {
576 s.report(false, 0)
577 return
578 }
579 s.compareAny(TypeAssertion{&typeAssertion{pathStep{vx.Type(), vx, vy}}})
580 }
581
582 func (s *state) report(eq bool, rf resultFlags) {
583 if rf&reportByIgnore == 0 {
584 if eq {
585 s.result.NumSame++
586 rf |= reportEqual
587 } else {
588 s.result.NumDiff++
589 rf |= reportUnequal
590 }
591 }
592 for _, r := range s.reporters {
593 r.Report(Result{flags: rf})
594 }
595 }
596
597
598
599 type recChecker struct{ next int }
600
601
602
603
604
605 func (rc *recChecker) Check(p Path) {
606 const minLen = 1 << 16
607 if rc.next == 0 {
608 rc.next = minLen
609 }
610 if len(p) < rc.next {
611 return
612 }
613 rc.next <<= 1
614
615
616 var ss []string
617 m := map[Option]int{}
618 for _, ps := range p {
619 if t, ok := ps.(Transform); ok {
620 t := t.Option()
621 if m[t] == 1 {
622 tf := t.(*transformer).fnc.Type()
623 ss = append(ss, fmt.Sprintf("%v: %v => %v", t, tf.In(0), tf.Out(0)))
624 }
625 m[t]++
626 }
627 }
628 if len(ss) > 0 {
629 const warning = "recursive set of Transformers detected"
630 const help = "consider using cmpopts.AcyclicTransformer"
631 set := strings.Join(ss, "\n\t")
632 panic(fmt.Sprintf("%s:\n\t%s\n%s", warning, set, help))
633 }
634 }
635
636
637
638
// dynChecker decides which calls to user-provided functions also get an
// extra determinism/symmetry check (see callTRFunc and callTTBFunc).
// The zero value is ready to use.
type dynChecker struct{ curr, next int }

// Next advances the checker and reports whether a check should run now.
//
// It returns true on calls 1, 2, 4, 7, 11, 16, ... The gap between
// successive true results grows by one each time, so checks become
// progressively rarer as the number of calls grows.
func (dc *dynChecker) Next() bool {
	if dc.curr != dc.next {
		dc.curr++
		return false
	}
	dc.next++
	dc.curr = 1
	return true
}
660
661
662
663
// makeAddressable returns v itself when it is addressable. Otherwise it
// returns an addressable copy of v, held in a newly allocated variable of
// v's type. v must be a valid reflect.Value (v.Type is called on it).
func makeAddressable(v reflect.Value) reflect.Value {
	if !v.CanAddr() {
		cp := reflect.New(v.Type()).Elem()
		cp.Set(v)
		v = cp
	}
	return v
}
672
View as plain text