...
1
16
17 package remotecommand
18
19 import (
20 "fmt"
21 "io"
22 "net/http"
23 "sync"
24
25 "k8s.io/api/core/v1"
26 "k8s.io/apimachinery/pkg/util/runtime"
27 )
28
29
30
31
32
33 type streamProtocolV2 struct {
34 StreamOptions
35
36 errorStream io.Reader
37 remoteStdin io.ReadWriteCloser
38 remoteStdout io.Reader
39 remoteStderr io.Reader
40 }
41
42 var _ streamProtocolHandler = &streamProtocolV2{}
43
44 func newStreamProtocolV2(options StreamOptions) streamProtocolHandler {
45 return &streamProtocolV2{
46 StreamOptions: options,
47 }
48 }
49
50 func (p *streamProtocolV2) createStreams(conn streamCreator) error {
51 var err error
52 headers := http.Header{}
53
54
55 headers.Set(v1.StreamType, v1.StreamTypeError)
56 p.errorStream, err = conn.CreateStream(headers)
57 if err != nil {
58 return err
59 }
60
61
62 if p.Stdin != nil {
63 headers.Set(v1.StreamType, v1.StreamTypeStdin)
64 p.remoteStdin, err = conn.CreateStream(headers)
65 if err != nil {
66 return err
67 }
68 }
69
70
71 if p.Stdout != nil {
72 headers.Set(v1.StreamType, v1.StreamTypeStdout)
73 p.remoteStdout, err = conn.CreateStream(headers)
74 if err != nil {
75 return err
76 }
77 }
78
79
80 if p.Stderr != nil && !p.Tty {
81 headers.Set(v1.StreamType, v1.StreamTypeStderr)
82 p.remoteStderr, err = conn.CreateStream(headers)
83 if err != nil {
84 return err
85 }
86 }
87 return nil
88 }
89
90 func (p *streamProtocolV2) copyStdin() {
91 if p.Stdin != nil {
92 var once sync.Once
93
94
95 go func() {
96 defer runtime.HandleCrash()
97
98
99
100
101 defer once.Do(func() { p.remoteStdin.Close() })
102
103 if _, err := io.Copy(p.remoteStdin, readerWrapper{p.Stdin}); err != nil {
104 runtime.HandleError(err)
105 }
106 }()
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122 go func() {
123 defer runtime.HandleCrash()
124 defer once.Do(func() { p.remoteStdin.Close() })
125
126
127
128 if _, err := io.Copy(io.Discard, p.remoteStdin); err != nil {
129 runtime.HandleError(err)
130 }
131 }()
132 }
133 }
134
135 func (p *streamProtocolV2) copyStdout(wg *sync.WaitGroup) {
136 if p.Stdout == nil {
137 return
138 }
139
140 wg.Add(1)
141 go func() {
142 defer runtime.HandleCrash()
143 defer wg.Done()
144
145
146
147 defer io.Copy(io.Discard, p.remoteStdout)
148
149 if _, err := io.Copy(p.Stdout, p.remoteStdout); err != nil {
150 runtime.HandleError(err)
151 }
152 }()
153 }
154
155 func (p *streamProtocolV2) copyStderr(wg *sync.WaitGroup) {
156 if p.Stderr == nil || p.Tty {
157 return
158 }
159
160 wg.Add(1)
161 go func() {
162 defer runtime.HandleCrash()
163 defer wg.Done()
164 defer io.Copy(io.Discard, p.remoteStderr)
165
166 if _, err := io.Copy(p.Stderr, p.remoteStderr); err != nil {
167 runtime.HandleError(err)
168 }
169 }()
170 }
171
172 func (p *streamProtocolV2) stream(conn streamCreator) error {
173 if err := p.createStreams(conn); err != nil {
174 return err
175 }
176
177
178
179 errorChan := watchErrorStream(p.errorStream, &errorDecoderV2{})
180
181 p.copyStdin()
182
183 var wg sync.WaitGroup
184 p.copyStdout(&wg)
185 p.copyStderr(&wg)
186
187
188 wg.Wait()
189
190
191 return <-errorChan
192 }
193
194
// errorDecoderV2 converts the raw contents of the error stream into an error
// by treating them as plain text.
type errorDecoderV2 struct{}

// decode embeds message verbatim, after the prefix
// "error executing remote command: ". The returned error is never nil.
func (*errorDecoderV2) decode(message []byte) error {
	wrapped := fmt.Errorf("error executing remote command: %s", message)
	return wrapped
}
200
View as plain text