...

Source file src/gotest.tools/v3/internal/assert/result.go

Documentation: gotest.tools/v3/internal/assert

     1  package assert
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"go/ast"
     7  
     8  	"gotest.tools/v3/assert/cmp"
     9  	"gotest.tools/v3/internal/format"
    10  	"gotest.tools/v3/internal/source"
    11  )
    12  
    13  // RunComparison and return Comparison.Success. If the comparison fails a messages
    14  // will be printed using t.Log.
    15  func RunComparison(
    16  	t LogT,
    17  	argSelector argSelector,
    18  	f cmp.Comparison,
    19  	msgAndArgs ...interface{},
    20  ) bool {
    21  	if ht, ok := t.(helperT); ok {
    22  		ht.Helper()
    23  	}
    24  	result := f()
    25  	if result.Success() {
    26  		return true
    27  	}
    28  
    29  	if source.IsUpdate() {
    30  		if updater, ok := result.(updateExpected); ok {
    31  			const stackIndex = 3 // Assert/Check, assert, RunComparison
    32  			err := updater.UpdatedExpected(stackIndex)
    33  			switch {
    34  			case err == nil:
    35  				return true
    36  			case errors.Is(err, source.ErrNotFound):
    37  				// do nothing, fallthrough to regular failure message
    38  			default:
    39  				t.Log("failed to update source", err)
    40  				return false
    41  			}
    42  		}
    43  	}
    44  
    45  	var message string
    46  	switch typed := result.(type) {
    47  	case resultWithComparisonArgs:
    48  		const stackIndex = 3 // Assert/Check, assert, RunComparison
    49  		args, err := source.CallExprArgs(stackIndex)
    50  		if err != nil {
    51  			t.Log(err.Error())
    52  		}
    53  		message = typed.FailureMessage(filterPrintableExpr(argSelector(args)))
    54  	case resultBasic:
    55  		message = typed.FailureMessage()
    56  	default:
    57  		message = fmt.Sprintf("comparison returned invalid Result type: %T", result)
    58  	}
    59  
    60  	t.Log(format.WithCustomMessage(failureMessage+message, msgAndArgs...))
    61  	return false
    62  }
    63  
    64  type resultWithComparisonArgs interface {
    65  	FailureMessage(args []ast.Expr) string
    66  }
    67  
    68  type resultBasic interface {
    69  	FailureMessage() string
    70  }
    71  
    72  type updateExpected interface {
    73  	UpdatedExpected(stackIndex int) error
    74  }
    75  
    76  // filterPrintableExpr filters the ast.Expr slice to only include Expr that are
    77  // easy to read when printed and contain relevant information to an assertion.
    78  //
    79  // Ident and SelectorExpr are included because they print nicely and the variable
    80  // names may provide additional context to their values.
    81  // BasicLit and CompositeLit are excluded because their source is equivalent to
    82  // their value, which is already available.
    83  // Other types are ignored for now, but could be added if they are relevant.
    84  func filterPrintableExpr(args []ast.Expr) []ast.Expr {
    85  	result := make([]ast.Expr, len(args))
    86  	for i, arg := range args {
    87  		if isShortPrintableExpr(arg) {
    88  			result[i] = arg
    89  			continue
    90  		}
    91  
    92  		if starExpr, ok := arg.(*ast.StarExpr); ok {
    93  			result[i] = starExpr.X
    94  			continue
    95  		}
    96  	}
    97  	return result
    98  }
    99  
   100  func isShortPrintableExpr(expr ast.Expr) bool {
   101  	switch expr.(type) {
   102  	case *ast.Ident, *ast.SelectorExpr, *ast.IndexExpr, *ast.SliceExpr:
   103  		return true
   104  	case *ast.BinaryExpr, *ast.UnaryExpr:
   105  		return true
   106  	default:
   107  		// CallExpr, ParenExpr, TypeAssertExpr, KeyValueExpr, StarExpr
   108  		return false
   109  	}
   110  }
   111  
   112  type argSelector func([]ast.Expr) []ast.Expr
   113  
   114  // ArgsAfterT selects args starting at position 1. Used when the caller has a
   115  // testing.T as the first argument, and the args to select should follow it.
   116  func ArgsAfterT(args []ast.Expr) []ast.Expr {
   117  	if len(args) < 1 {
   118  		return nil
   119  	}
   120  	return args[1:]
   121  }
   122  
   123  // ArgsFromComparisonCall selects args from the CallExpression at position 1.
   124  // Used when the caller has a testing.T as the first argument, and the args to
   125  // select are passed to the cmp.Comparison at position 1.
   126  func ArgsFromComparisonCall(args []ast.Expr) []ast.Expr {
   127  	if len(args) <= 1 {
   128  		return nil
   129  	}
   130  	if callExpr, ok := args[1].(*ast.CallExpr); ok {
   131  		return callExpr.Args
   132  	}
   133  	return nil
   134  }
   135  
   136  // ArgsAtZeroIndex selects args from the CallExpression at position 1.
   137  // Used when the caller accepts a single cmp.Comparison argument.
   138  func ArgsAtZeroIndex(args []ast.Expr) []ast.Expr {
   139  	if len(args) == 0 {
   140  		return nil
   141  	}
   142  	if callExpr, ok := args[0].(*ast.CallExpr); ok {
   143  		return callExpr.Args
   144  	}
   145  	return nil
   146  }
   147  

View as plain text