...

Source file src/oss.terrastruct.com/util-go/xmain/xmaintest.go

Documentation: oss.terrastruct.com/util-go/xmain

     1  package xmain
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"os"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"oss.terrastruct.com/util-go/assert"
    14  	"oss.terrastruct.com/util-go/cmdlog"
    15  	"oss.terrastruct.com/util-go/xdefer"
    16  	"oss.terrastruct.com/util-go/xos"
    17  )
    18  
    19  type TestState struct {
    20  	Run  func(context.Context, *State) error
    21  	Env  *xos.Env
    22  	Args []string
    23  	PWD  string
    24  
    25  	Stdin  io.Reader
    26  	Stdout io.Writer
    27  	Stderr io.Writer
    28  
    29  	ms      *State
    30  	sigs    chan os.Signal
    31  	done    chan struct{}
    32  	doneErr *error
    33  }
    34  
    35  func (ts *TestState) StdinPipe() (pw io.WriteCloser) {
    36  	ts.Stdin, pw = io.Pipe()
    37  	return pw
    38  }
    39  
    40  func (ts *TestState) StdoutPipe() (pr io.Reader) {
    41  	pr, ts.Stdout = io.Pipe()
    42  	return pr
    43  }
    44  
    45  func (ts *TestState) StderrPipe() (pr io.Reader) {
    46  	pr, ts.Stderr = io.Pipe()
    47  	return pr
    48  }
    49  
    50  func (ts *TestState) Start(tb testing.TB, ctx context.Context) {
    51  	tb.Helper()
    52  
    53  	if ts.done != nil {
    54  		tb.Fatal("xmain.TestState.Start cannot be called twice")
    55  	}
    56  
    57  	if ts.Env == nil {
    58  		ts.Env = xos.NewEnv(nil)
    59  	}
    60  	var tempDirCleanup func()
    61  	if ts.PWD == "" {
    62  		ts.PWD, tempDirCleanup = assert.TempDir(tb)
    63  	}
    64  
    65  	ts.sigs = make(chan os.Signal, 1)
    66  	ts.done = make(chan struct{})
    67  
    68  	name := ""
    69  	args := []string(nil)
    70  	if len(ts.Args) > 0 {
    71  		name = ts.Args[0]
    72  		args = ts.Args[1:]
    73  	}
    74  	log := cmdlog.NewTB(ts.Env, tb)
    75  	ts.ms = &State{
    76  		Name: name,
    77  
    78  		Log:  log,
    79  		Env:  ts.Env,
    80  		Opts: NewOpts(ts.Env, args),
    81  		PWD:  ts.PWD,
    82  	}
    83  
    84  	if ts.Stdin == nil {
    85  		ts.ms.Stdin = io.LimitReader(nil, 0)
    86  	} else if rc, ok := ts.Stdin.(io.ReadCloser); ok {
    87  		ts.ms.Stdin = rc
    88  	} else {
    89  		var pw io.WriteCloser
    90  		ts.ms.Stdin, pw = io.Pipe()
    91  		go func() {
    92  			defer pw.Close()
    93  			io.Copy(pw, ts.Stdin)
    94  		}()
    95  	}
    96  
    97  	var pipeWG sync.WaitGroup
    98  	if ts.Stdout == nil {
    99  		ts.ms.Stdout = nopWriterCloser{io.Discard}
   100  	} else if wc, ok := ts.Stdout.(io.WriteCloser); ok {
   101  		ts.ms.Stdout = wc
   102  	} else {
   103  		var pr io.Reader
   104  		pr, ts.ms.Stdout = io.Pipe()
   105  		pipeWG.Add(1)
   106  		go func() {
   107  			defer pipeWG.Done()
   108  			io.Copy(ts.Stdout, pr)
   109  		}()
   110  	}
   111  	if ts.Stderr == nil {
   112  		ts.ms.Stderr = nopWriterCloser{&prefixSuffixSaver{N: 1 << 25}}
   113  	} else if wc, ok := ts.Stderr.(io.WriteCloser); ok {
   114  		ts.ms.Stderr = wc
   115  	} else {
   116  		var pr io.Reader
   117  		pr, ts.ms.Stderr = io.Pipe()
   118  		pipeWG.Add(1)
   119  		go func() {
   120  			defer pipeWG.Done()
   121  			io.Copy(ts.Stderr, pr)
   122  		}()
   123  	}
   124  	ts.ms.Log = cmdlog.New(ts.ms.Env, ts.ms.Stderr)
   125  
   126  	go func() {
   127  		var err error
   128  		defer func() {
   129  			ts.closeStdin()
   130  			ts.ms.Stdout.Close()
   131  			ts.ms.Stderr.Close()
   132  			pipeWG.Wait()
   133  			if tempDirCleanup != nil {
   134  				tempDirCleanup()
   135  			}
   136  			ts.doneErr = &err
   137  			close(ts.done)
   138  		}()
   139  		err = ts.ms.Main(ctx, ts.sigs, ts.Run)
   140  		if err != nil {
   141  			if ts.Stderr == nil {
   142  				stderr := ts.ms.Stderr.(nopWriterCloser).Writer.(*prefixSuffixSaver).Bytes()
   143  				if len(stderr) > 0 {
   144  					err = fmt.Errorf("%w; stderr: %s", err, stderr)
   145  				}
   146  			}
   147  		}
   148  	}()
   149  }
   150  
   151  func (ts *TestState) closeStdin() {
   152  	if rc, ok := ts.ms.Stdin.(io.ReadCloser); ok {
   153  		rc.Close()
   154  	}
   155  }
   156  
   157  func (ts *TestState) Cleanup(tb testing.TB) {
   158  	tb.Helper()
   159  
   160  	select {
   161  	case <-ts.done:
   162  		// Already exited.
   163  		return
   164  	default:
   165  	}
   166  
   167  	ts.closeStdin()
   168  
   169  	ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
   170  	defer cancel()
   171  	err := ts.Signal(ctx, os.Interrupt)
   172  	if err != nil {
   173  		tb.Errorf("failed to os.Interrupt xmain test: %v", err)
   174  	}
   175  	err = ts.Wait(ctx)
   176  	if errors.Is(err, context.DeadlineExceeded) {
   177  		err = ts.Signal(ctx, os.Kill)
   178  		if err != nil {
   179  			tb.Errorf("failed to kill xmain test: %v", err)
   180  		}
   181  		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
   182  		defer cancel()
   183  		err = ts.Wait(ctx)
   184  	}
   185  	assert.Success(tb, err)
   186  }
   187  
   188  func (ts *TestState) Signal(ctx context.Context, sig os.Signal) (err error) {
   189  	defer xdefer.Errorf(&err, "failed to signal xmain test: %v", ts.ms.Name)
   190  
   191  	select {
   192  	case <-ctx.Done():
   193  		return ctx.Err()
   194  	case <-ts.done:
   195  		return fmt.Errorf("xmain test exited: %w", *ts.doneErr)
   196  	case ts.sigs <- sig:
   197  		return nil
   198  	}
   199  }
   200  
   201  func (ts *TestState) Wait(ctx context.Context) (err error) {
   202  	defer xdefer.Errorf(&err, "failed to wait xmain test: %v", ts.ms.Name)
   203  
   204  	select {
   205  	case <-ctx.Done():
   206  		return ctx.Err()
   207  	case <-ts.done:
   208  		return *ts.doneErr
   209  	}
   210  }
   211  
   212  type nopWriterCloser struct {
   213  	io.Writer
   214  }
   215  
   216  func (c nopWriterCloser) Close() error {
   217  	return nil
   218  }
   219  

View as plain text