...

Source file src/gotest.tools/v3/internal/source/source.go

Documentation: gotest.tools/v3/internal/source

     1  // Package source provides utilities for handling source-code.
     2  package source // import "gotest.tools/v3/internal/source"
     3  
     4  import (
     5  	"bytes"
     6  	"errors"
     7  	"fmt"
     8  	"go/ast"
     9  	"go/format"
    10  	"go/parser"
    11  	"go/token"
    12  	"os"
    13  	"path/filepath"
    14  	"runtime"
    15  )
    16  
    17  // FormattedCallExprArg returns the argument from an ast.CallExpr at the
    18  // index in the call stack. The argument is formatted using FormatNode.
    19  func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
    20  	args, err := CallExprArgs(stackIndex + 1)
    21  	if err != nil {
    22  		return "", err
    23  	}
    24  	if argPos >= len(args) {
    25  		return "", errors.New("failed to find expression")
    26  	}
    27  	return FormatNode(args[argPos])
    28  }
    29  
    30  // CallExprArgs returns the ast.Expr slice for the args of an ast.CallExpr at
    31  // the index in the call stack.
    32  func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
    33  	_, filename, line, ok := runtime.Caller(stackIndex + 1)
    34  	if !ok {
    35  		return nil, errors.New("failed to get call stack")
    36  	}
    37  	debug("call stack position: %s:%d", filename, line)
    38  
    39  	// Normally, `go` will compile programs with absolute paths in
    40  	// the debug metadata. However, in the name of reproducibility,
    41  	// Bazel uses a compilation strategy that results in relative paths
    42  	// (otherwise, since Bazel uses a random tmp dir for compile and sandboxing,
    43  	// the resulting binaries would change across compiles/test runs).
    44  	if inBazelTest && !filepath.IsAbs(filename) {
    45  		var err error
    46  		filename, err = bazelSourcePath(filename)
    47  		if err != nil {
    48  			return nil, err
    49  		}
    50  	}
    51  
    52  	fileset := token.NewFileSet()
    53  	astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
    54  	if err != nil {
    55  		return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
    56  	}
    57  
    58  	expr, err := getCallExprArgs(fileset, astFile, line)
    59  	if err != nil {
    60  		return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
    61  	}
    62  	return expr, nil
    63  }
    64  
    65  func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) {
    66  	if node := scanToLine(fileset, astFile, lineNum); node != nil {
    67  		return node, nil
    68  	}
    69  	if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
    70  		node, err := guessDefer(node)
    71  		if err != nil || node != nil {
    72  			return node, err
    73  		}
    74  	}
    75  	return nil, errors.New("failed to find expression")
    76  }
    77  
    78  func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
    79  	var matchedNode ast.Node
    80  	ast.Inspect(node, func(node ast.Node) bool {
    81  		switch {
    82  		case node == nil || matchedNode != nil:
    83  			return false
    84  		case fileset.Position(node.Pos()).Line == lineNum:
    85  			matchedNode = node
    86  			return false
    87  		}
    88  		return true
    89  	})
    90  	return matchedNode
    91  }
    92  
    93  func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) {
    94  	node, err := getNodeAtLine(fileset, astFile, line)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	debug("found node: %s", debugFormatNode{node})
   100  
   101  	visitor := &callExprVisitor{}
   102  	ast.Walk(visitor, node)
   103  	if visitor.expr == nil {
   104  		return nil, errors.New("failed to find an expression")
   105  	}
   106  	debug("callExpr: %s", debugFormatNode{visitor.expr})
   107  	return visitor.expr.Args, nil
   108  }
   109  
   110  type callExprVisitor struct {
   111  	expr *ast.CallExpr
   112  }
   113  
   114  func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
   115  	if v.expr != nil || node == nil {
   116  		return nil
   117  	}
   118  	debug("visit: %s", debugFormatNode{node})
   119  
   120  	switch typed := node.(type) {
   121  	case *ast.CallExpr:
   122  		v.expr = typed
   123  		return nil
   124  	case *ast.DeferStmt:
   125  		ast.Walk(v, typed.Call.Fun)
   126  		return nil
   127  	}
   128  	return v
   129  }
   130  
   131  // FormatNode using go/format.Node and return the result as a string
   132  func FormatNode(node ast.Node) (string, error) {
   133  	buf := new(bytes.Buffer)
   134  	err := format.Node(buf, token.NewFileSet(), node)
   135  	return buf.String(), err
   136  }
   137  
   138  var debugEnabled = os.Getenv("GOTESTTOOLS_DEBUG") != ""
   139  
   140  func debug(format string, args ...interface{}) {
   141  	if debugEnabled {
   142  		fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
   143  	}
   144  }
   145  
   146  type debugFormatNode struct {
   147  	ast.Node
   148  }
   149  
   150  func (n debugFormatNode) String() string {
   151  	if n.Node == nil {
   152  		return "none"
   153  	}
   154  	out, err := FormatNode(n.Node)
   155  	if err != nil {
   156  		return fmt.Sprintf("failed to format %s: %s", n.Node, err)
   157  	}
   158  	return fmt.Sprintf("(%T) %s", n.Node, out)
   159  }
   160  

View as plain text