...

Source file src/edge-infra.dev/test/f2/flags.go

Documentation: edge-infra.dev/test/f2

     1  package f2
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"os"
     7  	"path/filepath"
     8  	"strings"
     9  
    10  	"github.com/bazelbuild/rules_go/go/runfiles"
    11  	"github.com/peterbourgon/ff/v3"
    12  
    13  	"edge-infra.dev/pkg/lib/build/bazel"
    14  )
    15  
    16  // Flags is the flag set parsed and used by the [Framework]. Test authors should
    17  // bind custom flags to this instead of directly adding to the global command
    18  // line.
    19  var Flags = flag.NewFlagSet("", flag.ContinueOnError)
    20  
    21  var (
    22  	CfgFlagName = "test-config"
    23  	CfgPath     = "test/config.json"
    24  	cfgFlag     string
    25  
    26  	Labels     map[string]string
    27  	SkipLabels map[string]string
    28  )
    29  
    30  // handleFlags copies all flags registered with [Flags] to the command line,
    31  // and then parses them with ff for environment variable and config file
    32  // support
    33  func handleFlags() error {
    34  	// register flag for config file path, so ff can find it
    35  	Flags.StringVar(&cfgFlag, CfgFlagName, resolveCfgPath(CfgPath), "Path to test configuration file")
    36  	// copy all the flags from the global flagset each test suite registers
    37  	// config options with to the command line for parsing
    38  	CopyFlags(Flags, flag.CommandLine)
    39  	RegisterCommonFlags(flag.CommandLine)
    40  
    41  	// parse test configuration
    42  	err := ff.Parse(flag.CommandLine, os.Args[1:],
    43  		ff.WithConfigFileFlag(CfgFlagName),
    44  		ff.WithConfigFileParser(ff.JSONParser),
    45  		ff.WithAllowMissingConfigFile(true),
    46  		ff.WithIgnoreUndefined(true),
    47  	)
    48  	if err != nil {
    49  		return err
    50  	}
    51  
    52  	return Validate()
    53  }
    54  
    55  // resolveCfgPath resolves the default config path value. if the test is being
    56  // executed by Bazel, it attempts to resolve the path for the test config.json
    57  // provided via runfiles. otherwise, it returns the default provided and we
    58  // assume that test/config.json will be provided explicitly.
    59  func resolveCfgPath(d string) string {
    60  	if ws := os.Getenv(bazel.TestWorkspace); ws != "" {
    61  		path, err := runfiles.Rlocation(filepath.Join(ws, d))
    62  		if err == nil {
    63  			return path
    64  		}
    65  	}
    66  
    67  	return d
    68  }
    69  
    70  // CopyFlags ensures that all flags that are defined in the source flag
    71  // set appear in the target flag set as if they had been defined there
    72  // directly. From the flag package it inherits the behavior that there
    73  // is a panic if the target already contains a flag from the source.
    74  func CopyFlags(source *flag.FlagSet, target *flag.FlagSet) {
    75  	source.VisitAll(func(flag *flag.Flag) {
    76  		// We don't need to copy flag.DefValue. The original
    77  		// default (from, say, flag.String) was stored in
    78  		// the value and gets extracted by Var for the help
    79  		// message.
    80  		target.Var(flag.Value, flag.Name, flag.Usage)
    81  	})
    82  }
    83  
    84  func RegisterCommonFlags(flags *flag.FlagSet) {
    85  	flags.Func("labels", "Only run tests with the provided comma separated labels (eg: foo=bar,bar=baz,bar=boo)", func(s string) error {
    86  		Labels = commaSepValues(s)
    87  		return nil
    88  	})
    89  
    90  	flags.Func("skip-labels", "Run all tests except those with the provided comma separated labels (eg: foo=bar,bar=baz,bar=boo)", func(s string) error {
    91  		SkipLabels = commaSepValues(s)
    92  		return nil
    93  	})
    94  }
    95  
    96  // parses comma separated map values
    97  func commaSepValues(s string) map[string]string {
    98  	a := strings.Split(s, ",")
    99  	r := map[string]string{}
   100  	for _, v := range a {
   101  		if v != "" {
   102  			kv := strings.Split(v, "=")
   103  			// drop malformed input
   104  			if len(kv) > 1 {
   105  				// if theres already an entry append
   106  				if len(r[kv[0]]) != 0 {
   107  					r[kv[0]] = fmt.Sprintf("%s,%s", r[kv[0]], kv[1])
   108  				} else {
   109  					r[kv[0]] = kv[1]
   110  				}
   111  			}
   112  		}
   113  	}
   114  	return r
   115  }
   116  
   117  // Validate checks that the base test context was provided valid values via config
   118  func Validate() error {
   119  	// check if theres any matching labels and error if so
   120  	for key := range Labels {
   121  		if _, ok := SkipLabels[key]; ok {
   122  			return fmt.Errorf("-labels and -skip-labels cannot contain the same label %s", key)
   123  		}
   124  	}
   125  	return nil
   126  }
   127  

View as plain text