1 package types
2
3 import (
4 "flag"
5 "fmt"
6 "io"
7 "reflect"
8 "strings"
9 "time"
10
11 "github.com/onsi/ginkgo/v2/formatter"
12 )
13
// GinkgoFlag describes one command-line flag: its names, where its value
// lives inside a bindings object, how it is presented in usage output, and how
// it is named when flag arguments are regenerated.
type GinkgoFlag struct {
	// Name is the primary flag name. A flag with an empty Name is not
	// registered under a primary name and is omitted from usage output.
	Name string
	// KeyPath is the dot-separated path used to locate the bound value
	// inside the bindings object.
	KeyPath string
	// SectionKey is matched against GinkgoFlagSection.Key to group the flag
	// in usage output.
	SectionKey string

	// Usage is the help text registered for the flag.
	Usage string
	// UsageArgument, when non-empty, is shown in brackets after the flag name
	// in usage output instead of a name derived from the value's type.
	UsageArgument string
	// UsageDefaultValue, when non-empty, is shown as "(default: ...)" in
	// usage output.
	UsageDefaultValue string

	// DeprecatedName is an additional name bound to the same value and
	// registered with a "[DEPRECATED]" usage string.
	DeprecatedName string
	// DeprecatedDocLink and DeprecatedVersion are copied into the
	// Deprecation that is tracked when the deprecated name is used.
	DeprecatedDocLink string
	DeprecatedVersion string

	// ExportAs, when non-empty, replaces Name when flag arguments are
	// generated from bindings.
	ExportAs string
	// AlwaysExport forces string, numeric and duration flags to be emitted
	// by GenerateFlagArgs even when they hold their zero value.
	AlwaysExport bool
}

// GinkgoFlags is an ordered collection of flag descriptions.
type GinkgoFlags []GinkgoFlag
32
33 func (f GinkgoFlags) CopyAppend(flags ...GinkgoFlag) GinkgoFlags {
34 out := GinkgoFlags{}
35 out = append(out, f...)
36 out = append(out, flags...)
37 return out
38 }
39
40 func (f GinkgoFlags) WithPrefix(prefix string) GinkgoFlags {
41 if prefix == "" {
42 return f
43 }
44 out := GinkgoFlags{}
45 for _, flag := range f {
46 if flag.Name != "" {
47 flag.Name = prefix + "." + flag.Name
48 }
49 if flag.DeprecatedName != "" {
50 flag.DeprecatedName = prefix + "." + flag.DeprecatedName
51 }
52 if flag.ExportAs != "" {
53 flag.ExportAs = prefix + "." + flag.ExportAs
54 }
55 out = append(out, flag)
56 }
57 return out
58 }
59
60 func (f GinkgoFlags) SubsetWithNames(names ...string) GinkgoFlags {
61 out := GinkgoFlags{}
62 for _, flag := range f {
63 for _, name := range names {
64 if flag.Name == name {
65 out = append(out, flag)
66 break
67 }
68 }
69 }
70 return out
71 }
72
// GinkgoFlagSection describes a titled group of flags in usage output. The
// type must stay comparable because sections are used as map keys when usage
// is rendered.
type GinkgoFlagSection struct {
	// Key is what a GinkgoFlag's SectionKey is matched against.
	Key string
	// Style is a formatter style tag prepended to the heading and to the
	// section's flags.
	Style string
	// Succinct renders the section's flags as one comma-separated list of
	// names instead of one entry per flag.
	Succinct bool
	// Heading is the section title.
	Heading string
	// Description, when non-empty, is printed below the heading.
	Description string
}

