...

Source file src/gotest.tools/v3/assert/cmp/compare.go

Documentation: gotest.tools/v3/assert/cmp

     1  /*Package cmp provides Comparisons for Assert and Check*/
     2  package cmp // import "gotest.tools/v3/assert/cmp"
     3  
     4  import (
     5  	"errors"
     6  	"fmt"
     7  	"reflect"
     8  	"regexp"
     9  	"strings"
    10  
    11  	"github.com/google/go-cmp/cmp"
    12  	"gotest.tools/v3/internal/format"
    13  )
    14  
// Comparison is a function which compares values and returns [ResultSuccess] if
// the actual value matches the expected value. If the values do not match the
// [Result] will contain a message about why it failed.
//
// A Comparison takes no arguments: the constructors in this package (for
// example [DeepEqual] and [Equal]) return closures that capture the values
// being compared.
type Comparison func() Result
    19  
    20  // DeepEqual compares two values using [github.com/google/go-cmp/cmp]
    21  // and succeeds if the values are equal.
    22  //
    23  // The comparison can be customized using comparison Options.
    24  // Package [gotest.tools/v3/assert/opt] provides some additional
    25  // commonly used Options.
    26  func DeepEqual(x, y interface{}, opts ...cmp.Option) Comparison {
    27  	return func() (result Result) {
    28  		defer func() {
    29  			if panicmsg, handled := handleCmpPanic(recover()); handled {
    30  				result = ResultFailure(panicmsg)
    31  			}
    32  		}()
    33  		diff := cmp.Diff(x, y, opts...)
    34  		if diff == "" {
    35  			return ResultSuccess
    36  		}
    37  		return multiLineDiffResult(diff, x, y)
    38  	}
    39  }
    40  
    41  func handleCmpPanic(r interface{}) (string, bool) {
    42  	if r == nil {
    43  		return "", false
    44  	}
    45  	panicmsg, ok := r.(string)
    46  	if !ok {
    47  		panic(r)
    48  	}
    49  	switch {
    50  	case strings.HasPrefix(panicmsg, "cannot handle unexported field"):
    51  		return panicmsg, true
    52  	}
    53  	panic(r)
    54  }
    55  
    56  func toResult(success bool, msg string) Result {
    57  	if success {
    58  		return ResultSuccess
    59  	}
    60  	return ResultFailure(msg)
    61  }
    62  
