1
2
3 package xmain
4
5 import (
6 "context"
7 "errors"
8 "fmt"
9 "io"
10 "os"
11 "os/signal"
12 "path/filepath"
13 "strings"
14 "syscall"
15 "time"
16
17 "oss.terrastruct.com/util-go/cmdlog"
18 "oss.terrastruct.com/util-go/xos"
19 )
20
// RunFunc is the body of a command executed by Main. It receives a context
// that (*State).Main cancels when a shutdown signal arrives, plus the State
// describing the process; a non-nil returned error is handed to mainFatal.
type RunFunc func(context.Context, *State) error
22
23 func Main(run RunFunc) {
24 name := ""
25 args := []string(nil)
26 if len(os.Args) > 0 {
27 name = os.Args[0]
28 args = os.Args[1:]
29 }
30
31 ms := &State{
32 Name: name,
33
34 Stdin: os.Stdin,
35 Stdout: os.Stdout,
36 Stderr: os.Stderr,
37
38 Env: xos.NewEnv(os.Environ()),
39 }
40 ms.Log = cmdlog.New(ms.Env, ms.Stderr)
41 ms.Opts = NewOpts(ms.Env, args)
42
43 wd, err := os.Getwd()
44 if err != nil {
45 ms.mainFatal(err)
46 }
47 ms.PWD = wd
48
49 sigs := make(chan os.Signal, 1)
50 signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
51
52 err = ms.Main(context.Background(), sigs, run)
53 if err != nil {
54 ms.mainFatal(err)
55 }
56 }
57
58 func (ms *State) mainFatal(err error) {
59 code := 1
60 msg := ""
61 usage := false
62
63 var eerr ExitError
64 var uerr UsageError
65 if errors.As(err, &eerr) {
66 code = eerr.Code
67 msg = eerr.Message
68 } else if errors.As(err, &uerr) {
69 msg = err.Error()
70 usage = true
71 } else {
72 msg = err.Error()
73 }
74
75 if msg != "" {
76 ms.Log.Error.Print(msg)
77 if usage {
78 ms.Log.Error.Print("Run with --help to see usage.")
79 }
80 }
81 os.Exit(code)
82 }
83
// State carries the process-level context handed to a RunFunc: the command
// name, standard streams, logger, environment, parsed options and working
// directory. Main populates it from the real process.
type State struct {
	// Name is the program name; Main sets it to os.Args[0].
	Name string

	Stdin io.Reader
	// Stdout is a WriteCloser; WritePath closes it after writing to "-".
	Stdout io.WriteCloser
	Stderr io.WriteCloser

	// Log receives diagnostics; Main builds it from Env and Stderr.
	Log *cmdlog.Logger
	Env *xos.Env
	// Opts is built by Main from Env and the arguments after os.Args[0].
	Opts *Opts

	// PWD is the directory AbsPath and HumanPath resolve relative paths
	// against; Main sets it from os.Getwd.
	PWD string
}
97
98 func (ms *State) Main(ctx context.Context, sigs <-chan os.Signal, run func(context.Context, *State) error) error {
99 ctx, cancel := context.WithCancel(ctx)
100 defer cancel()
101
102 done := make(chan error, 1)
103 go func() {
104 defer close(done)
105 done <- run(ctx, ms)
106 }()
107
108 select {
109 case err := <-done:
110 return err
111 case sig := <-sigs:
112 ms.Log.Warn.Printf("received signal %v: shutting down...", sig)
113 cancel()
114 select {
115 case err := <-done:
116 if err != nil && !errors.Is(err, context.Canceled) {
117 return fmt.Errorf("failed to shutdown: %w", err)
118 }
119 if sig == syscall.SIGTERM {
120
121 return nil
122 }
123 return ExitError{Code: 1}
124 case <-time.After(time.Minute):
125 return ExitError{
126 Code: 1,
127 Message: "took longer than 1 minute to shutdown: exiting forcefully",
128 }
129 }
130 }
131 }
132
// ExitError is an error that tells mainFatal which exit status to use. When
// Message is non-empty, mainFatal logs it in place of the full Error() text.
type ExitError struct {
	Code    int    `json:"code"`
	Message string `json:"message"`
}

// ExitErrorf returns an ExitError with the given exit code and a Message
// formatted from msg and v as by fmt.Sprintf.
func ExitErrorf(code int, msg string, v ...interface{}) ExitError {
	var ee ExitError
	ee.Code = code
	ee.Message = fmt.Sprintf(msg, v...)
	return ee
}

// Error implements error. It reads "exiting with code N", followed by
// ": <Message>" when Message is non-empty.
func (ee ExitError) Error() string {
	head := fmt.Sprintf("exiting with code %d", ee.Code)
	if ee.Message == "" {
		return head
	}
	return head + ": " + ee.Message
}
152
// UsageError reports that the command was invoked incorrectly. mainFatal
// logs it followed by a hint to run with --help, and exits with status 1.
type UsageError struct {
	Message string `json:"message"`
}

// UsageErrorf returns a UsageError whose Message is formatted from msg and v
// as by fmt.Sprintf.
func UsageErrorf(msg string, v ...interface{}) UsageError {
	formatted := fmt.Sprintf(msg, v...)
	return UsageError{Message: formatted}
}

// Error implements error by prefixing Message with "bad usage: ".
func (ue UsageError) Error() string {
	const format = "bad usage: %s"
	return fmt.Sprintf(format, ue.Message)
}
166
167 func (ms *State) ReadPath(fp string) ([]byte, error) {
168 if fp == "-" {
169 return io.ReadAll(ms.Stdin)
170 }
171 return os.ReadFile(fp)
172 }
173
174 func (ms *State) WritePath(fp string, p []byte) error {
175 if fp == "-" {
176 _, err := ms.Stdout.Write(p)
177 if err != nil {
178 return err
179 }
180 return ms.Stdout.Close()
181 }
182 return os.WriteFile(fp, p, 0644)
183 }
184
185
186 func (ms *State) AbsPath(fp string) string {
187 if fp == "-" || filepath.IsAbs(fp) {
188 return fp
189 }
190 return filepath.Join(ms.PWD, fp)
191 }
192
193
194
195 func (ms *State) HumanPath(fp string) string {
196 if fp == "-" {
197 return fp
198 }
199 fp = ms.AbsPath(fp)
200
201 if strings.HasPrefix(fp, ms.Env.Getenv("HOME")) {
202 fp = filepath.Join("~", strings.TrimPrefix(fp, ms.Env.Getenv("HOME")))
203 }
204 pwd := ms.PWD
205 if strings.HasPrefix(pwd, ms.Env.Getenv("HOME")) {
206 pwd = filepath.Join("~", strings.TrimPrefix(pwd, ms.Env.Getenv("HOME")))
207 }
208
209 rel, err := filepath.Rel(pwd, fp)
210 if err != nil {
211 return fp
212 }
213 return rel
214 }
215
View as plain text