1
2
3
4
5 package testinggoroutine
6
7 import (
8 _ "embed"
9 "fmt"
10 "go/ast"
11 "go/token"
12 "go/types"
13
14 "golang.org/x/tools/go/analysis"
15 "golang.org/x/tools/go/analysis/passes/inspect"
16 "golang.org/x/tools/go/analysis/passes/internal/analysisutil"
17 "golang.org/x/tools/go/ast/astutil"
18 "golang.org/x/tools/go/ast/inspector"
19 "golang.org/x/tools/go/types/typeutil"
20 "golang.org/x/tools/internal/aliases"
21 )
22
23
24 var doc string
25
26 var reportSubtest bool
27
28 func init() {
29 Analyzer.Flags.BoolVar(&reportSubtest, "subtest", false, "whether to check if t.Run subtest is terminated correctly; experimental")
30 }
31
32 var Analyzer = &analysis.Analyzer{
33 Name: "testinggoroutine",
34 Doc: analysisutil.MustExtractDoc(doc, "testinggoroutine"),
35 URL: "https://pkg.go.dev/golang.org/x/tools/go/analysis/passes/testinggoroutine",
36 Requires: []*analysis.Analyzer{inspect.Analyzer},
37 Run: run,
38 }
39
40 func run(pass *analysis.Pass) (interface{}, error) {
41 inspect := pass.ResultOf[inspect.Analyzer].(*inspector.Inspector)
42
43 if !analysisutil.Imports(pass.Pkg, "testing") {
44 return nil, nil
45 }
46
47 toDecl := localFunctionDecls(pass.TypesInfo, pass.Files)
48
49
50
51
52
53 asyncs := make(map[ast.Node][]*asyncCall)
54 var regions []ast.Node
55 addCall := func(c *asyncCall) {
56 if c != nil {
57 r := c.region
58 if asyncs[r] == nil {
59 regions = append(regions, r)
60 }
61 asyncs[r] = append(asyncs[r], c)
62 }
63 }
64
65
66 inspect.Nodes([]ast.Node{
67 (*ast.FuncDecl)(nil),
68 (*ast.GoStmt)(nil),
69 (*ast.CallExpr)(nil),
70 }, func(node ast.Node, push bool) bool {
71 if !push {
72 return false
73 }
74 switch node := node.(type) {
75 case *ast.FuncDecl:
76 return hasBenchmarkOrTestParams(node)
77
78 case *ast.GoStmt:
79 c := goAsyncCall(pass.TypesInfo, node, toDecl)
80 addCall(c)
81
82 case *ast.CallExpr:
83 c := tRunAsyncCall(pass.TypesInfo, node)
84 addCall(c)
85 }
86 return true
87 })
88
89
90
91
92
93 for _, region := range regions {
94 ast.Inspect(region, func(n ast.Node) bool {
95 if n == region {
96 return true
97 } else if asyncs[n] != nil {
98 return false
99 }
100
101 call, ok := n.(*ast.CallExpr)
102 if !ok {
103 return true
104 }
105 x, sel, fn := forbiddenMethod(pass.TypesInfo, call)
106 if x == nil {
107 return true
108 }
109
110 for _, e := range asyncs[region] {
111 if !withinScope(e.scope, x) {
112 forbidden := formatMethod(sel, fn)
113
114 var context string
115 var where analysis.Range = e.async
116 if _, local := e.fun.(*ast.FuncLit); local {
117 where = call
118 } else if id, ok := e.fun.(*ast.Ident); ok {
119 context = fmt.Sprintf(" (%s calls %s)", id.Name, forbidden)
120 }
121 if _, ok := e.async.(*ast.GoStmt); ok {
122 pass.ReportRangef(where, "call to %s from a non-test goroutine%s", forbidden, context)
123 } else if reportSubtest {
124 pass.ReportRangef(where, "call to %s on %s defined outside of the subtest%s", forbidden, x.Name(), context)
125 }
126 }
127 }
128 return true
129 })
130 }
131
132 return nil, nil
133 }
134
135 func hasBenchmarkOrTestParams(fnDecl *ast.FuncDecl) bool {
136
137 params := fnDecl.Type.Params.List
138
139 for _, param := range params {
140 if _, ok := typeIsTestingDotTOrB(param.Type); ok {
141 return true
142 }
143 }
144
145 return false
146 }
147
// typeIsTestingDotTOrB reports whether expr is written as the pointer type
// *testing.T or *testing.B.
//
// The check is purely syntactic: expr must be a star expression whose operand
// is a selector on an identifier spelled "testing". The first result is the
// selected type name whenever expr has the shape *testing.X (even if X is
// neither "T" nor "B"), and "" otherwise. The second result is true only for
// "T" and "B".
func typeIsTestingDotTOrB(expr ast.Expr) (string, bool) {
	star, ok := expr.(*ast.StarExpr)
	if !ok {
		return "", false
	}
	sel, ok := star.X.(*ast.SelectorExpr)
	if !ok {
		return "", false
	}
	// The selector operand is not guaranteed to be a plain identifier (it may
	// be a nested selector or a parenthesized expression, for example), so
	// use a checked assertion rather than one that panics.
	pkg, ok := sel.X.(*ast.Ident)
	if !ok || pkg.Name != "testing" {
		return "", false
	}

	name := sel.Sel.Name
	return name, name == "B" || name == "T"
}
166
167
168
169
// asyncCall describes one place where a function is started asynchronously
// with respect to the enclosing test: either a go statement or a call to a
// testing "Run" method.
type asyncCall struct {
	// region is the node whose subtree is searched for forbidden calls. It
	// is also the key under which run groups async calls.
	region ast.Node

	// async is the *ast.GoStmt, or the *ast.CallExpr of the Run call, that
	// starts the function.
	async ast.Node

	// scope, when non-nil, is the node within which a testing variable must
	// be declared for a forbidden call on it to be accepted (see
	// withinScope). It is nil for go statements.
	scope ast.Node

	// fun is the unparenthesized function expression being started.
	fun ast.Expr
}
176
177
// withinScope reports whether the declaration position of x lies inside
// scope, that is within the closed interval [scope.Pos(), scope.End()].
// It returns false when scope is nil or when x has no position.
func withinScope(scope ast.Node, x *types.Var) bool {
	if scope == nil {
		return false
	}
	pos := x.Pos()
	if pos == token.NoPos {
		return false
	}
	return scope.Pos() <= pos && pos <= scope.End()
}
184
185
186 func goAsyncCall(info *types.Info, goStmt *ast.GoStmt, toDecl func(*types.Func) *ast.FuncDecl) *asyncCall {
187 call := goStmt.Call
188
189 fun := astutil.Unparen(call.Fun)
190 if id := funcIdent(fun); id != nil {
191 if lit := funcLitInScope(id); lit != nil {
192 return &asyncCall{region: lit, async: goStmt, scope: nil, fun: fun}
193 }
194 }
195
196 if fn := typeutil.StaticCallee(info, call); fn != nil {
197 if decl := toDecl(fn); decl != nil {
198 return &asyncCall{region: decl, async: goStmt, scope: nil, fun: fun}
199 }
200 }
201
202
203 return &asyncCall{region: goStmt, async: goStmt, scope: nil, fun: fun}
204 }
205
206
207 func tRunAsyncCall(info *types.Info, call *ast.CallExpr) *asyncCall {
208 if len(call.Args) != 2 {
209 return nil
210 }
211 run := typeutil.Callee(info, call)
212 if run, ok := run.(*types.Func); !ok || !isMethodNamed(run, "testing", "Run") {
213 return nil
214 }
215
216 fun := astutil.Unparen(call.Args[1])
217 if lit, ok := fun.(*ast.FuncLit); ok {
218 return &asyncCall{region: lit, async: call, scope: lit, fun: fun}
219 }
220
221 if id := funcIdent(fun); id != nil {
222 if lit := funcLitInScope(id); lit != nil {
223 return &asyncCall{region: lit, async: call, scope: lit, fun: fun}
224 }
225 }
226
227
228
229 return &asyncCall{region: call, async: call, scope: fun, fun: fun}
230 }
231
// forbidden lists the names of the testing methods that forbiddenMethod
// looks for.
var forbidden = []string{
	"FailNow",
	"Fatal",
	"Fatalf",
	"Skip",
	"Skipf",
	"SkipNow",
}
240
241
242
243
244 func forbiddenMethod(info *types.Info, call *ast.CallExpr) (*types.Var, *types.Selection, *types.Func) {
245
246 fun := astutil.Unparen(call.Fun)
247 selExpr, ok := fun.(*ast.SelectorExpr)
248 if !ok {
249 return nil, nil, nil
250 }
251 sel := info.Selections[selExpr]
252 if sel == nil {
253 return nil, nil, nil
254 }
255
256 var x *types.Var
257 if id, ok := astutil.Unparen(selExpr.X).(*ast.Ident); ok {
258 x, _ = info.Uses[id].(*types.Var)
259 }
260 if x == nil {
261 return nil, nil, nil
262 }
263
264 fn, _ := sel.Obj().(*types.Func)
265 if fn == nil || !isMethodNamed(fn, "testing", forbidden...) {
266 return nil, nil, nil
267 }
268 return x, sel, fn
269 }
270
271 func formatMethod(sel *types.Selection, fn *types.Func) string {
272 var ptr string
273 rtype := sel.Recv()
274 if p, ok := aliases.Unalias(rtype).(*types.Pointer); ok {
275 ptr = "*"
276 rtype = p.Elem()
277 }
278 return fmt.Sprintf("(%s%s).%s", ptr, rtype.String(), fn.Name())
279 }
280
View as plain text