1 package xmain 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "os" 9 "sync" 10 "testing" 11 "time" 12 13 "oss.terrastruct.com/util-go/assert" 14 "oss.terrastruct.com/util-go/cmdlog" 15 "oss.terrastruct.com/util-go/xdefer" 16 "oss.terrastruct.com/util-go/xos" 17 ) 18 19 type TestState struct { 20 Run func(context.Context, *State) error 21 Env *xos.Env 22 Args []string 23 PWD string 24 25 Stdin io.Reader 26 Stdout io.Writer 27 Stderr io.Writer 28 29 ms *State 30 sigs chan os.Signal 31 done chan struct{} 32 doneErr *error 33 } 34 35 func (ts *TestState) StdinPipe() (pw io.WriteCloser) { 36 ts.Stdin, pw = io.Pipe() 37 return pw 38 } 39 40 func (ts *TestState) StdoutPipe() (pr io.Reader) { 41 pr, ts.Stdout = io.Pipe() 42 return pr 43 } 44 45 func (ts *TestState) StderrPipe() (pr io.Reader) { 46 pr, ts.Stderr = io.Pipe() 47 return pr 48 } 49 50 func (ts *TestState) Start(tb testing.TB, ctx context.Context) { 51 tb.Helper() 52 53 if ts.done != nil { 54 tb.Fatal("xmain.TestState.Start cannot be called twice") 55 } 56 57 if ts.Env == nil { 58 ts.Env = xos.NewEnv(nil) 59 } 60 var tempDirCleanup func() 61 if ts.PWD == "" { 62 ts.PWD, tempDirCleanup = assert.TempDir(tb) 63 } 64 65 ts.sigs = make(chan os.Signal, 1) 66 ts.done = make(chan struct{}) 67 68 name := "" 69 args := []string(nil) 70 if len(ts.Args) > 0 { 71 name = ts.Args[0] 72 args = ts.Args[1:] 73 } 74 log := cmdlog.NewTB(ts.Env, tb) 75 ts.ms = &State{ 76 Name: name, 77 78 Log: log, 79 Env: ts.Env, 80 Opts: NewOpts(ts.Env, args), 81 PWD: ts.PWD, 82 } 83 84 if ts.Stdin == nil { 85 ts.ms.Stdin = io.LimitReader(nil, 0) 86 } else if rc, ok := ts.Stdin.(io.ReadCloser); ok { 87 ts.ms.Stdin = rc 88 } else { 89 var pw io.WriteCloser 90 ts.ms.Stdin, pw = io.Pipe() 91 go func() { 92 defer pw.Close() 93 io.Copy(pw, ts.Stdin) 94 }() 95 } 96 97 var pipeWG sync.WaitGroup 98 if ts.Stdout == nil { 99 ts.ms.Stdout = nopWriterCloser{io.Discard} 100 } else if wc, ok := ts.Stdout.(io.WriteCloser); ok { 101 ts.ms.Stdout = wc 102 } else { 103 var pr io.Reader 104 pr, ts.ms.Stdout = io.Pipe() 105 pipeWG.Add(1) 106 go func() { 107 defer pipeWG.Done() 108 io.Copy(ts.Stdout, pr) 109 }() 110 } 111 if ts.Stderr == nil { 112 ts.ms.Stderr = nopWriterCloser{&prefixSuffixSaver{N: 1 << 25}} 113 } else if wc, ok := ts.Stderr.(io.WriteCloser); ok { 114 ts.ms.Stderr = wc 115 } else { 116 var pr io.Reader 117 pr, ts.ms.Stderr = io.Pipe() 118 pipeWG.Add(1) 119 go func() { 120 defer pipeWG.Done() 121 io.Copy(ts.Stderr, pr) 122 }() 123 } 124 ts.ms.Log = cmdlog.New(ts.ms.Env, ts.ms.Stderr) 125 126 go func() { 127 var err error 128 defer func() { 129 ts.closeStdin() 130 ts.ms.Stdout.Close() 131 ts.ms.Stderr.Close() 132 pipeWG.Wait() 133 if tempDirCleanup != nil { 134 tempDirCleanup() 135 } 136 ts.doneErr = &err 137 close(ts.done) 138 }() 139 err = ts.ms.Main(ctx, ts.sigs, ts.Run) 140 if err != nil { 141 if ts.Stderr == nil { 142 stderr := ts.ms.Stderr.(nopWriterCloser).Writer.(*prefixSuffixSaver).Bytes() 143 if len(stderr) > 0 { 144 err = fmt.Errorf("%w; stderr: %s", err, stderr) 145 } 146 } 147 } 148 }() 149 } 150 151 func (ts *TestState) closeStdin() { 152 if rc, ok := ts.ms.Stdin.(io.ReadCloser); ok { 153 rc.Close() 154 } 155 } 156 157 func (ts *TestState) Cleanup(tb testing.TB) { 158 tb.Helper() 159 160 select { 161 case <-ts.done: 162 // Already exited. 163 return 164 default: 165 } 166 167 ts.closeStdin() 168 169 ctx, cancel := context.WithTimeout(context.Background(), time.Minute) 170 defer cancel() 171 err := ts.Signal(ctx, os.Interrupt) 172 if err != nil { 173 tb.Errorf("failed to os.Interrupt xmain test: %v", err) 174 } 175 err = ts.Wait(ctx) 176 if errors.Is(err, context.DeadlineExceeded) { 177 err = ts.Signal(ctx, os.Kill) 178 if err != nil { 179 tb.Errorf("failed to kill xmain test: %v", err) 180 } 181 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 182 defer cancel() 183 err = ts.Wait(ctx) 184 } 185 assert.Success(tb, err) 186 } 187 188 func (ts *TestState) Signal(ctx context.Context, sig os.Signal) (err error) { 189 defer xdefer.Errorf(&err, "failed to signal xmain test: %v", ts.ms.Name) 190 191 select { 192 case <-ctx.Done(): 193 return ctx.Err() 194 case <-ts.done: 195 return fmt.Errorf("xmain test exited: %w", *ts.doneErr) 196 case ts.sigs <- sig: 197 return nil 198 } 199 } 200 201 func (ts *TestState) Wait(ctx context.Context) (err error) { 202 defer xdefer.Errorf(&err, "failed to wait xmain test: %v", ts.ms.Name) 203 204 select { 205 case <-ctx.Done(): 206 return ctx.Err() 207 case <-ts.done: 208 return *ts.doneErr 209 } 210 } 211 212 type nopWriterCloser struct { 213 io.Writer 214 } 215 216 func (c nopWriterCloser) Close() error { 217 return nil 218 } 219