// RegexOrPattern may be either a [*regexp.Regexp] or a string that is a valid
// regexp pattern.
//
// The type is an empty interface, so the restriction is not enforced by the
// compiler; [Regexp] checks the dynamic type when the comparison runs and
// reports a failure for any other type.
type RegexOrPattern interface{}
    66  
    67  // Regexp succeeds if value v matches regular expression re.
    68  //
    69  // Example:
    70  //
    71  //	assert.Assert(t, cmp.Regexp("^[0-9a-f]{32}$", str))
    72  //	r := regexp.MustCompile("^[0-9a-f]{32}$")
    73  //	assert.Assert(t, cmp.Regexp(r, str))
    74  func Regexp(re RegexOrPattern, v string) Comparison {
    75  	match := func(re *regexp.Regexp) Result {
    76  		return toResult(
    77  			re.MatchString(v),
    78  			fmt.Sprintf("value %q does not match regexp %q", v, re.String()))
    79  	}
    80  
    81  	return func() Result {
    82  		switch regex := re.(type) {
    83  		case *regexp.Regexp:
    84  			return match(regex)
    85  		case string:
    86  			re, err := regexp.Compile(regex)
    87  			if err != nil {
    88  				return ResultFailure(err.Error())
    89  			}
    90  			return match(re)
    91  		default:
    92  			return ResultFailure(fmt.Sprintf("invalid type %T for regex pattern", regex))
    93  		}
    94  	}
    95  }
    96  
    97  // Equal succeeds if x == y. See [gotest.tools/v3/assert.Equal] for full documentation.
    98  func Equal(x, y interface{}) Comparison {
    99  	return func() Result {
   100  		switch {
   101  		case x == y:
   102  			return ResultSuccess
   103  		case isMultiLineStringCompare(x, y):
   104  			diff := format.UnifiedDiff(format.DiffConfig{A: x.(string), B: y.(string)})
   105  			return multiLineDiffResult(diff, x, y)
   106  		}
   107  		return ResultFailureTemplate(`
   108  			{{- printf "%v" .Data.x}} (
   109  				{{- with callArg 0 }}{{ formatNode . }} {{end -}}
   110  				{{- printf "%T" .Data.x -}}
   111  			) != {{ printf "%v" .Data.y}} (
   112  				{{- with callArg 1 }}{{ formatNode . }} {{end -}}
   113  				{{- printf "%T" .Data.y -}}
   114  			)`,
   115  			map[string]interface{}{"x": x, "y": y})
   116  	}
   117  }
   118  
   119  func isMultiLineStringCompare(x, y interface{}) bool {
   120  	strX, ok := x.(string)
   121  	if !ok {
   122  		return false
   123  	}
   124  	strY, ok := y.(string)
   125  	if !ok {
   126  		return false
   127  	}
   128  	return strings.Contains(strX, "\n") || strings.Contains(strY, "\n")
   129  }
   130  
   131  func multiLineDiffResult(diff string, x, y interface{}) Result {
   132  	return ResultFailureTemplate(`
   133  --- {{ with callArg 0 }}{{ formatNode . }}{{else}}←{{end}}
   134  +++ {{ with callArg 1 }}{{ formatNode . }}{{else}}→{{end}}
   135  {{ .Data.diff }}`,
   136  		map[string]interface{}{"diff": diff, "x": x, "y": y})
   137  }
   138  
   139  // Len succeeds if the sequence has the expected length.
   140  func Len(seq interface{}, expected int) Comparison {
   141  	return func() (result Result) {
   142  		defer func() {
   143  			if e := recover(); e != nil {
   144  				result = ResultFailure(fmt.Sprintf("type %T does not have a length", seq))
   145  			}
   146  		}()
   147  		value := reflect.ValueOf(seq)
   148  		length := value.Len()
   149  		if length == expected {
   150  			return ResultSuccess
   151  		}
   152  		msg := fmt.Sprintf("expected %s (length %d) to have length %d", seq, length, expected)
   153  		return ResultFailure(msg)
   154  	}
   155  }
   156  
   157  // Contains succeeds if item is in collection. Collection may be a string, map,
   158  // slice, or array.
   159  //
   160  // If collection is a string, item must also be a string, and is compared using
   161  // [strings.Contains].
   162  // If collection is a Map, contains will succeed if item is a key in the map.
   163  // If collection is a slice or array, item is compared to each item in the
   164  // sequence using [reflect.DeepEqual].
   165  func Contains(collection interface{}, item interface{}) Comparison {
   166  	return func() Result {
   167  		colValue := reflect.ValueOf(collection)
   168  		if !colValue.IsValid() {
   169  			return ResultFailure("nil does not contain items")
   170  		}
   171  		msg := fmt.Sprintf("%v does not contain %v", collection, item)
   172  
   173  		itemValue := reflect.ValueOf(item)
   174  		switch colValue.Type().Kind() {
   175  		case reflect.String:
   176  			if itemValue.Type().Kind() != reflect.String {
   177  				return ResultFailure("string may only contain strings")
   178  			}
   179  			return toResult(
   180  				strings.Contains(colValue.String(), itemValue.String()),
   181  				fmt.Sprintf("string %q does not contain %q", collection, item))
   182  
   183  		case reflect.Map:
   184  			if itemValue.Type() != colValue.Type().Key() {
   185  				return ResultFailure(fmt.Sprintf(
   186  					"%v can not contain a %v key", colValue.Type(), itemValue.Type()))
   187  			}
   188  			return toResult(colValue.MapIndex(itemValue).IsValid(), msg)
   189  
   190  		case reflect.Slice, reflect.Array:
   191  			for i := 0; i < colValue.Len(); i++ {
   192  				if reflect.DeepEqual(colValue.Index(i).Interface(), item) {
   193  					return ResultSuccess
   194  				}
   195  			}
   196  			return ResultFailure(msg)
   197  		default:
   198  			return ResultFailure(fmt.Sprintf("type %T does not contain items", collection))
   199  		}
   200  	}
   201  }
   202  
   203  // Panics succeeds if f() panics.
   204  func Panics(f func()) Comparison {
   205  	return func() (result Result) {
   206  		defer func() {
   207  			if err := recover(); err != nil {
   208  				result = ResultSuccess
   209  			}
   210  		}()
   211  		f()
   212  		return ResultFailure("did not panic")
   213  	}
   214  }
   215  
   216  // Error succeeds if err is a non-nil error, and the error message equals the
   217  // expected message.
   218  func Error(err error, message string) Comparison {
   219  	return func() Result {
   220  		switch {
   221  		case err == nil:
   222  			return ResultFailure("expected an error, got nil")
   223  		case err.Error() != message:
   224  			return ResultFailure(fmt.Sprintf(
   225  				"expected error %q, got %s", message, formatErrorMessage(err)))
   226  		}
   227  		return ResultSuccess
   228  	}
   229  }
   230  
   231  // ErrorContains succeeds if err is a non-nil error, and the error message contains
   232  // the expected substring.
   233  func ErrorContains(err error, substring string) Comparison {
   234  	return func() Result {
   235  		switch {
   236  		case err == nil:
   237  			return ResultFailure("expected an error, got nil")
   238  		case !strings.Contains(err.Error(), substring):
   239  			return ResultFailure(fmt.Sprintf(
   240  				"expected error to contain %q, got %s", substring, formatErrorMessage(err)))
   241  		}
   242  		return ResultSuccess
   243  	}
   244  }
   245  
