1
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31 package main
32
33 import (
34 "flag"
35 "fmt"
36 "go/ast"
37 "go/build"
38 "go/doc"
39 "go/parser"
40 "go/token"
41 "os"
42 "sort"
43 "strings"
44 "text/template"
45 )
46
// Import describes one import declaration emitted into the generated test
// main: Name is the identifier the package is imported under and Path is the
// import path. genTestMain rewrites Name to "_" for packages that contribute
// no tests, benchmarks, fuzz targets, examples or TestMain.
type Import struct {
	Name, Path string
}
51
// TestCase identifies a single test, benchmark or fuzz function. The
// generated main refers to it as Package.Name, so Package is the identifier
// the containing package is imported under and Name is the function name.
type TestCase struct {
	Package, Name string
}
56
// Example identifies a runnable example function together with the output it
// is expected to produce.
type Example struct {
	// Package and Name locate the function; the generated main refers to it
	// as Package.Name. Output is the expected output text.
	Package, Name, Output string
	// Unordered is copied from go/doc's Example.Unordered and is passed
	// through to testing.InternalExample.
	Unordered bool
}
63
64
65 type Cases struct {
66 Imports []*Import
67 Tests []TestCase
68 Benchmarks []TestCase
69 FuzzTargets []TestCase
70 Examples []Example
71 TestMain string
72 CoverMode string
73 CoverFormat string
74 Pkgname string
75 }
76
77
78 func (c *Cases) Version(v string) bool {
79 for _, r := range build.Default.ReleaseTags {
80 if v == r {
81 return true
82 }
83 }
84 return false
85 }
86
// testMainTpl is the text/template source of the generated test main
// package. genTestMain executes it with a *Cases value. The emitted program
// registers the discovered tests, benchmarks, fuzz targets (only when the
// toolchain has the go1.18 release tag) and examples with testing.MainStart,
// selects the tests for the current shard from the TEST_TOTAL_SHARDS and
// TEST_SHARD_INDEX environment variables, maps TESTBRIDGE_TEST_ONLY and
// TESTBRIDGE_TEST_RUNNER_FAIL_FAST onto the corresponding test flags,
// registers coverage data when CoverMode is set, and finally runs either the
// user's TestMain or m.Run and exits with the resulting code.
const testMainTpl = `
package main

// bzltestutil may change the current directory in its init function to emulate
// 'go test' behavior. It must be initialized before user packages.
// In Go 1.20 and earlier, this import declaration must appear before
// imports of user packages. See comment in bzltestutil/init.go.
import "github.com/bazelbuild/rules_go/go/tools/bzltestutil"

import (
"flag"
"log"
"os"
"os/exec"
{{if .TestMain}}
"reflect"
{{end}}
"strconv"
"strings"
"testing"
"testing/internal/testdeps"

{{if ne .CoverMode ""}}
"github.com/bazelbuild/rules_go/go/tools/coverdata"
{{end}}

{{range $p := .Imports}}
{{$p.Name}} "{{$p.Path}}"
{{end}}
)

var allTests = []testing.InternalTest{
{{range .Tests}}
{"{{.Name}}", {{.Package}}.{{.Name}} },
{{end}}
}

var benchmarks = []testing.InternalBenchmark{
{{range .Benchmarks}}
{"{{.Name}}", {{.Package}}.{{.Name}} },
{{end}}
}

{{if .Version "go1.18"}}
var fuzzTargets = []testing.InternalFuzzTarget{
{{range .FuzzTargets}}
{"{{.Name}}", {{.Package}}.{{.Name}} },
{{end}}
}
{{end}}

var examples = []testing.InternalExample{
{{range .Examples}}
{Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{printf "%q" .Output}}, Unordered: {{.Unordered}} },
{{end}}
}

func testsInShard() []testing.InternalTest {
totalShards, err := strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS"))
if err != nil || totalShards <= 1 {
return allTests
}
file, err := os.Create(os.Getenv("TEST_SHARD_STATUS_FILE"))
if err != nil {
log.Fatalf("Failed to touch TEST_SHARD_STATUS_FILE: %v", err)
}
_ = file.Close()
shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX"))
if err != nil || shardIndex < 0 {
return allTests
}
tests := []testing.InternalTest{}
for i, t := range allTests {
if i % totalShards == shardIndex {
tests = append(tests, t)
}
}
return tests
}

func main() {
if bzltestutil.ShouldWrap() {
err := bzltestutil.Wrap("{{.Pkgname}}")
if xerr, ok := err.(*exec.ExitError); ok {
os.Exit(xerr.ExitCode())
} else if err != nil {
log.Print(err)
os.Exit(bzltestutil.TestWrapperAbnormalExit)
} else {
os.Exit(0)
}
}

testDeps :=
{{if eq .CoverFormat "lcov"}}
bzltestutil.LcovTestDeps{TestDeps: testdeps.TestDeps{}}
{{else}}
testdeps.TestDeps{}
{{end}}
{{if .Version "go1.18"}}
m := testing.MainStart(testDeps, testsInShard(), benchmarks, fuzzTargets, examples)
{{else}}
m := testing.MainStart(testDeps, testsInShard(), benchmarks, examples)
{{end}}

if filter := os.Getenv("TESTBRIDGE_TEST_ONLY"); filter != "" {
filters := strings.Split(filter, ",")
var runTests []string
var skipTests []string

for _, f := range filters {
if strings.HasPrefix(f, "-") {
skipTests = append(skipTests, f[1:])
} else {
runTests = append(runTests, f)
}
}
if len(runTests) > 0 {
flag.Lookup("test.run").Value.Set(strings.Join(runTests, "|"))
}
if len(skipTests) > 0 {
flag.Lookup("test.skip").Value.Set(strings.Join(skipTests, "|"))
}
}

if failfast := os.Getenv("TESTBRIDGE_TEST_RUNNER_FAIL_FAST"); failfast != "" {
flag.Lookup("test.failfast").Value.Set("true")
}
{{if eq .CoverFormat "lcov"}}
panicOnExit0Flag := flag.Lookup("test.paniconexit0").Value
testDeps.OriginalPanicOnExit = panicOnExit0Flag.(flag.Getter).Get().(bool)
// Setting this flag provides a way to run hooks right before testing.M.Run() returns.
panicOnExit0Flag.Set("true")
{{end}}
{{if ne .CoverMode ""}}
if len(coverdata.Counters) > 0 {
testing.RegisterCover(testing.Cover{
Mode: "{{ .CoverMode }}",
Counters: coverdata.Counters,
Blocks: coverdata.Blocks,
})

if coverageDat, ok := os.LookupEnv("COVERAGE_OUTPUT_FILE"); ok {
{{if eq .CoverFormat "lcov"}}
flag.Lookup("test.coverprofile").Value.Set(coverageDat+".cover")
{{else}}
flag.Lookup("test.coverprofile").Value.Set(coverageDat)
{{end}}
}
}
{{end}}
bzltestutil.RegisterTimeoutHandler()
{{if not .TestMain}}
res := m.Run()
{{else}}
{{.TestMain}}(m)
{{/* See golang.org/issue/34129 and golang.org/cl/219639 */}}
res := int(reflect.ValueOf(m).Elem().FieldByName("exitCode").Int())
{{end}}
os.Exit(res)
}
`
249
250 func genTestMain(args []string) error {
251
252 args, _, err := expandParamsFiles(args)
253 if err != nil {
254 return err
255 }
256 imports := multiFlag{}
257 sources := multiFlag{}
258 flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError)
259 goenv := envFlags(flags)
260 out := flags.String("output", "", "output file to write. Defaults to stdout.")
261 coverMode := flags.String("cover_mode", "", "the coverage mode to use")
262 coverFormat := flags.String("cover_format", "", "the coverage report type to generate (go_cover or lcov)")
263 pkgname := flags.String("pkgname", "", "package name of test")
264 flags.Var(&imports, "import", "Packages to import")
265 flags.Var(&sources, "src", "Sources to process for tests")
266 if err := flags.Parse(args); err != nil {
267 return err
268 }
269 if err := goenv.checkFlags(); err != nil {
270 return err
271 }
272
273 importMap := map[string]*Import{}
274 for _, imp := range imports {
275 parts := strings.Split(imp, "=")
276 if len(parts) != 2 {
277 return fmt.Errorf("Invalid import %q specified", imp)
278 }
279 i := &Import{Name: parts[0], Path: parts[1]}
280 importMap[i.Name] = i
281 }
282
283 sourceList := []string{}
284 sourceMap := map[string]string{}
285 for _, s := range sources {
286 parts := strings.Split(s, "=")
287 if len(parts) != 2 {
288 return fmt.Errorf("Invalid source %q specified", s)
289 }
290 sourceList = append(sourceList, parts[1])
291 sourceMap[parts[1]] = parts[0]
292 }
293
294
295 filteredSrcs, err := filterAndSplitFiles(sourceList)
296 if err != nil {
297 return err
298 }
299 goSrcs := filteredSrcs.goSrcs
300
301 outFile := os.Stdout
302 if *out != "" {
303 var err error
304 outFile, err = os.Create(*out)
305 if err != nil {
306 return fmt.Errorf("os.Create(%q): %v", *out, err)
307 }
308 defer outFile.Close()
309 }
310
311 cases := Cases{
312 CoverFormat: *coverFormat,
313 CoverMode: *coverMode,
314 Pkgname: *pkgname,
315 }
316
317 testFileSet := token.NewFileSet()
318 pkgs := map[string]bool{}
319 for _, f := range goSrcs {
320 parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments)
321 if err != nil {
322 return fmt.Errorf("ParseFile(%q): %v", f.filename, err)
323 }
324 pkg := sourceMap[f.filename]
325 if strings.HasSuffix(parse.Name.String(), "_test") {
326 pkg += "_test"
327 }
328 for _, e := range doc.Examples(parse) {
329 if e.Output == "" && !e.EmptyOutput {
330 continue
331 }
332 cases.Examples = append(cases.Examples, Example{
333 Name: "Example" + e.Name,
334 Package: pkg,
335 Output: e.Output,
336 Unordered: e.Unordered,
337 })
338 pkgs[pkg] = true
339 }
340 for _, d := range parse.Decls {
341 fn, ok := d.(*ast.FuncDecl)
342 if !ok {
343 continue
344 }
345 if fn.Recv != nil {
346 continue
347 }
348 if fn.Name.Name == "TestMain" {
349
350 pkgs[pkg] = true
351 cases.TestMain = fmt.Sprintf("%s.%s", pkg, fn.Name.Name)
352 continue
353 }
354
355
356
357
358
359 if len(fn.Type.Params.List) != 1 {
360 continue
361 }
362
363
364 if fn.Type.Results != nil {
365 continue
366 }
367
368
369
370 starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
371 if !ok {
372 continue
373 }
374 selExpr, ok := starExpr.X.(*ast.SelectorExpr)
375 if !ok {
376 continue
377 }
378
379
380
381
382
383
384 if strings.HasPrefix(fn.Name.Name, "Test") {
385 if selExpr.Sel.Name != "T" {
386 continue
387 }
388 pkgs[pkg] = true
389 cases.Tests = append(cases.Tests, TestCase{
390 Package: pkg,
391 Name: fn.Name.Name,
392 })
393 }
394 if strings.HasPrefix(fn.Name.Name, "Benchmark") {
395 if selExpr.Sel.Name != "B" {
396 continue
397 }
398 pkgs[pkg] = true
399 cases.Benchmarks = append(cases.Benchmarks, TestCase{
400 Package: pkg,
401 Name: fn.Name.Name,
402 })
403 }
404 if strings.HasPrefix(fn.Name.Name, "Fuzz") {
405 if selExpr.Sel.Name != "F" {
406 continue
407 }
408 pkgs[pkg] = true
409 cases.FuzzTargets = append(cases.FuzzTargets, TestCase{
410 Package: pkg,
411 Name: fn.Name.Name,
412 })
413 }
414 }
415 }
416
417 for name := range importMap {
418
419 if !pkgs[name] {
420 importMap[name].Name = "_"
421 }
422 cases.Imports = append(cases.Imports, importMap[name])
423 }
424 sort.Slice(cases.Imports, func(i, j int) bool {
425 return cases.Imports[i].Name < cases.Imports[j].Name
426 })
427 tpl := template.Must(template.New("source").Parse(testMainTpl))
428 if err := tpl.Execute(outFile, &cases); err != nil {
429 return fmt.Errorf("template.Execute(%v): %v", cases, err)
430 }
431 return nil
432 }
433
View as plain text