1
16
17 package remotecommand
18
19 import (
20 "bytes"
21 "context"
22 "crypto/rand"
23 "encoding/json"
24 "fmt"
25 "io"
26 "math"
27 mrand "math/rand"
28 "net/http"
29 "net/http/httptest"
30 "net/url"
31 "reflect"
32 "strings"
33 "sync"
34 "testing"
35 "time"
36
37 gwebsocket "github.com/gorilla/websocket"
38
39 v1 "k8s.io/api/core/v1"
40 apierrors "k8s.io/apimachinery/pkg/api/errors"
41 metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
42 "k8s.io/apimachinery/pkg/util/httpstream/wsstream"
43 "k8s.io/apimachinery/pkg/util/remotecommand"
44 "k8s.io/apimachinery/pkg/util/wait"
45 "k8s.io/client-go/rest"
46 clientcmdapi "k8s.io/client-go/tools/clientcmd/api"
47 )
48
49
50
51
52
53
54 func TestWebSocketClient_LoopbackStdinToStdout(t *testing.T) {
55
56 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
57 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
58 if err != nil {
59 t.Fatalf("error on webSocketServerStreams: %v", err)
60 }
61 defer conns.conn.Close()
62
63 _, err = io.Copy(conns.stdoutStream, conns.stdinStream)
64 if err != nil {
65 t.Fatalf("error copying STDIN to STDOUT: %v", err)
66 }
67 }))
68 defer websocketServer.Close()
69
70
71
72 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
73 websocketLocation, err := url.Parse(websocketServer.URL)
74 if err != nil {
75 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
76 }
77 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
78 if err != nil {
79 t.Errorf("unexpected error creating websocket executor: %v", err)
80 }
81
82
83 randomSize := 1024 * 1024
84 randomData := make([]byte, randomSize)
85 if _, err := rand.Read(randomData); err != nil {
86 t.Errorf("unexpected error reading random data: %v", err)
87 }
88 var stdout bytes.Buffer
89 options := &StreamOptions{
90 Stdin: bytes.NewReader(randomData),
91 Stdout: &stdout,
92 }
93 errorChan := make(chan error)
94 go func() {
95
96 errorChan <- exec.StreamWithContext(context.Background(), *options)
97 }()
98
99 select {
100 case <-time.After(wait.ForeverTestTimeout):
101 t.Fatalf("expect stream to be closed after connection is closed.")
102 case err := <-errorChan:
103 if err != nil {
104 t.Errorf("unexpected error")
105 }
106
107 streamExec := exec.(*wsStreamExecutor)
108 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
109 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
110 }
111 }
112 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
113 if err != nil {
114 t.Fatalf("error reading the stream: %v", err)
115 }
116
117 if !bytes.Equal(randomData, data) {
118 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
119 }
120 }
121
122
123
124 func TestWebSocketClient_DifferentBufferSizes(t *testing.T) {
125
126
127 bufferSizes := []int{1 * 1024, 4 * 1024, 64 * 1024, 128 * 1024}
128 for _, bufferSize := range bufferSizes {
129
130 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
131 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
132 if err != nil {
133 t.Fatalf("error on webSocketServerStreams: %v", err)
134 }
135 defer conns.conn.Close()
136
137 buffer := make([]byte, bufferSize)
138 _, err = io.CopyBuffer(conns.stdoutStream, conns.stdinStream, buffer)
139 if err != nil {
140 t.Fatalf("error copying STDIN to STDOUT: %v", err)
141 }
142 }))
143 defer websocketServer.Close()
144
145
146
147 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
148 websocketLocation, err := url.Parse(websocketServer.URL)
149 if err != nil {
150 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
151 }
152 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
153 if err != nil {
154 t.Errorf("unexpected error creating websocket executor: %v", err)
155 }
156
157
158 randomSize := 1024 * 1024
159 randomData := make([]byte, randomSize)
160 if _, err := rand.Read(randomData); err != nil {
161 t.Errorf("unexpected error reading random data: %v", err)
162 }
163 var stdout bytes.Buffer
164 options := &StreamOptions{
165 Stdin: bytes.NewReader(randomData),
166 Stdout: &stdout,
167 }
168 errorChan := make(chan error)
169 go func() {
170
171 errorChan <- exec.StreamWithContext(context.Background(), *options)
172 }()
173
174 select {
175 case <-time.After(wait.ForeverTestTimeout):
176 t.Fatalf("expect stream to be closed after connection is closed.")
177 case err := <-errorChan:
178 if err != nil {
179 t.Errorf("unexpected error")
180 }
181
182 streamExec := exec.(*wsStreamExecutor)
183 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
184 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
185 }
186 }
187 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
188 if err != nil {
189 t.Errorf("error reading the stream: %v", err)
190 return
191 }
192
193 if !bytes.Equal(randomData, data) {
194 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
195 }
196 }
197 }
198
199
200
201
202
203 func TestWebSocketClient_LoopbackStdinAsPipe(t *testing.T) {
204
205 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
206 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
207 if err != nil {
208 t.Fatalf("error on webSocketServerStreams: %v", err)
209 }
210 defer conns.conn.Close()
211
212 _, err = io.Copy(conns.stdoutStream, conns.stdinStream)
213 if err != nil {
214 t.Fatalf("error copying STDIN to STDOUT: %v", err)
215 }
216 }))
217 defer websocketServer.Close()
218
219
220
221 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true"
222 websocketLocation, err := url.Parse(websocketServer.URL)
223 if err != nil {
224 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
225 }
226 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
227 if err != nil {
228 t.Errorf("unexpected error creating websocket executor: %v", err)
229 }
230
231
232 randomSize := 1024 * 1024
233 randomData := make([]byte, randomSize)
234 if _, err := rand.Read(randomData); err != nil {
235 t.Errorf("unexpected error reading random data: %v", err)
236 }
237 reader, writer := io.Pipe()
238 var stdout bytes.Buffer
239 options := &StreamOptions{
240 Stdin: reader,
241 Stdout: &stdout,
242 }
243 errorChan := make(chan error)
244 go func() {
245
246 errorChan <- exec.StreamWithContext(context.Background(), *options)
247 }()
248
249 _, err = writer.Write(randomData)
250 if err != nil {
251 t.Fatalf("unable to write random data to STDIN pipe: %v", err)
252 }
253 writer.Close()
254
255 select {
256 case <-time.After(wait.ForeverTestTimeout):
257 t.Fatalf("expect stream to be closed after connection is closed.")
258 case err := <-errorChan:
259 if err != nil {
260 t.Errorf("unexpected error")
261 }
262
263 streamExec := exec.(*wsStreamExecutor)
264 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
265 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
266 }
267 }
268 data, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
269 if err != nil {
270 t.Errorf("error reading the stream: %v", err)
271 return
272 }
273
274 if !bytes.Equal(randomData, data) {
275 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
276 }
277 }
278
279
280
281
282
283
284 func TestWebSocketClient_LoopbackStdinToStderr(t *testing.T) {
285
286 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
287 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
288 if err != nil {
289 t.Fatalf("error on webSocketServerStreams: %v", err)
290 }
291 defer conns.conn.Close()
292
293 _, err = io.Copy(conns.stderrStream, conns.stdinStream)
294 if err != nil {
295 t.Fatalf("error copying STDIN to STDERR: %v", err)
296 }
297 }))
298 defer websocketServer.Close()
299
300
301
302 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true"
303 websocketLocation, err := url.Parse(websocketServer.URL)
304 if err != nil {
305 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
306 }
307 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
308 if err != nil {
309 t.Errorf("unexpected error creating websocket executor: %v", err)
310 }
311
312
313 randomSize := 1024 * 1024
314 randomData := make([]byte, randomSize)
315 if _, err := rand.Read(randomData); err != nil {
316 t.Errorf("unexpected error reading random data: %v", err)
317 }
318 var stderr bytes.Buffer
319 options := &StreamOptions{
320 Stdin: bytes.NewReader(randomData),
321 Stderr: &stderr,
322 }
323 errorChan := make(chan error)
324 go func() {
325
326 errorChan <- exec.StreamWithContext(context.Background(), *options)
327 }()
328
329 select {
330 case <-time.After(wait.ForeverTestTimeout):
331 t.Fatalf("expect stream to be closed after connection is closed.")
332 case err := <-errorChan:
333 if err != nil {
334 t.Errorf("unexpected error")
335 }
336
337 streamExec := exec.(*wsStreamExecutor)
338 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
339 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
340 }
341 }
342 data, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
343 if err != nil {
344 t.Errorf("error reading the stream: %v", err)
345 return
346 }
347
348 if !bytes.Equal(randomData, data) {
349 t.Errorf("unexpected data received: %d sent: %d", len(data), len(randomData))
350 }
351 }
352
353
354
355 func TestWebSocketClient_MultipleReadChannels(t *testing.T) {
356
357
358 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
359 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
360 if err != nil {
361 t.Fatalf("error on webSocketServerStreams: %v", err)
362 }
363 defer conns.conn.Close()
364
365 stdinReader := io.TeeReader(conns.stdinStream, conns.stderrStream)
366
367 _, err = io.Copy(conns.stdoutStream, stdinReader)
368 if err != nil {
369 t.Errorf("error copying STDIN to STDOUT: %v", err)
370 }
371 }))
372 defer websocketServer.Close()
373
374
375 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stdout=true" + "&" + "stderr=true"
376 websocketLocation, err := url.Parse(websocketServer.URL)
377 if err != nil {
378 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
379 }
380 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
381 if err != nil {
382 t.Errorf("unexpected error creating websocket executor: %v", err)
383 }
384
385
386 randomSize := 1024 * 1024
387 randomData := make([]byte, randomSize)
388 if _, err := rand.Read(randomData); err != nil {
389 t.Errorf("unexpected error reading random data: %v", err)
390 }
391 var stdout, stderr bytes.Buffer
392 options := &StreamOptions{
393 Stdin: bytes.NewReader(randomData),
394 Stdout: &stdout,
395 Stderr: &stderr,
396 }
397 errorChan := make(chan error)
398 go func() {
399 errorChan <- exec.StreamWithContext(context.Background(), *options)
400 }()
401
402 select {
403 case <-time.After(wait.ForeverTestTimeout):
404 t.Fatalf("expect stream to be closed after connection is closed.")
405 case err := <-errorChan:
406 if err != nil {
407 t.Errorf("unexpected error: %v", err)
408 }
409
410 streamExec := exec.(*wsStreamExecutor)
411 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
412 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
413 }
414 }
415
416 stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
417 if err != nil {
418 t.Fatalf("error reading the stream: %v", err)
419 }
420 if !bytes.Equal(stdoutBytes, randomData) {
421 t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(randomData))
422 }
423
424 stderrBytes, err := io.ReadAll(bytes.NewReader(stderr.Bytes()))
425 if err != nil {
426 t.Fatalf("error reading the stream: %v", err)
427 }
428 if !bytes.Equal(stderrBytes, randomData) {
429 t.Errorf("unexpected data received (%d) sent (%d)", len(stderrBytes), len(randomData))
430 }
431 }
432
433
// randomExitCode returns a pseudo-random non-zero exit code in the range
// [1, 127]. Every value in the range is equally likely.
func randomExitCode() int {
	// Intn(127) yields [0, 126]; shifting by one excludes the "success" code 0
	// without favouring any particular value.
	return 1 + mrand.Intn(127)
}
441
442
443
444 func TestWebSocketClient_ErrorStream(t *testing.T) {
445 expectedExitCode := randomExitCode()
446
447 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
448 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
449 if err != nil {
450 t.Fatalf("error on webSocketServerStreams: %v", err)
451 }
452 defer conns.conn.Close()
453 _, err = io.Copy(conns.stderrStream, conns.stdinStream)
454 if err != nil {
455 t.Fatalf("error copying STDIN to STDERR: %v", err)
456 }
457
458 err = conns.writeStatus(&apierrors.StatusError{ErrStatus: metav1.Status{
459 Status: metav1.StatusFailure,
460 Reason: remotecommand.NonZeroExitCodeReason,
461 Details: &metav1.StatusDetails{
462 Causes: []metav1.StatusCause{
463 {
464 Type: remotecommand.ExitCodeCauseType,
465 Message: fmt.Sprintf("%d", expectedExitCode),
466 },
467 },
468 },
469 }})
470 if err != nil {
471 t.Fatalf("error writing status: %v", err)
472 }
473 }))
474 defer websocketServer.Close()
475
476
477 websocketServer.URL = websocketServer.URL + "?" + "stdin=true" + "&" + "stderr=true"
478 websocketLocation, err := url.Parse(websocketServer.URL)
479 if err != nil {
480 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
481 }
482 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
483 if err != nil {
484 t.Errorf("unexpected error creating websocket executor: %v", err)
485 }
486 randomData := make([]byte, 256)
487 if _, err := rand.Read(randomData); err != nil {
488 t.Errorf("unexpected error reading random data: %v", err)
489 }
490 var stderr bytes.Buffer
491 options := &StreamOptions{
492 Stdin: bytes.NewReader(randomData),
493 Stderr: &stderr,
494 }
495 errorChan := make(chan error)
496 go func() {
497
498 errorChan <- exec.StreamWithContext(context.Background(), *options)
499 }()
500
501 select {
502 case <-time.After(wait.ForeverTestTimeout):
503 t.Fatalf("expect stream to be closed after connection is closed.")
504 case err := <-errorChan:
505
506 streamExec := exec.(*wsStreamExecutor)
507 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
508 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
509 }
510
511 if err == nil {
512 t.Errorf("expected error, but received none")
513 }
514 expectedError := fmt.Sprintf("command terminated with exit code %d", expectedExitCode)
515
516 if expectedError != err.Error() {
517 t.Errorf("expected error (%s), got (%s)", expectedError, err)
518 }
519 }
520 }
521
522
523
524 type fakeTerminalSizeQueue struct {
525 maxSizes int
526 terminalSizes []TerminalSize
527 }
528
529
530
531 func newTerminalSizeQueue(max int) *fakeTerminalSizeQueue {
532 return &fakeTerminalSizeQueue{
533 maxSizes: max,
534 terminalSizes: make([]TerminalSize, 0, max),
535 }
536 }
537
538
539
540
541 func (f *fakeTerminalSizeQueue) Next() *TerminalSize {
542 if len(f.terminalSizes) >= f.maxSizes {
543 return nil
544 }
545 size := randomTerminalSize()
546 f.terminalSizes = append(f.terminalSizes, size)
547 return &size
548 }
549
550
551
552 func randomTerminalSize() TerminalSize {
553 randWidth := uint16(mrand.Intn(int(math.Pow(2, 16))))
554 randHeight := uint16(mrand.Intn(int(math.Pow(2, 16))))
555 return TerminalSize{
556 Width: randWidth,
557 Height: randHeight,
558 }
559 }
560
561
562
563
// randReader is an io.ReadCloser that produces random bytes until it is
// closed, and remembers every byte it has handed out.
type randReader struct {
	// randBytes accumulates all bytes returned by Read, in order.
	randBytes []byte
	// closed is set by Close; once set, Read reports io.EOF.
	closed bool
	// lock guards randBytes and closed.
	lock sync.Mutex
}

