...

Source file src/github.com/spf13/cobra/flag_groups.go

Documentation: github.com/spf13/cobra

     1  // Copyright 2013-2023 The Cobra Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //      http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package cobra
    16  
    17  import (
    18  	"fmt"
    19  	"sort"
    20  	"strings"
    21  
    22  	flag "github.com/spf13/pflag"
    23  )
    24  
// Annotation keys stored on pflag flags to record flag-group membership.
// Each annotation value holds one entry per group the flag belongs to; an
// entry is the group's flag names joined with a single space.
const (
	// requiredAsGroup is set by MarkFlagsRequiredTogether: if any flag of the
	// group is set, all of them must be set.
	requiredAsGroup   = "cobra_annotation_required_if_others_set"
	// oneRequired is set by MarkFlagsOneRequired: at least one flag of the
	// group must be set.
	oneRequired       = "cobra_annotation_one_required"
	// mutuallyExclusive is set by MarkFlagsMutuallyExclusive: at most one flag
	// of the group may be set.
	mutuallyExclusive = "cobra_annotation_mutually_exclusive"
)
    30  
    31  // MarkFlagsRequiredTogether marks the given flags with annotations so that Cobra errors
    32  // if the command is invoked with a subset (but not all) of the given flags.
    33  func (c *Command) MarkFlagsRequiredTogether(flagNames ...string) {
    34  	c.mergePersistentFlags()
    35  	for _, v := range flagNames {
    36  		f := c.Flags().Lookup(v)
    37  		if f == nil {
    38  			panic(fmt.Sprintf("Failed to find flag %q and mark it as being required in a flag group", v))
    39  		}
    40  		if err := c.Flags().SetAnnotation(v, requiredAsGroup, append(f.Annotations[requiredAsGroup], strings.Join(flagNames, " "))); err != nil {
    41  			// Only errs if the flag isn't found.
    42  			panic(err)
    43  		}
    44  	}
    45  }
    46  
    47  // MarkFlagsOneRequired marks the given flags with annotations so that Cobra errors
    48  // if the command is invoked without at least one flag from the given set of flags.
    49  func (c *Command) MarkFlagsOneRequired(flagNames ...string) {
    50  	c.mergePersistentFlags()
    51  	for _, v := range flagNames {
    52  		f := c.Flags().Lookup(v)
    53  		if f == nil {
    54  			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a one-required flag group", v))
    55  		}
    56  		if err := c.Flags().SetAnnotation(v, oneRequired, append(f.Annotations[oneRequired], strings.Join(flagNames, " "))); err != nil {
    57  			// Only errs if the flag isn't found.
    58  			panic(err)
    59  		}
    60  	}
    61  }
    62  
    63  // MarkFlagsMutuallyExclusive marks the given flags with annotations so that Cobra errors
    64  // if the command is invoked with more than one flag from the given set of flags.
    65  func (c *Command) MarkFlagsMutuallyExclusive(flagNames ...string) {
    66  	c.mergePersistentFlags()
    67  	for _, v := range flagNames {
    68  		f := c.Flags().Lookup(v)
    69  		if f == nil {
    70  			panic(fmt.Sprintf("Failed to find flag %q and mark it as being in a mutually exclusive flag group", v))
    71  		}
    72  		// Each time this is called is a single new entry; this allows it to be a member of multiple groups if needed.
    73  		if err := c.Flags().SetAnnotation(v, mutuallyExclusive, append(f.Annotations[mutuallyExclusive], strings.Join(flagNames, " "))); err != nil {
    74  			panic(err)
    75  		}
    76  	}
    77  }
    78  
    79  // ValidateFlagGroups validates the mutuallyExclusive/oneRequired/requiredAsGroup logic and returns the
    80  // first error encountered.
    81  func (c *Command) ValidateFlagGroups() error {
    82  	if c.DisableFlagParsing {
    83  		return nil
    84  	}
    85  
    86  	flags := c.Flags()
    87  
    88  	// groupStatus format is the list of flags as a unique ID,
    89  	// then a map of each flag name and whether it is set or not.
    90  	groupStatus := map[string]map[string]bool{}
    91  	oneRequiredGroupStatus := map[string]map[string]bool{}
    92  	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
    93  	flags.VisitAll(func(pflag *flag.Flag) {
    94  		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
    95  		processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
    96  		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
    97  	})
    98  
    99  	if err := validateRequiredFlagGroups(groupStatus); err != nil {
   100  		return err
   101  	}
   102  	if err := validateOneRequiredFlagGroups(oneRequiredGroupStatus); err != nil {
   103  		return err
   104  	}
   105  	if err := validateExclusiveFlagGroups(mutuallyExclusiveGroupStatus); err != nil {
   106  		return err
   107  	}
   108  	return nil
   109  }
   110  
   111  func hasAllFlags(fs *flag.FlagSet, flagnames ...string) bool {
   112  	for _, fname := range flagnames {
   113  		f := fs.Lookup(fname)
   114  		if f == nil {
   115  			return false
   116  		}
   117  	}
   118  	return true
   119  }
   120  
   121  func processFlagForGroupAnnotation(flags *flag.FlagSet, pflag *flag.Flag, annotation string, groupStatus map[string]map[string]bool) {
   122  	groupInfo, found := pflag.Annotations[annotation]
   123  	if found {
   124  		for _, group := range groupInfo {
   125  			if groupStatus[group] == nil {
   126  				flagnames := strings.Split(group, " ")
   127  
   128  				// Only consider this flag group at all if all the flags are defined.
   129  				if !hasAllFlags(flags, flagnames...) {
   130  					continue
   131  				}
   132  
   133  				groupStatus[group] = map[string]bool{}
   134  				for _, name := range flagnames {
   135  					groupStatus[group][name] = false
   136  				}
   137  			}
   138  
   139  			groupStatus[group][pflag.Name] = pflag.Changed
   140  		}
   141  	}
   142  }
   143  
   144  func validateRequiredFlagGroups(data map[string]map[string]bool) error {
   145  	keys := sortedKeys(data)
   146  	for _, flagList := range keys {
   147  		flagnameAndStatus := data[flagList]
   148  
   149  		unset := []string{}
   150  		for flagname, isSet := range flagnameAndStatus {
   151  			if !isSet {
   152  				unset = append(unset, flagname)
   153  			}
   154  		}
   155  		if len(unset) == len(flagnameAndStatus) || len(unset) == 0 {
   156  			continue
   157  		}
   158  
   159  		// Sort values, so they can be tested/scripted against consistently.
   160  		sort.Strings(unset)
   161  		return fmt.Errorf("if any flags in the group [%v] are set they must all be set; missing %v", flagList, unset)
   162  	}
   163  
   164  	return nil
   165  }
   166  
   167  func validateOneRequiredFlagGroups(data map[string]map[string]bool) error {
   168  	keys := sortedKeys(data)
   169  	for _, flagList := range keys {
   170  		flagnameAndStatus := data[flagList]
   171  		var set []string
   172  		for flagname, isSet := range flagnameAndStatus {
   173  			if isSet {
   174  				set = append(set, flagname)
   175  			}
   176  		}
   177  		if len(set) >= 1 {
   178  			continue
   179  		}
   180  
   181  		// Sort values, so they can be tested/scripted against consistently.
   182  		sort.Strings(set)
   183  		return fmt.Errorf("at least one of the flags in the group [%v] is required", flagList)
   184  	}
   185  	return nil
   186  }
   187  
   188  func validateExclusiveFlagGroups(data map[string]map[string]bool) error {
   189  	keys := sortedKeys(data)
   190  	for _, flagList := range keys {
   191  		flagnameAndStatus := data[flagList]
   192  		var set []string
   193  		for flagname, isSet := range flagnameAndStatus {
   194  			if isSet {
   195  				set = append(set, flagname)
   196  			}
   197  		}
   198  		if len(set) == 0 || len(set) == 1 {
   199  			continue
   200  		}
   201  
   202  		// Sort values, so they can be tested/scripted against consistently.
   203  		sort.Strings(set)
   204  		return fmt.Errorf("if any flags in the group [%v] are set none of the others can be; %v were all set", flagList, set)
   205  	}
   206  	return nil
   207  }
   208  
   209  func sortedKeys(m map[string]map[string]bool) []string {
   210  	keys := make([]string, len(m))
   211  	i := 0
   212  	for k := range m {
   213  		keys[i] = k
   214  		i++
   215  	}
   216  	sort.Strings(keys)
   217  	return keys
   218  }
   219  
   220  // enforceFlagGroupsForCompletion will do the following:
   221  // - when a flag in a group is present, other flags in the group will be marked required
   222  // - when none of the flags in a one-required group are present, all flags in the group will be marked required
   223  // - when a flag in a mutually exclusive group is present, other flags in the group will be marked as hidden
   224  // This allows the standard completion logic to behave appropriately for flag groups
   225  func (c *Command) enforceFlagGroupsForCompletion() {
   226  	if c.DisableFlagParsing {
   227  		return
   228  	}
   229  
   230  	flags := c.Flags()
   231  	groupStatus := map[string]map[string]bool{}
   232  	oneRequiredGroupStatus := map[string]map[string]bool{}
   233  	mutuallyExclusiveGroupStatus := map[string]map[string]bool{}
   234  	c.Flags().VisitAll(func(pflag *flag.Flag) {
   235  		processFlagForGroupAnnotation(flags, pflag, requiredAsGroup, groupStatus)
   236  		processFlagForGroupAnnotation(flags, pflag, oneRequired, oneRequiredGroupStatus)
   237  		processFlagForGroupAnnotation(flags, pflag, mutuallyExclusive, mutuallyExclusiveGroupStatus)
   238  	})
   239  
   240  	// If a flag that is part of a group is present, we make all the other flags
   241  	// of that group required so that the shell completion suggests them automatically
   242  	for flagList, flagnameAndStatus := range groupStatus {
   243  		for _, isSet := range flagnameAndStatus {
   244  			if isSet {
   245  				// One of the flags of the group is set, mark the other ones as required
   246  				for _, fName := range strings.Split(flagList, " ") {
   247  					_ = c.MarkFlagRequired(fName)
   248  				}
   249  			}
   250  		}
   251  	}
   252  
   253  	// If none of the flags of a one-required group are present, we make all the flags
   254  	// of that group required so that the shell completion suggests them automatically
   255  	for flagList, flagnameAndStatus := range oneRequiredGroupStatus {
   256  		set := 0
   257  
   258  		for _, isSet := range flagnameAndStatus {
   259  			if isSet {
   260  				set++
   261  			}
   262  		}
   263  
   264  		// None of the flags of the group are set, mark all flags in the group
   265  		// as required
   266  		if set == 0 {
   267  			for _, fName := range strings.Split(flagList, " ") {
   268  				_ = c.MarkFlagRequired(fName)
   269  			}
   270  		}
   271  	}
   272  
   273  	// If a flag that is mutually exclusive to others is present, we hide the other
   274  	// flags of that group so the shell completion does not suggest them
   275  	for flagList, flagnameAndStatus := range mutuallyExclusiveGroupStatus {
   276  		for flagName, isSet := range flagnameAndStatus {
   277  			if isSet {
   278  				// One of the flags of the mutually exclusive group is set, mark the other ones as hidden
   279  				// Don't mark the flag that is already set as hidden because it may be an
   280  				// array or slice flag and therefore must continue being suggested
   281  				for _, fName := range strings.Split(flagList, " ") {
   282  					if fName != flagName {
   283  						flag := c.Flags().Lookup(fName)
   284  						flag.Hidden = true
   285  					}
   286  				}
   287  			}
   288  		}
   289  	}
   290  }
   291  

View as plain text