...

Source file src/oss.terrastruct.com/util-go/xmain/xmain.go

Documentation: oss.terrastruct.com/util-go/xmain

// Package xmain provides a standard stub for a command's main function that
// handles logging, flags, signals and shutdown.
     3  package xmain
     4  
     5  import (
     6  	"context"
     7  	"errors"
     8  	"fmt"
     9  	"io"
    10  	"os"
    11  	"os/signal"
    12  	"path/filepath"
    13  	"strings"
    14  	"syscall"
    15  	"time"
    16  
    17  	"oss.terrastruct.com/util-go/cmdlog"
    18  	"oss.terrastruct.com/util-go/xos"
    19  )
    20  
// RunFunc is the body of a command, as accepted by Main. It is called with a
// context that is canceled once a shutdown signal is received, and with the
// State describing the running process. A non-nil error is handed to
// mainFatal, which logs it (unless it is an ExitError with an empty Message)
// and exits the process.
type RunFunc func(context.Context, *State) error
    22  
    23  func Main(run RunFunc) {
    24  	name := ""
    25  	args := []string(nil)
    26  	if len(os.Args) > 0 {
    27  		name = os.Args[0]
    28  		args = os.Args[1:]
    29  	}
    30  
    31  	ms := &State{
    32  		Name: name,
    33  
    34  		Stdin:  os.Stdin,
    35  		Stdout: os.Stdout,
    36  		Stderr: os.Stderr,
    37  
    38  		Env: xos.NewEnv(os.Environ()),
    39  	}
    40  	ms.Log = cmdlog.New(ms.Env, ms.Stderr)
    41  	ms.Opts = NewOpts(ms.Env, args)
    42  
    43  	wd, err := os.Getwd()
    44  	if err != nil {
    45  		ms.mainFatal(err)
    46  	}
    47  	ms.PWD = wd
    48  
    49  	sigs := make(chan os.Signal, 1)
    50  	signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
    51  
    52  	err = ms.Main(context.Background(), sigs, run)
    53  	if err != nil {
    54  		ms.mainFatal(err)
    55  	}
    56  }
    57  
    58  func (ms *State) mainFatal(err error) {
    59  	code := 1
    60  	msg := ""
    61  	usage := false
    62  
    63  	var eerr ExitError
    64  	var uerr UsageError
    65  	if errors.As(err, &eerr) {
    66  		code = eerr.Code
    67  		msg = eerr.Message
    68  	} else if errors.As(err, &uerr) {
    69  		msg = err.Error()
    70  		usage = true
    71  	} else {
    72  		msg = err.Error()
    73  	}
    74  
    75  	if msg != "" {
    76  		ms.Log.Error.Print(msg)
    77  		if usage {
    78  			ms.Log.Error.Print("Run with --help to see usage.")
    79  		}
    80  	}
    81  	os.Exit(code)
    82  }
    83  