// causer is implemented by errors that expose an underlying cause through a
// Cause method. formatErrorMessage uses it to detect errors wrapped with
// github.com/pkg/errors and print them in verbose form.
type causer interface {
	Cause() error
}
   249  
   250  func formatErrorMessage(err error) string {
   251  	//nolint:errorlint,nolintlint // unwrapping is not appropriate here
   252  	if _, ok := err.(causer); ok {
   253  		return fmt.Sprintf("%q\n%+v", err, err)
   254  	}
   255  	// This error was not wrapped with github.com/pkg/errors
   256  	return fmt.Sprintf("%q", err)
   257  }
   258  
   259  // Nil succeeds if obj is a nil interface, pointer, or function.
   260  //
   261  // Use [gotest.tools/v3/assert.NilError] for comparing errors. Use Len(obj, 0) for comparing slices,
   262  // maps, and channels.
   263  func Nil(obj interface{}) Comparison {
   264  	msgFunc := func(value reflect.Value) string {
   265  		return fmt.Sprintf("%v (type %s) is not nil", reflect.Indirect(value), value.Type())
   266  	}
   267  	return isNil(obj, msgFunc)
   268  }
   269  
   270  func isNil(obj interface{}, msgFunc func(reflect.Value) string) Comparison {
   271  	return func() Result {
   272  		if obj == nil {
   273  			return ResultSuccess
   274  		}
   275  		value := reflect.ValueOf(obj)
   276  		kind := value.Type().Kind()
   277  		if kind >= reflect.Chan && kind <= reflect.Slice {
   278  			if value.IsNil() {
   279  				return ResultSuccess
   280  			}
   281  			return ResultFailure(msgFunc(value))
   282  		}
   283  
   284  		return ResultFailure(fmt.Sprintf("%v (type %s) can not be nil", value, value.Type()))
   285  	}
   286  }
   287  
   288  // ErrorType succeeds if err is not nil and is of the expected type.
   289  // New code should use [ErrorIs] instead.
   290  //
   291  // Expected can be one of:
   292  //
   293  //	func(error) bool
   294  //
   295  // Function should return true if the error is the expected type.
   296  //
   297  //	type struct{}, type &struct{}
   298  //
   299  // A struct or a pointer to a struct.
   300  // Fails if the error is not of the same type as expected.
   301  //
   302  //	type &interface{}
   303  //
   304  // A pointer to an interface type.
   305  // Fails if err does not implement the interface.
   306  //
   307  //	reflect.Type
   308  //
   309  // Fails if err does not implement the [reflect.Type].
   310  func ErrorType(err error, expected interface{}) Comparison {
   311  	return func() Result {
   312  		switch expectedType := expected.(type) {
   313  		case func(error) bool:
   314  			return cmpErrorTypeFunc(err, expectedType)
   315  		case reflect.Type:
   316  			if expectedType.Kind() == reflect.Interface {
   317  				return cmpErrorTypeImplementsType(err, expectedType)
   318  			}
   319  			return cmpErrorTypeEqualType(err, expectedType)
   320  		case nil:
   321  			return ResultFailure("invalid type for expected: nil")
   322  		}
   323  
   324  		expectedType := reflect.TypeOf(expected)
   325  		switch {
   326  		case expectedType.Kind() == reflect.Struct, isPtrToStruct(expectedType):
   327  			return cmpErrorTypeEqualType(err, expectedType)
   328  		case isPtrToInterface(expectedType):
   329  			return cmpErrorTypeImplementsType(err, expectedType.Elem())
   330  		}
   331  		return ResultFailure(fmt.Sprintf("invalid type for expected: %T", expected))
   332  	}
   333  }
   334  
   335  func cmpErrorTypeFunc(err error, f func(error) bool) Result {
   336  	if f(err) {
   337  		return ResultSuccess
   338  	}
   339  	actual := "nil"
   340  	if err != nil {
   341  		actual = fmt.Sprintf("%s (%T)", err, err)
   342  	}
   343  	return ResultFailureTemplate(`error is {{ .Data.actual }}
   344  		{{- with callArg 1 }}, not {{ formatNode . }}{{end -}}`,
   345  		map[string]interface{}{"actual": actual})
   346  }
   347  
   348  func cmpErrorTypeEqualType(err error, expectedType reflect.Type) Result {
   349  	if err == nil {
   350  		return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
   351  	}
   352  	errValue := reflect.ValueOf(err)
   353  	if errValue.Type() == expectedType {
   354  		return ResultSuccess
   355  	}
   356  	return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
   357  }
   358  
   359  func cmpErrorTypeImplementsType(err error, expectedType reflect.Type) Result {
   360  	if err == nil {
   361  		return ResultFailure(fmt.Sprintf("error is nil, not %s", expectedType))
   362  	}
   363  	errValue := reflect.ValueOf(err)
   364  	if errValue.Type().Implements(expectedType) {
   365  		return ResultSuccess
   366  	}
   367  	return ResultFailure(fmt.Sprintf("error is %s (%T), not %s", err, err, expectedType))
   368  }
   369  
   370  func isPtrToInterface(typ reflect.Type) bool {
   371  	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Interface
   372  }
   373  
   374  func isPtrToStruct(typ reflect.Type) bool {
   375  	return typ.Kind() == reflect.Ptr && typ.Elem().Kind() == reflect.Struct
   376  }
   377  