// GinkgoFlagSections is an ordered list of sections; usage output follows
// this order.
type GinkgoFlagSections []GinkgoFlagSection
82
83 func (gfs GinkgoFlagSections) Lookup(key string) (GinkgoFlagSection, bool) {
84 for _, section := range gfs {
85 if section.Key == key {
86 return section, true
87 }
88 }
89
90 return GinkgoFlagSection{}, false
91 }
92
93 type GinkgoFlagSet struct {
94 flags GinkgoFlags
95 bindings interface{}
96
97 sections GinkgoFlagSections
98 extraGoFlagsSection GinkgoFlagSection
99
100 flagSet *flag.FlagSet
101 }
102
103
104 func NewGinkgoFlagSet(flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections) (GinkgoFlagSet, error) {
105 return bindFlagSet(GinkgoFlagSet{
106 flags: flags,
107 bindings: bindings,
108 sections: sections,
109 }, nil)
110 }
111
112
113 func NewAttachedGinkgoFlagSet(flagSet *flag.FlagSet, flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections, extraGoFlagsSection GinkgoFlagSection) (GinkgoFlagSet, error) {
114 return bindFlagSet(GinkgoFlagSet{
115 flags: flags,
116 bindings: bindings,
117 sections: sections,
118 extraGoFlagsSection: extraGoFlagsSection,
119 }, flagSet)
120 }
121
122 func bindFlagSet(f GinkgoFlagSet, flagSet *flag.FlagSet) (GinkgoFlagSet, error) {
123 if flagSet == nil {
124 f.flagSet = flag.NewFlagSet("", flag.ContinueOnError)
125
126 f.flagSet.SetOutput(io.Discard)
127 } else {
128 f.flagSet = flagSet
129
130
131 f.flagSet.Usage = f.substituteUsage
132 }
133
134 for _, flag := range f.flags {
135 name := flag.Name
136
137 deprecatedUsage := "[DEPRECATED]"
138 deprecatedName := flag.DeprecatedName
139 if name != "" {
140 deprecatedUsage = fmt.Sprintf("[DEPRECATED] use --%s instead", name)
141 } else if flag.Usage != "" {
142 deprecatedUsage += " " + flag.Usage
143 }
144
145 value, ok := valueAtKeyPath(f.bindings, flag.KeyPath)
146 if !ok {
147 return GinkgoFlagSet{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath)
148 }
149
150 iface, addr := value.Interface(), value.Addr().Interface()
151
152 switch value.Type() {
153 case reflect.TypeOf(string("")):
154 if name != "" {
155 f.flagSet.StringVar(addr.(*string), name, iface.(string), flag.Usage)
156 }
157 if deprecatedName != "" {
158 f.flagSet.StringVar(addr.(*string), deprecatedName, iface.(string), deprecatedUsage)
159 }
160 case reflect.TypeOf(int64(0)):
161 if name != "" {
162 f.flagSet.Int64Var(addr.(*int64), name, iface.(int64), flag.Usage)
163 }
164 if deprecatedName != "" {
165 f.flagSet.Int64Var(addr.(*int64), deprecatedName, iface.(int64), deprecatedUsage)
166 }
167 case reflect.TypeOf(float64(0)):
168 if name != "" {
169 f.flagSet.Float64Var(addr.(*float64), name, iface.(float64), flag.Usage)
170 }
171 if deprecatedName != "" {
172 f.flagSet.Float64Var(addr.(*float64), deprecatedName, iface.(float64), deprecatedUsage)
173 }
174 case reflect.TypeOf(int(0)):
175 if name != "" {
176 f.flagSet.IntVar(addr.(*int), name, iface.(int), flag.Usage)
177 }
178 if deprecatedName != "" {
179 f.flagSet.IntVar(addr.(*int), deprecatedName, iface.(int), deprecatedUsage)
180 }
181 case reflect.TypeOf(bool(true)):
182 if name != "" {
183 f.flagSet.BoolVar(addr.(*bool), name, iface.(bool), flag.Usage)
184 }
185 if deprecatedName != "" {
186 f.flagSet.BoolVar(addr.(*bool), deprecatedName, iface.(bool), deprecatedUsage)
187 }
188 case reflect.TypeOf(time.Duration(0)):
189 if name != "" {
190 f.flagSet.DurationVar(addr.(*time.Duration), name, iface.(time.Duration), flag.Usage)
191 }
192 if deprecatedName != "" {
193 f.flagSet.DurationVar(addr.(*time.Duration), deprecatedName, iface.(time.Duration), deprecatedUsage)
194 }
195
196 case reflect.TypeOf([]string{}):
197 if name != "" {
198 f.flagSet.Var(stringSliceVar{value}, name, flag.Usage)
199 }
200 if deprecatedName != "" {
201 f.flagSet.Var(stringSliceVar{value}, deprecatedName, deprecatedUsage)
202 }
203 default:
204 return GinkgoFlagSet{}, fmt.Errorf("unsupported type %T", iface)
205 }
206 }
207
208 return f, nil
209 }
210
211 func (f GinkgoFlagSet) IsZero() bool {
212 return f.flagSet == nil
213 }
214
215 func (f GinkgoFlagSet) WasSet(name string) bool {
216 found := false
217 f.flagSet.Visit(func(f *flag.Flag) {
218 if f.Name == name {
219 found = true
220 }
221 })
222
223 return found
224 }
225
226 func (f GinkgoFlagSet) Lookup(name string) *flag.Flag {
227 return f.flagSet.Lookup(name)
228 }
229
230 func (f GinkgoFlagSet) Parse(args []string) ([]string, error) {
231 if f.IsZero() {
232 return args, nil
233 }
234 err := f.flagSet.Parse(args)
235 if err != nil {
236 return []string{}, err
237 }
238 return f.flagSet.Args(), nil
239 }
240
241 func (f GinkgoFlagSet) ValidateDeprecations(deprecationTracker *DeprecationTracker) {
242 if f.IsZero() {
243 return
244 }
245 f.flagSet.Visit(func(flag *flag.Flag) {
246 for _, ginkgoFlag := range f.flags {
247 if ginkgoFlag.DeprecatedName != "" && strings.HasSuffix(flag.Name, ginkgoFlag.DeprecatedName) {
248 message := fmt.Sprintf("--%s is deprecated", ginkgoFlag.DeprecatedName)
249 if ginkgoFlag.Name != "" {
250 message = fmt.Sprintf("--%s is deprecated, use --%s instead", ginkgoFlag.DeprecatedName, ginkgoFlag.Name)
251 } else if ginkgoFlag.Usage != "" {
252 message += " " + ginkgoFlag.Usage
253 }
254
255 deprecationTracker.TrackDeprecation(Deprecation{
256 Message: message,
257 DocLink: ginkgoFlag.DeprecatedDocLink,
258 Version: ginkgoFlag.DeprecatedVersion,
259 })
260 }
261 }
262 })
263 }
264
265 func (f GinkgoFlagSet) Usage() string {
266 if f.IsZero() {
267 return ""
268 }
269 groupedFlags := map[GinkgoFlagSection]GinkgoFlags{}
270 ungroupedFlags := GinkgoFlags{}
271 managedFlags := map[string]bool{}
272 extraGoFlags := []*flag.Flag{}
273
274 for _, flag := range f.flags {
275 managedFlags[flag.Name] = true
276 managedFlags[flag.DeprecatedName] = true
277
278 if flag.Name == "" {
279 continue
280 }
281
282 section, ok := f.sections.Lookup(flag.SectionKey)
283 if ok {
284 groupedFlags[section] = append(groupedFlags[section], flag)
285 } else {
286 ungroupedFlags = append(ungroupedFlags, flag)
287 }
288 }
289
290 f.flagSet.VisitAll(func(flag *flag.Flag) {
291 if !managedFlags[flag.Name] {
292 extraGoFlags = append(extraGoFlags, flag)
293 }
294 })
295
296 out := ""
297 for _, section := range f.sections {
298 flags := groupedFlags[section]
299 if len(flags) == 0 {
300 continue
301 }
302 out += f.usageForSection(section)
303 if section.Succinct {
304 succinctFlags := []string{}
305 for _, flag := range flags {
306 if flag.Name != "" {
307 succinctFlags = append(succinctFlags, fmt.Sprintf("--%s", flag.Name))
308 }
309 }
310 out += formatter.Fiw(1, formatter.COLS, section.Style+strings.Join(succinctFlags, ", ")+"{{/}}\n")
311 } else {
312 for _, flag := range flags {
313 out += f.usageForFlag(flag, section.Style)
314 }
315 }
316 out += "\n"
317 }
318 if len(ungroupedFlags) > 0 {
319 for _, flag := range ungroupedFlags {
320 out += f.usageForFlag(flag, "")
321 }
322 out += "\n"
323 }
324 if len(extraGoFlags) > 0 {
325 out += f.usageForSection(f.extraGoFlagsSection)
326 for _, goFlag := range extraGoFlags {
327 out += f.usageForGoFlag(goFlag)
328 }
329 }
330
331 return out
332 }
333
334 func (f GinkgoFlagSet) substituteUsage() {
335 fmt.Fprintln(f.flagSet.Output(), f.Usage())
336 }
337
// valueAtKeyPath walks root along the dot-separated keyPath and returns the
// value found there. Pointers are dereferenced before each step. A step into a
// struct selects the field with that name; a step into a map looks up the
// component as a key and unwraps an interface-typed map value.
//
// The boolean is false, and the returned Value is the zero Value, when keyPath
// is empty, when a step hits something that is neither a struct nor a map
// (including a nil pointer or nil root), when the field or key does not exist,
// or when the map's key type cannot hold a string.
func valueAtKeyPath(root interface{}, keyPath string) (reflect.Value, bool) {
	if keyPath == "" {
		return reflect.Value{}, false
	}

	val := reflect.ValueOf(root)
	for _, component := range strings.Split(keyPath, ".") {
		val = reflect.Indirect(val)
		switch val.Kind() {
		case reflect.Map:
			key := reflect.ValueOf(component)
			keyType := val.Type().Key()
			switch {
			case key.Type().AssignableTo(keyType):
				// string and interface key types accept the component as is.
			case keyType.Kind() == reflect.String:
				// Named string types need an explicit conversion.
				key = key.Convert(keyType)
			default:
				// MapIndex would panic for any other key type.
				return reflect.Value{}, false
			}
			val = val.MapIndex(key)
			if val.Kind() == reflect.Interface {
				val = reflect.ValueOf(val.Interface())
			}
		case reflect.Struct:
			val = val.FieldByName(component)
		default:
			return reflect.Value{}, false
		}
		// IsValid is the supported way to detect the zero Value; comparing
		// reflect.Values with == is flagged by go vet.
		if !val.IsValid() {
			return reflect.Value{}, false
		}
	}

	return val, true
}
365
366 func (f GinkgoFlagSet) usageForSection(section GinkgoFlagSection) string {
367 out := formatter.F(section.Style + "{{bold}}{{underline}}" + section.Heading + "{{/}}\n")
368 if section.Description != "" {
369 out += formatter.Fiw(0, formatter.COLS, section.Description+"\n")
370 }
371 return out
372 }
373
374 func (f GinkgoFlagSet) usageForFlag(flag GinkgoFlag, style string) string {
375 argument := flag.UsageArgument
376 defValue := flag.UsageDefaultValue
377 if argument == "" {
378 value, _ := valueAtKeyPath(f.bindings, flag.KeyPath)
379 switch value.Type() {
380 case reflect.TypeOf(string("")):
381 argument = "string"
382 case reflect.TypeOf(int64(0)), reflect.TypeOf(int(0)):
383 argument = "int"
384 case reflect.TypeOf(time.Duration(0)):
385 argument = "duration"
386 case reflect.TypeOf(float64(0)):
387 argument = "float"
388 case reflect.TypeOf([]string{}):
389 argument = "string"
390 }
391 }
392 if argument != "" {
393 argument = "[" + argument + "] "
394 }
395 if defValue != "" {
396 defValue = fmt.Sprintf("(default: %s)", defValue)
397 }
398 hyphens := "--"
399 if len(flag.Name) == 1 {
400 hyphens = "-"
401 }
402
403 out := formatter.Fi(1, style+"%s%s{{/}} %s{{gray}}%s{{/}}\n", hyphens, flag.Name, argument, defValue)
404 out += formatter.Fiw(2, formatter.COLS, "{{light-gray}}%s{{/}}\n", flag.Usage)
405 return out
406 }
407
408 func (f GinkgoFlagSet) usageForGoFlag(goFlag *flag.Flag) string {
409
410 out := fmt.Sprintf(" -%s", goFlag.Name)
411 name, usage := flag.UnquoteUsage(goFlag)
412 if len(name) > 0 {
413 out += " " + name
414 }
415 if len(out) <= 4 {
416 out += "\t"
417 } else {
418 out += "\n \t"
419 }
420 out += strings.ReplaceAll(usage, "\n", "\n \t")
421 out += "\n"
422 return out
423 }
424
// stringSliceVar adapts a reflect.Value holding a settable []string to the
// flag.Value interface, so that a flag may be repeated to collect several
// values.
type stringSliceVar struct {
	slice reflect.Value
}

