1 package sink
2
3 import (
4 "context"
5 "flag"
6 "fmt"
7 "io"
8 "os"
9 "strings"
10 "text/tabwriter"
11
12 "github.com/peterbourgon/ff/v3"
13
14 "edge-infra.dev/pkg/lib/cli/rags"
15 )
16
17 var (
18 defaultOpts = []ff.Option{ff.WithEnvVarNoPrefix()}
19 )
20
21
22
// Command is a node in a CLI command tree. Parse parses the command's flags
// and, when the first remaining positional argument matches the name of one
// of its Commands (compared case-insensitively), hands the rest of the
// arguments to that subcommand. Run then executes whichever command Parse
// selected.
type Command struct {
	// Use is the one-line usage string. The text before the first space is
	// the command's name (see Name); the whole string is printed on the
	// "Usage:" line of the default usage output.
	Use string

	// Short is a brief description. The default usage output lists it next to
	// the command's name in its parent's "Commands:" section, and uses it as
	// the header when Long is empty.
	Short string

	// Long is the full description printed at the top of the default usage
	// output; it takes precedence over Short there.
	Long string

	// Flags are the command-specific flags, registered on the command's flag
	// set alongside the global --help, --log-level and --log-json flags.
	Flags []*rags.Rag

	// Exec is called when this command is the one selected by Parse, unless
	// the --help flag was given. When nil it defaults to usageCmd, which
	// prints the command's usage.
	Exec func(ctx context.Context, r Run) error

	// Commands are the subcommands of this command. Their names must be
	// unique within this slice.
	Commands []*Command

	// Extensions register additional flags on the command's flag set.
	// Extensions that also implement BeforeRunner or AfterRunner are invoked
	// around Exec, in slice order.
	Extensions []Extension

	// Options are ff parsing options for this command. AllParsingOptions
	// appends them after the options of every ancestor; when that combined
	// list is empty, Parse uses defaultOpts instead.
	Options []ff.Option

	// UsageFn, if non-nil, produces the text returned by Usage in place of
	// the default usage format.
	UsageFn func(*Command) string

	// out and err are the writers installed by SetOut and SetErr; when unset,
	// getOut/getErr fall back to os.Stdout/os.Stderr.
	out io.Writer
	err io.Writer

	// Internal state populated by compute and Parse.
	rs       *rags.RagSet // flag set: global flags, Flags, and extension flags
	selected *Command     // set by Parse to c itself or the matched subcommand; nil before Parse
	args     []string     // positional arguments left over after flag parsing
	computed bool         // set once compute has completed successfully
	logLvl   int          // bound to the --log-level / -v global flag
	logJSON  bool         // bound to the --log-json global flag
	help     bool         // bound to the --help / -h global flag
	parent   *Command     // linked by the parent's compute; nil for the root
}
84
85
86
87
// Extension augments a Command with extra flags. compute calls RegisterFlags
// with the command's flag set. An Extension may additionally implement
// BeforeRunner and/or AfterRunner to hook into command execution.
type Extension interface {
	// RegisterFlags adds the extension's flags to rs.
	RegisterFlags(rs *rags.RagSet)
}
91
92
93
94
95
// BeforeRunner is an optional interface for Extensions. BeforeRun hooks run in
// Extensions order before Exec (and before usage is printed for --help); each
// receives the context and Run returned by the previous hook. The first error
// aborts the command, and Exec is not called.
type BeforeRunner interface {
	BeforeRun(context.Context, Run) (context.Context, Run, error)
}
99
100
101
102
// AfterRunner is an optional interface for Extensions. AfterRun hooks run in
// Extensions order only after Exec returns without error (they are skipped
// when Exec fails or when --help was given). Each receives the context and Run
// returned by the previous hook. The first error is returned from the command.
type AfterRunner interface {
	AfterRun(context.Context, Run) (context.Context, Run, error)
}
106
107
108
109 func (c *Command) Name() string {
110 n := c.Use
111 if i := strings.Index(n, " "); i >= 0 {
112 n = n[:i]
113 }
114 return n
115 }
116
117
118
119 func (c *Command) LongName() string {
120 if c.HasParent() {
121 return c.Parent().LongName() + " " + c.Name()
122 }
123 return c.Name()
124 }
125
126
127 func (c *Command) HasParent() bool {
128 return c.parent != nil
129 }
130
131
132 func (c *Command) Parent() *Command {
133 return c.parent
134 }
135
136
137
138 func (c *Command) AllParsingOptions() []ff.Option {
139 if c.HasParent() {
140 return append(c.Parent().AllParsingOptions(), c.Options...)
141 }
142 return c.Options
143 }
144
145
146 func (c *Command) Usage() string {
147 if c.UsageFn != nil {
148 return c.UsageFn(c)
149 }
150 return defaultUsageFn(c)
151 }
152
153
154
155
156 func (c *Command) Parse(args []string) error {
157 if c.selected != nil {
158 return nil
159 }
160
161 if err := c.compute(); err != nil {
162 return fmt.Errorf("failed to initialize CLI: %w", err)
163 }
164
165 parsingOpts := c.AllParsingOptions()
166 if len(parsingOpts) == 0 {
167 parsingOpts = defaultOpts
168 }
169 if err := ff.Parse(c.rs.FlagSet(), args, parsingOpts...); err != nil {
170 return fmt.Errorf("failed to parse options: %w", err)
171 }
172
173 c.args = c.rs.FlagSet().Args()
174 if len(c.args) > 0 {
175 for _, scmd := range c.Commands {
176 if strings.EqualFold(c.args[0], scmd.Name()) {
177 c.selected = scmd
178 return scmd.Parse(c.args[1:])
179 }
180 }
181 }
182
183 c.selected = c
184 return nil
185 }
186
187
188 func (c *Command) compute() error {
189 if c.computed {
190 return nil
191 }
192
193 for i := range c.Commands {
194
195 c.Commands[i].parent = c
196
197 if c.Commands[i] == c {
198 return fmt.Errorf("command %s cannot be child of itself", c.Name())
199 }
200
201 for x := range c.Commands {
202 if x != i && c.Commands[i].Name() == c.Commands[x].Name() {
203 return fmt.Errorf("command %s defined twice", c.Commands[i].Name())
204 }
205 }
206 }
207
208
209 if c.rs == nil {
210 c.rs = rags.New(c.Name(), flag.ContinueOnError, c.globalFlags()...)
211 }
212 c.rs.Add(c.Flags...)
213 for _, e := range c.Extensions {
214 e.RegisterFlags(c.rs)
215 }
216
217
218 if c.Exec == nil {
219 c.Exec = usageCmd
220 }
221
222 for i := range c.Commands {
223 if err := c.Commands[i].compute(); err != nil {
224 return fmt.Errorf("command '%s' is invalid: %w", c.Commands[i].Name(), err)
225 }
226 }
227
228 c.computed = true
229 return nil
230 }
231
232
233 func (c *Command) Run(ctx context.Context) error {
234 switch {
235 case c.selected == nil:
236 return fmt.Errorf("Run() called without calling Parse()")
237 case c.selected == c:
238 return c.execute(ctx, newRun(c))
239 default:
240 return c.selected.Run(ctx)
241 }
242 }
243
244 func (c *Command) execute(ctx context.Context, r Run) (err error) {
245 defer func() {
246
247
248 if err != nil {
249 r.Log.Error(err, "command failed")
250 }
251 }()
252
253 if ctx, r, err = c.beforeRun(ctx, r); err != nil {
254 return err
255 }
256
257 if c.help {
258 return usageCmd(ctx, r)
259 }
260
261 if err = c.Exec(ctx, r); err != nil {
262 return
263 }
264
265 _, r, err = c.afterRun(ctx, r)
266 return
267 }
268
269 func (c *Command) beforeRun(ctx context.Context, r Run) (context.Context, Run, error) {
270 for _, e := range c.Extensions {
271 if b, ok := e.(BeforeRunner); ok {
272 var err error
273 ctx, r, err = b.BeforeRun(ctx, r)
274 if err != nil {
275 return ctx, r, err
276 }
277 }
278 }
279 return ctx, r, nil
280 }
281
282 func (c *Command) afterRun(ctx context.Context, r Run) (context.Context, Run, error) {
283 for _, e := range c.Extensions {
284 if b, ok := e.(AfterRunner); ok {
285 var err error
286 ctx, r, err = b.AfterRun(ctx, r)
287 if err != nil {
288 return ctx, r, err
289 }
290 }
291 }
292 return ctx, r, nil
293 }
294
295
296
297 func (c *Command) ParseAndRun(ctx context.Context, args []string) error {
298 if err := c.Parse(args); err != nil {
299 fmt.Fprintln(c.getErr(), err)
300 return err
301 }
302 return c.Run(ctx)
303 }
304
305 func (c *Command) globalFlags() []*rags.Rag {
306 return []*rags.Rag{
307 {
308 Name: "help",
309 Short: "h",
310 Usage: "Display help information",
311 Value: &rags.Bool{Var: &c.help},
312 Category: rags.Global,
313 },
314 {
315 Name: "log-level",
316 Short: "v",
317 Usage: "Control logging verbosity. A higher number means chattier logs",
318 Value: &rags.Int{Var: &c.logLvl},
319 Category: rags.Global,
320 },
321 {
322 Name: "log-json",
323 Usage: "Emit JSON logs",
324 Value: &rags.Bool{Var: &c.logJSON},
325 Category: rags.Global,
326 },
327 }
328 }
329
330
331 func (c *Command) SetOut(w io.Writer) {
332 c.out = w
333 }
334
335
336
337 func (c *Command) SetErr(w io.Writer) {
338 c.err = w
339 }
340
341 func (c *Command) getOut() io.Writer {
342 switch {
343 case c.out != nil:
344 return c.out
345 case c.HasParent() && c.Parent().out != nil:
346 return c.Parent().out
347 default:
348 return os.Stdout
349 }
350 }
351
352 func (c *Command) getErr() io.Writer {
353 switch {
354 case c.err != nil:
355 return c.err
356 case c.HasParent() && c.Parent().err != nil:
357 return c.Parent().err
358 default:
359 return os.Stderr
360 }
361 }
362
363
364
365
366
367
368 func usageCmd(_ context.Context, r Run) error {
369 fmt.Fprintln(r.Err(), r.Cmd().Usage())
370 return nil
371 }
372
373
374 func defaultUsageFn(c *Command) string {
375 var b strings.Builder
376
377
378 switch {
379 case c.Long != "":
380 fmt.Fprintln(&b, c.Long)
381 fmt.Fprintln(&b)
382 case c.Short != "":
383 fmt.Fprintln(&b, c.Short)
384 fmt.Fprintln(&b)
385 }
386
387 fmt.Fprintln(&b, "Usage:")
388 tw := tabwriter.NewWriter(&b, 2, 0, 2, ' ', 0)
389 defer tw.Flush()
390 fmt.Fprintf(tw, "\t%s\t\t\n\n", useline(c))
391
392 if len(c.Commands) > 0 {
393 fmt.Fprintln(tw, "Commands:")
394 for _, subcommand := range c.Commands {
395 fmt.Fprintf(tw, "\t%s\t%s\t\n", subcommand.Name(), subcommand.Short)
396 }
397 fmt.Fprintln(tw)
398 }
399
400 if c.rs != nil && len(c.rs.Rags()) > 0 {
401 c.rs.SetOutput(&b)
402 c.rs.Usage()
403 }
404
405 return b.String()
406 }
407
408
409 func useline(c *Command) string {
410 if c.HasParent() {
411 return c.Parent().LongName() + " " + c.Use
412 }
413 return c.Use
414 }
415
View as plain text