1
3 package icmd
4
5 import (
6 "bytes"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "strings"
12 "sync"
13 "time"
14
15 "gotest.tools/v3/assert"
16 "gotest.tools/v3/assert/cmp"
17 )
18
// helperT is the optional Helper method of a test handle such as *testing.T.
// Assert calls Helper when the assert.TestingT it receives implements this
// interface, so that (as with testing.T.Helper) failures are attributed to
// the caller rather than to this package.
type helperT interface {
	Helper()
}
22
23
// None is a sentinel value for Expected.Out and Expected.Err. Instead of
// requiring a substring, it requires the corresponding stream to be empty.
const None = "[NOTHING]"
25
// lockedBuffer is a bytes.Buffer whose Write and String are serialized by a
// read/write mutex. It is installed as the process's stdout/stderr writer,
// so output may be appended while another goroutine reads it through
// Result.Stdout/Stderr.
type lockedBuffer struct {
	m   sync.RWMutex
	buf bytes.Buffer
}

// Write appends p to the buffer under the exclusive lock.
func (lb *lockedBuffer) Write(p []byte) (int, error) {
	lb.m.Lock()
	defer lb.m.Unlock()
	n, err := lb.buf.Write(p)
	return n, err
}

// String returns a snapshot of everything written so far, taken under the
// shared lock.
func (lb *lockedBuffer) String() string {
	lb.m.RLock()
	defer lb.m.RUnlock()
	snapshot := lb.buf.String()
	return snapshot
}
42
43
// Result holds the outcome of running a command: the underlying exec.Cmd, its
// exit code and error, whether it was killed for exceeding its timeout, and
// the captured output (read it with Stdout, Stderr, or Combined).
type Result struct {
	// Cmd is the exec.Cmd that was built from the caller's Cmd.
	Cmd *exec.Cmd
	// ExitCode is derived from Error by setExitError. It stays 0 when Error
	// is nil, and it is not updated when WaitOnCmd hits the timeout.
	ExitCode int
	// Error is the error returned when starting or waiting on the command.
	Error error

	// Timeout reports whether WaitOnCmd killed the command because it ran
	// longer than the requested timeout.
	Timeout bool
	// Capture buffers for stdout/stderr. They are mutex-guarded because the
	// process output can still be arriving while it is read (e.g. WaitOnCmd
	// returns on timeout without waiting for Wait to finish).
	outBuffer *lockedBuffer
	errBuffer *lockedBuffer
}
53
54
55
56
57
58 func (r *Result) Assert(t assert.TestingT, exp Expected) *Result {
59 if ht, ok := t.(helperT); ok {
60 ht.Helper()
61 }
62 assert.Assert(t, r.Equal(exp))
63 return r
64 }
65
66
67
68
69 func (r *Result) Equal(exp Expected) cmp.Comparison {
70 return func() cmp.Result {
71 return cmp.ResultFromError(r.match(exp))
72 }
73 }
74
75
76 func (r *Result) Compare(exp Expected) error {
77 return r.match(exp)
78 }
79
80 func (r *Result) match(exp Expected) error {
81 errors := []string{}
82 add := func(format string, args ...interface{}) {
83 errors = append(errors, fmt.Sprintf(format, args...))
84 }
85
86 if exp.ExitCode != r.ExitCode {
87 add("ExitCode was %d expected %d", r.ExitCode, exp.ExitCode)
88 }
89 if exp.Timeout != r.Timeout {
90 if exp.Timeout {
91 add("Expected command to timeout")
92 } else {
93 add("Expected command to finish, but it hit the timeout")
94 }
95 }
96 if !matchOutput(exp.Out, r.Stdout()) {
97 add("Expected stdout to contain %q", exp.Out)
98 }
99 if !matchOutput(exp.Err, r.Stderr()) {
100 add("Expected stderr to contain %q", exp.Err)
101 }
102 switch {
103
104
105
106 case exp.Error == "" && exp.ExitCode != 0:
107 case exp.Error == "" && r.Error != nil:
108 add("Expected no error")
109 case exp.Error != "" && r.Error == nil:
110 add("Expected error to contain %q, but there was no error", exp.Error)
111 case exp.Error != "" && !strings.Contains(r.Error.Error(), exp.Error):
112 add("Expected error to contain %q", exp.Error)
113 }
114
115 if len(errors) == 0 {
116 return nil
117 }
118 return fmt.Errorf("%s\nFailures:\n%s", r, strings.Join(errors, "\n"))
119 }
120
121 func matchOutput(expected string, actual string) bool {
122 switch expected {
123 case None:
124 return actual == ""
125 default:
126 return strings.Contains(actual, expected)
127 }
128 }
129
// String renders a multi-line summary of the Result. It contains the command
// line, the exit code (suffixed with " (timeout)" when WaitOnCmd killed the
// command, and followed by an "Error:" line when there is an error), and the
// captured stdout and stderr. match uses it as the header of its failure
// message.
//
// NOTE(review): r.Cmd must be non-nil; a hand-built Result{} would panic here.
func (r *Result) String() string {
	var timeout string
	if r.Timeout {
		timeout = " (timeout)"
	}
	var errString string
	if r.Error != nil {
		errString = "\nError: " + r.Error.Error()
	}

	return fmt.Sprintf(`
Command: %s
ExitCode: %d%s%s
Stdout: %v
Stderr: %v
`,
		strings.Join(r.Cmd.Args, " "),
		r.ExitCode,
		timeout,
		errString,
		r.Stdout(),
		r.Stderr())
}
153
154
155
// Expected describes the outcome a Result is checked against by Assert,
// Equal, and Compare.
type Expected struct {
	// ExitCode is the exit code the Result must have.
	ExitCode int
	// Timeout is whether the command is expected to have hit its timeout.
	Timeout bool
	// Error, when non-empty, must be a substring of the Result's error. When
	// empty, the Result must have no error, unless ExitCode is non-zero, in
	// which case the error is not checked at all.
	Error string
	// Out must be a substring of stdout; None requires stdout to be empty.
	Out string
	// Err must be a substring of stderr; None requires stderr to be empty.
	Err string
}

