package processmanager

import (
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"strings"
	"time"

	"github.com/go-logr/logr"
	"golang.org/x/sys/unix"
)

// ReadyCheckFunc is a callback that takes a context and returns a bool
// and an error.
//
// NOTE(review): presumably the bool reports whether the process is ready;
// its callers are not visible in this file — confirm against the
// process manager implementation.
type ReadyCheckFunc func(context.Context) (bool, error)

// Process handles the running of a process.
//
// Starting the process with Start() runs it asynchronously. Calling
// Stop() will send a SIGTERM to the process.
//
// If the process exits unexpectedly (e.g. it reaches an error state),
// the result will be sent to the channel returned by Result().
type Process interface {
	// PID returns the PID of the process (nil if the process is not
	// running).
	PID() *int

	// WithArgs sets the command arguments. Has no effect until
	// the process is restarted.
	WithArgs(args ...string)

	// WithExpectNoExit tells the process manager that its process
	// is not expected to stop. If the process does stop, an error
	// should be returned, even if the process returned no error.
	WithExpectNoExit()

	// ExpectsExit reports whether the process is expected to exit.
	// See WithExpectNoExit().
	ExpectsExit() bool

	ProcessManager
}

// process is the Process implementation backed by an os/exec command.
type process struct {
	// Embedded manager state (declared elsewhere). This file relies on
	// its Mutex, isRunning, cancel, log, vlog and resultChan members.
	processManager

	// Path to the process binary.
	path string

	// Arguments the process will be run with.
	args []string

	// The os/exec instance of the running command.
	cmd *exec.Cmd

	// Used internally to receive the result of the os/exec command.
	procExitChan chan error

	// Whether the process is expected NOT to exit. If true, an error
	// will be returned to the result channel on exit, even if the
	// process exited cleanly.
	expectNoExit bool
}

// NewProcess creates a new Process, given a name, the path to the
// process binary and the arguments to run it with. An error is returned
// if the binary cannot be found at path.
func NewProcess(name string, path string, args ...string) (Process, error) { if _, err := os.ReadFile(path); err != nil { return nil, fmt.Errorf("unable to find command %s: %w", path, err) } return &process{ processManager: processManager{ name: name, resultChan: make(chan error, 1), log: logr.Discard(), vlog: logr.Discard(), }, path: path, args: args, procExitChan: make(chan error), }, nil } func (proc *process) Start(ctx context.Context) (err error) { proc.Mutex.Lock() defer proc.Mutex.Unlock() if proc.isRunning { return nil } // cleanup process and long-running threads if we fail defer func() { err = proc.cleanupOnFailure(ctx, err) }() args := strings.Join(proc.args, " ") proc.log.Info("running process", "args", args) // create child context so we can cancel long-running threads // when Stop() is called procCtx, cancel := context.WithCancel(ctx) proc.cancel = cancel // start context handler to call Stop() if ctx is cancelled proc.startContextHandler(procCtx, ctx) if err := proc.executeHooks(procCtx, preStart); err != nil { return err } if err := proc.startProcess(); err != nil { return fmt.Errorf("unable to start %s process: %w", proc.Name(), err) } proc.vlog.Info("process is running", "PID", proc.PID(), "args", args) // start exit-handler to handle process exiting unexpectedly proc.startExitHandler(procCtx) if err := proc.executeHooks(procCtx, postStart); err != nil { return err } if err := proc.waitUntilReadyWithTimeout(ctx); err != nil { return fmt.Errorf("%s process is not ready: %w", proc.Name(), err) } proc.vlog.Info("process is ready", "PID", proc.PID(), "args", args) proc.isRunning = true return nil } func (proc *process) cleanupOnFailure(ctx context.Context, err error) error { if err == nil { return nil } proc.vlog.Info("starting process failed, cleaning up") return errors.Join(err, proc.stop(ctx)) } func (proc *process) startContextHandler(ctx, startCtx context.Context) { if proc.skipContextHandling { return } go func() { if err := contextHandler(ctx, 
startCtx, proc, proc.log); err != nil { proc.log.Error(err, "failed to shutdown") } }() } // Starts the process and sets stdout and stderr. Sends the result of // the process to the exit channel once it completes. func (proc *process) startProcess() error { proc.cmd = exec.Command(proc.path, proc.args...) //#nosec G204 proc.cmd.Stdout = os.Stdout proc.cmd.Stderr = os.Stderr if err := proc.cmd.Start(); err != nil { return err } // send the result of the command to the exit channel once it exits go func() { proc.procExitChan <- proc.cmd.Wait() }() return nil } // If an exit signal is received from the process, perform cleanup and // return the result to the result channel for consumers to evaluate. // // If the context is cancelled, Stop() has been called, so the exit-handler // can stop running. func (proc *process) startExitHandler(ctx context.Context) { go func() { select { case result := <-proc.procExitChan: proc.vlog.Info("process has exited", "PID", proc.PID()) proc.resultChan <- errors.Join( resultError(proc.Name(), result, proc.expectNoExit), proc.Stop(ctx), ) case <-ctx.Done(): return } }() } func (proc *process) Stop(ctx context.Context) error { proc.Mutex.Lock() defer proc.Mutex.Unlock() if !proc.isRunning { return nil } return proc.stop(ctx) } func (proc *process) stop(ctx context.Context) error { pid := proc.PID() proc.log.Info("stopping process", "PID", pid) // cancel long-running threads proc.cancel() if err := proc.executeHooks(ctx, preStop); err != nil { return err } if err := proc.stopProcess(); err != nil { return fmt.Errorf("unable to stop process: %w", err) } if err := proc.executeHooks(ctx, postStop); err != nil { return err } proc.vlog.Info("process has stopped", "PID", pid) proc.isRunning = false return nil } // Sends a SIGTERM signal to the process, with a 10 second wait for // the process to complete. func (proc *process) stopProcess() error { pid := proc.PID() if pid == nil { return nil } // If the process is already complete, we are done. 
if err := proc.cmd.Process.Signal(unix.SIGTERM); err == os.ErrProcessDone { return nil } else if err != nil { return fmt.Errorf("failed to send SIGTERM to process %s with PID=%d: %w", proc.path, pid, err) } if err := proc.waitForProcessExit(); err != nil { return fmt.Errorf("process with PID=%d did not exit: %w", pid, err) } proc.cmd = nil return nil } // Waits for 10 seconds for the process to exit, ignoring any exit errors. func (proc *process) waitForProcessExit() error { select { case err := <-proc.procExitChan: if _, ok := err.(*exec.ExitError); ok || err == nil { return nil } return fmt.Errorf("error received whilst exiting: %w", err) case <-time.After(exitTimeout): return fmt.Errorf("timeout reached") } } func (proc *process) Restart(ctx context.Context) error { if err := proc.Stop(ctx); err != nil { return err } return proc.Start(ctx) } func (proc *process) WithLogger(log logr.Logger, verbose bool) { proc.log = log.WithName(fmt.Sprintf("%s-process", proc.Name())).WithValues("process", proc.Name(), "path", proc.path) if verbose { proc.vlog = proc.log } } func (proc *process) PID() *int { if proc.cmd != nil && proc.cmd.Process != nil { return &proc.cmd.Process.Pid } return nil } func (proc *process) WithArgs(args ...string) { proc.args = args } func (proc *process) WithExpectNoExit() { proc.expectNoExit = true } func (proc *process) ExpectsExit() bool { return !proc.expectNoExit }