1
16
17 package remotecommand
18
19 import (
20 "bytes"
21 "context"
22 "crypto/rand"
23 "encoding/json"
24 "errors"
25 "io"
26 "net/http"
27 "net/http/httptest"
28 "net/url"
29 "strings"
30 "testing"
31 "time"
32
33 v1 "k8s.io/api/core/v1"
34 apierrors "k8s.io/apimachinery/pkg/api/errors"
35 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
36 "k8s.io/apimachinery/pkg/util/httpstream"
37 "k8s.io/apimachinery/pkg/util/httpstream/spdy"
38 remotecommandconsts "k8s.io/apimachinery/pkg/util/remotecommand"
39 "k8s.io/apimachinery/pkg/util/wait"
40 "k8s.io/client-go/rest"
41 )
42
43 type AttachFunc func(in io.Reader, out, err io.WriteCloser, tty bool, resize <-chan TerminalSize) error
44 type streamContext struct {
45 conn io.Closer
46 stdinStream io.ReadCloser
47 stdoutStream io.WriteCloser
48 stderrStream io.WriteCloser
49 writeStatus func(status *apierrors.StatusError) error
50 }
51
52 type streamAndReply struct {
53 httpstream.Stream
54 replySent <-chan struct{}
55 }
56
// fakeEmptyDataPty is a stub terminal whose Read and Write always report the
// whole buffer as transferred and never fail.
type fakeEmptyDataPty struct{}

// Read claims to have filled p without actually touching it, so callers see
// whatever bytes p already held. It never returns an error.
func (*fakeEmptyDataPty) Read(p []byte) (int, error) {
	n := len(p)
	return n, nil
}

// Write accepts and discards p, never returning an error.
func (*fakeEmptyDataPty) Write(p []byte) (int, error) {
	n := len(p)
	return n, nil
}
67
// fakeMassiveDataPty simulates a misbehaving client terminal: every Read and
// Write stalls for one second and then fails.
type fakeMassiveDataPty struct{}

// Read waits one second, transfers nothing and reports a client crash.
func (*fakeMassiveDataPty) Read(p []byte) (int, error) {
	time.Sleep(time.Second)
	return 0, errors.New("client crashed after 1 second")
}