// Dynamic types of the error values created by the standard library.
// NOTE(review): presumably consulted by the notStdlibErrorType template
// function used in ErrorIs, which is not defined in this part of the file.
var (
	// stdlibErrorNewType is the type of the value returned by errors.New.
	stdlibErrorNewType = reflect.TypeOf(errors.New(""))
	// stdlibFmtErrorType is the type of the value returned by fmt.Errorf when
	// the format wraps another error with %w.
	stdlibFmtErrorType = reflect.TypeOf(fmt.Errorf("%w", fmt.Errorf("")))
)
   382  
   383  // ErrorIs succeeds if errors.Is(actual, expected) returns true. See
   384  // [errors.Is] for accepted argument values.
   385  func ErrorIs(actual error, expected error) Comparison {
   386  	return func() Result {
   387  		if errors.Is(actual, expected) {
   388  			return ResultSuccess
   389  		}
   390  
   391  		// The type of stdlib errors is excluded because the type is not relevant
   392  		// in those cases. The type is only important when it is a user defined
   393  		// custom error type.
   394  		return ResultFailureTemplate(`error is
   395  			{{- if not .Data.a }} nil,{{ else }}
   396  				{{- printf " \"%v\"" .Data.a }}
   397  				{{- if notStdlibErrorType .Data.a }} ({{ printf "%T" .Data.a }}){{ end }},
   398  			{{- end }} not {{ printf "\"%v\"" .Data.x }} (
   399  			{{- with callArg 1 }}{{ formatNode . }}{{ end }}
   400  			{{- if notStdlibErrorType .Data.x }}{{ printf " %T" .Data.x }}{{ end }})`,
   401  			map[string]interface{}{"a": actual, "x": expected})
   402  	}
   403  }
   404  

View as plain text