...

Source file src/edge-infra.dev/pkg/edge/edgecli/flagutil/validation.go

Documentation: edge-infra.dev/pkg/edge/edgecli/flagutil

     1  package flagutil
     2  
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"regexp"
	"strconv"
	"time"

	"github.com/shurcooL/graphql"
	"golang.org/x/term"
	"k8s.io/client-go/tools/clientcmd"

	"edge-infra.dev/pkg/edge/api/graph/model"
	clustertype "edge-infra.dev/pkg/edge/constants/api/cluster"
	"edge-infra.dev/pkg/edge/constants/api/fleet"
	"edge-infra.dev/pkg/edge/edgecli"
	"edge-infra.dev/pkg/lib/cli/rags"
	"edge-infra.dev/pkg/lib/networkvalidator"
)
    25  
    26  func ValidateRequiredFlags(rs *rags.RagSet) error {
    27  	for _, rag := range rs.Rags() {
    28  		if rag.Required {
    29  			switch rag.Value.(type) {
    30  			case *rags.Value[string]:
    31  				if rag.Value.String() == "" {
    32  					return fmt.Errorf("Flag '%s' is required", rag.Name)
    33  				}
    34  			case *rags.Value[int], *rags.Value[float64]:
    35  				if rag.Value.String() == "0" {
    36  					return fmt.Errorf("Flag '%s' is required", rag.Name)
    37  				}
    38  			}
    39  		}
    40  	}
    41  	return nil
    42  }
    43  
    44  func ValidateConnectionFlags(rs *rags.RagSet, cfg *edgecli.Config) error {
    45  	// if token passed or defaults to non empty context, no errors
    46  	if GetStringFlag(rs, BearerTokenFlag) != "" {
    47  		now := time.Now()
    48  		expiration, err := time.Parse(time.RFC3339, cfg.CurrentBanner().TokenTime)
    49  		if err != nil {
    50  			return err
    51  		}
    52  		if now.Before(expiration) {
    53  			return nil
    54  		}
    55  	}
    56  
    57  	var input string
    58  	var err error
    59  
    60  	// if username neither passed nor found in context
    61  	if GetStringFlag(rs, UsernameFlag) == "" {
    62  		fmt.Println("Please enter BSL username.")
    63  		_, _ = fmt.Scanln(&input)
    64  		for input == "" {
    65  			fmt.Println("Please enter a valid username.")
    66  			_, _ = fmt.Scanln(&input)
    67  		}
    68  		if err = SetFlag(rs, UsernameFlag, input); err != nil {
    69  			return err
    70  		}
    71  	}
    72  
    73  	// prompt for password
    74  	if GetStringFlag(rs, PasswordFlag) == "" {
    75  		password, err := PromptPassword()
    76  		if err != nil {
    77  			return err
    78  		}
    79  		if err = SetFlag(rs, PasswordFlag, password); err != nil {
    80  			return err
    81  		}
    82  	}
    83  
    84  	// if organization neither passed nor found in context
    85  	if GetStringFlag(rs, OrganizationFlag) == "" {
    86  		fmt.Println("Please enter BSL organization.")
    87  		_, _ = fmt.Scanln(&input)
    88  		for input == "" {
    89  			fmt.Println("Please enter a valid organization.")
    90  			_, _ = fmt.Scanln(&input)
    91  		}
    92  		if err = SetFlag(rs, OrganizationFlag, input); err != nil {
    93  			return err
    94  		}
    95  	}
    96  
    97  	if GetStringFlag(rs, Endpoint) == "" && GetStringFlag(rs, "bff-endpoint") == "" {
    98  		fmt.Println("Please enter a API Endpoint.")
    99  		_, _ = fmt.Scanln(&input)
   100  		for input == "" {
   101  			fmt.Println("Please enter a valid API Endpoint.")
   102  			_, _ = fmt.Scanln(&input)
   103  		}
   104  		if err = SetFlag(rs, Endpoint, input); err != nil {
   105  			return err
   106  		}
   107  	} else if GetStringFlag(rs, Endpoint) == "" && GetStringFlag(rs, "bff-endpoint") != "" {
   108  		fmt.Println("WARNING: bff-endpoint flag is being deprecated in Edge Version 0.22, please consider switching over to api-endpoint")
   109  		errSettingAlias := SetFlag(rs, Endpoint, GetStringFlag(rs, "bff-endpoint"))
   110  		if errSettingAlias != nil {
   111  			return errSettingAlias
   112  		}
   113  	}
   114  
   115  	err = UpdateBearerToken(
   116  		cfg,
   117  		GetStringFlag(rs, Endpoint),
   118  		GetStringFlag(rs, UsernameFlag),
   119  		GetStringFlag(rs, PasswordFlag),
   120  		GetStringFlag(rs, OrganizationFlag),
   121  	)
   122  	return err
   123  }
   124  
   125  func ValidateFleetType(rs *rags.RagSet) error {
   126  	flag := rs.FlagSet().Lookup(FleetFlag)
   127  	if flag == nil {
   128  		return errors.New("Fleet flag not found")
   129  	}
   130  	fleetType := flag.Value
   131  	if fleetType.String() != "" {
   132  		fleetTypeError := fleet.IsValid(fleetType.String())
   133  		if fleetTypeError != nil {
   134  			return fleetTypeError
   135  		}
   136  	}
   137  	return nil
   138  }
   139  
   140  func ValidateClusterType(rs *rags.RagSet) error {
   141  	flag := rs.FlagSet().Lookup(ClusterTypeFlag)
   142  	if flag == nil {
   143  		return errors.New("Cluster type flag not found")
   144  	}
   145  	clusterType := flag.Value
   146  	if clusterType.String() != "" {
   147  		clusterTypeError := clustertype.Type(clusterType.String()).IsValid()
   148  		if clusterTypeError != nil {
   149  			return clusterTypeError
   150  		}
   151  	}
   152  	return nil
   153  }
   154  
   155  func CheckClusterIdentifierValue(rs *rags.RagSet) error {
   156  	if GetStringFlag(rs, KubeConfigContextFlag) == "" {
   157  		clientCfg, _ := clientcmd.NewDefaultClientConfigLoadingRules().Load()
   158  		if clientCfg.CurrentContext == "" {
   159  			return fmt.Errorf("failed to determine current config context: no context is set, or there is no kubeconfig")
   160  		}
   161  		context := clientCfg.CurrentContext
   162  		if err := rs.FlagSet().Lookup(KubeConfigContextFlag).Value.Set(context); err != nil {
   163  			return fmt.Errorf("failed to apply current config context: %w", err)
   164  		}
   165  		fmt.Printf("using default kubeconfig-context value: %v\n", GetStringFlag(rs, KubeConfigContextFlag))
   166  	}
   167  	if GetStringFlag(rs, KubeConfigFlag) == "" {
   168  		clientCfg, err := clientcmd.NewDefaultClientConfigLoadingRules().Load()
   169  		if err != nil {
   170  			return fmt.Errorf("failed to find kubeconfig: %w", err)
   171  		}
   172  		ctx, ok := clientCfg.Contexts[clientCfg.CurrentContext]
   173  		if !ok {
   174  			return fmt.Errorf("current context %s not found in contexts", clientCfg.CurrentContext)
   175  		}
   176  		configLoc := ctx.LocationOfOrigin
   177  		if err := rs.FlagSet().Lookup(KubeConfigFlag).Value.Set(configLoc); err != nil {
   178  			return fmt.Errorf("failed to apply current kubeconfig value: %w", err)
   179  		}
   180  		fmt.Printf("kubeconfig input is empty, using default kubeconfig value: %v\n", GetStringFlag(rs, KubeConfigFlag))
   181  	}
   182  	return nil
   183  }
   184  
   185  func PromptPassword() (string, error) {
   186  	fmt.Println("Please enter BSL password.")
   187  
   188  	bytePassword, err := term.ReadPassword(int(os.Stdin.Fd()))
   189  	if err != nil {
   190  		return "", err
   191  	}
   192  	for string(bytePassword) == "" {
   193  		fmt.Println("Please enter a valid password.")
   194  		bytePassword, err = term.ReadPassword(int(os.Stdin.Fd()))
   195  		if err != nil {
   196  			return "", err
   197  		}
   198  	}
   199  	return string(bytePassword), nil
   200  }
   201  
   202  func GetBearerToken(url, username, password, organization string) (string, error) {
   203  	client := graphql.NewClient(url, nil)
   204  
   205  	var mutation struct {
   206  		Login struct {
   207  			Token              graphql.String
   208  			CredentialsExpired graphql.Boolean
   209  		} `graphql:"login(username: $username, password: $password, organization: $organization)"`
   210  	}
   211  
   212  	variables := map[string]interface{}{
   213  		"username":     graphql.String(username),
   214  		"password":     graphql.String(password),
   215  		"organization": graphql.String(organization),
   216  	}
   217  
   218  	err := client.Mutate(context.Background(), &mutation, variables)
   219  	if mutation.Login.CredentialsExpired {
   220  		return "", errors.New("credentials have expired. Please navigate to the UI to reset your credentials")
   221  	}
   222  	return string(mutation.Login.Token), err
   223  }
   224  
   225  func UpdateBearerToken(cfg *edgecli.Config, url, username, password, organization string) error {
   226  	token, err := GetBearerToken(url, username, password, organization)
   227  	if err != nil {
   228  		return err
   229  	}
   230  	cfg.CurrentBanner().Token = token
   231  	cfg.CurrentBanner().TokenTime = time.Now().Add(15 * time.Minute).Format(time.RFC3339)
   232  	return edgecli.WriteConfig(cfg)
   233  }
   234  
   235  // This avoids an apparent bug whereby cli.Context.IsSet(myflag) always returns true for a
   236  // flag that is defined, even if you didn't pass it when you ran the CLI command.
   237  func FlagWasPassed(rs *rags.RagSet, flagName string) bool {
   238  	value := Value(rs, flagName)
   239  	var result bool
   240  
   241  	// Boolean flags are different from flags of any other type, since their "value" is
   242  	// indicated by their presence/absence, not by a value that follows the flag.
   243  	switch val := value.(type) {
   244  	case bool:
   245  		{
   246  			result = val
   247  		}
   248  	case int:
   249  		{
   250  			result = val != 0
   251  		}
   252  	default:
   253  		{
   254  			result = val != ""
   255  		}
   256  	}
   257  	return result
   258  }
   259  
   260  func GetOptionalFlagValue[T any](rs *rags.RagSet, flagName string) *T {
   261  	var flagValue *T
   262  
   263  	if FlagWasPassed(rs, flagName) {
   264  		value := Value(rs, flagName).(T)
   265  		flagValue = &value
   266  	}
   267  	return flagValue
   268  }
   269  
   270  func Value(rs *rags.RagSet, name string) interface{} {
   271  	return rs.FlagSet().Lookup(name).Value.(flag.Getter).Get()
   272  }
   273  
   274  // Check that all strings successfully parse as the things they're supposed to represent. Return only the parsed value(s)
   275  // that we currently need, or an error on any validation failure.
   276  func ValidateTerminalFlags(
   277  	role model.TerminalRoleType, class *model.TerminalClassType, _ string, mac string, ipv4addr *string, ipv6addr *string, prefixLen4 string, prefixLen6 string, hostname *string, dhcp6 bool, dhcp4 bool) (int, int, error) {
   278  	var err error
   279  
   280  	if !(role == model.TerminalRoleTypeControlplane || role == model.TerminalRoleTypeWorker) {
   281  		return 0, 0, fmt.Errorf("role must be one of '%v' or '%v'", model.TerminalRoleTypeControlplane, model.TerminalRoleTypeWorker)
   282  	}
   283  
   284  	if !isValidTerminalClass(class) {
   285  		return 0, 0, fmt.Errorf("class must be one of '%v' or '%v'", model.TerminalClassTypeServer, model.TerminalClassTypeTouchpoint)
   286  	}
   287  
   288  	_, err = net.ParseMAC(mac)
   289  	if err != nil {
   290  		return 0, 0, fmt.Errorf("invalid MAC address '%v'", mac)
   291  	}
   292  
   293  	if ipv4addr != nil && !networkvalidator.ValidateIP(*ipv4addr) {
   294  		return 0, 0, fmt.Errorf("invalid IPv4 address %s", *ipv4addr)
   295  	}
   296  
   297  	if ipv6addr != nil && !networkvalidator.ValidateIP(*ipv6addr) {
   298  		return 0, 0, fmt.Errorf("invalid IPv6 address %s", *ipv6addr)
   299  	}
   300  
   301  	if ipv6addr != nil && dhcp6 {
   302  		return 0, 0, fmt.Errorf("must not provide IPv6 address if DHCP6 is enabled")
   303  	}
   304  
   305  	if ipv4addr != nil && dhcp4 {
   306  		return 0, 0, fmt.Errorf("must not provide IPv4 address if DHCP4 is enabled")
   307  	}
   308  
   309  	prefLen4, err := strconv.Atoi(prefixLen4)
   310  	if err != nil || prefLen4 < 0 || prefLen4 > 32 {
   311  		return 0, 0, fmt.Errorf("invalid ipv4 network prefix length '%v'", prefixLen4)
   312  	}
   313  
   314  	prefLen6, err := strconv.Atoi(prefixLen6)
   315  	if err != nil || prefLen6 < 0 || prefLen6 > 128 {
   316  		return 0, 0, fmt.Errorf("invalid ipv6 network prefix length '%v'", prefixLen6)
   317  	}
   318  
   319  	if hostname != nil {
   320  		if err = networkvalidator.IsValidHostname(*hostname); err != nil {
   321  			return 0, 0, err
   322  		}
   323  	}
   324  
   325  	return prefLen4, prefLen6, nil
   326  }
   327  
   328  func IsValidUUID(uuidInput string) bool {
   329  	pattern := `^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`
   330  	matchedString, err := regexp.MatchString(pattern, uuidInput)
   331  	if err != nil {
   332  		return false
   333  	}
   334  	return matchedString
   335  }
   336  
   337  func IsValidServiceType(serviceType string) bool {
   338  	return serviceType == "ntp" || serviceType == "dns" || serviceType == "kube-vip"
   339  }
   340  

View as plain text