...

Source file src/gotest.tools/v3/internal/assert/assert.go

Documentation: gotest.tools/v3/internal/assert

     1  // Package assert provides internal utilities for assertions.
     2  package assert
     3  
     4  import (
     5  	"fmt"
     6  	"go/ast"
     7  	"go/token"
     8  	"reflect"
     9  
    10  	"gotest.tools/v3/assert/cmp"
    11  	"gotest.tools/v3/internal/format"
    12  	"gotest.tools/v3/internal/source"
    13  )
    14  
    15  // LogT is the subset of testing.T used by the assert package.
    16  type LogT interface {
    17  	Log(args ...interface{})
    18  }
    19  
    20  type helperT interface {
    21  	Helper()
    22  }
    23  
    24  const failureMessage = "assertion failed: "
    25  
    26  // Eval the comparison and print a failure messages if the comparison has failed.
    27  func Eval(
    28  	t LogT,
    29  	argSelector argSelector,
    30  	comparison interface{},
    31  	msgAndArgs ...interface{},
    32  ) bool {
    33  	if ht, ok := t.(helperT); ok {
    34  		ht.Helper()
    35  	}
    36  	var success bool
    37  	switch check := comparison.(type) {
    38  	case bool:
    39  		if check {
    40  			return true
    41  		}
    42  		logFailureFromBool(t, msgAndArgs...)
    43  
    44  	// Undocumented legacy comparison without Result type
    45  	case func() (success bool, message string):
    46  		success = runCompareFunc(t, check, msgAndArgs...)
    47  
    48  	case nil:
    49  		return true
    50  
    51  	case error:
    52  		msg := failureMsgFromError(check)
    53  		t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
    54  
    55  	case cmp.Comparison:
    56  		success = RunComparison(t, argSelector, check, msgAndArgs...)
    57  
    58  	case func() cmp.Result:
    59  		success = RunComparison(t, argSelector, check, msgAndArgs...)
    60  
    61  	default:
    62  		t.Log(fmt.Sprintf("invalid Comparison: %v (%T)", check, check))
    63  	}
    64  	return success
    65  }
    66  
    67  func runCompareFunc(
    68  	t LogT,
    69  	f func() (success bool, message string),
    70  	msgAndArgs ...interface{},
    71  ) bool {
    72  	if ht, ok := t.(helperT); ok {
    73  		ht.Helper()
    74  	}
    75  	if success, message := f(); !success {
    76  		t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
    77  		return false
    78  	}
    79  	return true
    80  }
    81  
// logFailureFromBool logs the failure of a bool comparison that evaluated to
// false. Any custom message built from msgAndArgs is included.
//
// It uses source.CallExprArgs to find the arguments of the calling assertion,
// so the message can describe the expression that was false. If the lookup or
// formatting fails, the problem is logged to t and a generic message is used.
// In every case, one final failure message is logged.
func logFailureFromBool(t LogT, msgAndArgs ...interface{}) {
	if ht, ok := t.(helperT); ok {
		ht.Helper()
	}
	// The stack depth is fixed. This function is called directly by Eval,
	// which is presumably called directly by Assert()/Check(). Adding or
	// removing a frame in that chain would make the lookup inspect the wrong
	// call.
	const stackIndex = 3 // Assert()/Check(), Eval(), logFailureFromBool()
	args, err := source.CallExprArgs(stackIndex)
	if err != nil {
		// Not fatal: the generic message below is used when args is empty.
		t.Log(err.Error())
	}

	var msg string
	const comparisonArgIndex = 1 // Assert(t, comparison)
	if len(args) <= comparisonArgIndex {
		msg = "but assert failed to find the expression to print"
	} else {
		msg, err = boolFailureMessage(args[comparisonArgIndex])
		if err != nil {
			t.Log(err.Error())
			msg = "expression is false"
		}
	}

	t.Log(format.WithCustomMessage(failureMessage+msg, msgAndArgs...))
}
   106  
   107  func failureMsgFromError(err error) string {
   108  	// Handle errors with non-nil types
   109  	v := reflect.ValueOf(err)
   110  	if v.Kind() == reflect.Ptr && v.IsNil() {
   111  		return fmt.Sprintf("error is not nil: error has type %T", err)
   112  	}
   113  	return "error is not nil: " + err.Error()
   114  }
   115  
   116  func boolFailureMessage(expr ast.Expr) (string, error) {
   117  	if binaryExpr, ok := expr.(*ast.BinaryExpr); ok {
   118  		x, err := source.FormatNode(binaryExpr.X)
   119  		if err != nil {
   120  			return "", err
   121  		}
   122  		y, err := source.FormatNode(binaryExpr.Y)
   123  		if err != nil {
   124  			return "", err
   125  		}
   126  
   127  		switch binaryExpr.Op {
   128  		case token.NEQ:
   129  			return x + " is " + y, nil
   130  		case token.EQL:
   131  			return x + " is not " + y, nil
   132  		case token.GTR:
   133  			return x + " is <= " + y, nil
   134  		case token.LSS:
   135  			return x + " is >= " + y, nil
   136  		case token.GEQ:
   137  			return x + " is less than " + y, nil
   138  		case token.LEQ:
   139  			return x + " is greater than " + y, nil
   140  		}
   141  	}
   142  
   143  	if unaryExpr, ok := expr.(*ast.UnaryExpr); ok && unaryExpr.Op == token.NOT {
   144  		x, err := source.FormatNode(unaryExpr.X)
   145  		if err != nil {
   146  			return "", err
   147  		}
   148  		return x + " is true", nil
   149  	}
   150  
   151  	if ident, ok := expr.(*ast.Ident); ok {
   152  		return ident.Name + " is false", nil
   153  	}
   154  
   155  	formatted, err := source.FormatNode(expr)
   156  	if err != nil {
   157  		return "", err
   158  	}
   159  	return "expression is false: " + formatted, nil
   160  }
   161  

View as plain text