1 package flagutil
2
import (
	"context"
	"errors"
	"flag"
	"fmt"
	"io"
	"net"
	"os"
	"regexp"
	"strconv"
	"time"

	"github.com/shurcooL/graphql"
	"golang.org/x/term"
	"k8s.io/client-go/tools/clientcmd"

	"edge-infra.dev/pkg/edge/api/graph/model"
	clustertype "edge-infra.dev/pkg/edge/constants/api/cluster"
	"edge-infra.dev/pkg/edge/constants/api/fleet"
	"edge-infra.dev/pkg/edge/edgecli"
	"edge-infra.dev/pkg/lib/cli/rags"
	"edge-infra.dev/pkg/lib/networkvalidator"
)
25
// ValidateRequiredFlags returns an error naming the first rag in rs that is
// marked Required but still renders as its empty value.
//
// Only string, int and float64 rags are inspected:
//   - a string rag is considered unset when its String() is "";
//   - an int or float64 rag when its String() is "0".
//
// Required rags of any other value type are never reported.
func ValidateRequiredFlags(rs *rags.RagSet) error {
	for _, r := range rs.Rags() {
		if !r.Required {
			continue
		}

		unset := false
		switch r.Value.(type) {
		case *rags.Value[string]:
			unset = r.Value.String() == ""
		case *rags.Value[int], *rags.Value[float64]:
			unset = r.Value.String() == "0"
		}

		if unset {
			return fmt.Errorf("Flag '%s' is required", r.Name)
		}
	}
	return nil
}
43
// ValidateConnectionFlags makes sure rs carries what is needed to reach the
// API, then refreshes the bearer token stored in cfg.
//
// Early exit: if BearerTokenFlag is set and the TokenTime stored on cfg's
// current banner (RFC 3339) is still in the future, it returns nil without
// doing anything else.
//
// Otherwise it:
//   - fills in any empty username, password or organization flag by
//     prompting on stdin;
//   - resolves the API endpoint. If only the deprecated "bff-endpoint" flag
//     was given, its value is copied into Endpoint with a warning. If neither
//     was given, the endpoint is prompted for;
//   - calls UpdateBearerToken with the resulting values.
//
// Returns an error if the stored token time cannot be parsed, if stdin is
// exhausted before a required value is entered, if a flag cannot be set, or if
// UpdateBearerToken fails.
func ValidateConnectionFlags(rs *rags.RagSet, cfg *edgecli.Config) error {
	if GetStringFlag(rs, BearerTokenFlag) != "" {
		expiration, err := time.Parse(time.RFC3339, cfg.CurrentBanner().TokenTime)
		if err != nil {
			return fmt.Errorf("parsing stored token time: %w", err)
		}
		if time.Now().Before(expiration) {
			return nil
		}
	}

	if GetStringFlag(rs, UsernameFlag) == "" {
		username, err := promptRequiredInput("Please enter BSL username.", "Please enter a valid username.")
		if err != nil {
			return err
		}
		if err := SetFlag(rs, UsernameFlag, username); err != nil {
			return err
		}
	}

	if GetStringFlag(rs, PasswordFlag) == "" {
		password, err := PromptPassword()
		if err != nil {
			return err
		}
		if err := SetFlag(rs, PasswordFlag, password); err != nil {
			return err
		}
	}

	if GetStringFlag(rs, OrganizationFlag) == "" {
		organization, err := promptRequiredInput("Please enter BSL organization.", "Please enter a valid organization.")
		if err != nil {
			return err
		}
		if err := SetFlag(rs, OrganizationFlag, organization); err != nil {
			return err
		}
	}

	endpoint := GetStringFlag(rs, Endpoint)
	legacyEndpoint := GetStringFlag(rs, "bff-endpoint")
	switch {
	case endpoint == "" && legacyEndpoint == "":
		input, err := promptRequiredInput("Please enter a API Endpoint.", "Please enter a valid API Endpoint.")
		if err != nil {
			return err
		}
		if err := SetFlag(rs, Endpoint, input); err != nil {
			return err
		}
	case endpoint == "":
		// Only the deprecated alias was supplied; promote it to Endpoint.
		fmt.Println("WARNING: bff-endpoint flag is being deprecated in Edge Version 0.22, please consider switching over to api-endpoint")
		if err := SetFlag(rs, Endpoint, legacyEndpoint); err != nil {
			return err
		}
	}

	return UpdateBearerToken(
		cfg,
		GetStringFlag(rs, Endpoint),
		GetStringFlag(rs, UsernameFlag),
		GetStringFlag(rs, PasswordFlag),
		GetStringFlag(rs, OrganizationFlag),
	)
}

