...

Source file src/github.com/bazelbuild/rules_go/go/tools/builders/generate_test_main.go

Documentation: github.com/bazelbuild/rules_go/go/tools/builders

     1  /* Copyright 2016 The Bazel Authors. All rights reserved.
     2  
     3  Licensed under the Apache License, Version 2.0 (the "License");
     4  you may not use this file except in compliance with the License.
     5  You may obtain a copy of the License at
     6  
     7     http://www.apache.org/licenses/LICENSE-2.0
     8  
     9  Unless required by applicable law or agreed to in writing, software
    10  distributed under the License is distributed on an "AS IS" BASIS,
    11  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  See the License for the specific language governing permissions and
    13  limitations under the License.
    14  */
    15  
    16  // Go testing support for Bazel.
    17  //
    18  // A Go test comprises three packages:
    19  //
    20  // 1. An internal test package, compiled from the sources of the library being
    21  //    tested and any _test.go files with the same package name.
    22  // 2. An external test package, compiled from _test.go files with a package
    23  //    name ending with "_test".
    24  // 3. A generated main package that imports both packages and initializes the
    25  //    test framework with a list of tests, benchmarks, examples, and fuzz
    26  //    targets read from source files.
    27  //
    28  // This action generates the source code for (3). The equivalent code for
    29  // 'go test' is in $GOROOT/src/cmd/go/internal/load/test.go.
    30  
    31  package main
    32  
    33  import (
    34  	"flag"
    35  	"fmt"
    36  	"go/ast"
    37  	"go/build"
    38  	"go/doc"
    39  	"go/parser"
    40  	"go/token"
    41  	"os"
    42  	"sort"
    43  	"strings"
    44  	"text/template"
    45  )
    46  
// Import is one import declaration emitted into the generated main package.
type Import struct {
	// Name is the identifier the package is imported as. genTestMain replaces
	// it with "_" when no generated code refers to the package.
	Name string
	// Path is the import path of the package.
	Path string
}
    51  
// TestCase identifies a Test, Benchmark or Fuzz function that the generated
// main registers with the testing package.
type TestCase struct {
	// Package is the import name the generated code uses to qualify the
	// function; for files in an external test package it carries a "_test"
	// suffix.
	Package string
	// Name is the function's identifier, e.g. "TestFoo".
	Name    string
}
    56  
// Example describes an Example function whose output the generated main asks
// the testing package to check.
type Example struct {
	// Package is the import name the generated code uses to qualify the
	// function.
	Package   string
	// Name is the function's identifier, e.g. "ExampleFoo".
	Name      string
	// Output is the expected output, as extracted by go/doc.
	Output    string
	// Unordered reports whether the output lines may appear in any order.
	Unordered bool
}
    63  
// Cases holds template data for testMainTpl.
type Cases struct {
	// Imports are emitted as import declarations in the generated main.
	Imports     []*Import
	// Tests become the allTests table (subject to sharding).
	Tests       []TestCase
	// Benchmarks become the benchmarks table.
	Benchmarks  []TestCase
	// FuzzTargets become the fuzzTargets table; it is only emitted when the
	// "go1.18" release tag is present.
	FuzzTargets []TestCase
	// Examples become the examples table.
	Examples    []Example
	// TestMain is the qualified name ("pkg.TestMain") of a user-defined
	// TestMain, or "" if there is none.
	TestMain    string
	// CoverMode is the coverage mode; "" disables coverage registration.
	CoverMode   string
	// CoverFormat selects the coverage output; "lcov" switches to
	// bzltestutil.LcovTestDeps and a ".cover"-suffixed profile path.
	CoverFormat string
	// Pkgname is passed to bzltestutil.Wrap in the generated main.
	Pkgname     string
}
    76  
    77  // Version returns whether v is a supported Go version (like "go1.18").
    78  func (c *Cases) Version(v string) bool {
    79  	for _, r := range build.Default.ReleaseTags {
    80  		if v == r {
    81  			return true
    82  		}
    83  	}
    84  	return false
    85  }
    86  
