...
1 package f2
2
3 import (
4 "flag"
5 "fmt"
6 "os"
7 "path/filepath"
8 "strings"
9
10 "github.com/bazelbuild/rules_go/go/runfiles"
11 "github.com/peterbourgon/ff/v3"
12
13 "edge-infra.dev/pkg/lib/build/bazel"
14 )
15
16
17
18
// Package-level flag state shared by the test harness.
var (
	// Flags holds package-specific flags. handleFlags copies every flag in it
	// onto flag.CommandLine before parsing, so they are set from the command
	// line or the JSON config file.
	Flags = flag.NewFlagSet("", flag.ContinueOnError)

	// CfgFlagName is the name of the flag that selects the test config file.
	CfgFlagName = "test-config"
	// CfgPath is the default config file path, before runfiles resolution.
	CfgPath = "test/config.json"
	// cfgFlag receives the value of the CfgFlagName flag.
	cfgFlag string

	// Labels holds the key/value pairs given via -labels; values for a
	// repeated key are joined with commas.
	Labels map[string]string
	// SkipLabels holds the key/value pairs given via -skip-labels, in the
	// same form as Labels.
	SkipLabels map[string]string
)
29
30
31
32
33 func handleFlags() error {
34
35 Flags.StringVar(&cfgFlag, CfgFlagName, resolveCfgPath(CfgPath), "Path to test configuration file")
36
37
38 CopyFlags(Flags, flag.CommandLine)
39 RegisterCommonFlags(flag.CommandLine)
40
41
42 err := ff.Parse(flag.CommandLine, os.Args[1:],
43 ff.WithConfigFileFlag(CfgFlagName),
44 ff.WithConfigFileParser(ff.JSONParser),
45 ff.WithAllowMissingConfigFile(true),
46 ff.WithIgnoreUndefined(true),
47 )
48 if err != nil {
49 return err
50 }
51
52 return Validate()
53 }
54
55
56
57
58
59 func resolveCfgPath(d string) string {
60 if ws := os.Getenv(bazel.TestWorkspace); ws != "" {
61 path, err := runfiles.Rlocation(filepath.Join(ws, d))
62 if err == nil {
63 return path
64 }
65 }
66
67 return d
68 }
69
70
71
72
73
// CopyFlags defines every flag of source (including ones not yet set) on
// target, reusing the same name, usage string and flag.Value.
//
// The flag.Value is shared, not duplicated, so parsing target also updates
// the variables bound to source's flags. Like flag.FlagSet.Var, it panics if
// target already defines a flag with one of the copied names.
func CopyFlags(source *flag.FlagSet, target *flag.FlagSet) {
	// Named f rather than flag so the parameter does not shadow the flag package.
	source.VisitAll(func(f *flag.Flag) {
		target.Var(f.Value, f.Name, f.Usage)
	})
}
83
84 func RegisterCommonFlags(flags *flag.FlagSet) {
85 flags.Func("labels", "Only run tests with the provided comma separated labels (eg: foo=bar,bar=baz,bar=boo)", func(s string) error {
86 Labels = commaSepValues(s)
87 return nil
88 })
89
90 flags.Func("skip-labels", "Run all tests except those with the provided comma separated labels (eg: foo=bar,bar=baz,bar=boo)", func(s string) error {
91 SkipLabels = commaSepValues(s)
92 return nil
93 })
94 }
95
96
// commaSepValues parses a comma-separated list of key=value pairs, such as
// "foo=bar,bar=baz,bar=boo", into a map.
//
// Values for a repeated key are joined with commas in input order, so the
// example yields {"foo": "bar", "bar": "baz,boo"}. Only the first '=' in an
// entry separates the key from the value, so values may themselves contain
// '='. Empty entries and entries without any '=' are ignored.
func commaSepValues(s string) map[string]string {
	r := map[string]string{}
	for _, entry := range strings.Split(s, ",") {
		// SplitN keeps everything after the first '=' as the value; a plain
		// Split would silently truncate values such as "sel=a=b".
		kv := strings.SplitN(entry, "=", 2)
		if len(kv) != 2 {
			continue
		}
		key, val := kv[0], kv[1]
		if prev := r[key]; prev != "" {
			r[key] = prev + "," + val
		} else {
			r[key] = val
		}
	}
	return r
}
116
117
118 func Validate() error {
119
120 for key := range Labels {
121 if _, ok := SkipLabels[key]; ok {
122 return fmt.Errorf("-labels and -skip-labels cannot contain the same label %s", key)
123 }
124 }
125 return nil
126 }
127
View as plain text