1
2
3 package exec
4
5 import (
6 "errors"
7 "fmt"
8 "os"
9 "strings"
10 "syscall"
11 "unicode/utf16"
12 "unsafe"
13
14 "golang.org/x/sys/windows"
15 )
16
// errProcNotStarted is returned by Wait and Kill when no process has been
// started yet.
var errProcNotStarted = errors.New("process has not started yet")

// errProcNotFinished is returned by close when the process has not been
// waited on yet.
var errProcNotFinished = errors.New("process has not finished yet")
21
22
23
24
25
26
27
28
29
30
31
32
// Exec represents an external process launched through CreateProcess (or
// CreateProcessAsUser). Obtain one from New rather than constructing it
// directly: New applies the options and creates any requested stdio pipes.
type Exec struct {
	// path is the executable to launch. cmdline is the full command line
	// handed to CreateProcess.
	path    string
	cmdline string

	// process is set by Start once the child has been created and located.
	process *os.Process

	// procState is set by Wait. waitCalled makes a second Wait fail.
	procState  *os.ProcessState
	waitCalled bool

	// stdioPipesOurSide holds the parent's ends of the stdin, stdout and
	// stderr pipes (indices 0, 1, 2). Stdin/Stdout/Stderr return these.
	stdioPipesOurSide [3]*os.File

	// stdioPipesProcSide holds the inheritable ends given to the child, in
	// the same index order.
	stdioPipesProcSide [3]*os.File
	// attrList is the extended-startup attribute list built by Start and
	// released by close.
	attrList *windows.ProcThreadAttributeListContainer
	*execConfig
}
50
51
52
53 func New(path, cmdLine string, opts ...ExecOpts) (*Exec, error) {
54
55 if path == "" {
56 return nil, errors.New("path cannot be empty")
57 }
58
59
60 eopts := &execConfig{}
61 for _, o := range opts {
62 if err := o(eopts); err != nil {
63 return nil, err
64 }
65 }
66
67 e := &Exec{
68 path: path,
69 cmdline: cmdLine,
70 execConfig: eopts,
71 }
72
73 if err := e.setupStdio(); err != nil {
74 return nil, err
75 }
76 return e, nil
77 }
78
79
80
81 func (e *Exec) Start() error {
82 argv0 := e.path
83 if len(e.dir) != 0 {
84
85
86
87
88 var err error
89 argv0, err = joinExeDirAndFName(e.dir, e.path)
90 if err != nil {
91 return err
92 }
93 }
94
95 argv0p, err := windows.UTF16PtrFromString(argv0)
96 if err != nil {
97 return err
98 }
99
100 argvp, err := windows.UTF16PtrFromString(e.cmdline)
101 if err != nil {
102 return err
103 }
104
105 var dirp *uint16
106 if len(e.dir) != 0 {
107 dirp, err = windows.UTF16PtrFromString(e.dir)
108 if err != nil {
109 return err
110 }
111 }
112
113 siEx := new(windows.StartupInfoEx)
114 siEx.Flags = windows.STARTF_USESTDHANDLES
115 pi := new(windows.ProcessInformation)
116
117
118 flags := uint32(windows.CREATE_UNICODE_ENVIRONMENT) | windows.EXTENDED_STARTUPINFO_PRESENT | e.execConfig.processFlags
119
120
121
122
123
124
125 e.attrList, err = windows.NewProcThreadAttributeList(3)
126 if err != nil {
127 return fmt.Errorf("failed to initialize process thread attribute list: %w", err)
128 }
129
130
131
132 inheritHandles := e.stdioPipesProcSide[0] != nil || e.stdioPipesProcSide[1] != nil || e.stdioPipesProcSide[2] != nil
133 if inheritHandles {
134 var handles []uintptr
135 for _, file := range e.stdioPipesProcSide {
136 if file.Fd() != uintptr(syscall.InvalidHandle) {
137 handles = append(handles, file.Fd())
138 }
139 }
140
141
142 err := e.attrList.Update(
143 windows.PROC_THREAD_ATTRIBUTE_HANDLE_LIST,
144 unsafe.Pointer(&handles[0]),
145 uintptr(len(handles))*unsafe.Sizeof(handles[0]),
146 )
147 if err != nil {
148 return err
149 }
150
151
152 if e.stdioPipesProcSide[0] != nil {
153 siEx.StdInput = windows.Handle(e.stdioPipesProcSide[0].Fd())
154 }
155 if e.stdioPipesProcSide[1] != nil {
156 siEx.StdOutput = windows.Handle(e.stdioPipesProcSide[1].Fd())
157 }
158 if e.stdioPipesProcSide[2] != nil {
159 siEx.StdErr = windows.Handle(e.stdioPipesProcSide[2].Fd())
160 }
161 }
162
163 if e.job != nil {
164 if err := e.job.UpdateProcThreadAttribute(e.attrList); err != nil {
165 return err
166 }
167 }
168
169 if e.cpty != nil {
170 if err := e.cpty.UpdateProcThreadAttribute(e.attrList); err != nil {
171 return err
172 }
173 }
174
175 var zeroSec windows.SecurityAttributes
176 pSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
177 tSec := &windows.SecurityAttributes{Length: uint32(unsafe.Sizeof(zeroSec)), InheritHandle: 1}
178
179 siEx.ProcThreadAttributeList = e.attrList.List()
180 siEx.Cb = uint32(unsafe.Sizeof(*siEx))
181 if e.execConfig.token != 0 {
182 err = windows.CreateProcessAsUser(
183 e.execConfig.token,
184 argv0p,
185 argvp,
186 pSec,
187 tSec,
188 inheritHandles,
189 flags,
190 createEnvBlock(addCriticalEnv(dedupEnvCase(true, e.env))),
191 dirp,
192 &siEx.StartupInfo,
193 pi,
194 )
195 } else {
196 err = windows.CreateProcess(
197 argv0p,
198 argvp,
199 pSec,
200 tSec,
201 inheritHandles,
202 flags,
203 createEnvBlock(addCriticalEnv(dedupEnvCase(true, e.env))),
204 dirp,
205 &siEx.StartupInfo,
206 pi,
207 )
208 }
209 if err != nil {
210 return fmt.Errorf("failed to create process: %w", err)
211 }
212
213 defer func() {
214 _ = windows.CloseHandle(windows.Handle(pi.Thread))
215 }()
216
217
218
219 e.process, err = os.FindProcess(int(pi.ProcessId))
220 if err != nil {
221
222
223 if tErr := windows.TerminateProcess(pi.Process, 1); tErr != nil {
224 return fmt.Errorf("failed to terminate process after process not found: %w", tErr)
225 }
226 return fmt.Errorf("failed to find process after starting: %w", err)
227 }
228 return nil
229 }
230
231
232 func (e *Exec) Run() error {
233 if err := e.Start(); err != nil {
234 return err
235 }
236 return e.Wait()
237 }
238
239
240 func (e *Exec) close() error {
241 if e.procState == nil {
242 return errProcNotFinished
243 }
244 e.attrList.Delete()
245 e.closeStdio()
246 return nil
247 }
248
249
250 func (e *Exec) Pid() int {
251 if e.process == nil {
252 return -1
253 }
254 return e.process.Pid
255 }
256
257
258 func (e *Exec) Exited() bool {
259 if e.procState == nil {
260 return false
261 }
262 return e.procState.Exited()
263 }
264
265
266 func (e *Exec) ExitCode() int {
267 if e.procState == nil {
268 return -1
269 }
270 return e.procState.ExitCode()
271 }
272
273
274
275 func (e *Exec) Wait() (err error) {
276 if e.process == nil {
277 return errProcNotStarted
278 }
279 if e.waitCalled {
280 return errors.New("exec: Wait was already called")
281 }
282 e.waitCalled = true
283 e.procState, err = e.process.Wait()
284 if err != nil {
285 return err
286 }
287 return e.close()
288 }
289
290
291 func (e *Exec) Kill() error {
292 if e.process == nil {
293 return errProcNotStarted
294 }
295 return e.process.Kill()
296 }
297
298
299 func (e *Exec) Stdin() *os.File {
300 if e.cpty != nil {
301 return e.cpty.InPipe()
302 }
303 return e.stdioPipesOurSide[0]
304 }
305
306
307
308 func (e *Exec) Stdout() *os.File {
309 if e.cpty != nil {
310 return e.cpty.OutPipe()
311 }
312 return e.stdioPipesOurSide[1]
313 }
314
315
316
317 func (e *Exec) Stderr() *os.File {
318 if e.cpty != nil {
319 return e.cpty.OutPipe()
320 }
321 return e.stdioPipesOurSide[2]
322 }
323
324
325 func (e *Exec) setupStdio() error {
326 stdioRequested := e.stdin || e.stderr || e.stdout
327
328
329 if e.cpty != nil && stdioRequested {
330 return nil
331 }
332
333
334
335
336 if e.stdin {
337 pr, pw, err := os.Pipe()
338 if err != nil {
339 return err
340 }
341 e.stdioPipesOurSide[0] = pw
342
343 if err := windows.SetHandleInformation(
344 windows.Handle(pr.Fd()),
345 windows.HANDLE_FLAG_INHERIT,
346 windows.HANDLE_FLAG_INHERIT,
347 ); err != nil {
348 return fmt.Errorf("failed to make stdin pipe inheritable: %w", err)
349 }
350 e.stdioPipesProcSide[0] = pr
351 }
352
353 if e.stdout {
354 pr, pw, err := os.Pipe()
355 if err != nil {
356 return err
357 }
358 e.stdioPipesOurSide[1] = pr
359
360 if err := windows.SetHandleInformation(
361 windows.Handle(pw.Fd()),
362 windows.HANDLE_FLAG_INHERIT,
363 windows.HANDLE_FLAG_INHERIT,
364 ); err != nil {
365 return fmt.Errorf("failed to make stdout pipe inheritable: %w", err)
366 }
367 e.stdioPipesProcSide[1] = pw
368 }
369
370 if e.stderr {
371 pr, pw, err := os.Pipe()
372 if err != nil {
373 return err
374 }
375 e.stdioPipesOurSide[2] = pr
376
377 if err := windows.SetHandleInformation(
378 windows.Handle(pw.Fd()),
379 windows.HANDLE_FLAG_INHERIT,
380 windows.HANDLE_FLAG_INHERIT,
381 ); err != nil {
382 return fmt.Errorf("failed to make stderr pipe inheritable: %w", err)
383 }
384 e.stdioPipesProcSide[2] = pw
385 }
386 return nil
387 }
388
389 func (e *Exec) closeStdio() {
390 for i, file := range e.stdioPipesOurSide {
391 if file != nil {
392 file.Close()
393 }
394 e.stdioPipesOurSide[i] = nil
395 }
396 for i, file := range e.stdioPipesProcSide {
397 if file != nil {
398 file.Close()
399 }
400 e.stdioPipesProcSide[i] = nil
401 }
402 }
403
404
405
406
407
408
// isSlash reports whether c is a Windows path separator: a backslash or a
// forward slash.
func isSlash(c uint8) bool {
	switch c {
	case '\\', '/':
		return true
	}
	return false
}
412
413 func normalizeDir(dir string) (name string, err error) {
414 ndir, err := syscall.FullPath(dir)
415 if err != nil {
416 return "", err
417 }
418 if len(ndir) > 2 && isSlash(ndir[0]) && isSlash(ndir[1]) {
419
420 return "", syscall.EINVAL
421 }
422 return ndir, nil
423 }
424
// volToUpper upper-cases an ASCII lowercase letter so drive letters can be
// compared case-insensitively. Every other value is returned unchanged.
func volToUpper(ch int) int {
	if ch >= 'a' && ch <= 'z' {
		return ch - ('a' - 'A')
	}
	return ch
}
431
432 func joinExeDirAndFName(dir, p string) (name string, err error) {
433 if len(p) == 0 {
434 return "", syscall.EINVAL
435 }
436 if len(p) > 2 && isSlash(p[0]) && isSlash(p[1]) {
437
438 return p, nil
439 }
440 if len(p) > 1 && p[1] == ':' {
441
442 if len(p) == 2 {
443 return "", syscall.EINVAL
444 }
445 if isSlash(p[2]) {
446 return p, nil
447 } else {
448 d, err := normalizeDir(dir)
449 if err != nil {
450 return "", err
451 }
452 if volToUpper(int(p[0])) == volToUpper(int(d[0])) {
453 return syscall.FullPath(d + "\\" + p[2:])
454 } else {
455 return syscall.FullPath(p)
456 }
457 }
458 } else {
459
460 d, err := normalizeDir(dir)
461 if err != nil {
462 return "", err
463 }
464 if isSlash(p[0]) {
465 return windows.FullPath(d[:2] + p)
466 } else {
467 return windows.FullPath(d + "\\" + p)
468 }
469 }
470 }
471
472
473
474
475
// createEnvBlock encodes envv as a UTF-16 environment block, the format
// CreateProcess expects when CREATE_UNICODE_ENVIRONMENT is set. Each entry is
// NUL-terminated and the block ends with one more NUL. An empty envv
// therefore yields two NULs.
//
// NOTE(review): entries that themselves contain NUL are not rejected and would
// truncate the block as seen by Windows.
func createEnvBlock(envv []string) *uint16 {
	block := strings.Join(envv, "\x00") + "\x00\x00"
	return &utf16.Encode([]rune(block))[0]
}
498
499
500
// dedupEnvCase returns a copy of env with duplicate keys removed. When a key
// repeats, the later value replaces the earlier one, but it stays at the
// position where the key first appeared. Entries without '=' are kept
// unchanged. If caseInsensitive is true, keys are compared case-insensitively.
//
// Windows keeps per-drive working directories in variables whose names start
// with '=' (for example "=C:=C:\work"). For such entries the key runs up to
// the second '=', so variables for different drives are not collapsed into a
// single entry keyed by "".
func dedupEnvCase(caseInsensitive bool, env []string) []string {
	out := make([]string, 0, len(env))
	saw := make(map[string]int, len(env)) // key -> index into out
	for _, kv := range env {
		eq := strings.Index(kv, "=")
		if eq == 0 {
			// Leading '=' belongs to the name; find the real separator.
			eq = strings.Index(kv[1:], "=") + 1
		}
		if eq < 0 {
			out = append(out, kv)
			continue
		}
		k := kv[:eq]
		if caseInsensitive {
			k = strings.ToLower(k)
		}
		if dupIdx, isDup := saw[k]; isDup {
			out[dupIdx] = kv
			continue
		}
		saw[k] = len(out)
		out = append(out, kv)
	}
	return out
}
523
524
525
526
// addCriticalEnv makes sure env contains SYSTEMROOT, which Windows programs
// almost always need. If a SYSTEMROOT entry is already present (the name is
// compared case-insensitively), env is returned as is. Otherwise the current
// process's value is appended. Entries without '=' are ignored when looking
// for it.
func addCriticalEnv(env []string) []string {
	hasSystemRoot := false
	for _, entry := range env {
		i := strings.Index(entry, "=")
		if i >= 0 && strings.EqualFold(entry[:i], "SYSTEMROOT") {
			hasSystemRoot = true
			break
		}
	}
	if hasSystemRoot {
		return env
	}
	return append(env, "SYSTEMROOT="+os.Getenv("SYSTEMROOT"))
}
541
View as plain text