// testMainTpl is the text/template source of the generated test main package.
// It is executed with a *Cases value. Imports, Tests, Benchmarks, FuzzTargets
// and Examples expand into import declarations and the tables passed to
// testing.MainStart. TestMain, CoverMode, CoverFormat and Pkgname select
// optional code paths.
//
// The generated main:
//   - calls bzltestutil.Wrap(Pkgname) when bzltestutil.ShouldWrap() is true,
//     exiting with the *exec.ExitError's code, with
//     bzltestutil.TestWrapperAbnormalExit on any other error, or with 0;
//   - when TEST_TOTAL_SHARDS parses to a number > 1, creates
//     TEST_SHARD_STATUS_FILE and keeps only the tests at indexes i with
//     i % TEST_TOTAL_SHARDS == TEST_SHARD_INDEX (all tests if the index is
//     unparsable or negative);
//   - splits TESTBRIDGE_TEST_ONLY on "," and joins the entries with "|" into
//     -test.run, or into -test.skip for entries prefixed with "-";
//   - sets -test.failfast when TESTBRIDGE_TEST_RUNNER_FAIL_FAST is non-empty;
//   - for lcov, saves the original -test.paniconexit0 value in testDeps and
//     then forces the flag on;
//   - when CoverMode is set and coverdata has counters, registers coverage and
//     points -test.coverprofile at COVERAGE_OUTPUT_FILE (plus ".cover" for
//     lcov);
//   - runs the user's TestMain if present, reading the exit code from the
//     unexported testing.M field "exitCode" via reflection.
//
// NOTE(review): flag.Lookup("test.skip") returns nil if the testing package
// defines no -test.skip flag (it was added in Go 1.20), so a "-"-prefixed
// filter would panic on older releases. Confirm the minimum Go version this
// template must support.
const testMainTpl = `
package main

// bzltestutil may change the current directory in its init function to emulate
// 'go test' behavior. It must be initialized before user packages.
// In Go 1.20 and earlier, this import declaration must appear before
// imports of user packages. See comment in bzltestutil/init.go.
import "github.com/bazelbuild/rules_go/go/tools/bzltestutil"

import (
	"flag"
	"log"
	"os"
	"os/exec"
{{if .TestMain}}
	"reflect"
{{end}}
	"strconv"
	"strings"
	"testing"
	"testing/internal/testdeps"

{{if ne .CoverMode ""}}
	"github.com/bazelbuild/rules_go/go/tools/coverdata"
{{end}}

{{range $p := .Imports}}
	{{$p.Name}} "{{$p.Path}}"
{{end}}
)

var allTests = []testing.InternalTest{
{{range .Tests}}
	{"{{.Name}}", {{.Package}}.{{.Name}} },
{{end}}
}

var benchmarks = []testing.InternalBenchmark{
{{range .Benchmarks}}
	{"{{.Name}}", {{.Package}}.{{.Name}} },
{{end}}
}

{{if .Version "go1.18"}}
var fuzzTargets = []testing.InternalFuzzTarget{
{{range .FuzzTargets}}
  {"{{.Name}}", {{.Package}}.{{.Name}} },
{{end}}
}
{{end}}

var examples = []testing.InternalExample{
{{range .Examples}}
	{Name: "{{.Name}}", F: {{.Package}}.{{.Name}}, Output: {{printf "%q" .Output}}, Unordered: {{.Unordered}} },
{{end}}
}

func testsInShard() []testing.InternalTest {
	totalShards, err := strconv.Atoi(os.Getenv("TEST_TOTAL_SHARDS"))
	if err != nil || totalShards <= 1 {
		return allTests
	}
	file, err := os.Create(os.Getenv("TEST_SHARD_STATUS_FILE"))
	if err != nil {
		log.Fatalf("Failed to touch TEST_SHARD_STATUS_FILE: %v", err)
	}
	_ = file.Close()
	shardIndex, err := strconv.Atoi(os.Getenv("TEST_SHARD_INDEX"))
	if err != nil || shardIndex < 0 {
		return allTests
	}
	tests := []testing.InternalTest{}
	for i, t := range allTests {
		if i % totalShards == shardIndex {
			tests = append(tests, t)
		}
	}
	return tests
}

func main() {
	if bzltestutil.ShouldWrap() {
		err := bzltestutil.Wrap("{{.Pkgname}}")
		if xerr, ok := err.(*exec.ExitError); ok {
			os.Exit(xerr.ExitCode())
		} else if err != nil {
			log.Print(err)
			os.Exit(bzltestutil.TestWrapperAbnormalExit)
		} else {
			os.Exit(0)
		}
	}

	testDeps :=
  {{if eq .CoverFormat "lcov"}}
		bzltestutil.LcovTestDeps{TestDeps: testdeps.TestDeps{}}
  {{else}}
		testdeps.TestDeps{}
  {{end}}
  {{if .Version "go1.18"}}
	m := testing.MainStart(testDeps, testsInShard(), benchmarks, fuzzTargets, examples)
  {{else}}
	m := testing.MainStart(testDeps, testsInShard(), benchmarks, examples)
  {{end}}

	if filter := os.Getenv("TESTBRIDGE_TEST_ONLY"); filter != "" {
		filters := strings.Split(filter, ",")
		var runTests []string
		var skipTests []string

		for _, f := range filters {
			if strings.HasPrefix(f, "-") {
				skipTests = append(skipTests, f[1:])
			} else {
				runTests = append(runTests, f)
			}
		}
		if len(runTests) > 0 {
			flag.Lookup("test.run").Value.Set(strings.Join(runTests, "|"))
		}
		if len(skipTests) > 0 {
			flag.Lookup("test.skip").Value.Set(strings.Join(skipTests, "|"))
		}
	}

	if failfast := os.Getenv("TESTBRIDGE_TEST_RUNNER_FAIL_FAST"); failfast != "" {
		flag.Lookup("test.failfast").Value.Set("true")
	}
{{if eq .CoverFormat "lcov"}}
	panicOnExit0Flag := flag.Lookup("test.paniconexit0").Value
	testDeps.OriginalPanicOnExit = panicOnExit0Flag.(flag.Getter).Get().(bool)
	// Setting this flag provides a way to run hooks right before testing.M.Run() returns.
	panicOnExit0Flag.Set("true")
{{end}}
{{if ne .CoverMode ""}}
	if len(coverdata.Counters) > 0 {
		testing.RegisterCover(testing.Cover{
			Mode: "{{ .CoverMode }}",
			Counters: coverdata.Counters,
			Blocks: coverdata.Blocks,
		})

		if coverageDat, ok := os.LookupEnv("COVERAGE_OUTPUT_FILE"); ok {
			{{if eq .CoverFormat "lcov"}}
			flag.Lookup("test.coverprofile").Value.Set(coverageDat+".cover")
			{{else}}
			flag.Lookup("test.coverprofile").Value.Set(coverageDat)
			{{end}}
		}
	}
	{{end}}
	bzltestutil.RegisterTimeoutHandler()
	{{if not .TestMain}}
	res := m.Run()
	{{else}}
	{{.TestMain}}(m)
	{{/* See golang.org/issue/34129 and golang.org/cl/219639 */}}
	res := int(reflect.ValueOf(m).Elem().FieldByName("exitCode").Int())
	{{end}}
	os.Exit(res)
}
`
   249  
   250  func genTestMain(args []string) error {
   251  	// Prepare our flags
   252  	args, _, err := expandParamsFiles(args)
   253  	if err != nil {
   254  		return err
   255  	}
   256  	imports := multiFlag{}
   257  	sources := multiFlag{}
   258  	flags := flag.NewFlagSet("GoTestGenTest", flag.ExitOnError)
   259  	goenv := envFlags(flags)
   260  	out := flags.String("output", "", "output file to write. Defaults to stdout.")
   261  	coverMode := flags.String("cover_mode", "", "the coverage mode to use")
   262  	coverFormat := flags.String("cover_format", "", "the coverage report type to generate (go_cover or lcov)")
   263  	pkgname := flags.String("pkgname", "", "package name of test")
   264  	flags.Var(&imports, "import", "Packages to import")
   265  	flags.Var(&sources, "src", "Sources to process for tests")
   266  	if err := flags.Parse(args); err != nil {
   267  		return err
   268  	}
   269  	if err := goenv.checkFlags(); err != nil {
   270  		return err
   271  	}
   272  	// Process import args
   273  	importMap := map[string]*Import{}
   274  	for _, imp := range imports {
   275  		parts := strings.Split(imp, "=")
   276  		if len(parts) != 2 {
   277  			return fmt.Errorf("Invalid import %q specified", imp)
   278  		}
   279  		i := &Import{Name: parts[0], Path: parts[1]}
   280  		importMap[i.Name] = i
   281  	}
   282  	// Process source args
   283  	sourceList := []string{}
   284  	sourceMap := map[string]string{}
   285  	for _, s := range sources {
   286  		parts := strings.Split(s, "=")
   287  		if len(parts) != 2 {
   288  			return fmt.Errorf("Invalid source %q specified", s)
   289  		}
   290  		sourceList = append(sourceList, parts[1])
   291  		sourceMap[parts[1]] = parts[0]
   292  	}
   293  
   294  	// filter our input file list
   295  	filteredSrcs, err := filterAndSplitFiles(sourceList)
   296  	if err != nil {
   297  		return err
   298  	}
   299  	goSrcs := filteredSrcs.goSrcs
   300  
   301  	outFile := os.Stdout
   302  	if *out != "" {
   303  		var err error
   304  		outFile, err = os.Create(*out)
   305  		if err != nil {
   306  			return fmt.Errorf("os.Create(%q): %v", *out, err)
   307  		}
   308  		defer outFile.Close()
   309  	}
   310  
   311  	cases := Cases{
   312  		CoverFormat: *coverFormat,
   313  		CoverMode:   *coverMode,
   314  		Pkgname:     *pkgname,
   315  	}
   316  
   317  	testFileSet := token.NewFileSet()
   318  	pkgs := map[string]bool{}
   319  	for _, f := range goSrcs {
   320  		parse, err := parser.ParseFile(testFileSet, f.filename, nil, parser.ParseComments)
   321  		if err != nil {
   322  			return fmt.Errorf("ParseFile(%q): %v", f.filename, err)
   323  		}
   324  		pkg := sourceMap[f.filename]
   325  		if strings.HasSuffix(parse.Name.String(), "_test") {
   326  			pkg += "_test"
   327  		}
   328  		for _, e := range doc.Examples(parse) {
   329  			if e.Output == "" && !e.EmptyOutput {
   330  				continue
   331  			}
   332  			cases.Examples = append(cases.Examples, Example{
   333  				Name:      "Example" + e.Name,
   334  				Package:   pkg,
   335  				Output:    e.Output,
   336  				Unordered: e.Unordered,
   337  			})
   338  			pkgs[pkg] = true
   339  		}
   340  		for _, d := range parse.Decls {
   341  			fn, ok := d.(*ast.FuncDecl)
   342  			if !ok {
   343  				continue
   344  			}
   345  			if fn.Recv != nil {
   346  				continue
   347  			}
   348  			if fn.Name.Name == "TestMain" {
   349  				// TestMain is not, itself, a test
   350  				pkgs[pkg] = true
   351  				cases.TestMain = fmt.Sprintf("%s.%s", pkg, fn.Name.Name)
   352  				continue
   353  			}
   354  
   355  			// Here we check the signature of the Test* function. To
   356  			// be considered a test:
   357  
   358  			// 1. The function should have a single argument.
   359  			if len(fn.Type.Params.List) != 1 {
   360  				continue
   361  			}
   362  
   363  			// 2. The function should return nothing.
   364  			if fn.Type.Results != nil {
   365  				continue
   366  			}
   367  
   368  			// 3. The only parameter should have a type identified as
   369  			//    *<something>.T
   370  			starExpr, ok := fn.Type.Params.List[0].Type.(*ast.StarExpr)
   371  			if !ok {
   372  				continue
   373  			}
   374  			selExpr, ok := starExpr.X.(*ast.SelectorExpr)
   375  			if !ok {
   376  				continue
   377  			}
   378  
   379  			// We do not descriminate on the referenced type of the
   380  			// parameter being *testing.T. Instead we assert that it
   381  			// should be *<something>.T. This is because the import
   382  			// could have been aliased as a different identifier.
   383  
   384  			if strings.HasPrefix(fn.Name.Name, "Test") {
   385  				if selExpr.Sel.Name != "T" {
   386  					continue
   387  				}
   388  				pkgs[pkg] = true
   389  				cases.Tests = append(cases.Tests, TestCase{
   390  					Package: pkg,
   391  					Name:    fn.Name.Name,
   392  				})
   393  			}
   394  			if strings.HasPrefix(fn.Name.Name, "Benchmark") {
   395  				if selExpr.Sel.Name != "B" {
   396  					continue
   397  				}
   398  				pkgs[pkg] = true
   399  				cases.Benchmarks = append(cases.Benchmarks, TestCase{
   400  					Package: pkg,
   401  					Name:    fn.Name.Name,
   402  				})
   403  			}
   404  			if strings.HasPrefix(fn.Name.Name, "Fuzz") {
   405  				if selExpr.Sel.Name != "F" {
   406  					continue
   407  				}
   408  				pkgs[pkg] = true
   409  				cases.FuzzTargets = append(cases.FuzzTargets, TestCase{
   410  					Package: pkg,
   411  					Name:    fn.Name.Name,
   412  				})
   413  			}
   414  		}
   415  	}
   416  
   417  	for name := range importMap {
   418  		// Set the names for all unused imports to "_"
   419  		if !pkgs[name] {
   420  			importMap[name].Name = "_"
   421  		}
   422  		cases.Imports = append(cases.Imports, importMap[name])
   423  	}
   424  	sort.Slice(cases.Imports, func(i, j int) bool {
   425  		return cases.Imports[i].Name < cases.Imports[j].Name
   426  	})
   427  	tpl := template.Must(template.New("source").Parse(testMainTpl))
   428  	if err := tpl.Execute(outFile, &cases); err != nil {
   429  		return fmt.Errorf("template.Execute(%v): %v", cases, err)
   430  	}
   431  	return nil
   432  }
   433  

View as plain text