1
2
3
4 package cmd
5
6 import (
7 "bytes"
8 "context"
9 "errors"
10 "io"
11 "os"
12 "os/exec"
13 "strings"
14 "syscall"
15 "testing"
16 "time"
17
18 "github.com/Microsoft/hcsshim/internal/cow"
19 hcsschema "github.com/Microsoft/hcsshim/internal/hcs/schema2"
20 )
21
// localProcessHost is a cow.ProcessHost test double that launches ordinary
// processes on the local machine with os.StartProcess instead of inside a
// container or utility VM.
type localProcessHost struct{}

// localProcess is the cow.Process returned by localProcessHost.CreateProcess.
//
// stdin, stdout and stderr hold the parent's ends of the pipes created for
// the child; each stays nil when the corresponding pipe was not requested.
type localProcess struct {
	p     *os.Process      // the started child process
	state *os.ProcessState // exit state; only read after ch is closed
	ch    chan struct{}    // closed once the background Wait on p returns

	stdin, stdout, stderr *os.File
}

// OS reports the operating system processes run on. Local processes are
// started with Windows-only SysProcAttr settings, so this is always "windows".
func (h *localProcessHost) OS() string {
	const hostOS = "windows"
	return hostOS
}

// IsOCI reports whether process configurations are OCI specs. This host
// consumes *hcsschema.ProcessParameters instead, so it always answers false.
func (h *localProcessHost) IsOCI() bool {
	return false
}
39
40 func (h *localProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (_ cow.Process, err error) {
41 params := cfg.(*hcsschema.ProcessParameters)
42 lp := &localProcess{ch: make(chan struct{})}
43 defer func() {
44 if err != nil {
45 lp.Close()
46 }
47 }()
48 var stdin, stdout, stderr *os.File
49 if params.CreateStdInPipe {
50 stdin, lp.stdin, err = os.Pipe()
51 if err != nil {
52 return nil, err
53 }
54 defer stdin.Close()
55 }
56 if params.CreateStdOutPipe {
57 lp.stdout, stdout, err = os.Pipe()
58 if err != nil {
59 return nil, err
60 }
61 defer stdout.Close()
62 }
63 if params.CreateStdErrPipe {
64 lp.stderr, stderr, err = os.Pipe()
65 if err != nil {
66 return nil, err
67 }
68 defer stderr.Close()
69 }
70 path := strings.Split(params.CommandLine, " ")[0]
71 if ppath, err := exec.LookPath(path); err == nil {
72 path = ppath
73 }
74 lp.p, err = os.StartProcess(path, nil, &os.ProcAttr{
75 Files: []*os.File{stdin, stdout, stderr},
76 Sys: &syscall.SysProcAttr{
77 CmdLine: params.CommandLine,
78 },
79 })
80 if err != nil {
81 return nil, err
82 }
83 go func() {
84 lp.state, _ = lp.p.Wait()
85 close(lp.ch)
86 }()
87 return lp, nil
88 }
89
90 func (p *localProcess) Close() error {
91 if p.p != nil {
92 _ = p.p.Release()
93 }
94 if p.stdin != nil {
95 p.stdin.Close()
96 }
97 if p.stdout != nil {
98 p.stdout.Close()
99 }
100 if p.stderr != nil {
101 p.stderr.Close()
102 }
103 return nil
104 }
105
106 func (p *localProcess) CloseStdin(ctx context.Context) error {
107 return p.stdin.Close()
108 }
109
110 func (p *localProcess) CloseStdout(ctx context.Context) error {
111 return p.stdout.Close()
112 }
113
114 func (p *localProcess) CloseStderr(ctx context.Context) error {
115 return p.stderr.Close()
116 }
117
118 func (p *localProcess) ExitCode() (int, error) {
119 select {
120 case <-p.ch:
121 return p.state.ExitCode(), nil
122 default:
123 return -1, errors.New("not exited")
124 }
125 }
126
127 func (p *localProcess) Kill(ctx context.Context) (bool, error) {
128 return true, p.p.Kill()
129 }
130
131 func (p *localProcess) Signal(ctx context.Context, _ interface{}) (bool, error) {
132 return p.Kill(ctx)
133 }
134
135 func (p *localProcess) Pid() int {
136 return p.p.Pid
137 }
138
139 func (p *localProcess) ResizeConsole(ctx context.Context, x, y uint16) error {
140 return errors.New("not supported")
141 }
142
143 func (p *localProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
144 return p.stdin, p.stdout, p.stderr
145 }
146
147 func (p *localProcess) Wait() error {
148 <-p.ch
149 return nil
150 }
151
152 func TestCmdExitCode(t *testing.T) {
153 cmd := Command(&localProcessHost{}, "cmd", "/c", "exit", "/b", "64")
154 err := cmd.Run()
155 if e, ok := err.(*ExitError); !ok || e.ExitCode() != 64 {
156 t.Fatal("expected exit code 64, got ", err)
157 }
158 }
159
160 func TestCmdOutput(t *testing.T) {
161 cmd := Command(&localProcessHost{}, "cmd", "/c", "echo", "hello")
162 output, err := cmd.Output()
163 if err != nil {
164 t.Fatal(err)
165 }
166 if string(output) != "hello\r\n" {
167 t.Fatalf("got %q", string(output))
168 }
169 }
170
171 func TestCmdContext(t *testing.T) {
172 ctx, cancel := context.WithTimeout(context.Background(), 250*time.Millisecond)
173 defer cancel()
174 cmd := CommandContext(ctx, &localProcessHost{}, "cmd", "/c", "pause")
175 r, w := io.Pipe()
176 cmd.Stdin = r
177 err := cmd.Start()
178 if err != nil {
179 t.Fatal(err)
180 }
181 _ = cmd.Process.Wait()
182 w.Close()
183 err = cmd.Wait()
184 if e, ok := err.(*ExitError); !ok || e.ExitCode() != 1 || ctx.Err() == nil {
185 t.Fatal(err)
186 }
187 }
188
189 func TestCmdStdin(t *testing.T) {
190 cmd := Command(&localProcessHost{}, "findstr", "x*")
191 cmd.Stdin = bytes.NewBufferString("testing 1 2 3")
192 out, err := cmd.Output()
193 if err != nil {
194 t.Fatal(err)
195 }
196 if string(out) != "testing 1 2 3\r\n" {
197 t.Fatalf("got %q", string(out))
198 }
199 }
200
201 func TestCmdStdinBlocked(t *testing.T) {
202 cmd := Command(&localProcessHost{}, "cmd", "/c", "pause")
203 r, w := io.Pipe()
204 defer r.Close()
205 go func() {
206 b := []byte{'\n'}
207 _, _ = w.Write(b)
208 }()
209 cmd.Stdin = r
210 _, err := cmd.Output()
211 if err != nil {
212 t.Fatal(err)
213 }
214 }
215
216 type stuckIoProcessHost struct {
217 cow.ProcessHost
218 }
219
220 type stuckIoProcess struct {
221 cow.Process
222 stdin, pstdout, pstderr *io.PipeWriter
223 pstdin, stdout, stderr *io.PipeReader
224 }
225
226 func (h *stuckIoProcessHost) CreateProcess(ctx context.Context, cfg interface{}) (cow.Process, error) {
227 p, err := h.ProcessHost.CreateProcess(ctx, cfg)
228 if err != nil {
229 return nil, err
230 }
231 sp := &stuckIoProcess{
232 Process: p,
233 }
234 sp.pstdin, sp.stdin = io.Pipe()
235 sp.stdout, sp.pstdout = io.Pipe()
236 sp.stderr, sp.pstderr = io.Pipe()
237 return sp, nil
238 }
239
240 func (p *stuckIoProcess) Stdio() (io.Writer, io.Reader, io.Reader) {
241 return p.stdin, p.stdout, p.stderr
242 }
243
244 func (p *stuckIoProcess) Close() error {
245 p.stdin.Close()
246 p.stdout.Close()
247 p.stderr.Close()
248 return p.Process.Close()
249 }
250
251 func TestCmdStuckIo(t *testing.T) {
252 cmd := Command(&stuckIoProcessHost{&localProcessHost{}}, "cmd", "/c", "echo", "hello")
253 cmd.CopyAfterExitTimeout = time.Millisecond * 200
254 _, err := cmd.Output()
255 if err != io.ErrClosedPipe {
256 t.Fatal(err)
257 }
258 }
259
View as plain text