...
1 package sh
2
import (
	"bytes"
	"encoding/json"
	"encoding/xml"
	"errors"
	"io"
	"os"
	"strings"
	"sync"
	"syscall"
	"time"
)
14
var (
	// ErrExecTimeout is returned by WaitTimeout (and therefore by Run when the
	// session has a timeout) if the pipeline does not finish in time.
	ErrExecTimeout = errors.New("execute timeout")
)
16
17
18 func (s *Session) UnmarshalJSON(data interface{}) (err error) {
19 bufrw := bytes.NewBuffer(nil)
20 s.Stdout = bufrw
21 if err = s.Run(); err != nil {
22 return
23 }
24 return json.NewDecoder(bufrw).Decode(data)
25 }
26
27
28 func (s *Session) UnmarshalXML(data interface{}) (err error) {
29 bufrw := bytes.NewBuffer(nil)
30 s.Stdout = bufrw
31 if err = s.Run(); err != nil {
32 return
33 }
34 return xml.NewDecoder(bufrw).Decode(data)
35 }
36
37
38 func (s *Session) Start() (err error) {
39 s.started = true
40 var rd *io.PipeReader
41 var wr *io.PipeWriter
42 var length = len(s.cmds)
43 if s.ShowCMD {
44 var cmds = make([]string, 0, 4)
45 for _, cmd := range s.cmds {
46 cmds = append(cmds, strings.Join(cmd.Args, " "))
47 }
48 s.writePrompt(strings.Join(cmds, " | "))
49 }
50 for index, cmd := range s.cmds {
51 if index == 0 {
52 cmd.Stdin = s.Stdin
53 } else {
54 cmd.Stdin = rd
55 }
56 if index != length {
57 rd, wr = io.Pipe()
58 cmd.Stdout = wr
59 if s.PipeStdErrors {
60 cmd.Stderr = s.Stderr
61 } else {
62 cmd.Stderr = os.Stderr
63 }
64 }
65 if index == length-1 {
66 cmd.Stdout = s.Stdout
67 cmd.Stderr = s.Stderr
68 }
69 err = cmd.Start()
70 if err != nil {
71 return
72 }
73 }
74 return
75 }
76
77
78
79 func (s *Session) Wait() error {
80 var pipeErr, lastErr error
81 for _, cmd := range s.cmds {
82 if lastErr = cmd.Wait(); lastErr != nil {
83 pipeErr = lastErr
84 }
85 wr, ok := cmd.Stdout.(*io.PipeWriter)
86 if ok {
87 wr.Close()
88 }
89 }
90 if s.PipeFail {
91 return pipeErr
92 }
93 return lastErr
94 }
95
96 func (s *Session) Kill(sig os.Signal) {
97 for _, cmd := range s.cmds {
98 if cmd.Process != nil {
99 cmd.Process.Signal(sig)
100 }
101 }
102 }
103
104 func (s *Session) WaitTimeout(timeout time.Duration) (err error) {
105 select {
106 case <-time.After(timeout):
107 s.Kill(syscall.SIGKILL)
108 return ErrExecTimeout
109 case err = <-Go(s.Wait):
110 return err
111 }
112 }
113
// Go runs f in a new goroutine and returns a channel on which f's result is
// delivered exactly once. The channel has a buffer of one, so the goroutine
// can deliver its result and exit even if nobody ever receives from it.
func Go(f func() error) chan error {
	result := make(chan error, 1)
	go func(out chan<- error) {
		out <- f()
	}(result)
	return result
}
120 }
121
122 func (s *Session) Run() (err error) {
123 if err = s.Start(); err != nil {
124 return
125 }
126 if s.timeout != time.Duration(0) {
127 return s.WaitTimeout(s.timeout)
128 }
129 return s.Wait()
130 }
131
132 func (s *Session) Output() (out []byte, err error) {
133 oldout := s.Stdout
134 defer func() {
135 s.Stdout = oldout
136 }()
137 stdout := bytes.NewBuffer(nil)
138 s.Stdout = stdout
139 err = s.Run()
140 out = stdout.Bytes()
141 return
142 }
143
144 func (s *Session) WriteStdout(f string) error {
145 oldout := s.Stdout
146 defer func() {
147 s.Stdout = oldout
148 }()
149
150 out, err := os.Create(f)
151 if err != nil {
152 return err
153 }
154 defer out.Close()
155 s.Stdout = out
156 return s.Run()
157 }
158
159 func (s *Session) CombinedOutput() (out []byte, err error) {
160 oldout := s.Stdout
161 olderr := s.Stderr
162 defer func() {
163 s.Stdout = oldout
164 s.Stderr = olderr
165 }()
166 stdout := bytes.NewBuffer(nil)
167 s.Stdout = stdout
168 s.Stderr = stdout
169
170 err = s.Run()
171 out = stdout.Bytes()
172 return
173 }
174
View as plain text