// Success is the Expected value for a command that exits 0 with no error and
// no timeout. Any stdout/stderr is accepted, because the empty Out and Err
// are substrings of every output.
var Success = Expected{}
167
168
169 func (r *Result) Stdout() string {
170 return r.outBuffer.String()
171 }
172
173
174 func (r *Result) Stderr() string {
175 return r.errBuffer.String()
176 }
177
178
179 func (r *Result) Combined() string {
180 return r.outBuffer.String() + r.errBuffer.String()
181 }
182
183 func (r *Result) setExitError(err error) {
184 if err == nil {
185 return
186 }
187 r.Error = err
188 r.ExitCode = processExitCode(err)
189 }
190
191
192
// Cmd describes a command to run: its argv plus the options that buildCmd
// copies onto the underlying exec.Cmd.
type Cmd struct {
	// Command is the argv: the program followed by its arguments.
	Command []string
	// Timeout, when non-zero, is how long the command may run before
	// WaitOnCmd kills it.
	Timeout time.Duration
	// Stdin is connected to the process's standard input.
	Stdin io.Reader
	// Stdout and Stderr, when set, receive a copy of the process output in
	// addition to the Result's own capture buffers.
	Stdout io.Writer
	Stderr io.Writer
	// Dir is the working directory (see exec.Cmd.Dir).
	Dir string
	// Env is the process environment (see exec.Cmd.Env; nil inherits the
	// current process's environment).
	Env []string
	// ExtraFiles are additional open files passed to the process (see
	// exec.Cmd.ExtraFiles).
	ExtraFiles []*os.File
}
203
204
205 func Command(command string, args ...string) Cmd {
206 return Cmd{Command: append([]string{command}, args...)}
207 }
208
209
210 func RunCmd(cmd Cmd, cmdOperators ...CmdOp) *Result {
211 result := StartCmd(cmd, cmdOperators...)
212 if result.Error != nil {
213 return result
214 }
215 return WaitOnCmd(cmd.Timeout, result)
216 }
217
218
219 func RunCommand(command string, args ...string) *Result {
220 return RunCmd(Command(command, args...))
221 }
222
223
224 func StartCmd(cmd Cmd, cmdOperators ...CmdOp) *Result {
225 for _, op := range cmdOperators {
226 op(&cmd)
227 }
228 result := buildCmd(cmd)
229 if result.Error != nil {
230 return result
231 }
232 result.setExitError(result.Cmd.Start())
233 return result
234 }
235
236
237 func buildCmd(cmd Cmd) *Result {
238 var execCmd *exec.Cmd
239 switch len(cmd.Command) {
240 case 1:
241 execCmd = exec.Command(cmd.Command[0])
242 default:
243 execCmd = exec.Command(cmd.Command[0], cmd.Command[1:]...)
244 }
245 outBuffer := new(lockedBuffer)
246 errBuffer := new(lockedBuffer)
247
248 execCmd.Stdin = cmd.Stdin
249 execCmd.Dir = cmd.Dir
250 execCmd.Env = cmd.Env
251 if cmd.Stdout != nil {
252 execCmd.Stdout = io.MultiWriter(outBuffer, cmd.Stdout)
253 } else {
254 execCmd.Stdout = outBuffer
255 }
256 if cmd.Stderr != nil {
257 execCmd.Stderr = io.MultiWriter(errBuffer, cmd.Stderr)
258 } else {
259 execCmd.Stderr = errBuffer
260 }
261 execCmd.ExtraFiles = cmd.ExtraFiles
262
263 return &Result{
264 Cmd: execCmd,
265 outBuffer: outBuffer,
266 errBuffer: errBuffer,
267 }
268 }
269
270
271
272 func WaitOnCmd(timeout time.Duration, result *Result) *Result {
273 if timeout == time.Duration(0) {
274 result.setExitError(result.Cmd.Wait())
275 return result
276 }
277
278 done := make(chan error, 1)
279
280 go func() {
281 done <- result.Cmd.Wait()
282 }()
283
284 select {
285 case <-time.After(timeout):
286 killErr := result.Cmd.Process.Kill()
287 if killErr != nil {
288 fmt.Printf("failed to kill (pid=%d): %v\n", result.Cmd.Process.Pid, killErr)
289 }
290 result.Timeout = true
291 case err := <-done:
292 result.setExitError(err)
293 }
294 return result
295 }
296
View as plain text