// Write waits one second, then reports the whole buffer as written together
// with an error.
func (*fakeMassiveDataPty) Write(p []byte) (int, error) {
	time.Sleep(time.Second)
	return len(p), errors.New("return err")
}
79
80 func fakeMassiveDataAttacher(stdin io.Reader, stdout, stderr io.WriteCloser, tty bool, resize <-chan TerminalSize) error {
81
82 copyDone := make(chan struct{}, 3)
83
84 if stdin == nil {
85 return errors.New("stdin is requested")
86 }
87
88 go func() {
89 io.Copy(io.Discard, stdin)
90 copyDone <- struct{}{}
91 }()
92
93 go func() {
94 if stdout == nil {
95 return
96 }
97 copyDone <- writeMassiveData(stdout)
98 }()
99
100 go func() {
101 if stderr == nil {
102 return
103 }
104 copyDone <- writeMassiveData(stderr)
105 }()
106
107 select {
108 case <-copyDone:
109 return nil
110 }
111 }
112
// writeMassiveData writes the string "something" to stdStream repeatedly
// until a write fails, then returns an empty struct so the result can be
// sent directly on a completion channel.
//
// An error that is (or wraps) io.EOF does not stop the loop; any other error
// does. io.Copy itself never reports the source's EOF as an error, so only a
// writer that returns io.EOF is affected by that exemption.
func writeMassiveData(stdStream io.Writer) struct{} {
	for {
		_, err := io.Copy(stdStream, strings.NewReader("something"))
		if err != nil && !errors.Is(err, io.EOF) {
			return struct{}{}
		}
	}
}
122
123 func TestSPDYExecutorStream(t *testing.T) {
124 tests := []struct {
125 timeout time.Duration
126 name string
127 options StreamOptions
128 expectError string
129 attacher AttachFunc
130 }{
131 {
132 name: "stdoutBlockTest",
133 options: StreamOptions{
134 Stdin: &fakeMassiveDataPty{},
135 Stdout: &fakeMassiveDataPty{},
136 },
137 expectError: "",
138 attacher: fakeMassiveDataAttacher,
139 },
140 {
141 name: "stderrBlockTest",
142 options: StreamOptions{
143 Stdin: &fakeMassiveDataPty{},
144 Stderr: &fakeMassiveDataPty{},
145 },
146 expectError: "",
147 attacher: fakeMassiveDataAttacher,
148 },
149 {
150 timeout: 500 * time.Millisecond,
151 name: "timeoutTest",
152 options: StreamOptions{
153 Stdin: &fakeMassiveDataPty{},
154 Stderr: &fakeMassiveDataPty{},
155 },
156 expectError: context.DeadlineExceeded.Error(),
157 attacher: fakeMassiveDataAttacher,
158 },
159 }
160
161 for _, test := range tests {
162 t.Run(test.name, func(t *testing.T) {
163 server := newTestHTTPServer(test.attacher, &test.options)
164 defer server.Close()
165
166 ctx, cancel := context.Background(), func() {}
167 if test.timeout > 0 {
168 ctx, cancel = context.WithTimeout(ctx, test.timeout)
169 }
170 defer cancel()
171
172 err := attach2Server(ctx, server.URL, test.options)
173
174 gotError := ""
175 if err != nil {
176 gotError = err.Error()
177 }
178 if test.expectError != gotError {
179 t.Errorf("%s: expected [%v], got [%v]", test.name, test.expectError, gotError)
180 }
181 })
182 }
183 }
184
185 func newTestHTTPServer(f AttachFunc, options *StreamOptions) *httptest.Server {
186
187 server := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
188 ctx, err := createHTTPStreams(writer, request, options)
189 if err != nil {
190 return
191 }
192 defer ctx.conn.Close()
193
194
195 err = f(ctx.stdinStream, ctx.stdoutStream, ctx.stderrStream, false, nil)
196 if err != nil {
197 ctx.writeStatus(apierrors.NewInternalError(err))
198 } else {
199 ctx.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
200 Status: metav1.StatusSuccess,
201 }})
202 }
203 }))
204 return server
205 }
206
207 func attach2Server(ctx context.Context, rawURL string, options StreamOptions) error {
208 uri, _ := url.Parse(rawURL)
209 exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
210 if err != nil {
211 return err
212 }
213
214 e := make(chan error, 1)
215 go func(e chan error) {
216 e <- exec.StreamWithContext(ctx, options)
217 }(e)
218 select {
219 case err := <-e:
220 return err
221 case <-time.After(wait.ForeverTestTimeout):
222 return errors.New("execute timeout")
223 }
224 }
225
226
227 func createHTTPStreams(w http.ResponseWriter, req *http.Request, opts *StreamOptions) (*streamContext, error) {
228 _, err := httpstream.Handshake(req, w, []string{remotecommandconsts.StreamProtocolV4Name})
229 if err != nil {
230 return nil, err
231 }
232
233 upgrader := spdy.NewResponseUpgrader()
234 streamCh := make(chan streamAndReply)
235 conn := upgrader.UpgradeResponse(w, req, func(stream httpstream.Stream, replySent <-chan struct{}) error {
236 streamCh <- streamAndReply{Stream: stream, replySent: replySent}
237 return nil
238 })
239 ctx := &streamContext{
240 conn: conn,
241 }
242
243
244 replyChan := make(chan struct{}, 4)
245 defer close(replyChan)
246 receivedStreams := 0
247 expectedStreams := 1
248 if opts.Stdout != nil {
249 expectedStreams++
250 }
251 if opts.Stdin != nil {
252 expectedStreams++
253 }
254 if opts.Stderr != nil {
255 expectedStreams++
256 }
257 WaitForStreams:
258 for {
259 select {
260 case stream := <-streamCh:
261 streamType := stream.Headers().Get(v1.StreamType)
262 switch streamType {
263 case v1.StreamTypeError:
264 replyChan <- struct{}{}
265 ctx.writeStatus = v4WriteStatusFunc(stream)
266 case v1.StreamTypeStdout:
267 replyChan <- struct{}{}
268 ctx.stdoutStream = stream
269 case v1.StreamTypeStdin:
270 replyChan <- struct{}{}
271 ctx.stdinStream = stream
272 case v1.StreamTypeStderr:
273 replyChan <- struct{}{}
274 ctx.stderrStream = stream
275 default:
276
277 return nil, errors.New("unimplemented stream type")
278 }
279 case <-replyChan:
280 receivedStreams++
281 if receivedStreams == expectedStreams {
282 break WaitForStreams
283 }
284 }
285 }
286
287 return ctx, nil
288 }
289
290 func v4WriteStatusFunc(stream io.Writer) func(status *apierrors.StatusError) error {
291 return func(status *apierrors.StatusError) error {
292 bs, err := json.Marshal(status.Status())
293 if err != nil {
294 return err
295 }
296 _, err = stream.Write(bs)
297 return err
298 }
299 }
300
301
// writeDetector wraps an io.Writer and signals the first call to Write by
// closing its written channel.
type writeDetector struct {
	written chan bool // closed on the first Write; never sent on
	closed  bool      // whether written has already been closed
	io.Writer
}

