...

Source file src/github.com/onsi/ginkgo/v2/types/flags.go

Documentation: github.com/onsi/ginkgo/v2/types

     1  package types
     2  
     3  import (
     4  	"flag"
     5  	"fmt"
     6  	"io"
     7  	"reflect"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/onsi/ginkgo/v2/formatter"
    12  )
    13  
    14  type GinkgoFlag struct {
    15  	Name       string
    16  	KeyPath    string
    17  	SectionKey string
    18  
    19  	Usage             string
    20  	UsageArgument     string
    21  	UsageDefaultValue string
    22  
    23  	DeprecatedName    string
    24  	DeprecatedDocLink string
    25  	DeprecatedVersion string
    26  
    27  	ExportAs     string
    28  	AlwaysExport bool
    29  }
    30  
    31  type GinkgoFlags []GinkgoFlag
    32  
    33  func (f GinkgoFlags) CopyAppend(flags ...GinkgoFlag) GinkgoFlags {
    34  	out := GinkgoFlags{}
    35  	out = append(out, f...)
    36  	out = append(out, flags...)
    37  	return out
    38  }
    39  
    40  func (f GinkgoFlags) WithPrefix(prefix string) GinkgoFlags {
    41  	if prefix == "" {
    42  		return f
    43  	}
    44  	out := GinkgoFlags{}
    45  	for _, flag := range f {
    46  		if flag.Name != "" {
    47  			flag.Name = prefix + "." + flag.Name
    48  		}
    49  		if flag.DeprecatedName != "" {
    50  			flag.DeprecatedName = prefix + "." + flag.DeprecatedName
    51  		}
    52  		if flag.ExportAs != "" {
    53  			flag.ExportAs = prefix + "." + flag.ExportAs
    54  		}
    55  		out = append(out, flag)
    56  	}
    57  	return out
    58  }
    59  
    60  func (f GinkgoFlags) SubsetWithNames(names ...string) GinkgoFlags {
    61  	out := GinkgoFlags{}
    62  	for _, flag := range f {
    63  		for _, name := range names {
    64  			if flag.Name == name {
    65  				out = append(out, flag)
    66  				break
    67  			}
    68  		}
    69  	}
    70  	return out
    71  }
    72  
    73  type GinkgoFlagSection struct {
    74  	Key         string
    75  	Style       string
    76  	Succinct    bool
    77  	Heading     string
    78  	Description string
    79  }
    80  
    81  type GinkgoFlagSections []GinkgoFlagSection
    82  
    83  func (gfs GinkgoFlagSections) Lookup(key string) (GinkgoFlagSection, bool) {
    84  	for _, section := range gfs {
    85  		if section.Key == key {
    86  			return section, true
    87  		}
    88  	}
    89  
    90  	return GinkgoFlagSection{}, false
    91  }
    92  
    93  type GinkgoFlagSet struct {
    94  	flags    GinkgoFlags
    95  	bindings interface{}
    96  
    97  	sections            GinkgoFlagSections
    98  	extraGoFlagsSection GinkgoFlagSection
    99  
   100  	flagSet *flag.FlagSet
   101  }
   102  
   103  // Call NewGinkgoFlagSet to create GinkgoFlagSet that creates and binds to it's own *flag.FlagSet
   104  func NewGinkgoFlagSet(flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections) (GinkgoFlagSet, error) {
   105  	return bindFlagSet(GinkgoFlagSet{
   106  		flags:    flags,
   107  		bindings: bindings,
   108  		sections: sections,
   109  	}, nil)
   110  }
   111  
   112  // Call NewGinkgoFlagSet to create GinkgoFlagSet that extends an existing *flag.FlagSet
   113  func NewAttachedGinkgoFlagSet(flagSet *flag.FlagSet, flags GinkgoFlags, bindings interface{}, sections GinkgoFlagSections, extraGoFlagsSection GinkgoFlagSection) (GinkgoFlagSet, error) {
   114  	return bindFlagSet(GinkgoFlagSet{
   115  		flags:               flags,
   116  		bindings:            bindings,
   117  		sections:            sections,
   118  		extraGoFlagsSection: extraGoFlagsSection,
   119  	}, flagSet)
   120  }
   121  
   122  func bindFlagSet(f GinkgoFlagSet, flagSet *flag.FlagSet) (GinkgoFlagSet, error) {
   123  	if flagSet == nil {
   124  		f.flagSet = flag.NewFlagSet("", flag.ContinueOnError)
   125  		//suppress all output as Ginkgo is responsible for formatting usage
   126  		f.flagSet.SetOutput(io.Discard)
   127  	} else {
   128  		f.flagSet = flagSet
   129  		//we're piggybacking on an existing flagset (typically go test) so we have limited control
   130  		//on user feedback
   131  		f.flagSet.Usage = f.substituteUsage
   132  	}
   133  
   134  	for _, flag := range f.flags {
   135  		name := flag.Name
   136  
   137  		deprecatedUsage := "[DEPRECATED]"
   138  		deprecatedName := flag.DeprecatedName
   139  		if name != "" {
   140  			deprecatedUsage = fmt.Sprintf("[DEPRECATED] use --%s instead", name)
   141  		} else if flag.Usage != "" {
   142  			deprecatedUsage += " " + flag.Usage
   143  		}
   144  
   145  		value, ok := valueAtKeyPath(f.bindings, flag.KeyPath)
   146  		if !ok {
   147  			return GinkgoFlagSet{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath)
   148  		}
   149  
   150  		iface, addr := value.Interface(), value.Addr().Interface()
   151  
   152  		switch value.Type() {
   153  		case reflect.TypeOf(string("")):
   154  			if name != "" {
   155  				f.flagSet.StringVar(addr.(*string), name, iface.(string), flag.Usage)
   156  			}
   157  			if deprecatedName != "" {
   158  				f.flagSet.StringVar(addr.(*string), deprecatedName, iface.(string), deprecatedUsage)
   159  			}
   160  		case reflect.TypeOf(int64(0)):
   161  			if name != "" {
   162  				f.flagSet.Int64Var(addr.(*int64), name, iface.(int64), flag.Usage)
   163  			}
   164  			if deprecatedName != "" {
   165  				f.flagSet.Int64Var(addr.(*int64), deprecatedName, iface.(int64), deprecatedUsage)
   166  			}
   167  		case reflect.TypeOf(float64(0)):
   168  			if name != "" {
   169  				f.flagSet.Float64Var(addr.(*float64), name, iface.(float64), flag.Usage)
   170  			}
   171  			if deprecatedName != "" {
   172  				f.flagSet.Float64Var(addr.(*float64), deprecatedName, iface.(float64), deprecatedUsage)
   173  			}
   174  		case reflect.TypeOf(int(0)):
   175  			if name != "" {
   176  				f.flagSet.IntVar(addr.(*int), name, iface.(int), flag.Usage)
   177  			}
   178  			if deprecatedName != "" {
   179  				f.flagSet.IntVar(addr.(*int), deprecatedName, iface.(int), deprecatedUsage)
   180  			}
   181  		case reflect.TypeOf(bool(true)):
   182  			if name != "" {
   183  				f.flagSet.BoolVar(addr.(*bool), name, iface.(bool), flag.Usage)
   184  			}
   185  			if deprecatedName != "" {
   186  				f.flagSet.BoolVar(addr.(*bool), deprecatedName, iface.(bool), deprecatedUsage)
   187  			}
   188  		case reflect.TypeOf(time.Duration(0)):
   189  			if name != "" {
   190  				f.flagSet.DurationVar(addr.(*time.Duration), name, iface.(time.Duration), flag.Usage)
   191  			}
   192  			if deprecatedName != "" {
   193  				f.flagSet.DurationVar(addr.(*time.Duration), deprecatedName, iface.(time.Duration), deprecatedUsage)
   194  			}
   195  
   196  		case reflect.TypeOf([]string{}):
   197  			if name != "" {
   198  				f.flagSet.Var(stringSliceVar{value}, name, flag.Usage)
   199  			}
   200  			if deprecatedName != "" {
   201  				f.flagSet.Var(stringSliceVar{value}, deprecatedName, deprecatedUsage)
   202  			}
   203  		default:
   204  			return GinkgoFlagSet{}, fmt.Errorf("unsupported type %T", iface)
   205  		}
   206  	}
   207  
   208  	return f, nil
   209  }
   210  
   211  func (f GinkgoFlagSet) IsZero() bool {
   212  	return f.flagSet == nil
   213  }
   214  
   215  func (f GinkgoFlagSet) WasSet(name string) bool {
   216  	found := false
   217  	f.flagSet.Visit(func(f *flag.Flag) {
   218  		if f.Name == name {
   219  			found = true
   220  		}
   221  	})
   222  
   223  	return found
   224  }
   225  
   226  func (f GinkgoFlagSet) Lookup(name string) *flag.Flag {
   227  	return f.flagSet.Lookup(name)
   228  }
   229  
   230  func (f GinkgoFlagSet) Parse(args []string) ([]string, error) {
   231  	if f.IsZero() {
   232  		return args, nil
   233  	}
   234  	err := f.flagSet.Parse(args)
   235  	if err != nil {
   236  		return []string{}, err
   237  	}
   238  	return f.flagSet.Args(), nil
   239  }
   240  
   241  func (f GinkgoFlagSet) ValidateDeprecations(deprecationTracker *DeprecationTracker) {
   242  	if f.IsZero() {
   243  		return
   244  	}
   245  	f.flagSet.Visit(func(flag *flag.Flag) {
   246  		for _, ginkgoFlag := range f.flags {
   247  			if ginkgoFlag.DeprecatedName != "" && strings.HasSuffix(flag.Name, ginkgoFlag.DeprecatedName) {
   248  				message := fmt.Sprintf("--%s is deprecated", ginkgoFlag.DeprecatedName)
   249  				if ginkgoFlag.Name != "" {
   250  					message = fmt.Sprintf("--%s is deprecated, use --%s instead", ginkgoFlag.DeprecatedName, ginkgoFlag.Name)
   251  				} else if ginkgoFlag.Usage != "" {
   252  					message += " " + ginkgoFlag.Usage
   253  				}
   254  
   255  				deprecationTracker.TrackDeprecation(Deprecation{
   256  					Message: message,
   257  					DocLink: ginkgoFlag.DeprecatedDocLink,
   258  					Version: ginkgoFlag.DeprecatedVersion,
   259  				})
   260  			}
   261  		}
   262  	})
   263  }
   264  
   265  func (f GinkgoFlagSet) Usage() string {
   266  	if f.IsZero() {
   267  		return ""
   268  	}
   269  	groupedFlags := map[GinkgoFlagSection]GinkgoFlags{}
   270  	ungroupedFlags := GinkgoFlags{}
   271  	managedFlags := map[string]bool{}
   272  	extraGoFlags := []*flag.Flag{}
   273  
   274  	for _, flag := range f.flags {
   275  		managedFlags[flag.Name] = true
   276  		managedFlags[flag.DeprecatedName] = true
   277  
   278  		if flag.Name == "" {
   279  			continue
   280  		}
   281  
   282  		section, ok := f.sections.Lookup(flag.SectionKey)
   283  		if ok {
   284  			groupedFlags[section] = append(groupedFlags[section], flag)
   285  		} else {
   286  			ungroupedFlags = append(ungroupedFlags, flag)
   287  		}
   288  	}
   289  
   290  	f.flagSet.VisitAll(func(flag *flag.Flag) {
   291  		if !managedFlags[flag.Name] {
   292  			extraGoFlags = append(extraGoFlags, flag)
   293  		}
   294  	})
   295  
   296  	out := ""
   297  	for _, section := range f.sections {
   298  		flags := groupedFlags[section]
   299  		if len(flags) == 0 {
   300  			continue
   301  		}
   302  		out += f.usageForSection(section)
   303  		if section.Succinct {
   304  			succinctFlags := []string{}
   305  			for _, flag := range flags {
   306  				if flag.Name != "" {
   307  					succinctFlags = append(succinctFlags, fmt.Sprintf("--%s", flag.Name))
   308  				}
   309  			}
   310  			out += formatter.Fiw(1, formatter.COLS, section.Style+strings.Join(succinctFlags, ", ")+"{{/}}\n")
   311  		} else {
   312  			for _, flag := range flags {
   313  				out += f.usageForFlag(flag, section.Style)
   314  			}
   315  		}
   316  		out += "\n"
   317  	}
   318  	if len(ungroupedFlags) > 0 {
   319  		for _, flag := range ungroupedFlags {
   320  			out += f.usageForFlag(flag, "")
   321  		}
   322  		out += "\n"
   323  	}
   324  	if len(extraGoFlags) > 0 {
   325  		out += f.usageForSection(f.extraGoFlagsSection)
   326  		for _, goFlag := range extraGoFlags {
   327  			out += f.usageForGoFlag(goFlag)
   328  		}
   329  	}
   330  
   331  	return out
   332  }
   333  
   334  func (f GinkgoFlagSet) substituteUsage() {
   335  	fmt.Fprintln(f.flagSet.Output(), f.Usage())
   336  }
   337  
   338  func valueAtKeyPath(root interface{}, keyPath string) (reflect.Value, bool) {
   339  	if len(keyPath) == 0 {
   340  		return reflect.Value{}, false
   341  	}
   342  
   343  	val := reflect.ValueOf(root)
   344  	components := strings.Split(keyPath, ".")
   345  	for _, component := range components {
   346  		val = reflect.Indirect(val)
   347  		switch val.Kind() {
   348  		case reflect.Map:
   349  			val = val.MapIndex(reflect.ValueOf(component))
   350  			if val.Kind() == reflect.Interface {
   351  				val = reflect.ValueOf(val.Interface())
   352  			}
   353  		case reflect.Struct:
   354  			val = val.FieldByName(component)
   355  		default:
   356  			return reflect.Value{}, false
   357  		}
   358  		if (val == reflect.Value{}) {
   359  			return reflect.Value{}, false
   360  		}
   361  	}
   362  
   363  	return val, true
   364  }
   365  
   366  func (f GinkgoFlagSet) usageForSection(section GinkgoFlagSection) string {
   367  	out := formatter.F(section.Style + "{{bold}}{{underline}}" + section.Heading + "{{/}}\n")
   368  	if section.Description != "" {
   369  		out += formatter.Fiw(0, formatter.COLS, section.Description+"\n")
   370  	}
   371  	return out
   372  }
   373  
   374  func (f GinkgoFlagSet) usageForFlag(flag GinkgoFlag, style string) string {
   375  	argument := flag.UsageArgument
   376  	defValue := flag.UsageDefaultValue
   377  	if argument == "" {
   378  		value, _ := valueAtKeyPath(f.bindings, flag.KeyPath)
   379  		switch value.Type() {
   380  		case reflect.TypeOf(string("")):
   381  			argument = "string"
   382  		case reflect.TypeOf(int64(0)), reflect.TypeOf(int(0)):
   383  			argument = "int"
   384  		case reflect.TypeOf(time.Duration(0)):
   385  			argument = "duration"
   386  		case reflect.TypeOf(float64(0)):
   387  			argument = "float"
   388  		case reflect.TypeOf([]string{}):
   389  			argument = "string"
   390  		}
   391  	}
   392  	if argument != "" {
   393  		argument = "[" + argument + "] "
   394  	}
   395  	if defValue != "" {
   396  		defValue = fmt.Sprintf("(default: %s)", defValue)
   397  	}
   398  	hyphens := "--"
   399  	if len(flag.Name) == 1 {
   400  		hyphens = "-"
   401  	}
   402  
   403  	out := formatter.Fi(1, style+"%s%s{{/}} %s{{gray}}%s{{/}}\n", hyphens, flag.Name, argument, defValue)
   404  	out += formatter.Fiw(2, formatter.COLS, "{{light-gray}}%s{{/}}\n", flag.Usage)
   405  	return out
   406  }
   407  
   408  func (f GinkgoFlagSet) usageForGoFlag(goFlag *flag.Flag) string {
   409  	//Taken directly from the flag package
   410  	out := fmt.Sprintf("  -%s", goFlag.Name)
   411  	name, usage := flag.UnquoteUsage(goFlag)
   412  	if len(name) > 0 {
   413  		out += " " + name
   414  	}
   415  	if len(out) <= 4 {
   416  		out += "\t"
   417  	} else {
   418  		out += "\n    \t"
   419  	}
   420  	out += strings.ReplaceAll(usage, "\n", "\n    \t")
   421  	out += "\n"
   422  	return out
   423  }
   424  
   425  type stringSliceVar struct {
   426  	slice reflect.Value
   427  }
   428  
   429  func (ssv stringSliceVar) String() string { return "" }
   430  func (ssv stringSliceVar) Set(s string) error {
   431  	ssv.slice.Set(reflect.AppendSlice(ssv.slice, reflect.ValueOf([]string{s})))
   432  	return nil
   433  }
   434  
   435  // given a set of GinkgoFlags and bindings, generate flag arguments suitable to be passed to an application with that set of flags configured.
   436  func GenerateFlagArgs(flags GinkgoFlags, bindings interface{}) ([]string, error) {
   437  	result := []string{}
   438  	for _, flag := range flags {
   439  		name := flag.ExportAs
   440  		if name == "" {
   441  			name = flag.Name
   442  		}
   443  		if name == "" {
   444  			continue
   445  		}
   446  
   447  		value, ok := valueAtKeyPath(bindings, flag.KeyPath)
   448  		if !ok {
   449  			return []string{}, fmt.Errorf("could not load KeyPath: %s", flag.KeyPath)
   450  		}
   451  
   452  		iface := value.Interface()
   453  		switch value.Type() {
   454  		case reflect.TypeOf(string("")):
   455  			if iface.(string) != "" || flag.AlwaysExport {
   456  				result = append(result, fmt.Sprintf("--%s=%s", name, iface))
   457  			}
   458  		case reflect.TypeOf(int64(0)):
   459  			if iface.(int64) != 0 || flag.AlwaysExport {
   460  				result = append(result, fmt.Sprintf("--%s=%d", name, iface))
   461  			}
   462  		case reflect.TypeOf(float64(0)):
   463  			if iface.(float64) != 0 || flag.AlwaysExport {
   464  				result = append(result, fmt.Sprintf("--%s=%f", name, iface))
   465  			}
   466  		case reflect.TypeOf(int(0)):
   467  			if iface.(int) != 0 || flag.AlwaysExport {
   468  				result = append(result, fmt.Sprintf("--%s=%d", name, iface))
   469  			}
   470  		case reflect.TypeOf(bool(true)):
   471  			if iface.(bool) {
   472  				result = append(result, fmt.Sprintf("--%s", name))
   473  			}
   474  		case reflect.TypeOf(time.Duration(0)):
   475  			if iface.(time.Duration) != time.Duration(0) || flag.AlwaysExport {
   476  				result = append(result, fmt.Sprintf("--%s=%s", name, iface))
   477  			}
   478  
   479  		case reflect.TypeOf([]string{}):
   480  			strings := iface.([]string)
   481  			for _, s := range strings {
   482  				result = append(result, fmt.Sprintf("--%s=%s", name, s))
   483  			}
   484  		default:
   485  			return []string{}, fmt.Errorf("unsupported type %T", iface)
   486  		}
   487  	}
   488  
   489  	return result, nil
   490  }
   491  

View as plain text