// String always returns the empty string.
func (ssv stringSliceVar) String() string {
	return ""
}

// Set appends s to the wrapped slice. It never returns an error.
func (ssv stringSliceVar) Set(s string) error {
	grown := reflect.Append(ssv.slice, reflect.ValueOf(s))
	ssv.slice.Set(grown)
	return nil
}
434
435
436 func GenerateFlagArgs(flags GinkgoFlags, bindings interface{}) ([]string, error) {
437 result := []string{}
438 for _, flag := range flags {
439 name := flag.ExportAs
440 if name == "" {
441 name = flag.Name
442 }
443 if name == "" {
444 continue
445 }
446
447 value, ok := valueAtKeyPath(bindings, flag.KeyPath)
448 if !ok {
449 return []string{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath)
450 }
451
452 iface := value.Interface()
453 switch value.Type() {
454 case reflect.TypeOf(string("")):
455 if iface.(string) != "" || flag.AlwaysExport {
456 result = append(result, fmt.Sprintf("--%s=%s", name, iface))
457 }
458 case reflect.TypeOf(int64(0)):
459 if iface.(int64) != 0 || flag.AlwaysExport {
460 result = append(result, fmt.Sprintf("--%s=%d", name, iface))
461 }
462 case reflect.TypeOf(float64(0)):
463 if iface.(float64) != 0 || flag.AlwaysExport {
464 result = append(result, fmt.Sprintf("--%s=%f", name, iface))
465 }
466 case reflect.TypeOf(int(0)):
467 if iface.(int) != 0 || flag.AlwaysExport {
468 result = append(result, fmt.Sprintf("--%s=%d", name, iface))
469 }
470 case reflect.TypeOf(bool(true)):
471 if iface.(bool) {
472 result = append(result, fmt.Sprintf("--%s", name))
473 }
474 case reflect.TypeOf(time.Duration(0)):
475 if iface.(time.Duration) != time.Duration(0) || flag.AlwaysExport {
476 result = append(result, fmt.Sprintf("--%s=%s", name, iface))
477 }
478
479 case reflect.TypeOf([]string{}):
480 strings := iface.([]string)
481 for _, s := range strings {
482 result = append(result, fmt.Sprintf("--%s=%s", name, s))
483 }
484 default:
485 return []string{}, fmt.Errorf("unsupported type %T", iface)
486 }
487 }
488
489 return result, nil
490 }
491
View as plain text