// Read fills b with random bytes and appends the bytes actually read to
// randBytes. After Close has been called it returns (0, io.EOF).
func (r *randReader) Read(b []byte) (int, error) {
	r.lock.Lock()
	defer r.lock.Unlock()
	if r.closed {
		return 0, io.EOF
	}
	n, err := rand.Read(b)
	// Record only the n bytes that were produced; append copies them, so the
	// record does not alias the caller's buffer.
	r.randBytes = append(r.randBytes, b[:n]...)
	return n, err
}

// Close marks the reader as closed so that subsequent Reads return io.EOF.
// It always returns nil.
func (r *randReader) Close() (err error) {
	r.lock.Lock()
	defer r.lock.Unlock()
	r.closed = true
	return nil
}
594
595
596
597 func TestWebSocketClient_MultipleWriteChannels(t *testing.T) {
598
599
600 numSizeQueue := 10000
601 sizeQueue := newTerminalSizeQueue(numSizeQueue)
602 actualTerminalSizes := make([]TerminalSize, 0, numSizeQueue)
603
604 stdinReader := randReader{randBytes: []byte{}, closed: false}
605
606
607
608 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
609 var wg sync.WaitGroup
610 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
611 if err != nil {
612 t.Fatalf("error on webSocketServerStreams: %v", err)
613 }
614 defer conns.conn.Close()
615
616 wg.Add(1)
617 go func() {
618 _, err := io.Copy(conns.stdoutStream, conns.stdinStream)
619 if err != nil {
620 t.Errorf("error copying STDIN to STDOUT: %v", err)
621 }
622 wg.Done()
623 }()
624
625 for i := 0; i < numSizeQueue; i++ {
626 actualTerminalSize := <-conns.resizeChan
627 actualTerminalSizes = append(actualTerminalSizes, actualTerminalSize)
628 }
629 stdinReader.Close()
630 wg.Wait()
631 }))
632 defer websocketServer.Close()
633
634
635 websocketServer.URL = websocketServer.URL + "?" + "tty=true" + "&" + "stdin=true" + "&" + "stdout=true"
636 websocketLocation, err := url.Parse(websocketServer.URL)
637 if err != nil {
638 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
639 }
640 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
641 if err != nil {
642 t.Errorf("unexpected error creating websocket executor: %v", err)
643 }
644 var stdout bytes.Buffer
645 options := &StreamOptions{
646 Stdin: &stdinReader,
647 Stdout: &stdout,
648 Tty: true,
649 TerminalSizeQueue: sizeQueue,
650 }
651 errorChan := make(chan error)
652 go func() {
653 errorChan <- exec.StreamWithContext(context.Background(), *options)
654 }()
655
656 select {
657 case <-time.After(wait.ForeverTestTimeout):
658 t.Fatalf("expect stream to be closed after connection is closed.")
659 case err := <-errorChan:
660 if err != nil {
661 t.Errorf("unexpected error: %v", err)
662 }
663
664 streamExec := exec.(*wsStreamExecutor)
665 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
666 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
667 }
668 }
669
670
671 stdoutBytes, err := io.ReadAll(bytes.NewReader(stdout.Bytes()))
672 if err != nil {
673 t.Fatalf("error reading the stream: %v", err)
674 }
675 if len(stdoutBytes) == 0 {
676 t.Errorf("No STDOUT bytes processed before resize stream finished: %d", len(stdoutBytes))
677 }
678 if !bytes.Equal(stdoutBytes, stdinReader.randBytes) {
679 t.Errorf("unexpected data received (%d) sent (%d)", len(stdoutBytes), len(stdinReader.randBytes))
680 }
681
682
683 if len(actualTerminalSizes) != numSizeQueue {
684 t.Errorf("expected received terminal size window (%d), got (%d)",
685 numSizeQueue, len(actualTerminalSizes))
686 }
687 for i, actual := range actualTerminalSizes {
688 expected := sizeQueue.terminalSizes[i]
689 if !reflect.DeepEqual(expected, actual) {
690 t.Errorf("expected terminal resize window %v, got %v", expected, actual)
691 }
692 }
693 }
694
695
696
697 func TestWebSocketClient_ProtocolVersions(t *testing.T) {
698
699
700 var upgrader = gwebsocket.Upgrader{
701 CheckOrigin: func(r *http.Request) bool {
702 return true
703 },
704 Subprotocols: []string{
705 remotecommand.StreamProtocolV4Name,
706 remotecommand.StreamProtocolV3Name,
707 remotecommand.StreamProtocolV2Name,
708 },
709 }
710
711 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
712 conn, err := upgrader.Upgrade(w, req, nil)
713 if err != nil {
714 t.Fatalf("unable to upgrade to create websocket connection: %v", err)
715 }
716 defer conn.Close()
717 }))
718 defer websocketServer.Close()
719
720
721 websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
722 websocketLocation, err := url.Parse(websocketServer.URL)
723 if err != nil {
724 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
725 }
726 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
727 if err != nil {
728 t.Errorf("unexpected error creating websocket executor: %v", err)
729 }
730
731
732 versions := []string{
733 remotecommand.StreamProtocolV4Name,
734 remotecommand.StreamProtocolV3Name,
735 remotecommand.StreamProtocolV2Name,
736 }
737 for _, requestedVersion := range versions {
738 streamExec := exec.(*wsStreamExecutor)
739 streamExec.protocols = []string{requestedVersion}
740 var stdout bytes.Buffer
741 options := &StreamOptions{
742 Stdout: &stdout,
743 }
744 errorChan := make(chan error)
745 go func() {
746
747 errorChan <- exec.StreamWithContext(context.Background(), *options)
748 }()
749
750 select {
751 case <-time.After(wait.ForeverTestTimeout):
752 t.Fatalf("expect stream to be closed after connection is closed.")
753 case <-errorChan:
754
755 streamExec := exec.(*wsStreamExecutor)
756 if requestedVersion != streamExec.negotiated {
757 t.Fatalf("expected protocol version (%s), got (%s)", requestedVersion, streamExec.negotiated)
758 }
759 }
760 }
761 }
762
763
764
765
766 func TestWebSocketClient_BadHandshake(t *testing.T) {
767
768 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
769
770 _, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
771 if err == nil {
772 t.Fatalf("expected error, but received none.")
773 }
774 if !strings.Contains(err.Error(), "websocket server finished before becoming ready") {
775 t.Errorf("expected websocket server error, but got: %v", err)
776 }
777 }))
778 defer websocketServer.Close()
779
780 websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
781 websocketLocation, err := url.Parse(websocketServer.URL)
782 if err != nil {
783 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
784 }
785 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
786 if err != nil {
787 t.Errorf("unexpected error creating websocket executor: %v", err)
788 }
789 streamExec := exec.(*wsStreamExecutor)
790
791 streamExec.protocols = []string{remotecommand.StreamProtocolV4Name}
792
793 var stdout bytes.Buffer
794 options := &StreamOptions{
795 Stdout: &stdout,
796 }
797 errorChan := make(chan error)
798 go func() {
799
800 errorChan <- streamExec.StreamWithContext(context.Background(), *options)
801 }()
802
803 select {
804 case <-time.After(wait.ForeverTestTimeout):
805 t.Fatalf("expect stream to be closed after connection is closed.")
806 case err := <-errorChan:
807
808 if err == nil {
809 t.Errorf("expected error but received none")
810 }
811 if !strings.Contains(err.Error(), "bad handshake") {
812 t.Errorf("expected bad handshake error, got (%s)", err)
813 }
814 }
815 }
816
817
818
819 func TestWebSocketClient_HeartbeatTimeout(t *testing.T) {
820 blockRequestCtx, unblockRequest := context.WithCancel(context.Background())
821 defer unblockRequest()
822
823 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
824 conns, err := webSocketServerStreams(req, w, streamOptionsFromRequest(req))
825 if err != nil {
826 t.Fatalf("error on webSocketServerStreams: %v", err)
827 }
828 defer conns.conn.Close()
829 <-blockRequestCtx.Done()
830 }))
831 defer websocketServer.Close()
832
833 websocketServer.URL = websocketServer.URL + "?" + "stdin=true"
834 websocketLocation, err := url.Parse(websocketServer.URL)
835 if err != nil {
836 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
837 }
838 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
839 if err != nil {
840 t.Errorf("unexpected error creating websocket executor: %v", err)
841 }
842 streamExec := exec.(*wsStreamExecutor)
843
844 pingPeriod := wait.ForeverTestTimeout
845 pingDeadline := time.Second
846 streamExec.heartbeatPeriod = pingPeriod
847 streamExec.heartbeatDeadline = pingDeadline
848
849 randomData := make([]byte, 128)
850 if _, err := rand.Read(randomData); err != nil {
851 t.Errorf("unexpected error reading random data: %v", err)
852 }
853 options := &StreamOptions{
854 Stdin: bytes.NewReader(randomData),
855 }
856 errorChan := make(chan error)
857 go func() {
858
859 errorChan <- streamExec.StreamWithContext(context.Background(), *options)
860 }()
861
862 select {
863 case <-time.After(wait.ForeverTestTimeout):
864 t.Fatalf("expected heartbeat timeout, got none.")
865 case err := <-errorChan:
866
867 if err == nil {
868 t.Fatalf("expected error but received none")
869 }
870 if !strings.Contains(err.Error(), "i/o timeout") {
871 t.Errorf("expected heartbeat timeout error, got (%s)", err)
872 }
873
874 streamExec := exec.(*wsStreamExecutor)
875 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
876 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
877 }
878 }
879 }
880
881
882
883
884 func TestWebSocketClient_TextMessageTypeError(t *testing.T) {
885 var upgrader = gwebsocket.Upgrader{
886 CheckOrigin: func(r *http.Request) bool {
887 return true
888 },
889 Subprotocols: []string{remotecommand.StreamProtocolV5Name},
890 }
891
892 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
893 conn, err := upgrader.Upgrade(w, req, nil)
894 if err != nil {
895 t.Fatalf("unable to upgrade to create websocket connection: %v", err)
896 }
897 defer conn.Close()
898 msg := []byte("test message with wrong message type.")
899 stdOutMsg := append([]byte{remotecommand.StreamStdOut}, msg...)
900
901 err = conn.WriteMessage(gwebsocket.TextMessage, stdOutMsg)
902 if err != nil {
903 t.Fatalf("error writing text message to websocket: %v", err)
904 }
905
906 }))
907 defer websocketServer.Close()
908
909
910 websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
911 websocketLocation, err := url.Parse(websocketServer.URL)
912 if err != nil {
913 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
914 }
915 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
916 if err != nil {
917 t.Errorf("unexpected error creating websocket executor: %v", err)
918 }
919 var stdout bytes.Buffer
920 options := &StreamOptions{
921 Stdout: &stdout,
922 }
923 errorChan := make(chan error)
924 go func() {
925
926 errorChan <- exec.StreamWithContext(context.Background(), *options)
927 }()
928
929 select {
930 case <-time.After(wait.ForeverTestTimeout):
931 t.Fatalf("expect stream to be closed after connection is closed.")
932 case err := <-errorChan:
933
934 if err == nil {
935 t.Fatalf("expected error but received none")
936 }
937 if !strings.Contains(err.Error(), "unexpected message type") {
938 t.Errorf("expected bad message type error, got (%s)", err)
939 }
940
941 streamExec := exec.(*wsStreamExecutor)
942 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
943 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
944 }
945 }
946 }
947
948
949
950
951 func TestWebSocketClient_EmptyMessageHandled(t *testing.T) {
952 var upgrader = gwebsocket.Upgrader{
953 CheckOrigin: func(r *http.Request) bool {
954 return true
955 },
956 Subprotocols: []string{remotecommand.StreamProtocolV5Name},
957 }
958
959 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
960 conn, err := upgrader.Upgrade(w, req, nil)
961 if err != nil {
962 t.Fatalf("unable to upgrade to create websocket connection: %v", err)
963 }
964 defer conn.Close()
965
966 conn.WriteMessage(gwebsocket.BinaryMessage, []byte{})
967 }))
968 defer websocketServer.Close()
969
970
971 websocketServer.URL = websocketServer.URL + "?" + "stdout=true"
972 websocketLocation, err := url.Parse(websocketServer.URL)
973 if err != nil {
974 t.Fatalf("Unable to parse WebSocket server URL: %s", websocketServer.URL)
975 }
976 exec, err := NewWebSocketExecutor(&rest.Config{Host: websocketLocation.Host}, "GET", websocketServer.URL)
977 if err != nil {
978 t.Errorf("unexpected error creating websocket executor: %v", err)
979 }
980 var stdout bytes.Buffer
981 options := &StreamOptions{
982 Stdout: &stdout,
983 }
984 errorChan := make(chan error)
985 go func() {
986
987 errorChan <- exec.StreamWithContext(context.Background(), *options)
988 }()
989
990 select {
991 case <-time.After(wait.ForeverTestTimeout):
992 t.Fatalf("expect stream to be closed after connection is closed.")
993 case err := <-errorChan:
994
995 if err == nil {
996 t.Fatalf("expected error but received none")
997 }
998 if !strings.Contains(err.Error(), "read stream id") {
999 t.Errorf("expected error reading stream id, got (%s)", err)
1000 }
1001
1002 streamExec := exec.(*wsStreamExecutor)
1003 if remotecommand.StreamProtocolV5Name != streamExec.negotiated {
1004 t.Fatalf("expected remote command v5 protocol, got (%s)", streamExec.negotiated)
1005 }
1006 }
1007 }
1008
1009 func TestWebSocketClient_ExecutorErrors(t *testing.T) {
1010
1011 config := rest.Config{
1012 ExecProvider: &clientcmdapi.ExecConfig{},
1013 AuthProvider: &clientcmdapi.AuthProviderConfig{},
1014 }
1015 _, err := NewWebSocketExecutor(&config, "GET", "http://localhost")
1016 if err == nil {
1017 t.Errorf("expecting executor constructor error, but received none.")
1018 } else if !strings.Contains(err.Error(), "error creating websocket transports") {
1019 t.Errorf("expecting error creating transports, got (%s)", err.Error())
1020 }
1021
1022 exec, err := NewWebSocketExecutor(&rest.Config{}, "GET", "http://localhost")
1023 if err != nil {
1024 t.Errorf("unexpected error creating websocket executor: %v", err)
1025 }
1026 errorChan := make(chan error)
1027 go func() {
1028
1029 var ctx context.Context
1030 errorChan <- exec.StreamWithContext(ctx, StreamOptions{})
1031 }()
1032
1033 select {
1034 case <-time.After(wait.ForeverTestTimeout):
1035 t.Fatalf("expect stream to be closed after connection is closed.")
1036 case err := <-errorChan:
1037
1038 if err == nil {
1039 t.Fatalf("expected error but received none")
1040 }
1041 if !strings.Contains(err.Error(), "nil Context") {
1042 t.Errorf("expected nil context error, got (%s)", err)
1043 }
1044 }
1045 }
1046
1047 func TestWebSocketClient_HeartbeatSucceeds(t *testing.T) {
1048 var upgrader = gwebsocket.Upgrader{
1049 CheckOrigin: func(r *http.Request) bool {
1050 return true
1051 },
1052 }
1053
1054 websocketServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
1055 conn, err := upgrader.Upgrade(w, req, nil)
1056 if err != nil {
1057 t.Fatalf("unable to upgrade to create websocket connection: %v", err)
1058 }
1059 defer conn.Close()
1060 for {
1061 _, _, err := conn.ReadMessage()
1062 if err != nil {
1063 break
1064 }
1065 }
1066 }))
1067 defer websocketServer.Close()
1068
1069 url := strings.ReplaceAll(websocketServer.URL, "http", "ws")
1070 client, _, err := gwebsocket.DefaultDialer.Dial(url, nil)
1071 if err != nil {
1072 t.Fatalf("dial: %v", err)
1073 }
1074 defer client.Close()
1075
1076
1077 var expectedMsg = "test heartbeat message"
1078 var period = 100 * time.Millisecond
1079 var deadline = 200 * time.Millisecond
1080 heartbeat := newHeartbeat(client, period, deadline)
1081 heartbeat.setMessage(expectedMsg)
1082
1083 pongMsgCh := make(chan string)
1084 pongHandler := heartbeat.conn.PongHandler()
1085 heartbeat.conn.SetPongHandler(func(msg string) error {
1086 pongMsgCh <- msg
1087 return pongHandler(msg)
1088 })
1089 go heartbeat.start()
1090
1091 var wg sync.WaitGroup
1092 wg.Add(1)
1093 go func() {
1094 defer wg.Done()
1095 for {
1096 _, _, err := client.ReadMessage()
1097 if err != nil {
1098 t.Logf("client err reading message: %v", err)
1099 return
1100 }
1101 }
1102 }()
1103
1104 select {
1105 case actualMsg := <-pongMsgCh:
1106 close(heartbeat.closer)
1107
1108 if expectedMsg != actualMsg {
1109 t.Errorf("expected received pong message (%s), got (%s)", expectedMsg, actualMsg)
1110 }
1111 case <-time.After(period * 4):
1112
1113 close(heartbeat.closer)
1114 t.Errorf("unexpected heartbeat timeout")
1115 }
1116 wg.Wait()
1117 }
1118
1119 func TestLateStreamCreation(t *testing.T) {
1120 c := newWSStreamCreator(nil)
1121 c.closeAllStreamReaders(nil)
1122 if err := c.setStream(0, nil); err == nil {
1123 t.Fatal("expected error adding stream after closeAllStreamReaders")
1124 }
1125 }
1126
1127 func TestWebSocketClient_StreamsAndExpectedErrors(t *testing.T) {
1128
1129 c := newWSStreamCreator(nil)
1130 headers := http.Header{}
1131 headers.Set(v1.StreamType, v1.StreamTypeStdin)
1132 s, err := c.CreateStream(headers)
1133 if err != nil {
1134 t.Errorf("unexpected stream creation error: %v", err)
1135 }
1136 expectedStreamID := uint32(remotecommand.StreamStdIn)
1137 actualStreamID := s.Identifier()
1138 if expectedStreamID != actualStreamID {
1139 t.Errorf("expecting stream id (%d), got (%d)", expectedStreamID, actualStreamID)
1140 }
1141 actualHeaders := s.Headers()
1142 if !reflect.DeepEqual(headers, actualHeaders) {
1143 t.Errorf("expecting stream headers (%v), got (%v)", headers, actualHeaders)
1144 }
1145
1146 err = s.Reset()
1147 if err != nil {
1148 t.Errorf("unexpected error in stream reset: %v", err)
1149 }
1150
1151 err = s.Close()
1152 if err == nil {
1153 t.Errorf("expecting stream Close error, but received none")
1154 }
1155 if !strings.Contains(err.Error(), "Close() on already closed stream") {
1156 t.Errorf("expected stream close error, got (%s)", err)
1157 }
1158
1159 n, err := s.Write([]byte("not written"))
1160 if n != 0 {
1161 t.Errorf("expected zero bytes written, wrote (%d) instead", n)
1162 }
1163 if err == nil {
1164 t.Errorf("expecting stream Write error, but received none")
1165 }
1166 if !strings.Contains(err.Error(), "write on closed stream") {
1167 t.Errorf("expected stream write error, got (%s)", err)
1168 }
1169
1170 headers = http.Header{}
1171 headers.Set(v1.StreamType, "UNKNOWN")
1172 _, err = c.CreateStream(headers)
1173 if err == nil {
1174 t.Errorf("expecting CreateStream error, but received none")
1175 } else if !strings.Contains(err.Error(), "unknown stream type") {
1176 t.Errorf("expecting unknown stream type error, got (%s)", err.Error())
1177 }
1178
1179 headers.Set(v1.StreamType, v1.StreamTypeError)
1180 c.streams[remotecommand.StreamErr] = &stream{}
1181 _, err = c.CreateStream(headers)
1182 if err == nil {
1183 t.Errorf("expecting CreateStream error, but received none")
1184 } else if !strings.Contains(err.Error(), "duplicate stream") {
1185 t.Errorf("expecting duplicate stream error, got (%s)", err.Error())
1186 }
1187 }
1188
1189
1190
// options holds one boolean flag per stream (stdin, stdout, stderr) plus a
// tty flag; all flags default to false.
type options struct {
	stdin, stdout, stderr, tty bool
}
1197
1198
1199 func streamOptionsFromRequest(req *http.Request) *options {
1200 query := req.URL.Query()
1201 tty := query.Get("tty") == "true"
1202 stdin := query.Get("stdin") == "true"
1203 stdout := query.Get("stdout") == "true"
1204 stderr := query.Get("stderr") == "true"
1205 return &options{
1206 stdin: stdin,
1207 stdout: stdout,
1208 stderr: stderr,
1209 tty: tty,
1210 }
1211 }
1212
1213
1214 type websocketStreams struct {
1215 conn io.Closer
1216 stdinStream io.ReadCloser
1217 stdoutStream io.WriteCloser
1218 stderrStream io.WriteCloser
1219 writeStatus func(status *apierrors.StatusError) error
1220 resizeStream io.ReadCloser
1221 resizeChan chan TerminalSize
1222 tty bool
1223 }
1224
1225
1226
1227 func webSocketServerStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) {
1228 conn, err := createWebSocketStreams(req, w, opts)
1229 if err != nil {
1230 return nil, err
1231 }
1232
1233 if conn.resizeStream != nil {
1234 conn.resizeChan = make(chan TerminalSize)
1235 go handleResizeEvents(req.Context(), conn.resizeStream, conn.resizeChan)
1236 }
1237
1238 return conn, nil
1239 }
1240
1241
1242 func handleResizeEvents(ctx context.Context, stream io.Reader, channel chan<- TerminalSize) {
1243 defer close(channel)
1244
1245 decoder := json.NewDecoder(stream)
1246 for {
1247 size := TerminalSize{}
1248 if err := decoder.Decode(&size); err != nil {
1249 break
1250 }
1251
1252 select {
1253 case channel <- size:
1254 case <-ctx.Done():
1255
1256
1257
1258 return
1259 }
1260 }
1261 }
1262
1263
1264
1265 func createChannels(opts *options) []wsstream.ChannelType {
1266
1267 channels := make([]wsstream.ChannelType, 5)
1268 channels[remotecommand.StreamStdIn] = readChannel(opts.stdin)
1269 channels[remotecommand.StreamStdOut] = writeChannel(opts.stdout)
1270 channels[remotecommand.StreamStdErr] = writeChannel(opts.stderr)
1271 channels[remotecommand.StreamErr] = wsstream.WriteChannel
1272 channels[remotecommand.StreamResize] = wsstream.ReadChannel
1273 return channels
1274 }
1275
1276
1277 func readChannel(real bool) wsstream.ChannelType {
1278 if real {
1279 return wsstream.ReadChannel
1280 }
1281 return wsstream.IgnoreChannel
1282 }
1283
1284
1285 func writeChannel(real bool) wsstream.ChannelType {
1286 if real {
1287 return wsstream.WriteChannel
1288 }
1289 return wsstream.IgnoreChannel
1290 }
1291
1292
1293
1294 func createWebSocketStreams(req *http.Request, w http.ResponseWriter, opts *options) (*websocketStreams, error) {
1295 channels := createChannels(opts)
1296 conn := wsstream.NewConn(map[string]wsstream.ChannelProtocolConfig{
1297 remotecommand.StreamProtocolV5Name: {
1298 Binary: true,
1299 Channels: channels,
1300 },
1301 })
1302 conn.SetIdleTimeout(4 * time.Hour)
1303
1304
1305 _, streams, err := conn.Open(w, req)
1306 if err != nil {
1307 return nil, err
1308 }
1309
1310
1311
1312 switch {
1313 case opts.stdout:
1314 streams[remotecommand.StreamStdOut].Write([]byte{})
1315 case opts.stderr:
1316 streams[remotecommand.StreamStdErr].Write([]byte{})
1317 default:
1318 streams[remotecommand.StreamErr].Write([]byte{})
1319 }
1320
1321 wsStreams := &websocketStreams{
1322 conn: conn,
1323 stdinStream: streams[remotecommand.StreamStdIn],
1324 stdoutStream: streams[remotecommand.StreamStdOut],
1325 stderrStream: streams[remotecommand.StreamStdErr],
1326 tty: opts.tty,
1327 resizeStream: streams[remotecommand.StreamResize],
1328 }
1329
1330 wsStreams.writeStatus = func(stream io.Writer) func(status *apierrors.StatusError) error {
1331 return func(status *apierrors.StatusError) error {
1332 bs, err := json.Marshal(status.Status())
1333 if err != nil {
1334 return err
1335 }
1336 _, err = stream.Write(bs)
1337 return err
1338 }
1339 }(streams[remotecommand.StreamErr])
1340
1341 return wsStreams, nil
1342 }
1343
View as plain text