package flagutil import ( "context" "errors" "flag" "fmt" "net" "os" "regexp" "strconv" "time" "github.com/shurcooL/graphql" "golang.org/x/term" "k8s.io/client-go/tools/clientcmd" "edge-infra.dev/pkg/edge/api/graph/model" clustertype "edge-infra.dev/pkg/edge/constants/api/cluster" "edge-infra.dev/pkg/edge/constants/api/fleet" "edge-infra.dev/pkg/edge/edgecli" "edge-infra.dev/pkg/lib/cli/rags" "edge-infra.dev/pkg/lib/networkvalidator" ) func ValidateRequiredFlags(rs *rags.RagSet) error { for _, rag := range rs.Rags() { if rag.Required { switch rag.Value.(type) { case *rags.Value[string]: if rag.Value.String() == "" { return fmt.Errorf("Flag '%s' is required", rag.Name) } case *rags.Value[int], *rags.Value[float64]: if rag.Value.String() == "0" { return fmt.Errorf("Flag '%s' is required", rag.Name) } } } } return nil } func ValidateConnectionFlags(rs *rags.RagSet, cfg *edgecli.Config) error { // if token passed or defaults to non empty context, no errors if GetStringFlag(rs, BearerTokenFlag) != "" { now := time.Now() expiration, err := time.Parse(time.RFC3339, cfg.CurrentBanner().TokenTime) if err != nil { return err } if now.Before(expiration) { return nil } } var input string var err error // if username neither passed nor found in context if GetStringFlag(rs, UsernameFlag) == "" { fmt.Println("Please enter BSL username.") _, _ = fmt.Scanln(&input) for input == "" { fmt.Println("Please enter a valid username.") _, _ = fmt.Scanln(&input) } if err = SetFlag(rs, UsernameFlag, input); err != nil { return err } } // prompt for password if GetStringFlag(rs, PasswordFlag) == "" { password, err := PromptPassword() if err != nil { return err } if err = SetFlag(rs, PasswordFlag, password); err != nil { return err } } // if organization neither passed nor found in context if GetStringFlag(rs, OrganizationFlag) == "" { fmt.Println("Please enter BSL organization.") _, _ = fmt.Scanln(&input) for input == "" { fmt.Println("Please enter a valid organization.") _, _ = fmt.Scanln(&input) } if err = SetFlag(rs, OrganizationFlag, input); err != nil { return err } } if GetStringFlag(rs, Endpoint) == "" && GetStringFlag(rs, "bff-endpoint") == "" { fmt.Println("Please enter a API Endpoint.") _, _ = fmt.Scanln(&input) for input == "" { fmt.Println("Please enter a valid API Endpoint.") _, _ = fmt.Scanln(&input) } if err = SetFlag(rs, Endpoint, input); err != nil { return err } } else if GetStringFlag(rs, Endpoint) == "" && GetStringFlag(rs, "bff-endpoint") != "" { fmt.Println("WARNING: bff-endpoint flag is being deprecated in Edge Version 0.22, please consider switching over to api-endpoint") errSettingAlias := SetFlag(rs, Endpoint, GetStringFlag(rs, "bff-endpoint")) if errSettingAlias != nil { return errSettingAlias } } err = UpdateBearerToken( cfg, GetStringFlag(rs, Endpoint), GetStringFlag(rs, UsernameFlag), GetStringFlag(rs, PasswordFlag), GetStringFlag(rs, OrganizationFlag), ) return err } func ValidateFleetType(rs *rags.RagSet) error { flag := rs.FlagSet().Lookup(FleetFlag) if flag == nil { return errors.New("Fleet flag not found") } fleetType := flag.Value if fleetType.String() != "" { fleetTypeError := fleet.IsValid(fleetType.String()) if fleetTypeError != nil { return fleetTypeError } } return nil } func ValidateClusterType(rs *rags.RagSet) error { flag := rs.FlagSet().Lookup(ClusterTypeFlag) if flag == nil { return errors.New("Cluster type flag not found") } clusterType := flag.Value if clusterType.String() != "" { clusterTypeError := clustertype.Type(clusterType.String()).IsValid() if clusterTypeError != nil { return clusterTypeError } } return nil } func CheckClusterIdentifierValue(rs *rags.RagSet) error { if GetStringFlag(rs, KubeConfigContextFlag) == "" { clientCfg, _ := clientcmd.NewDefaultClientConfigLoadingRules().Load() if clientCfg.CurrentContext == "" { return fmt.Errorf("failed to determine current config context: no context is set, or there is no kubeconfig") } context := clientCfg.CurrentContext if err := rs.FlagSet().Lookup(KubeConfigContextFlag).Value.Set(context); err != nil { return fmt.Errorf("failed to apply current config context: %w", err) } fmt.Printf("using default kubeconfig-context value: %v\n", GetStringFlag(rs, KubeConfigContextFlag)) } if GetStringFlag(rs, KubeConfigFlag) == "" { clientCfg, err := clientcmd.NewDefaultClientConfigLoadingRules().Load() if err != nil { return fmt.Errorf("failed to find kubeconfig: %w", err) } ctx, ok := clientCfg.Contexts[clientCfg.CurrentContext] if !ok { return fmt.Errorf("current context %s not found in contexts", clientCfg.CurrentContext) } configLoc := ctx.LocationOfOrigin if err := rs.FlagSet().Lookup(KubeConfigFlag).Value.Set(configLoc); err != nil { return fmt.Errorf("failed to apply current kubeconfig value: %w", err) } fmt.Printf("kubeconfig input is empty, using default kubeconfig value: %v\n", GetStringFlag(rs, KubeConfigFlag)) } return nil } func PromptPassword() (string, error) { fmt.Println("Please enter BSL password.") bytePassword, err := term.ReadPassword(int(os.Stdin.Fd())) if err != nil { return "", err } for string(bytePassword) == "" { fmt.Println("Please enter a valid password.") bytePassword, err = term.ReadPassword(int(os.Stdin.Fd())) if err != nil { return "", err } } return string(bytePassword), nil } func GetBearerToken(url, username, password, organization string) (string, error) { client := graphql.NewClient(url, nil) var mutation struct { Login struct { Token graphql.String CredentialsExpired graphql.Boolean } `graphql:"login(username: $username, password: $password, organization: $organization)"` } variables := map[string]interface{}{ "username": graphql.String(username), "password": graphql.String(password), "organization": graphql.String(organization), } err := client.Mutate(context.Background(), &mutation, variables) if mutation.Login.CredentialsExpired { return "", errors.New("credentials have expired. Please navigate to the UI to reset your credentials") } return string(mutation.Login.Token), err } func UpdateBearerToken(cfg *edgecli.Config, url, username, password, organization string) error { token, err := GetBearerToken(url, username, password, organization) if err != nil { return err } cfg.CurrentBanner().Token = token cfg.CurrentBanner().TokenTime = time.Now().Add(15 * time.Minute).Format(time.RFC3339) return edgecli.WriteConfig(cfg) } // This avoids an apparent bug whereby cli.Context.IsSet(myflag) always returns true for a // flag that is defined, even if you didn't pass it when you ran the CLI command. func FlagWasPassed(rs *rags.RagSet, flagName string) bool { value := Value(rs, flagName) var result bool // Boolean flags are different from flags of any other type, since their "value" is // indicated by their presence/absence, not by a value that follows the flag. switch val := value.(type) { case bool: { result = val } case int: { result = val != 0 } default: { result = val != "" } } return result } func GetOptionalFlagValue[T any](rs *rags.RagSet, flagName string) *T { var flagValue *T if FlagWasPassed(rs, flagName) { value := Value(rs, flagName).(T) flagValue = &value } return flagValue } func Value(rs *rags.RagSet, name string) interface{} { return rs.FlagSet().Lookup(name).Value.(flag.Getter).Get() } // Check that all strings successfully parse as the things they're supposed to represent. Return only the parsed value(s) // that we currently need, or an error on any validation failure. func ValidateTerminalFlags( role model.TerminalRoleType, class *model.TerminalClassType, _ string, mac string, ipv4addr *string, ipv6addr *string, prefixLen4 string, prefixLen6 string, hostname *string, dhcp6 bool, dhcp4 bool) (int, int, error) { var err error if !(role == model.TerminalRoleTypeControlplane || role == model.TerminalRoleTypeWorker) { return 0, 0, fmt.Errorf("role must be one of '%v' or '%v'", model.TerminalRoleTypeControlplane, model.TerminalRoleTypeWorker) } if !isValidTerminalClass(class) { return 0, 0, fmt.Errorf("class must be one of '%v' or '%v'", model.TerminalClassTypeServer, model.TerminalClassTypeTouchpoint) } _, err = net.ParseMAC(mac) if err != nil { return 0, 0, fmt.Errorf("invalid MAC address '%v'", mac) } if ipv4addr != nil && !networkvalidator.ValidateIP(*ipv4addr) { return 0, 0, fmt.Errorf("invalid IPv4 address %s", *ipv4addr) } if ipv6addr != nil && !networkvalidator.ValidateIP(*ipv6addr) { return 0, 0, fmt.Errorf("invalid IPv6 address %s", *ipv6addr) } if ipv6addr != nil && dhcp6 { return 0, 0, fmt.Errorf("must not provide IPv6 address if DHCP6 is enabled") } if ipv4addr != nil && dhcp4 { return 0, 0, fmt.Errorf("must not provide IPv4 address if DHCP4 is enabled") } prefLen4, err := strconv.Atoi(prefixLen4) if err != nil || prefLen4 < 0 || prefLen4 > 32 { return 0, 0, fmt.Errorf("invalid ipv4 network prefix length '%v'", prefixLen4) } prefLen6, err := strconv.Atoi(prefixLen6) if err != nil || prefLen6 < 0 || prefLen6 > 128 { return 0, 0, fmt.Errorf("invalid ipv6 network prefix length '%v'", prefixLen6) } if hostname != nil { if err = networkvalidator.IsValidHostname(*hostname); err != nil { return 0, 0, err } } return prefLen4, prefLen6, nil } func IsValidUUID(uuidInput string) bool { pattern := `^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$` matchedString, err := regexp.MatchString(pattern, uuidInput) if err != nil { return false } return matchedString } func IsValidServiceType(serviceType string) bool { return serviceType == "ntp" || serviceType == "dns" || serviceType == "kube-vip" }