...

Source file src/github.com/go-test/deep/deep.go

Documentation: github.com/go-test/deep

     1  // Package deep provides function deep.Equal which is like reflect.DeepEqual but
     2  // returns a list of differences. This is helpful when comparing complex types
     3  // like structures and maps.
     4  package deep
     5  
     6  import (
     7  	"errors"
     8  	"fmt"
     9  	"log"
    10  	"reflect"
    11  	"strings"
    12  )
    13  
    14  var (
    15  	// FloatPrecision is the number of decimal places to round float values
    16  	// to when comparing.
    17  	FloatPrecision = 10
    18  
    19  	// MaxDiff specifies the maximum number of differences to return.
    20  	MaxDiff = 10
    21  
    22  	// MaxDepth specifies the maximum levels of a struct to recurse into,
    23  	// if greater than zero. If zero, there is no limit.
    24  	MaxDepth = 0
    25  
    26  	// LogErrors causes errors to be logged to STDERR when true.
    27  	LogErrors = false
    28  
    29  	// CompareUnexportedFields causes unexported struct fields, like s in
    30  	// T{s int}, to be compared when true. This does not work for comparing
    31  	// error or Time types on unexported fields because methods on unexported
    32  	// fields cannot be called.
    33  	CompareUnexportedFields = false
    34  
    35  	// CompareFunctions compares functions the same as reflect.DeepEqual:
    36  	// only two nil functions are equal. Every other combination is not equal.
    37  	// This is disabled by default because previous versions of this package
    38  	// ignored functions. Enabling it can possibly report new diffs.
    39  	CompareFunctions = false
    40  
    41  	// NilSlicesAreEmpty causes a nil slice to be equal to an empty slice.
    42  	NilSlicesAreEmpty = false
    43  
    44  	// NilMapsAreEmpty causes a nil map to be equal to an empty map.
    45  	NilMapsAreEmpty = false
    46  )
    47  
    48  var (
    49  	// ErrMaxRecursion is logged when MaxDepth is reached.
    50  	ErrMaxRecursion = errors.New("recursed to MaxDepth")
    51  
    52  	// ErrTypeMismatch is logged when Equal passed two different types of values.
    53  	ErrTypeMismatch = errors.New("variables are different reflect.Type")
    54  
    55  	// ErrNotHandled is logged when a primitive Go kind is not handled.
    56  	ErrNotHandled = errors.New("cannot compare the reflect.Kind")
    57  )
    58  
    59  const (
    60  	// FLAG_NONE is a placeholder for default Equal behavior. You don't have to
    61  	// pass it to Equal; if you do, it does nothing.
    62  	FLAG_NONE byte = iota
    63  
    64  	// FLAG_IGNORE_SLICE_ORDER causes Equal to ignore slice order so that
    65  	// []int{1, 2} and []int{2, 1} are equal. Only slices of primitive scalars
    66  	// like numbers and strings are supported. Slices of complex types,
    67  	// like []T where T is a struct, are undefined because Equal does not
    68  	// recurse into the slice value when this flag is enabled.
    69  	FLAG_IGNORE_SLICE_ORDER
    70  )
    71  
    72  type cmp struct {
    73  	diff        []string
    74  	buff        []string
    75  	floatFormat string
    76  	flag        map[byte]bool
    77  }
    78  
    79  var errorType = reflect.TypeOf((*error)(nil)).Elem()
    80  
    81  // Equal compares variables a and b, recursing into their structure up to
    82  // MaxDepth levels deep (if greater than zero), and returns a list of differences,
    83  // or nil if there are none. Some differences may not be found if an error is
    84  // also returned.
    85  //
    86  // If a type has an Equal method, like time.Equal, it is called to check for
    87  // equality.
    88  //
    89  // When comparing a struct, if a field has the tag `deep:"-"` then it will be
    90  // ignored.
    91  func Equal(a, b interface{}, flags ...interface{}) []string {
    92  	aVal := reflect.ValueOf(a)
    93  	bVal := reflect.ValueOf(b)
    94  	c := &cmp{
    95  		diff:        []string{},
    96  		buff:        []string{},
    97  		floatFormat: fmt.Sprintf("%%.%df", FloatPrecision),
    98  		flag:        map[byte]bool{},
    99  	}
   100  	for i := range flags {
   101  		c.flag[flags[i].(byte)] = true
   102  	}
   103  	if a == nil && b == nil {
   104  		return nil
   105  	} else if a == nil && b != nil {
   106  		c.saveDiff("<nil pointer>", b)
   107  	} else if a != nil && b == nil {
   108  		c.saveDiff(a, "<nil pointer>")
   109  	}
   110  	if len(c.diff) > 0 {
   111  		return c.diff
   112  	}
   113  
   114  	c.equals(aVal, bVal, 0)
   115  	if len(c.diff) > 0 {
   116  		return c.diff // diffs
   117  	}
   118  	return nil // no diffs
   119  }
   120  
   121  func (c *cmp) equals(a, b reflect.Value, level int) {
   122  	if MaxDepth > 0 && level > MaxDepth {
   123  		logError(ErrMaxRecursion)
   124  		return
   125  	}
   126  
   127  	// Check if one value is nil, e.g. T{x: *X} and T.x is nil
   128  	if !a.IsValid() || !b.IsValid() {
   129  		if a.IsValid() && !b.IsValid() {
   130  			c.saveDiff(a.Type(), "<nil pointer>")
   131  		} else if !a.IsValid() && b.IsValid() {
   132  			c.saveDiff("<nil pointer>", b.Type())
   133  		}
   134  		return
   135  	}
   136  
   137  	// If different types, they can't be equal
   138  	aType := a.Type()
   139  	bType := b.Type()
   140  	if aType != bType {
   141  		// Built-in types don't have a name, so don't report [3]int != [2]int as " != "
   142  		if aType.Name() == "" || aType.Name() != bType.Name() {
   143  			c.saveDiff(aType, bType)
   144  		} else {
   145  			// Type names can be the same, e.g. pkg/v1.Error and pkg/v2.Error
   146  			// are both exported as pkg, so unless we include the full pkg path
   147  			// the diff will be "pkg.Error != pkg.Error"
   148  			// https://github.com/go-test/deep/issues/39
   149  			aFullType := aType.PkgPath() + "." + aType.Name()
   150  			bFullType := bType.PkgPath() + "." + bType.Name()
   151  			c.saveDiff(aFullType, bFullType)
   152  		}
   153  		logError(ErrTypeMismatch)
   154  		return
   155  	}
   156  
   157  	// Primitive https://golang.org/pkg/reflect/#Kind
   158  	aKind := a.Kind()
   159  	bKind := b.Kind()
   160  
   161  	// Do a and b have underlying elements? Yes if they're ptr or interface.
   162  	aElem := aKind == reflect.Ptr || aKind == reflect.Interface
   163  	bElem := bKind == reflect.Ptr || bKind == reflect.Interface
   164  
   165  	// If both types implement the error interface, compare the error strings.
   166  	// This must be done before dereferencing because errors.New() returns a
   167  	// pointer to a struct that implements the interface:
   168  	//   func (e *errorString) Error() string {
   169  	// And we check CanInterface as a hack to make sure the underlying method
   170  	// is callable because https://github.com/golang/go/issues/32438
   171  	// Issues:
   172  	//   https://github.com/go-test/deep/issues/31
   173  	//   https://github.com/go-test/deep/issues/45
   174  	if (aType.Implements(errorType) && bType.Implements(errorType)) &&
   175  		((!aElem || !a.IsNil()) && (!bElem || !b.IsNil())) &&
   176  		(a.CanInterface() && b.CanInterface()) {
   177  		aString := a.MethodByName("Error").Call(nil)[0].String()
   178  		bString := b.MethodByName("Error").Call(nil)[0].String()
   179  		if aString != bString {
   180  			c.saveDiff(aString, bString)
   181  		}
   182  		return
   183  	}
   184  
   185  	// Dereference pointers and interface{}
   186  	if aElem || bElem {
   187  		if aElem {
   188  			a = a.Elem()
   189  		}
   190  		if bElem {
   191  			b = b.Elem()
   192  		}
   193  		c.equals(a, b, level+1)
   194  		return
   195  	}
   196  
   197  	switch aKind {
   198  
   199  	/////////////////////////////////////////////////////////////////////
   200  	// Iterable kinds
   201  	/////////////////////////////////////////////////////////////////////
   202  
   203  	case reflect.Struct:
   204  		/*
   205  			The variables are structs like:
   206  				type T struct {
   207  					FirstName string
   208  					LastName  string
   209  				}
   210  			Type = <pkg>.T, Kind = reflect.Struct
   211  
   212  			Iterate through the fields (FirstName, LastName), recurse into their values.
   213  		*/
   214  
   215  		// Types with an Equal() method, like time.Time, only if struct field
   216  		// is exported (CanInterface)
   217  		if eqFunc := a.MethodByName("Equal"); eqFunc.IsValid() && eqFunc.CanInterface() {
   218  			// Handle https://github.com/go-test/deep/issues/15:
   219  			// Don't call T.Equal if the method is from an embedded struct, like:
   220  			//   type Foo struct { time.Time }
   221  			// First, we'll encounter Equal(Ttime, time.Time) but if we pass b
   222  			// as the 2nd arg we'll panic: "Call using pkg.Foo as type time.Time"
   223  			// As far as I can tell, there's no way to see that the method is from
   224  			// time.Time not Foo. So we check the type of the 1st (0) arg and skip
   225  			// unless it's b type. Later, we'll encounter the time.Time anonymous/
   226  			// embedded field and then we'll have Equal(time.Time, time.Time).
   227  			funcType := eqFunc.Type()
   228  			if funcType.NumIn() == 1 && funcType.In(0) == bType {
   229  				retVals := eqFunc.Call([]reflect.Value{b})
   230  				if !retVals[0].Bool() {
   231  					c.saveDiff(a, b)
   232  				}
   233  				return
   234  			}
   235  		}
   236  
   237  		for i := 0; i < a.NumField(); i++ {
   238  			if aType.Field(i).PkgPath != "" && !CompareUnexportedFields {
   239  				continue // skip unexported field, e.g. s in type T struct {s string}
   240  			}
   241  
   242  			if aType.Field(i).Tag.Get("deep") == "-" {
   243  				continue // field wants to be ignored
   244  			}
   245  
   246  			c.push(aType.Field(i).Name) // push field name to buff
   247  
   248  			// Get the Value for each field, e.g. FirstName has Type = string,
   249  			// Kind = reflect.String.
   250  			af := a.Field(i)
   251  			bf := b.Field(i)
   252  
   253  			// Recurse to compare the field values
   254  			c.equals(af, bf, level+1)
   255  
   256  			c.pop() // pop field name from buff
   257  
   258  			if len(c.diff) >= MaxDiff {
   259  				break
   260  			}
   261  		}
   262  	case reflect.Map:
   263  		/*
   264  			The variables are maps like:
   265  				map[string]int{
   266  					"foo": 1,
   267  					"bar": 2,
   268  				}
   269  			Type = map[string]int, Kind = reflect.Map
   270  
   271  			Or:
   272  				type T map[string]int{}
   273  			Type = <pkg>.T, Kind = reflect.Map
   274  
   275  			Iterate through the map keys (foo, bar), recurse into their values.
   276  		*/
   277  
   278  		if a.IsNil() || b.IsNil() {
   279  			if NilMapsAreEmpty {
   280  				if a.IsNil() && b.Len() != 0 {
   281  					c.saveDiff("<nil map>", b)
   282  					return
   283  				} else if a.Len() != 0 && b.IsNil() {
   284  					c.saveDiff(a, "<nil map>")
   285  					return
   286  				}
   287  			} else {
   288  				if a.IsNil() && !b.IsNil() {
   289  					c.saveDiff("<nil map>", b)
   290  				} else if !a.IsNil() && b.IsNil() {
   291  					c.saveDiff(a, "<nil map>")
   292  				}
   293  			}
   294  			return
   295  		}
   296  
   297  		if a.Pointer() == b.Pointer() {
   298  			return
   299  		}
   300  
   301  		for _, key := range a.MapKeys() {
   302  			c.push(fmt.Sprintf("map[%v]", key))
   303  
   304  			aVal := a.MapIndex(key)
   305  			bVal := b.MapIndex(key)
   306  			if bVal.IsValid() {
   307  				c.equals(aVal, bVal, level+1)
   308  			} else {
   309  				c.saveDiff(aVal, "<does not have key>")
   310  			}
   311  
   312  			c.pop()
   313  
   314  			if len(c.diff) >= MaxDiff {
   315  				return
   316  			}
   317  		}
   318  
   319  		for _, key := range b.MapKeys() {
   320  			if aVal := a.MapIndex(key); aVal.IsValid() {
   321  				continue
   322  			}
   323  
   324  			c.push(fmt.Sprintf("map[%v]", key))
   325  			c.saveDiff("<does not have key>", b.MapIndex(key))
   326  			c.pop()
   327  			if len(c.diff) >= MaxDiff {
   328  				return
   329  			}
   330  		}
   331  	case reflect.Array:
   332  		n := a.Len()
   333  		for i := 0; i < n; i++ {
   334  			c.push(fmt.Sprintf("array[%d]", i))
   335  			c.equals(a.Index(i), b.Index(i), level+1)
   336  			c.pop()
   337  			if len(c.diff) >= MaxDiff {
   338  				break
   339  			}
   340  		}
   341  	case reflect.Slice:
   342  		if NilSlicesAreEmpty {
   343  			if a.IsNil() && b.Len() != 0 {
   344  				c.saveDiff("<nil slice>", b)
   345  				return
   346  			} else if a.Len() != 0 && b.IsNil() {
   347  				c.saveDiff(a, "<nil slice>")
   348  				return
   349  			}
   350  		} else {
   351  			if a.IsNil() && !b.IsNil() {
   352  				c.saveDiff("<nil slice>", b)
   353  				return
   354  			} else if !a.IsNil() && b.IsNil() {
   355  				c.saveDiff(a, "<nil slice>")
   356  				return
   357  			}
   358  		}
   359  
   360  		// Equal if same underlying pointer and same length, this latter handles
   361  		//   foo := []int{1, 2, 3, 4}
   362  		//   a := foo[0:2] // == {1,2}
   363  		//   b := foo[2:4] // == {3,4}
   364  		// a and b are same pointer but different slices (lengths) of the underlying
   365  		// array, so not equal.
   366  		aLen := a.Len()
   367  		bLen := b.Len()
   368  		if a.Pointer() == b.Pointer() && aLen == bLen {
   369  			return
   370  		}
   371  
   372  		if c.flag[FLAG_IGNORE_SLICE_ORDER] {
   373  			// Compare slices by value and value count; ignore order.
   374  			// Value equality is impliclity established by the maps:
   375  			// any value v1 will hash to the same map value if it's equal
   376  			// to another value v2. Then equality is determiend by value
   377  			// count: presuming v1==v2, then the slics are equal if there
   378  			// are equal numbers of v1 in each slice.
   379  			am := map[interface{}]int{}
   380  			for i := 0; i < a.Len(); i++ {
   381  				am[a.Index(i).Interface()] += 1
   382  			}
   383  			bm := map[interface{}]int{}
   384  			for i := 0; i < b.Len(); i++ {
   385  				bm[b.Index(i).Interface()] += 1
   386  			}
   387  			c.cmpMapValueCounts(a, b, am, bm, true)  // a cmp b
   388  			c.cmpMapValueCounts(b, a, bm, am, false) // b cmp a
   389  		} else {
   390  			// Compare slices by order
   391  			n := aLen
   392  			if bLen > aLen {
   393  				n = bLen
   394  			}
   395  			for i := 0; i < n; i++ {
   396  				c.push(fmt.Sprintf("slice[%d]", i))
   397  				if i < aLen && i < bLen {
   398  					c.equals(a.Index(i), b.Index(i), level+1)
   399  				} else if i < aLen {
   400  					c.saveDiff(a.Index(i), "<no value>")
   401  				} else {
   402  					c.saveDiff("<no value>", b.Index(i))
   403  				}
   404  				c.pop()
   405  				if len(c.diff) >= MaxDiff {
   406  					break
   407  				}
   408  			}
   409  		}
   410  
   411  	/////////////////////////////////////////////////////////////////////
   412  	// Primitive kinds
   413  	/////////////////////////////////////////////////////////////////////
   414  
   415  	case reflect.Float32, reflect.Float64:
   416  		// Round floats to FloatPrecision decimal places to compare with
   417  		// user-defined precision. As is commonly know, floats have "imprecision"
   418  		// such that 0.1 becomes 0.100000001490116119384765625. This cannot
   419  		// be avoided; it can only be handled. Issue 30 suggested that floats
   420  		// be compared using an epsilon: equal = |a-b| < epsilon.
   421  		// In many cases the result is the same, but I think epsilon is a little
   422  		// less clear for users to reason about. See issue 30 for details.
   423  		aval := fmt.Sprintf(c.floatFormat, a.Float())
   424  		bval := fmt.Sprintf(c.floatFormat, b.Float())
   425  		if aval != bval {
   426  			c.saveDiff(a.Float(), b.Float())
   427  		}
   428  	case reflect.Bool:
   429  		if a.Bool() != b.Bool() {
   430  			c.saveDiff(a.Bool(), b.Bool())
   431  		}
   432  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   433  		if a.Int() != b.Int() {
   434  			c.saveDiff(a.Int(), b.Int())
   435  		}
   436  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   437  		if a.Uint() != b.Uint() {
   438  			c.saveDiff(a.Uint(), b.Uint())
   439  		}
   440  	case reflect.String:
   441  		if a.String() != b.String() {
   442  			c.saveDiff(a.String(), b.String())
   443  		}
   444  	case reflect.Func:
   445  		if CompareFunctions {
   446  			if !a.IsNil() || !b.IsNil() {
   447  				aVal, bVal := "nil func", "nil func"
   448  				if !a.IsNil() {
   449  					aVal = "func"
   450  				}
   451  				if !b.IsNil() {
   452  					bVal = "func"
   453  				}
   454  				c.saveDiff(aVal, bVal)
   455  			}
   456  		}
   457  	default:
   458  		logError(ErrNotHandled)
   459  	}
   460  }
   461  
   462  func (c *cmp) push(name string) {
   463  	c.buff = append(c.buff, name)
   464  }
   465  
   466  func (c *cmp) pop() {
   467  	if len(c.buff) > 0 {
   468  		c.buff = c.buff[0 : len(c.buff)-1]
   469  	}
   470  }
   471  
   472  func (c *cmp) saveDiff(aval, bval interface{}) {
   473  	if len(c.buff) > 0 {
   474  		varName := strings.Join(c.buff, ".")
   475  		c.diff = append(c.diff, fmt.Sprintf("%s: %v != %v", varName, aval, bval))
   476  	} else {
   477  		c.diff = append(c.diff, fmt.Sprintf("%v != %v", aval, bval))
   478  	}
   479  }
   480  
   481  func (c *cmp) cmpMapValueCounts(a, b reflect.Value, am, bm map[interface{}]int, a2b bool) {
   482  	for v := range am {
   483  		aCount, _ := am[v]
   484  		bCount, _ := bm[v]
   485  
   486  		if aCount != bCount {
   487  			c.push(fmt.Sprintf("(unordered) slice[]=%v: value count", v))
   488  			if a2b {
   489  				c.saveDiff(fmt.Sprintf("%d", aCount), fmt.Sprintf("%d", bCount))
   490  			} else {
   491  				c.saveDiff(fmt.Sprintf("%d", bCount), fmt.Sprintf("%d", aCount))
   492  			}
   493  			c.pop()
   494  		}
   495  		delete(am, v)
   496  		delete(bm, v)
   497  	}
   498  }
   499  
   500  func logError(err error) {
   501  	if LogErrors {
   502  		log.Println(err)
   503  	}
   504  }
   505  

View as plain text