// newWriterDetector returns a writeDetector that forwards writes to w.
func newWriterDetector(w io.Writer) *writeDetector {
	d := new(writeDetector)
	d.written = make(chan bool)
	d.Writer = w
	return d
}

// BlockUntilWritten blocks until Write has been called at least once.
func (w *writeDetector) BlockUntilWritten() {
	<-w.written
}

// Write closes the written channel on its first invocation, then forwards p
// to the wrapped writer. The closed flag is not synchronized, so Write must
// not be called concurrently.
func (w *writeDetector) Write(p []byte) (n int, err error) {
	if w.closed {
		return w.Writer.Write(p)
	}
	w.closed = true
	close(w.written)
	return w.Writer.Write(p)
}
326
327
328
329
330 func TestStreamExitsAfterConnectionIsClosed(t *testing.T) {
331 writeDetector := newWriterDetector(&fakeEmptyDataPty{})
332 options := StreamOptions{
333 Stdin: &fakeEmptyDataPty{},
334 Stdout: writeDetector,
335 }
336 server := newTestHTTPServer(fakeMassiveDataAttacher, &options)
337
338 ctx, cancelFn := context.WithTimeout(context.Background(), 500*time.Millisecond)
339 defer cancelFn()
340
341 uri, _ := url.Parse(server.URL)
342 exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
343 if err != nil {
344 t.Fatal(err)
345 }
346 streamExec := exec.(*spdyStreamExecutor)
347
348 conn, streamer, err := streamExec.newConnectionAndStream(ctx, options)
349 if err != nil {
350 t.Fatal(err)
351 }
352
353 errorChan := make(chan error)
354 go func() {
355 errorChan <- streamer.stream(conn)
356 }()
357
358
359 writeDetector.BlockUntilWritten()
360
361
362 conn.Close()
363
364 select {
365 case <-time.After(1 * time.Second):
366 t.Fatalf("expect stream to be closed after connection is closed.")
367 case <-errorChan:
368 return
369 }
370 }
371
372 func TestStreamRandomData(t *testing.T) {
373 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
374 var stdin, stdout bytes.Buffer
375 ctx, err := createHTTPStreams(w, req, &StreamOptions{
376 Stdin: &stdin,
377 Stdout: &stdout,
378 })
379 if err != nil {
380 t.Errorf("error on createHTTPStreams: %v", err)
381 return
382 }
383 defer ctx.conn.Close()
384
385 io.Copy(ctx.stdoutStream, ctx.stdinStream)
386 }))
387
388 defer server.Close()
389
390 uri, _ := url.Parse(server.URL)
391 exec, err := NewSPDYExecutor(&rest.Config{Host: uri.Host}, "POST", uri)
392 if err != nil {
393 t.Fatal(err)
394 }
395
396 randomData := make([]byte, 1024*1024)
397 if _, err := rand.Read(randomData); err != nil {
398 t.Errorf("unexpected error reading random data: %v", err)
399 }
400 var stdout bytes.Buffer
401 options := &StreamOptions{
402 Stdin: bytes.NewReader(randomData),
403 Stdout: &stdout,
404 }
405 errorChan := make(chan error)
406 go func() {
407 errorChan <- exec.StreamWithContext(context.Background(), *options)
408 }()
409
410 select {
411 case <-time.After(wait.ForeverTestTimeout):
412 t.Fatalf("expect stream to be closed after connection is closed.")
413 case err := <-errorChan:
414 if err != nil {
415 t.Errorf("unexpected error")
416 }
417 }
418
419 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
420 if err != nil {
421 t.Errorf("error reading the stream: %v", err)
422 return
423 }
424 if !bytes.Equal(randomData, data) {
425 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
426 }
427
428 }
429
View as plain text