1
2 package source
3
4 import (
5 "bytes"
6 "errors"
7 "fmt"
8 "go/ast"
9 "go/format"
10 "go/parser"
11 "go/token"
12 "os"
13 "path/filepath"
14 "runtime"
15 )
16
17
18
19 func FormattedCallExprArg(stackIndex int, argPos int) (string, error) {
20 args, err := CallExprArgs(stackIndex + 1)
21 if err != nil {
22 return "", err
23 }
24 if argPos >= len(args) {
25 return "", errors.New("failed to find expression")
26 }
27 return FormatNode(args[argPos])
28 }
29
30
31
32 func CallExprArgs(stackIndex int) ([]ast.Expr, error) {
33 _, filename, line, ok := runtime.Caller(stackIndex + 1)
34 if !ok {
35 return nil, errors.New("failed to get call stack")
36 }
37 debug("call stack position: %s:%d", filename, line)
38
39
40
41
42
43
44 if inBazelTest && !filepath.IsAbs(filename) {
45 var err error
46 filename, err = bazelSourcePath(filename)
47 if err != nil {
48 return nil, err
49 }
50 }
51
52 fileset := token.NewFileSet()
53 astFile, err := parser.ParseFile(fileset, filename, nil, parser.AllErrors)
54 if err != nil {
55 return nil, fmt.Errorf("failed to parse source file %s: %w", filename, err)
56 }
57
58 expr, err := getCallExprArgs(fileset, astFile, line)
59 if err != nil {
60 return nil, fmt.Errorf("call from %s:%d: %w", filename, line, err)
61 }
62 return expr, nil
63 }
64
65 func getNodeAtLine(fileset *token.FileSet, astFile ast.Node, lineNum int) (ast.Node, error) {
66 if node := scanToLine(fileset, astFile, lineNum); node != nil {
67 return node, nil
68 }
69 if node := scanToDeferLine(fileset, astFile, lineNum); node != nil {
70 node, err := guessDefer(node)
71 if err != nil || node != nil {
72 return node, err
73 }
74 }
75 return nil, errors.New("failed to find expression")
76 }
77
// scanToLine walks node depth-first (pre-order) and returns the first node
// whose starting position, as resolved through fileset, lies on lineNum.
// Once a match is found the traversal stops descending. It returns nil when
// no node starts on that line.
func scanToLine(fileset *token.FileSet, node ast.Node, lineNum int) ast.Node {
	var found ast.Node
	ast.Inspect(node, func(current ast.Node) bool {
		if current == nil || found != nil {
			return false
		}
		if fileset.Position(current.Pos()).Line != lineNum {
			return true
		}
		found = current
		return false
	})
	return found
}
92
93 func getCallExprArgs(fileset *token.FileSet, astFile ast.Node, line int) ([]ast.Expr, error) {
94 node, err := getNodeAtLine(fileset, astFile, line)
95 if err != nil {
96 return nil, err
97 }
98
99 debug("found node: %s", debugFormatNode{node})
100
101 visitor := &callExprVisitor{}
102 ast.Walk(visitor, node)
103 if visitor.expr == nil {
104 return nil, errors.New("failed to find an expression")
105 }
106 debug("callExpr: %s", debugFormatNode{visitor.expr})
107 return visitor.expr.Args, nil
108 }
109
110 type callExprVisitor struct {
111 expr *ast.CallExpr
112 }
113
114 func (v *callExprVisitor) Visit(node ast.Node) ast.Visitor {
115 if v.expr != nil || node == nil {
116 return nil
117 }
118 debug("visit: %s", debugFormatNode{node})
119
120 switch typed := node.(type) {
121 case *ast.CallExpr:
122 v.expr = typed
123 return nil
124 case *ast.DeferStmt:
125 ast.Walk(v, typed.Call.Fun)
126 return nil
127 }
128 return v
129 }
130
131
// FormatNode renders node as Go source text using go/format.
//
// The node is printed against a new, empty token.FileSet rather than the
// FileSet it was parsed with. If formatting fails, whatever text was written
// before the failure is returned together with the error.
func FormatNode(node ast.Node) (string, error) {
	var out bytes.Buffer
	fset := token.NewFileSet()
	if err := format.Node(&out, fset, node); err != nil {
		return out.String(), err
	}
	return out.String(), nil
}
137
// debugEnabled reports whether the GOTESTTOOLS_DEBUG environment variable was
// set to a non-empty value when the package was initialized.
var debugEnabled = len(os.Getenv("GOTESTTOOLS_DEBUG")) > 0

// debug writes a message built from format and args to standard error. The
// message is prefixed with "DEBUG: " and terminated with a newline. It is a
// no-op unless debugEnabled is true.
func debug(format string, args ...interface{}) {
	if !debugEnabled {
		return
	}
	fmt.Fprintf(os.Stderr, "DEBUG: "+format+"\n", args...)
}
145
146 type debugFormatNode struct {
147 ast.Node
148 }
149
150 func (n debugFormatNode) String() string {
151 if n.Node == nil {
152 return "none"
153 }
154 out, err := FormatNode(n.Node)
155 if err != nil {
156 return fmt.Sprintf("failed to format %s: %s", n.Node, err)
157 }
158 return fmt.Sprintf("(%T) %s", n.Node, out)
159 }
160
View as plain text