// promptRequiredInput prints prompt, then reads a whitespace-delimited token
// from stdin, printing retry after each empty answer until a non-empty value
// is entered.
//
// Each attempt starts from a fresh, empty value, so an answer given to an
// earlier prompt can never be mistaken for this one. It returns an error
// instead of looping forever once stdin reaches end of input.
func promptRequiredInput(prompt, retry string) (string, error) {
	fmt.Println(prompt)
	for {
		var input string
		// Scanln reports "unexpected newline" on a blank line and leaves input
		// empty; that case is handled by re-prompting, so the error is only
		// fatal when it signals end of input.
		_, err := fmt.Scanln(&input)
		if input != "" {
			return input, nil
		}
		if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
			return "", fmt.Errorf("reading input from stdin: %w", err)
		}
		fmt.Println(retry)
	}
}
124
// ValidateFleetType checks the value of the fleet flag registered on rs.
//
// It returns an error if the flag is not registered. An empty value is
// accepted. Any other value must pass fleet.IsValid, whose error is returned
// unchanged.
func ValidateFleetType(rs *rags.RagSet) error {
	fleetFlag := rs.FlagSet().Lookup(FleetFlag)
	if fleetFlag == nil {
		return errors.New("Fleet flag not found")
	}

	requested := fleetFlag.Value.String()
	if requested == "" {
		return nil
	}

	if err := fleet.IsValid(requested); err != nil {
		return err
	}
	return nil
}
139
// ValidateClusterType checks the value of the cluster-type flag registered on
// rs.
//
// It returns an error if the flag is not registered. An empty value is
// accepted. Any other value must satisfy clustertype.Type(value).IsValid(),
// whose error is returned unchanged.
func ValidateClusterType(rs *rags.RagSet) error {
	typeFlag := rs.FlagSet().Lookup(ClusterTypeFlag)
	if typeFlag == nil {
		return errors.New("Cluster type flag not found")
	}

	requested := typeFlag.Value.String()
	if requested == "" {
		return nil
	}

	if err := clustertype.Type(requested).IsValid(); err != nil {
		return err
	}
	return nil
}
154
// CheckClusterIdentifierValue defaults the kubeconfig-context and kubeconfig
// flags from the client-go default loading rules when they are empty.
//
//   - An empty KubeConfigContextFlag is set to the merged kubeconfig's
//     current-context.
//   - An empty KubeConfigFlag is set to the file that defines the context
//     named by KubeConfigContextFlag, whether the user supplied it or it was
//     defaulted above.
//
// Each defaulted value is reported on stdout. Returns an error if no context
// can be determined, if the kubeconfig cannot be loaded when the file location
// is needed, if the chosen context is not defined, or if a flag is not
// registered or cannot be set.
func CheckClusterIdentifierValue(rs *rags.RagSet) error {
	needContext := GetStringFlag(rs, KubeConfigContextFlag) == ""
	needConfig := GetStringFlag(rs, KubeConfigFlag) == ""
	if !needContext && !needConfig {
		return nil
	}

	// applyDefault writes value into the named flag. FlagSet.Lookup returns
	// nil for an unregistered name, so that case is checked explicitly.
	applyDefault := func(name, value string) error {
		f := rs.FlagSet().Lookup(name)
		if f == nil {
			return fmt.Errorf("flag %q is not registered", name)
		}
		return f.Value.Set(value)
	}

	// Load the merged kubeconfig once. NOTE(review): Load may return a usable
	// config together with an aggregated per-file error, or a nil config, so
	// both results are checked where they matter.
	clientCfg, loadErr := clientcmd.NewDefaultClientConfigLoadingRules().Load()

	if needContext {
		if clientCfg == nil || clientCfg.CurrentContext == "" {
			return fmt.Errorf("failed to determine current config context: no context is set, or there is no kubeconfig")
		}
		if err := applyDefault(KubeConfigContextFlag, clientCfg.CurrentContext); err != nil {
			return fmt.Errorf("failed to apply current config context: %w", err)
		}
		fmt.Printf("using default kubeconfig-context value: %v\n", GetStringFlag(rs, KubeConfigContextFlag))
	}

	if needConfig {
		if loadErr != nil {
			return fmt.Errorf("failed to find kubeconfig: %w", loadErr)
		}
		if clientCfg == nil {
			return errors.New("failed to find kubeconfig: no configuration loaded")
		}
		// Resolve the file from the context actually in use, not the
		// kubeconfig's current-context: they differ when the user passed
		// --kubeconfig-context explicitly.
		contextName := GetStringFlag(rs, KubeConfigContextFlag)
		kubeCtx, ok := clientCfg.Contexts[contextName]
		if !ok || kubeCtx == nil {
			return fmt.Errorf("context %s not found in contexts", contextName)
		}
		if err := applyDefault(KubeConfigFlag, kubeCtx.LocationOfOrigin); err != nil {
			return fmt.Errorf("failed to apply current kubeconfig value: %w", err)
		}
		fmt.Printf("kubeconfig input is empty, using default kubeconfig value: %v\n", GetStringFlag(rs, KubeConfigFlag))
	}
	return nil
}
184
// PromptPassword asks for the BSL password on stdout and reads it from stdin
// with term.ReadPassword, which reads from the terminal without echoing.
//
// It prints "Please enter BSL password." first, then re-prompts with
// "Please enter a valid password." after every empty answer. Any read error,
// for example when stdin is not a terminal, is returned as-is.
func PromptPassword() (string, error) {
	message := "Please enter BSL password."
	for {
		fmt.Println(message)
		secret, err := term.ReadPassword(int(os.Stdin.Fd()))
		if err != nil {
			return "", err
		}
		if len(secret) > 0 {
			return string(secret), nil
		}
		message = "Please enter a valid password."
	}
}
201
// GetBearerToken performs the GraphQL `login` mutation against url with the
// given credentials and returns the session token it yields.
//
// Errors:
//   - when the response marks the credentials as expired;
//   - when the request fails;
//   - when the login succeeds but no token is returned.
//
// The request is bounded by a timeout. graphql.NewClient(url, nil) uses
// http.DefaultClient, which has no timeout, so without one an unreachable
// endpoint would block the CLI indefinitely.
func GetBearerToken(url, username, password, organization string) (string, error) {
	// loginTimeout caps the whole login round trip.
	const loginTimeout = 30 * time.Second

	client := graphql.NewClient(url, nil)

	var mutation struct {
		Login struct {
			Token              graphql.String
			CredentialsExpired graphql.Boolean
		} `graphql:"login(username: $username, password: $password, organization: $organization)"`
	}

	variables := map[string]interface{}{
		"username":     graphql.String(username),
		"password":     graphql.String(password),
		"organization": graphql.String(organization),
	}

	ctx, cancel := context.WithTimeout(context.Background(), loginTimeout)
	defer cancel()

	err := client.Mutate(ctx, &mutation, variables)
	// Checked before err, as in the original. NOTE(review): the client may
	// decode response data alongside GraphQL errors, and the
	// expired-credentials message is the more actionable one.
	if mutation.Login.CredentialsExpired {
		return "", errors.New("credentials have expired. Please navigate to the UI to reset your credentials")
	}
	if err != nil {
		return "", fmt.Errorf("login request to %s: %w", url, err)
	}
	if mutation.Login.Token == "" {
		// Never hand an empty token to callers; they persist it to the config.
		return "", errors.New("login response did not include a token")
	}
	return string(mutation.Login.Token), nil
}
224
// UpdateBearerToken logs in via GetBearerToken and stores the new token on
// cfg's current banner. It also stores a TokenTime 15 minutes from now,
// formatted as RFC 3339, and then persists cfg with edgecli.WriteConfig.
// Errors from either the login or the write are returned unchanged.
func UpdateBearerToken(cfg *edgecli.Config, url, username, password, organization string) error {
	token, loginErr := GetBearerToken(url, username, password, organization)
	if loginErr != nil {
		return loginErr
	}

	banner := cfg.CurrentBanner()
	banner.Token = token
	banner.TokenTime = time.Now().Add(15 * time.Minute).Format(time.RFC3339)

	return edgecli.WriteConfig(cfg)
}
234
235
236
// FlagWasPassed reports whether the flag named flagName currently holds a
// non-zero value, as returned by Value.
//
// It cannot tell an explicitly passed zero value (for example --count=0) from
// an unset flag. A nil value and the zero value of every handled type count as
// "not passed". Values of unrecognised types keep the historical behaviour of
// counting as passed.
func FlagWasPassed(rs *rags.RagSet, flagName string) bool {
	switch val := Value(rs, flagName).(type) {
	case nil:
		return false
	case bool:
		return val
	case string:
		return val != ""
	case int:
		return val != 0
	case int64:
		return val != 0
	case uint:
		return val != 0
	case uint64:
		return val != 0
	case float64:
		return val != 0
	case time.Duration:
		return val != 0
	case []string:
		return len(val) > 0
	default:
		return true
	}
}
259
260 func GetOptionalFlagValue[T any](rs *rags.RagSet, flagName string) *T {
261 var flagValue *T
262
263 if FlagWasPassed(rs, flagName) {
264 value := Value(rs, flagName).(T)
265 flagValue = &value
266 }
267 return flagValue
268 }
269
// Value returns the typed value held by the flag called name on rs, as
// reported by flag.Getter.Get.
//
// It panics in two cases:
//   - no such flag is registered (FlagSet.Lookup returns nil);
//   - the flag's Value does not implement flag.Getter.
func Value(rs *rags.RagSet, name string) interface{} {
	registered := rs.FlagSet().Lookup(name)
	getter := registered.Value.(flag.Getter)
	return getter.Get()
}
273
274
275
276 func ValidateTerminalFlags(
277 role model.TerminalRoleType, class *model.TerminalClassType, _ string, mac string, ipv4addr *string, ipv6addr *string, prefixLen4 string, prefixLen6 string, hostname *string, dhcp6 bool, dhcp4 bool) (int, int, error) {
278 var err error
279
280 if !(role == model.TerminalRoleTypeControlplane || role == model.TerminalRoleTypeWorker) {
281 return 0, 0, fmt.Errorf("role must be one of '%v' or '%v'", model.TerminalRoleTypeControlplane, model.TerminalRoleTypeWorker)
282 }
283
284 if !isValidTerminalClass(class) {
285 return 0, 0, fmt.Errorf("class must be one of '%v' or '%v'", model.TerminalClassTypeServer, model.TerminalClassTypeTouchpoint)
286 }
287
288 _, err = net.ParseMAC(mac)
289 if err != nil {
290 return 0, 0, fmt.Errorf("invalid MAC address '%v'", mac)
291 }
292
293 if ipv4addr != nil && !networkvalidator.ValidateIP(*ipv4addr) {
294 return 0, 0, fmt.Errorf("invalid IPv4 address %s", *ipv4addr)
295 }
296
297 if ipv6addr != nil && !networkvalidator.ValidateIP(*ipv6addr) {
298 return 0, 0, fmt.Errorf("invalid IPv6 address %s", *ipv6addr)
299 }
300
301 if ipv6addr != nil && dhcp6 {
302 return 0, 0, fmt.Errorf("must not provide IPv6 address if DHCP6 is enabled")
303 }
304
305 if ipv4addr != nil && dhcp4 {
306 return 0, 0, fmt.Errorf("must not provide IPv4 address if DHCP4 is enabled")
307 }
308
309 prefLen4, err := strconv.Atoi(prefixLen4)
310 if err != nil || prefLen4 < 0 || prefLen4 > 32 {
311 return 0, 0, fmt.Errorf("invalid ipv4 network prefix length '%v'", prefixLen4)
312 }
313
314 prefLen6, err := strconv.Atoi(prefixLen6)
315 if err != nil || prefLen6 < 0 || prefLen6 > 128 {
316 return 0, 0, fmt.Errorf("invalid ipv6 network prefix length '%v'", prefixLen6)
317 }
318
319 if hostname != nil {
320 if err = networkvalidator.IsValidHostname(*hostname); err != nil {
321 return 0, 0, err
322 }
323 }
324
325 return prefLen4, prefLen6, nil
326 }
327
// canonicalUUIDRegexp matches a lower-case, hyphenated 8-4-4-4-12 hex UUID
// string. It is compiled once here rather than on every IsValidUUID call.
var canonicalUUIDRegexp = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`)

// IsValidUUID reports whether uuidInput is a lower-case, hyphenated
// 8-4-4-4-12 hex UUID string.
//
// Upper-case hex digits, braces, surrounding whitespace and the un-hyphenated
// form are all rejected. No version or variant bits are checked.
func IsValidUUID(uuidInput string) bool {
	return canonicalUUIDRegexp.MatchString(uuidInput)
}
336
// IsValidServiceType reports whether serviceType is one of the supported
// service names: "ntp", "dns" or "kube-vip". The comparison is exact and
// case-sensitive.
func IsValidServiceType(serviceType string) bool {
	switch serviceType {
	case "ntp", "dns", "kube-vip":
		return true
	default:
		return false
	}
}
340
View as plain text