// State describes the process a RunFunc runs in. Main populates every field
// from the real process before invoking run.
type State struct {
	// Name is the program name: os.Args[0] in Main, or "" if os.Args is empty.
	Name string

	// Standard streams; Main sets them to os.Stdin, os.Stdout and os.Stderr.
	// Note that WritePath closes Stdout after writing to "-".
	Stdin  io.Reader
	Stdout io.WriteCloser
	Stderr io.WriteCloser

	// Log is built by Main with cmdlog.New(Env, Stderr); mainFatal and
	// State.Main report errors and warnings through it.
	// Env is the process environment (xos.NewEnv(os.Environ()) in Main).
	// Opts is built by Main with NewOpts(Env, args), where args is
	// os.Args without the program name.
	Log  *cmdlog.Logger
	Env  *xos.Env
	Opts *Opts

	// PWD is the working directory (os.Getwd in Main). AbsPath and
	// HumanPath resolve paths against it.
	PWD string
}
    97  
    98  func (ms *State) Main(ctx context.Context, sigs <-chan os.Signal, run func(context.Context, *State) error) error {
    99  	ctx, cancel := context.WithCancel(ctx)
   100  	defer cancel()
   101  
   102  	done := make(chan error, 1)
   103  	go func() {
   104  		defer close(done)
   105  		done <- run(ctx, ms)
   106  	}()
   107  
   108  	select {
   109  	case err := <-done:
   110  		return err
   111  	case sig := <-sigs:
   112  		ms.Log.Warn.Printf("received signal %v: shutting down...", sig)
   113  		cancel()
   114  		select {
   115  		case err := <-done:
   116  			if err != nil && !errors.Is(err, context.Canceled) {
   117  				return fmt.Errorf("failed to shutdown: %w", err)
   118  			}
   119  			if sig == syscall.SIGTERM {
   120  				// We successfully shutdown.
   121  				return nil
   122  			}
   123  			return ExitError{Code: 1}
   124  		case <-time.After(time.Minute):
   125  			return ExitError{
   126  				Code:    1,
   127  				Message: "took longer than 1 minute to shutdown: exiting forcefully",
   128  			}
   129  		}
   130  	}
   131  }
   132  
   133  type ExitError struct {
   134  	Code    int    `json:"code"`
   135  	Message string `json:"message"`
   136  }
   137  
   138  func ExitErrorf(code int, msg string, v ...interface{}) ExitError {
   139  	return ExitError{
   140  		Code:    code,
   141  		Message: fmt.Sprintf(msg, v...),
   142  	}
   143  }
   144  
   145  func (ee ExitError) Error() string {
   146  	s := fmt.Sprintf("exiting with code %d", ee.Code)
   147  	if ee.Message != "" {
   148  		s += ": " + ee.Message
   149  	}
   150  	return s
   151  }
   152  
   153  type UsageError struct {
   154  	Message string `json:"message"`
   155  }
   156  
   157  func UsageErrorf(msg string, v ...interface{}) UsageError {
   158  	return UsageError{
   159  		Message: fmt.Sprintf(msg, v...),
   160  	}
   161  }
   162  
   163  func (ue UsageError) Error() string {
   164  	return fmt.Sprintf("bad usage: %s", ue.Message)
   165  }
   166  
   167  func (ms *State) ReadPath(fp string) ([]byte, error) {
   168  	if fp == "-" {
   169  		return io.ReadAll(ms.Stdin)
   170  	}
   171  	return os.ReadFile(fp)
   172  }
   173  
   174  func (ms *State) WritePath(fp string, p []byte) error {
   175  	if fp == "-" {
   176  		_, err := ms.Stdout.Write(p)
   177  		if err != nil {
   178  			return err
   179  		}
   180  		return ms.Stdout.Close()
   181  	}
   182  	return os.WriteFile(fp, p, 0644)
   183  }
   184  
   185  // AbsPath joins the PWD with fp to give the absolute path to fp.
   186  func (ms *State) AbsPath(fp string) string {
   187  	if fp == "-" || filepath.IsAbs(fp) {
   188  		return fp
   189  	}
   190  	return filepath.Join(ms.PWD, fp)
   191  }
   192  
   193  // HumanPath makes absolute path fp more suitable for human consumption
   194  // by replacing $HOME in fp with ~ and making it relative to the current PWD.
   195  func (ms *State) HumanPath(fp string) string {
   196  	if fp == "-" {
   197  		return fp
   198  	}
   199  	fp = ms.AbsPath(fp)
   200  
   201  	if strings.HasPrefix(fp, ms.Env.Getenv("HOME")) {
   202  		fp = filepath.Join("~", strings.TrimPrefix(fp, ms.Env.Getenv("HOME")))
   203  	}
   204  	pwd := ms.PWD
   205  	if strings.HasPrefix(pwd, ms.Env.Getenv("HOME")) {
   206  		pwd = filepath.Join("~", strings.TrimPrefix(pwd, ms.Env.Getenv("HOME")))
   207  	}
   208  
   209  	rel, err := filepath.Rel(pwd, fp)
   210  	if err != nil {
   211  		return fp
   212  	}
   213  	return rel
   214  }
   215  

View as plain text